Merge pull request #29 from joelmathewthomas/feature/api

feature/api: Add endpoints
This commit is contained in:
Joel Mathew Thomas
2025-02-26 20:23:06 +05:30
committed by GitHub
21 changed files with 525 additions and 202 deletions
+130 -7
View File
@@ -1,11 +1,134 @@
from __future__ import absolute_import, unicode_literals
import os
import shutil
from pathlib import Path
from celery import shared_task
from src.input.file_reader import read_audio #export PYTHONPATH="/home/karthikeyan/code/MainProject/freq-split-enhance:$PYTHONPATH" for exporting the module
from freqsplit.input.file_reader import read_audio
from freqsplit.preprocessing.classify import classify_audio
from freqsplit.preprocessing.normalize import normalize_audio
from freqsplit.preprocessing.trim import trim_audio
from freqsplit.preprocessing.resample import resample
from freqsplit.postprocessing.audio_writer import export_audio
from freqsplit.separation.demucs_wrapper import separate_audio_with_demucs
from freqsplit.refinement.deepfilternet_wrapper import noisereduce
import time
@shared_task
def process_uploaded_file(file_path):
    """Celery task: load the uploaded audio file to confirm it is readable.

    Args:
        file_path (str): Path to the uploaded audio file.

    Returns:
        str: A fixed completion message.
    """
    # Reading the file stands in for a long-running processing step.
    read_audio(file_path=file_path)
    return 'File processed'
def save_and_classify(file_path, file_content):
    """Persist the uploaded bytes to *file_path*, then classify the audio.

    Args:
        file_path (str): Destination path for the uploaded file.
        file_content (bytes): Raw bytes of the uploaded file.

    Returns:
        tuple: (predicted audio class, original sample rate of the file)
    """
    with open(file_path, 'wb') as out_file:
        out_file.write(file_content)

    # First read only to capture the file's native sampling rate.
    _, org_sr = read_audio(file_path)
    # Second read resamples to 32 kHz mono, the rate the classifier expects.
    waveform, sr = read_audio(file_path, 32000, mono=True)

    return classify_audio(waveform, sr), org_sr
@shared_task
def normalize_audio_task(file_path):
    """Celery task: normalize the audio at *file_path* in place.

    Returns:
        bool: True on success, False if any step failed.
    """
    try:
        samples, rate = read_audio(file_path)
        export_audio(normalize_audio(samples), file_path, rate)
        return True
    except Exception:
        # Best-effort: the caller only needs a success flag.
        return False
@shared_task
def trim_audio_task(file_path):
    """Celery task: strip leading/trailing silence from *file_path* in place.

    Returns:
        bool: True on success, False if any step failed.
    """
    try:
        samples, rate = read_audio(file_path)
        export_audio(trim_audio(samples, rate), file_path, rate)
        return True
    except Exception:
        # Best-effort: the caller only needs a success flag.
        return False
@shared_task
def resample_audio_task(file_path, sr):
    """Celery task to resample the audio file in place.

    Args:
        file_path (str): Path to the audio file.
        sr (str | int): Target sample rate; arrives as a string from the API.

    Returns:
        bool: True on success, False on failure.
    """
    try:
        audio, org_sr = read_audio(file_path)
        # int() instead of eval(): `sr` comes straight from an HTTP request,
        # and eval() on untrusted input allows arbitrary code execution.
        # int() raises ValueError on garbage, which the except below turns
        # into the same False result callers already expect.
        resampled_audio, sr = resample(audio, org_sr, int(sr))
        export_audio(resampled_audio, file_path, sr)
        return True
    except Exception:
        return False
@shared_task
def music_separation_task(file_path):
    """Celery task: split a music file into stems with Demucs.

    The original file is replaced by the separated vocals track, and the
    remaining stems (bass, drums, other) are moved into a ``sources/``
    directory beside it.

    Returns:
        bool: True on success, False if any file operation failed.
    """
    file_path = Path(file_path)
    print("File path is ", file_path)

    # Demucs writes its results next to the input file.
    output_path = file_path.parent
    separate_audio_with_demucs(str(file_path), str(output_path))

    # Demucs output layout: <output>/htdemucs/<input-stem>/*.wav
    demucs_dir = output_path / 'htdemucs'
    file_folder = demucs_dir / file_path.stem
    if not file_folder.exists():
        raise RuntimeError(f"Demucs output folder not found: {file_folder}")

    stems = ["bass.wav", "drums.wav", "other.wav", "vocals.wav"]

    # Non-vocal stems are collected under a dedicated "sources" directory.
    sources_dir = output_path / "sources"
    sources_dir.mkdir(exist_ok=True)

    try:
        vocals_path = file_folder / "vocals.wav"
        if not vocals_path.exists():
            raise RuntimeError("Vocals file not found in Demucs output")

        # The vocals track takes over the original file's name and location.
        shutil.move(str(vocals_path), str(file_path))

        for stem in stems:
            stem_path = file_folder / stem
            if stem_path.exists() and stem != "vocals.wav":
                shutil.move(str(stem_path), str(sources_dir / stem))

        # Cleanup: remove the now-empty htdemucs working directory.
        shutil.rmtree(str(demucs_dir))
        return True
    except Exception:
        return False
@shared_task
def noisereduce_task(file_path):
    """Celery task: run noise reduction on *file_path*, overwriting it.

    Returns:
        bool: True on success, False if noise reduction raised.
    """
    file_path = Path(file_path)
    try:
        # Input and output are the same path: the file is denoised in place.
        noisereduce(file_path, file_path)
        return True
    except Exception:
        return False
@shared_task
def cleanup_task(file_path):
    """Celery task: delete the upload directory containing *file_path*.

    Returns:
        bool: True on success, False if removal failed.
    """
    try:
        # Remove the whole per-upload folder, not just the single file.
        shutil.rmtree(os.path.dirname(Path(file_path)))
        return True
    except Exception:
        return False
+16
View File
@@ -0,0 +1,16 @@
# api/utils.py
import os
from rest_framework import status
def get_audio_file_path(request, base_dir):
    """Return the path of the audio file stored under the request's UUID folder.

    Args:
        request: DRF request; expects ``file_uuid`` in ``request.data``.
        base_dir (str): Root directory containing one sub-folder per upload.

    Returns:
        tuple: (ok, payload, http_status) — payload is the file path on
        success, or an error message on failure.
    """
    file_uuid = request.data.get("file_uuid")
    if not file_uuid:
        return False, "Missing file_uuid", status.HTTP_400_BAD_REQUEST

    dir_path = os.path.join(base_dir, file_uuid)
    # NOTE(review): 500 for a missing upload looks like it should be 404, but
    # callers forward this status as-is, so it is kept unchanged here.
    if not os.path.exists(dir_path) or not os.listdir(dir_path):
        return False, "No file found", status.HTTP_500_INTERNAL_SERVER_ERROR

    # Pick the first regular file (sorted for determinism). The folder may
    # also contain a "sources/" sub-directory after music separation, so
    # taking os.listdir(dir_path)[0] blindly could return a directory.
    for entry in sorted(os.listdir(dir_path)):
        candidate = os.path.join(dir_path, entry)
        if os.path.isfile(candidate):
            return True, candidate, status.HTTP_200_OK
    return False, "No file found", status.HTTP_500_INTERNAL_SERVER_ERROR
+193 -15
View File
@@ -1,14 +1,28 @@
import os
import uuid
import zipfile
from django.http import FileResponse, HttpResponse
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework import status
from freqsplit.preprocessing.classify import classify_audio
from .utils import get_audio_file_path
from .tasks import save_and_classify
from .tasks import normalize_audio_task
from .tasks import trim_audio_task
from .tasks import resample_audio_task
from .tasks import music_separation_task
from .tasks import noisereduce_task
from .tasks import cleanup_task
from freqsplit.input.format_checker import is_supported_format
UPLOAD_DIR = "/tmp/freq-split-enhance"
UPLOAD_DIR = "/tmp/freqsplit"
# Ensure the temp directory exists
os.makedirs(UPLOAD_DIR, exist_ok=True)
#
# Endpoint to upload audio and classify it to audio_class
@api_view(['POST'])
def upload_audio(request):
"""Handles audio file upload and saves it to /tmp/freq-split-enhance"""
@@ -16,19 +30,183 @@ def upload_audio(request):
return Response({"Error: No file provided"}, status=status.HTTP_400_BAD_REQUEST)
audio_file = request.FILES['file']
file_path = os.path.join(UPLOAD_DIR, audio_file.name)
# Check file format before proceeding
if not is_supported_format(audio_file.name):
return Response({"error": "Unsupported file format"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
# Generate a unique ID for this upload
file_uuid = str(uuid.uuid4())[:8]
#Create a subdirectory for this upload
upload_dir = os.path.join(UPLOAD_DIR, file_uuid)
os.makedirs(upload_dir, exist_ok=True)
file_path = os.path.join(upload_dir, audio_file.name)
# Save the uploaded file
with open(file_path, 'wb') as destination:
for chunk in audio_file.chunks():
destination.write(chunk)
task = save_and_classify.apply(args=(file_path, audio_file.read()))
if task.successful():
audio_class = task.result[0]
return Response(
{
"Status": "File uploaded successfully",
"file_uuid": file_uuid,
"audio_class": audio_class,
"sr": task.result[1]
},
status=status.HTTP_201_CREATED,
)
audio_class = classify_audio(file_path)
return Response(
{
"Status": "File uploaded successfully",
"file_path": file_path,
"audio_class": audio_class
}, status=status.HTTP_201_CREATED,
)
# Endpoint to normalize audio
@api_view(['POST'])
def normalize_audio(request):
    """Handles audio normalization request"""
    ok, payload, code = get_audio_file_path(request, UPLOAD_DIR)
    if not ok:
        return Response({"error": payload}, status=code)

    # Run the Celery task in-process and wait for its boolean result.
    task = normalize_audio_task.apply(args=(payload,))
    if not task.get():
        return Response({"error": "Failed to normalize audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
    return Response({"message": "Audio normalized successfully"}, status=status.HTTP_200_OK)
# Endpoint to trim audio
@api_view(['POST'])
def trim_audio(request):
    """Handles trimming of leading and trailing silence from an audio clip"""
    ok, payload, code = get_audio_file_path(request, UPLOAD_DIR)
    if not ok:
        return Response({"error": payload}, status=code)

    # Run the Celery task in-process and wait for its boolean result.
    task = trim_audio_task.apply(args=(payload,))
    if not task.get():
        return Response({"error": "Failed to trim audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
    return Response({"message": "Audio trimmed successfully"}, status=status.HTTP_200_OK)
# Endpoint to resample audio
@api_view(['POST'])
def resample_audio(request):
    """Handles the resampling of audio"""
    ok, payload, code = get_audio_file_path(request, UPLOAD_DIR)
    if not ok:
        return Response({"error": payload}, status=code)

    sr = request.data.get("sr")
    if not sr:
        return Response({"error": "Missing sr"}, status=status.HTTP_400_BAD_REQUEST)

    # Run the Celery task in-process; `sr` is forwarded as received.
    task = resample_audio_task.apply(args=(payload, sr))
    if not task.get():
        return Response({"error": "Failed to resample audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
    return Response({"message": f"Audio resampled to {sr} successfully"}, status=status.HTTP_200_OK)
# Endpoint to separate music into sources
@api_view(['POST'])
def separate_music(request):
    """Handles the separation of music audio into source components"""
    ok, payload, code = get_audio_file_path(request, UPLOAD_DIR)
    if not ok:
        return Response({"error": payload}, status=code)

    # Run the Celery task in-process and wait for its boolean result.
    task = music_separation_task.apply(args=(payload,))
    if not task.get():
        return Response({"error": "Failed to source separate audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
    return Response({"message": f"Audio separated into sources successfully"}, status=status.HTTP_200_OK)
# Endpoint to reduce noise from audio
@api_view(['POST'])
def noisereduce(request):
    """Handles the reduction of noise from audio"""
    ok, payload, code = get_audio_file_path(request, UPLOAD_DIR)
    if not ok:
        return Response({"error": payload}, status=code)

    # Run the Celery task in-process and wait for its boolean result.
    task = noisereduce_task.apply(args=(payload,))
    if not task.get():
        return Response({"error": "Failed to remove noise from audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
    return Response({"message": f"Removed noise from audio successfully"}, status=status.HTTP_200_OK)
# Endpoint to download audio file or zipped directory
@api_view(['GET'])
def download_audio(request):
    """Handles downloading an audio file or a zipped directory."""
    file_uuid = request.query_params.get("file_uuid")
    if not file_uuid:
        return Response({"error": "file_uuid is required"}, status=status.HTTP_400_BAD_REQUEST)

    dir_path = os.path.join(UPLOAD_DIR, file_uuid)
    if not os.path.exists(dir_path) or not os.listdir(dir_path):
        return Response({"error": "No file found"}, status=status.HTTP_404_NOT_FOUND)

    entries = os.listdir(dir_path)
    audio_files = [name for name in entries if name.endswith(('.wav', '.mp3', '.flac'))]
    has_sources = "sources" in entries

    # Exactly one processed file and no separated stems: stream it directly.
    if len(audio_files) == 1 and not has_sources:
        file_path = os.path.join(dir_path, audio_files[0])
        return FileResponse(open(file_path, "rb"), as_attachment=True, filename=os.path.basename(file_path))

    # Otherwise bundle the whole upload folder (stems included) into a ZIP.
    zip_file_path = os.path.join(UPLOAD_DIR, f"{file_uuid}.zip")
    if os.path.exists(zip_file_path):
        os.remove(zip_file_path)  # never serve a stale archive

    with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for root, _, files in os.walk(dir_path):
            for name in files:
                full_path = os.path.join(root, name)
                # Relative arcname preserves the folder layout inside the ZIP.
                archive.write(full_path, os.path.relpath(full_path, dir_path))

    # Stream the freshly built ZIP file.
    return FileResponse(open(zip_file_path, "rb"), as_attachment=True, filename=os.path.basename(zip_file_path))
@api_view(['POST'])
def cleanup(request):
    """Handles file cleanup after pipeline processing"""
    ok, payload, code = get_audio_file_path(request, UPLOAD_DIR)
    if not ok:
        return Response({"error": payload}, status=code)

    # Run the Celery task in-process and wait for its boolean result.
    task = cleanup_task.apply(args=(payload,))
    if not task.get():
        return Response({"error": "Failed to cleanup files on the server"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
    return Response({"message": f"Successfully cleaned up files on the server"}, status=status.HTTP_200_OK)
@api_view(['POST'])
def cleanup_zip(request):
    """Handles cleanup of all zip files leftover by api/download.

    Bug fix: the original returned from inside the loop, so at most one ZIP
    was deleted per call, and a directory containing no ZIPs fell off the
    end of the function and returned an implicit None (no HTTP Response).
    Now every ZIP in UPLOAD_DIR is removed, and errors are accumulated.
    """
    errors = []
    for file in os.listdir(UPLOAD_DIR):
        if file.endswith(".zip"):
            file_path = os.path.join(UPLOAD_DIR, file)
            try:
                os.remove(file_path)
            except Exception as e:
                # Keep going: one failed deletion should not abort the rest.
                errors.append(f"Error deleting {file_path}: {e}")

    if errors:
        return Response({"message": "; ".join(errors)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
    return Response({"message": "Cleaned up zipfiles on the server"}, status=status.HTTP_200_OK)
+3 -1
View File
@@ -1 +1,3 @@
from celery_app import app as celery
from .celery import app as celery_app
__all__ = ('celery_app',)
+9
View File
@@ -0,0 +1,9 @@
import os
from celery import Celery
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings')
app = Celery('backend')
# Load config from Django settings, using a `CELERY_` prefix
app.config_from_object('django.conf:settings', namespace='CELERY')
app.autodiscover_tasks()
+4
View File
@@ -123,6 +123,10 @@ STATIC_URL = 'static/'
# https://docs.djangoproject.com/en/5.1/ref/settings/#default-auto-field
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
# Configure Redis as message broker
CELERY_BROKER_URL = 'redis://localhost:6379/0'
CELERY_ACCEPT_CONTENT = ['json']
CELERY_TASK_SERIALIZER = 'json'
CELERY_BROKER_CONNECTION_RETRY_ON_STARTUP = True
+17 -1
View File
@@ -17,8 +17,24 @@ Including another URLconf
from django.contrib import admin
from django.urls import path
from api.views import upload_audio
from api.views import normalize_audio
from api.views import trim_audio
from api.views import resample_audio
from api.views import separate_music
from api.views import noisereduce
from api.views import download_audio
from api.views import cleanup
from api.views import cleanup_zip
# URL routes for the audio-processing API. Diff residue fixed: the old
# `path('api/upload', ...)` line (missing its trailing comma) duplicated the
# new one, which is both a syntax error and a duplicate route.
urlpatterns = [
    path('admin/', admin.site.urls),
    path('api/upload', upload_audio, name='upload_audio'),
    path('api/normalize', normalize_audio, name="normalize_audio"),
    path('api/trim', trim_audio, name='trim_audio'),
    path('api/resample', resample_audio, name='resample_audio'),
    path('api/separate', separate_music, name="separate_music"),
    # NOTE(review): route name "noisreduce" looks like a typo for
    # "noisereduce"; kept unchanged in case it is reverse()d elsewhere.
    path('api/noisereduce', noisereduce, name="noisreduce"),
    path('api/download', download_audio, name="download_audio"),
    path('api/cleanup', cleanup, name="cleanup"),
    path('api/cleanup_zip', cleanup_zip, name="cleanup_zip"),
]
-20
View File
@@ -1,20 +0,0 @@
import os
import sys
from celery import Celery
# Automatically set environment variables in celery_app.py
# Set Django settings module
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings')
# Add the project directory to sys.path (similar to the manual PYTHONPATH)
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(project_root)
app = Celery('backend')
# Load configuration from Django settings, using the CELERY namespace.
app.config_from_object('django.conf:settings', namespace='CELERY')
# Autodiscover tasks from installed apps.
app.autodiscover_tasks()
+25 -46
View File
@@ -6,9 +6,8 @@ build-backend = "setuptools.build_meta"
name = "freqsplit"
version = "0.1.0"
dependencies = [
"absl-py==2.1.0",
"aiohappyeyeballs==2.4.4",
"aiohttp==3.11.11",
"aiohappyeyeballs==2.4.6",
"aiohttp==3.11.13",
"aiosignal==1.3.2",
"amqp==5.3.1",
"antlr4-python3-runtime==4.9.3",
@@ -16,13 +15,12 @@ dependencies = [
"asgiref==3.8.1",
"asteroid==0.7.0",
"asteroid-filterbanks==0.4.0",
"astunparse==1.6.3",
"attrs==25.1.0",
"audioread==3.0.1",
"billiard==4.2.1",
"cached-property==2.0.1",
"celery==5.4.0",
"certifi==2024.12.14",
"certifi==2025.1.31",
"cffi==1.17.1",
"charset-normalizer==3.4.1",
"click==8.1.8",
@@ -32,50 +30,38 @@ dependencies = [
"cloudpickle==3.1.1",
"contourpy==1.3.1",
"cycler==0.12.1",
"decorator==5.1.1",
"decorator==5.2.1",
"DeepFilterLib==0.5.6",
"DeepFilterNet==0.5.6",
"demucs==4.0.1",
"Django==5.1.6",
"djangorestframework==3.15.2",
"dora_search==0.1.12",
"einops==0.8.0",
"einops==0.8.1",
"filelock==3.17.0",
"flatbuffers==25.2.10",
"fonttools==4.55.6",
"fonttools==4.56.0",
"frozenlist==1.5.0",
"fsspec==2024.12.0",
"future==1.0.0",
"gast==0.6.0",
"google-pasta==0.2.0",
"grpcio==1.70.0",
"h5py==3.13.0",
"huggingface-hub==0.28.0",
"fsspec==2025.2.0",
"huggingface-hub==0.29.1",
"idna==3.10",
"iniconfig==2.0.0",
"Jinja2==3.1.5",
"joblib==1.4.2",
"julius==0.2.7",
"keras==3.8.0",
"kiwisolver==1.4.8",
"kombu==5.4.2",
"lameenc==1.8.1",
"lazy_loader==0.4",
"libclang==18.1.1",
"librosa==0.10.2.post1",
"lightning-utilities==0.11.9",
"lightning-utilities==0.12.0",
"llvmlite==0.44.0",
"loguru==0.7.3",
"Markdown==3.7",
"markdown-it-py==3.0.0",
"MarkupSafe==3.0.2",
"matplotlib==3.10.0",
"mdurl==0.1.2",
"mir_eval==0.7",
"ml-dtypes==0.4.1",
"mir_eval==0.8.2",
"mpmath==1.3.0",
"msgpack==1.1.0",
"multidict==6.1.0",
"namex==0.0.8",
"networkx==3.4.2",
"numba==0.61.0",
"numpy==1.26.4",
@@ -88,15 +74,15 @@ dependencies = [
"nvidia-curand-cu12==10.3.5.147",
"nvidia-cusolver-cu12==11.6.1.9",
"nvidia-cusparse-cu12==12.3.1.170",
"nvidia-cusparselt-cu12==0.6.2",
"nvidia-nccl-cu12==2.21.5",
"nvidia-nvjitlink-cu12==12.4.127",
"nvidia-nvtx-cu12==12.4.127",
"omegaconf==2.3.0",
"openunmix==1.3.0",
"opt_einsum==3.4.0",
"optree==0.14.0",
"packaging==23.2",
"pandas==2.2.3",
"panns-inference==0.1.1",
"pb-bss-eval==0.0.2",
"pesq==0.0.4",
"pillow==11.1.0",
@@ -104,54 +90,47 @@ dependencies = [
"pluggy==1.5.0",
"pooch==1.8.2",
"prompt_toolkit==3.0.50",
"propcache==0.2.1",
"protobuf==5.29.3",
"propcache==0.3.0",
"pycparser==2.22",
"Pygments==2.19.1",
"pyparsing==3.2.1",
"pystoi==0.4.1",
"pytest==8.3.4",
"python-dateutil==2.9.0.post0",
"pytorch-lightning==2.5.0.post0",
"pytorch-ranger==0.1.1",
"pytz==2024.2",
"pytz==2025.1",
"PyYAML==6.0.2",
"redis==5.2.1",
"regex==2024.11.6",
"requests==2.32.3",
"retrying==1.3.4",
"rich==13.9.4",
"safetensors==0.5.3",
"scikit-learn==1.6.1",
"scipy==1.15.1",
"setuptools==75.8.0",
"scipy==1.15.2",
"setuptools==75.8.1",
"six==1.17.0",
"soundfile==0.13.1",
"soxr==0.5.0.post1",
"sqlparse==0.5.3",
"submitit==1.5.2",
"sympy==1.13.1",
"tensorboard==2.18.0",
"tensorboard-data-server==0.7.2",
"tensorflow==2.18.0",
"tensorflow-hub==0.16.1",
"termcolor==2.5.0",
"tf_keras==2.18.0",
"threadpoolctl==3.5.0",
"torch==2.5.1",
"tokenizers==0.21.0",
"torch==2.6.0",
"torch-optimizer==0.1.0",
"torch-stoi==0.2.3",
"torchaudio==2.5.1",
"torchaudio==2.6.0",
"torchlibrosa==0.1.0",
"torchmetrics==0.11.4",
"tqdm==4.67.1",
"transformers==4.49.0",
"treetable==0.2.5",
"triton==3.1.0",
"triton==3.2.0",
"typing_extensions==4.12.2",
"tzdata==2025.1",
"urllib3==2.3.0",
"vine==5.1.0",
"wcwidth==0.2.13",
"Werkzeug==3.1.3",
"wheel==0.45.1",
"wrapt==1.17.2",
"yarl==1.18.3",
]
+24 -46
View File
@@ -1,6 +1,5 @@
absl-py==2.1.0
aiohappyeyeballs==2.4.4
aiohttp==3.11.11
aiohappyeyeballs==2.4.6
aiohttp==3.11.13
aiosignal==1.3.2
amqp==5.3.1
antlr4-python3-runtime==4.9.3
@@ -8,13 +7,12 @@ appdirs==1.4.4
asgiref==3.8.1
asteroid==0.7.0
asteroid-filterbanks==0.4.0
astunparse==1.6.3
attrs==25.1.0
audioread==3.0.1
billiard==4.2.1
cached-property==2.0.1
celery==5.4.0
certifi==2024.12.14
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
click==8.1.8
@@ -24,51 +22,38 @@ click-repl==0.3.0
cloudpickle==3.1.1
contourpy==1.3.1
cycler==0.12.1
decorator==5.1.1
decorator==5.2.1
DeepFilterLib==0.5.6
DeepFilterNet==0.5.6
demucs==4.0.1
Django==5.1.6
djangorestframework==3.15.2
dora_search==0.1.12
einops==0.8.0
einops==0.8.1
filelock==3.17.0
flatbuffers==25.2.10
fonttools==4.55.6
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2024.12.0
future==1.0.0
gast==0.6.0
google-pasta==0.2.0
grpcio==1.70.0
h5py==3.13.0
huggingface-hub==0.28.0
fsspec==2025.2.0
huggingface-hub==0.29.1
idna==3.10
iniconfig==2.0.0
Jinja2==3.1.5
joblib==1.4.2
julius==0.2.7
keras==3.8.0
kiwisolver==1.4.8
kombu==5.4.2
lameenc==1.8.1
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1
lightning-utilities==0.11.9
lightning-utilities==0.12.0
llvmlite==0.44.0
loguru==0.7.3
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.0
mdurl==0.1.2
mir_eval==0.7
ml-dtypes==0.4.1
mir_eval==0.8.2
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
namex==0.0.8
networkx==3.4.2
numba==0.61.0
numpy==1.26.4
@@ -81,15 +66,15 @@ nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0
openunmix==1.3.0
opt_einsum==3.4.0
optree==0.14.0
packaging==23.2
pandas==2.2.3
panns-inference==0.1.1
pb-bss-eval==0.0.2
pesq==0.0.4
pillow==11.1.0
@@ -97,52 +82,45 @@ platformdirs==4.3.6
pluggy==1.5.0
pooch==1.8.2
prompt_toolkit==3.0.50
propcache==0.2.1
protobuf==5.29.3
propcache==0.3.0
pycparser==2.22
Pygments==2.19.1
pyparsing==3.2.1
pystoi==0.4.1
pytest==8.3.4
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pytorch-ranger==0.1.1
pytz==2024.2
pytz==2025.1
PyYAML==6.0.2
redis==5.2.1
regex==2024.11.6
requests==2.32.3
retrying==1.3.4
rich==13.9.4
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.1
setuptools==75.8.0
scipy==1.15.2
setuptools==75.8.1
six==1.17.0
soundfile==0.13.1
soxr==0.5.0.post1
sqlparse==0.5.3
submitit==1.5.2
sympy==1.13.1
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tensorflow==2.18.0
tensorflow-hub==0.16.1
termcolor==2.5.0
tf_keras==2.18.0
threadpoolctl==3.5.0
torch==2.5.1
tokenizers==0.21.0
torch==2.6.0
torch-optimizer==0.1.0
torch-stoi==0.2.3
torchaudio==2.5.1
torchaudio==2.6.0
torchlibrosa==0.1.0
torchmetrics==0.11.4
tqdm==4.67.1
transformers==4.49.0
treetable==0.2.5
triton==3.1.0
triton==3.2.0
typing_extensions==4.12.2
tzdata==2025.1
urllib3==2.3.0
vine==5.1.0
wcwidth==0.2.13
Werkzeug==3.1.3
wheel==0.45.1
wrapt==1.17.2
yarl==1.18.3
+1 -1
View File
@@ -9,4 +9,4 @@ logging.basicConfig(
level = logging.INFO
)
logging.info("freq-split-enhance/input package has been imported.")
logging.info("freqsplit/input package has been imported.")
+8 -2
View File
@@ -1,12 +1,14 @@
import os
import librosa
def read_audio(file_path):
def read_audio(file_path, sr=None, mono=None):
"""
Reads an audio file and returns the audio time series and sampling rate.
Args:
file_path (str): Path to the audio file.
sr (int): Sample rate at which the audio is to be loaded
mono (bool): True to loaded audio with single channels, else False.
Returns:
tuple: audio_time_series (numpy.ndarray), sampling_rate (int)
@@ -15,7 +17,11 @@ def read_audio(file_path):
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
audio, sr = librosa.load(file_path, sr=None) # Load with original sampling rate.
librosa_kwargs = {"sr": sr}
if mono is not None: # Only add 'mono' if explicitly provided
librosa_kwargs["mono"] = mono
audio, sr = librosa.load(file_path, **librosa_kwargs)
return audio, sr
except Exception as e:
raise RuntimeError(f"Error reading the audio file: {e}")
+1 -1
View File
@@ -9,4 +9,4 @@ logging.basicConfig(
level = logging.INFO
)
logging.info("freq-split-enhance/postprocessing package has been imported.")
logging.info("freqsplit/postprocessing package has been imported.")
@@ -12,9 +12,6 @@ def export_audio(audio, output_path, sr):
"""
try:
print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}")
if audio.ndim == 2 and audio.shape[0] == 2:
# Transpose stereo audio to match the expected shape
audio = audio.T # From (2, num_samples) to (num_samples, 2)
@@ -26,10 +23,6 @@ def export_audio(audio, output_path, sr):
if np.max(np.abs(audio)) > 0: # Avoid divide by zero
audio = audio / np.max(np.abs(audio))
# Verify final format
print(f"Final audio shape: {audio.shape}, dtype: {audio.dtype}, max: {np.max(audio)}, min: {np.min(audio)}")
sf.write(output_path, audio, sr, format='wav')
print(f"Audio saved to {output_path}")
except Exception as e:
+1 -1
View File
@@ -9,4 +9,4 @@ logging.basicConfig(
level = logging.INFO
)
logging.info("freq-split-enhance/preprocessing package has been imported.")
logging.info("freqsplit/preprocessing package has been imported.")
+31 -41
View File
@@ -1,53 +1,43 @@
import tensorflow as tf
import tensorflow_hub as hub
import librosa
import numpy as np
import csv
import os
from panns_inference import AudioTagging, labels
# Initialize PANNs model
at = AudioTagging(checkpoint_path=None, device='cuda')
# Force TensorFlow to use only CPU
tf.config.set_visible_devices([], 'GPU')
model = hub.load('https://tfhub.dev/google/yamnet/1')
#Find the name of the class with the top score when mean-aggregated across frames.
def class_names_from_csv(class_map_scv_text):
"""Returns list of class names corresponding to score vector."""
class_names = []
with tf.io.gfile.GFile(class_map_scv_text) as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
class_names.append(row['display_name'])
return class_names
# Main function to process audio and classify
def classify_audio(file_path):
def classify_audio(waveform, sr):
"""
Given an audio file, this function loads the audio, resamples it,
normalizes it, and runs it through the YAMNet model to classify the sound.
normalizes it, and runs it through the PANNs model to classify the sound.
Args:
- file_path (str): Path to the audio file (WAV, MP3, etc.).
- waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.).
- sr (int): Sampling rate of the audio.
Returns:
- str: Predicted class label of the audio.
"""
# Load audio using librosa (this handles both loading, resampling, and conversion to mono)
waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True) # Ensuring 16k sample rate and mono
# Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
waveform = waveform / np.max(np.abs(waveform))
# Execute the YAMNet model
scores, embeddings, spectrogram = model(waveform)
# Extract the class names from the model
class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path)
# Find the class with the highest score
scores_np = scores.numpy()
inferred_class = class_names[scores_np.mean(axis=0).argmax()]
return inferred_class
# Check if the sampling rate is 32000Hz
try:
if sr != 32000:
raise RuntimeError
except Exception:
raise RuntimeError(f"The audio is not sampled at 32000Hz, failed to classify audio.")
# Normalize the waveform to [-1.0, 1.0]
waveform = waveform / np.max(np.abs(waveform))
# Ensure waveform shape is correct for model input
waveform = waveform[None, :]
# Execute the PANNs model
try:
clipwise_output, _ = at.inference(waveform)
except Exception as e:
raise RuntimeError(f"Error: Failed to classify audio: {e}")
# Get the top predicted class
predicted_index = np.argmax(clipwise_output)
inferred_class = labels[predicted_index]
return inferred_class
+1 -1
View File
@@ -9,4 +9,4 @@ logging.basicConfig(
level = logging.INFO
)
logging.info("freq-split-enhance/refinement package has been imported.")
logging.info("freqsplit/refinement package has been imported.")
@@ -1,35 +1,83 @@
import os
import librosa
import torch
import shutil
import soundfile as sf
import numpy as np
from df.enhance import enhance, init_df, load_audio, save_audio
def split_audio(audio, sr, chunk_size=5):
"""Split audio into chunks of `chunk_size` seconds."""
samples_per_chunk = sr * chunk_size
return [audio[i:i + samples_per_chunk] for i in range(0, len(audio), samples_per_chunk)]
def noisereduce(input_audio_path, output_audio_path, model_path=None):
"""
Apply noise reduction using DeepFilterNet.
Apply noise reduction using DeepFilterNet with chunking.
Args:
input_audio_path (str): Path to the input noisy audio file.
output_audio_path (str): Path to save the enhanced audio file.
model_path (str, optional): Path to a custom DeepFilterNet model. Defaults to None (uses the pre-trained model).
model_path (str, optional): Path to a custom DeepFilterNet model. Defaults to None (uses pre-trained model).
Returns:
str: Path to the enhanced audio file.
"""
if not os.path.exists(input_audio_path):
raise FileNotFoundError(f"Input file {input_audio_path} not found")
output_dir = os.path.dirname(output_audio_path)
os.makedirs(output_dir, exist_ok=True) # Ensure the directory exists
# Initialize DeepFilterNet model
model, df_state, _ = init_df(model_path)
# Load audio
audio, _ = load_audio(input_audio_path, sr=df_state.sr())
audio, sr = librosa.load(input_audio_path, sr=None)
# Ensure output path exists
os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
# Ensure output and chunk directories exist
parent_dir = os.path.dirname(input_audio_path)
chunk_dir = os.path.join(parent_dir, "chunks")
output_chunk_dir = os.path.join(chunk_dir, "output")
os.makedirs(chunk_dir, exist_ok=True)
os.makedirs(output_chunk_dir, exist_ok=True)
# Apply noise reduction
enhanced_audio = enhance(model, df_state, audio)
# Split audio into 5-second chunks
chunks = split_audio(audio, sr, chunk_size=5)
chunk_paths = []
for i, chunk in enumerate(chunks):
chunk_path = os.path.join(chunk_dir, f"chunk_{i}.wav")
sf.write(chunk_path, chunk, sr)
chunk_paths.append(chunk_path)
enhanced_chunk_paths = []
# Process each chunk sequentially to avoid OOM errors
for chunk_path in chunk_paths:
output_chunk_path = os.path.join(output_chunk_dir, os.path.basename(chunk_path))
# Load and enhance
chunk_audio, _ = load_audio(chunk_path, sr=df_state.sr())
enhanced_audio = enhance(model, df_state, chunk_audio)
# Save enhanced chunk
save_audio(output_chunk_path, enhanced_audio, df_state.sr())
enhanced_chunk_paths.append(output_chunk_path)
# Combine enhanced chunks back into a single audio file
final_audio = []
for chunk_path in enhanced_chunk_paths:
chunk_audio, _ = librosa.load(chunk_path, sr=sr) # Keep original sample rate
final_audio.append(chunk_audio)
final_audio = np.concatenate(final_audio, axis=0)
# Save final enhanced audio
sf.write(output_audio_path, final_audio, sr)
# Clean up temporary chunk files and directories
shutil.rmtree(chunk_dir, ignore_errors=True)
# Save the enhanced audio
save_audio(output_audio_path, enhanced_audio, df_state.sr())
return output_audio_path
+1 -1
View File
@@ -9,4 +9,4 @@ logging.basicConfig(
level = logging.INFO
)
logging.info("freq-split-enhance/separation package has been imported.")
logging.info("freqsplit/separation package has been imported.")
+1 -1
View File
@@ -9,4 +9,4 @@ logging.basicConfig(
level = logging.INFO
)
logging.info("freq-split-enhance/spectogram package has been imported.")
logging.info("freqsplit/spectogram package has been imported.")
+2 -1
View File
@@ -24,8 +24,9 @@ def test_trim_audio():
def test_classify():
file_path = "tests/test_audio/cafe_crowd_talk.wav"
waveform, sr = read_audio(file_path, 32000, mono=True)
expected_class = "Speech"
predicted_class = classify_audio(file_path)
predicted_class = classify_audio(waveform, sr)
assert predicted_class == expected_class , f"Expected {expected_class}, but got {predicted_class}"