diff --git a/api/api/tasks.py b/api/api/tasks.py index 71afd05..716bbec 100644 --- a/api/api/tasks.py +++ b/api/api/tasks.py @@ -1,11 +1,134 @@ -from __future__ import absolute_import, unicode_literals +import os +import shutil +from pathlib import Path from celery import shared_task -from src.input.file_reader import read_audio #export PYTHONPATH="/home/karthikeyan/code/MainProject/freq-split-enhance:$PYTHONPATH" for exporting the module +from freqsplit.input.file_reader import read_audio +from freqsplit.preprocessing.classify import classify_audio +from freqsplit.preprocessing.normalize import normalize_audio +from freqsplit.preprocessing.trim import trim_audio +from freqsplit.preprocessing.resample import resample +from freqsplit.postprocessing.audio_writer import export_audio +from freqsplit.separation.demucs_wrapper import separate_audio_with_demucs +from freqsplit.refinement.deepfilternet_wrapper import noisereduce -import time @shared_task -def process_uploaded_file(file_path): - # Simulate long-running task - read_audio(file_path=file_path) - return 'File processed' +def save_and_classify(file_path, file_content): + """Save uploaded file asynchronously and classify the audio file""" + with open(file_path, 'wb') as destination: + destination.write(file_content) + + # Read the saved audio file + _, org_sr = read_audio(file_path) # Get original sampling rate + waveform, sr = read_audio(file_path, 32000, mono=True) + + # Classify the audio + audio_class = classify_audio(waveform, sr) + + return audio_class, org_sr + +@shared_task +def normalize_audio_task(file_path): + """Celery task to normalize audio synchronously""" + try: + audio, sr = read_audio(file_path) # Read audio + normalized_audio = normalize_audio(audio) # Normalize + export_audio(normalized_audio, file_path, sr) # Save file + return True + except Exception as e: + return False + +@shared_task +def trim_audio_task(file_path): + """Celery task to trim audio synchronously""" + try: + audio, sr = read_audio(file_path) + trimmed_audio = trim_audio(audio, sr) + export_audio(trimmed_audio, file_path, sr) + return True + except Exception as e: + return False + +@shared_task +def resample_audio_task(file_path, sr): + """Celery task to resample the audio asynchronously""" + try: + audio, org_sr = read_audio(file_path) + resampled_audio, sr = resample(audio, org_sr, eval(sr)) + export_audio(resampled_audio, file_path, sr) + return True + except Exception as e: + return False + +@shared_task +def music_separation_task(file_path): + """Celery task to separate music audio into sources""" + file_path = Path(file_path) + print("File path is ", file_path) + + # Determine the base directory (output path) + output_path = file_path.parent + + # Run Demucs separation + separate_audio_with_demucs(str(file_path), str(output_path)) + + # Define expected output dir + demucs_dir = output_path / 'htdemucs' + file_folder = demucs_dir / file_path.stem + + if not file_folder.exists(): + raise RuntimeError(f"Demucs output folder not found: {file_folder}") + + # Expected output files + expected_files = ["bass.wav", "drums.wav", "other.wav", "vocals.wav"] + + # Create "sources" directory to store separated components + sources_dir = output_path / "sources" + sources_dir.mkdir(exist_ok=True) + + # Move separate files to output_path and replace original file with vocals.wav, and move other files into sources/ + try: + vocals_path = file_folder / "vocals.wav" + if not vocals_path.exists(): + raise RuntimeError("Vocals file not found in Demucs output") + + # Replace original file with vocals.wav while keeping original name + shutil.move(str(vocals_path), str(file_path)) + + # Move other separated files to the "sources" directory + for expected_file in expected_files: + src_file = file_folder / expected_file + if src_file.exists() and expected_file != "vocals.wav": + shutil.move(str(src_file), str(sources_dir / expected_file)) + + # Cleanup: Remove htedemucs directory + shutil.rmtree(str(demucs_dir)) + + return True + + except Exception as e: + return False + +@shared_task +def noisereduce_task(file_path): + """Celery task to remove noise from audio""" + file_path = Path(file_path) + + # Run noisereduction + try: + noisereduce(file_path, file_path) + return True + except Exception as e: + return False + +@shared_task +def cleanup_task(file_path): + """Celery task to cleanup files""" + file_path = Path(file_path) + + # Cleanup + try: + shutil.rmtree(os.path.dirname(file_path)) + return True + except Exception as e: + return False \ No newline at end of file diff --git a/api/api/utils.py b/api/api/utils.py new file mode 100644 index 0000000..b3b7c2f --- /dev/null +++ b/api/api/utils.py @@ -0,0 +1,16 @@ +# api/utils.py +import os +from rest_framework import status + +def get_audio_file_path(request, base_dir): + """Returns the full path to the audio file inside the given UUID folder.""" + + file_uuid = request.data.get("file_uuid") + if not file_uuid: + return False, "Missing file_uuid", status.HTTP_400_BAD_REQUEST + + dir_path = os.path.join(base_dir, file_uuid) + + if not os.path.exists(dir_path) or not os.listdir(dir_path): + return False, "No file found", status.HTTP_500_INTERNAL_SERVER_ERROR + return True, os.path.join(dir_path, os.listdir(dir_path)[0]), status.HTTP_200_OK # Assumes only one file exists \ No newline at end of file diff --git a/api/api/views.py b/api/api/views.py index ea7ea36..215eb6e 100644 --- a/api/api/views.py +++ b/api/api/views.py @@ -1,14 +1,28 @@ import os +import uuid +import zipfile +from django.http import FileResponse, HttpResponse from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework import status -from freqsplit.preprocessing.classify import classify_audio +from .utils import get_audio_file_path +from .tasks import save_and_classify +from .tasks import normalize_audio_task +from .tasks import trim_audio_task +from .tasks import resample_audio_task +from .tasks import music_separation_task +from .tasks import noisereduce_task +from .tasks import cleanup_task +from freqsplit.input.format_checker import is_supported_format -UPLOAD_DIR = "/tmp/freq-split-enhance" +UPLOAD_DIR = "/tmp/freqsplit" # Ensure the temp directory exists os.makedirs(UPLOAD_DIR, exist_ok=True) +# + +# Endpoint to upload audio and classify it to audio_class @api_view(['POST']) def upload_audio(request): """Handles audio file upload and saves it to /tmp/freq-split-enhance""" @@ -16,19 +30,183 @@ def upload_audio(request): return Response({"Error: No file provided"}, status=status.HTTP_400_BAD_REQUEST) audio_file = request.FILES['file'] - file_path = os.path.join(UPLOAD_DIR, audio_file.name) + + # Check file format before proceeding + if not is_supported_format(audio_file.name): + return Response({"error": "Unsupported file format"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + # Generate a unique ID for this upload + file_uuid = str(uuid.uuid4())[:8] + + #Create a subdirectory for this upload + upload_dir = os.path.join(UPLOAD_DIR, file_uuid) + os.makedirs(upload_dir, exist_ok=True) + + file_path = os.path.join(upload_dir, audio_file.name) # Save the uploaded file - with open(file_path, 'wb') as destination: - for chunk in audio_file.chunks(): - destination.write(chunk) + task = save_and_classify.apply(args=(file_path, audio_file.read())) + + if task.successful(): + audio_class = task.result[0] + return Response( + { + "Status": "File uploaded successfully", + "file_uuid": file_uuid, + "audio_class": audio_class, + "sr": task.result[1] + }, + status=status.HTTP_201_CREATED, + ) - audio_class = classify_audio(file_path) - - return Response( - { - "Status": "File uploaded successfully", - "file_path": file_path, - "audio_class": audio_class - }, status=status.HTTP_201_CREATED, - ) +# Endpoint to normalize audio +@api_view(['POST']) +def normalize_audio(request): + """Handles audio normalization request""" + stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR) + if stat == False: + return Response({"error": result}, status=status_code) + + # Call Celery task synchronously + task = normalize_audio_task.apply(args=(result,)) + + if task.get(): + return Response({"message": "Audio normalized successfully"}, status=status.HTTP_200_OK) + else: + return Response({"error": "Failed to normalize audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +# Endpoint to trim audio +@api_view(['POST']) +def trim_audio(request): + """Handles trimming of leading and trailing silence from an audio clip""" + stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR) + if stat == False: + return Response({"error": result}, status=status_code) + + # Call Celery task synchronously + task = trim_audio_task.apply(args=(result,)) + + if task.get(): + return Response({"message": "Audio trimmed successfully"}, status=status.HTTP_200_OK) + else: + return Response({"error": "Failed to trim audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +# Endpoint to resample audio +@api_view(['POST']) +def resample_audio(request): + """Handles the resampling of audio""" + stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR) + if stat == False: + return Response({"error": result}, status=status_code) + + sr = request.data.get("sr") + if not sr: + return Response({"error": "Missing sr"}, status=status.HTTP_400_BAD_REQUEST) + + # Call Celery task synchronously + task = resample_audio_task.apply(args=(result, sr)) + + if task.get(): + return Response({"message": f"Audio resampled to {sr} successfully"}, status=status.HTTP_200_OK) + else: + return Response({"error": "Failed to resample audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +# Endpoint to separate music into sources +@api_view(['POST']) +def separate_music(request): + """Handles the separatioo of music audio into source components""" + stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR) + if stat == False: + return Response({"error": result}, status=status_code) + + # Call Celery task synchronously + task = music_separation_task.apply(args=(result,)) + + if task.get(): + return Response({"message": f"Audio separated into sources successfully"}, status=status.HTTP_200_OK) + else: + return Response({"error": "Failed to source separate audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +# Endpoint to reduce noise from audio +@api_view(['POST']) +def noisereduce(request): + """Handles the reduction of noise from audio""" + stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR) + if stat == False: + return Response({"error": result}, status=status_code) + + # Call Celery task synchronously + task = noisereduce_task.apply(args=(result,)) + + if task.get(): + return Response({"message": f"Removed noise from audio successfully"}, status=status.HTTP_200_OK) + else: + return Response({"error": "Failed to remove noise from audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +# Endpoint to download audio file or zipped directory +@api_view(['GET']) +def download_audio(request): + """Handles downloading an audio file or a zipped directory.""" + file_uuid = request.query_params.get("file_uuid") + + if not file_uuid: + return Response({"error": "file_uuid is required"}, status=status.HTTP_400_BAD_REQUEST) + + dir_path = os.path.join(UPLOAD_DIR, file_uuid) + + if not os.path.exists(dir_path) or not os.listdir(dir_path): + return Response({"error": "No file found"}, status=status.HTTP_404_NOT_FOUND) + + audio_files = [f for f in os.listdir(dir_path) if f.endswith(('.wav', '.mp3', '.flac'))] + sources_folder = "sources" in os.listdir(dir_path) + + # If only one audio file exists (no sources/) + if len(audio_files) == 1 and not sources_folder: + file_path = os.path.join(dir_path, audio_files[0]) + return FileResponse(open(file_path, "rb"), as_attachment=True, filename=os.path.basename(file_path)) + + # If there are multiple audio files or a sources/ directory, create a ZIP + zip_file_path = os.path.join(UPLOAD_DIR, f"{file_uuid}.zip") + + # Ensure ZIP file is always fresh + if os.path.exists(zip_file_path): + os.remove(zip_file_path) + + # Create ZIP file + with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf: + for root, _, files in os.walk(dir_path): + for file in files: + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, dir_path) # Preserve folder structure inside ZIP + zipf.write(file_path, arcname) + + # Stream the ZIP file + return FileResponse(open(zip_file_path, "rb"), as_attachment=True, filename=os.path.basename(zip_file_path)) + +@api_view(['POST']) +def cleanup(request): + """Handles file cleanup after pipeline processing""" + stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR) + if stat == False: + return Response({"error": result}, status=status_code) + + # Call Celery task synchronously + task = cleanup_task.apply(args=(result,)) + + if task.get(): + return Response({"message": f"Successfully cleaned up files on the server"}, status=status.HTTP_200_OK) + else: + return Response({"error": "Failed to cleanup files on the server"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +@api_view(['POST']) +def cleanup_zip(request): + """Handles cleanup of all zip files leftover by api/download""" + # Delete all ZIP files in UPLOAD_DIR + for file in os.listdir(UPLOAD_DIR): + if file.endswith(".zip"): + file_path = os.path.join(UPLOAD_DIR, file) + try: + os.remove(file_path) + return Response({"message": "Cleaned up zipfiles on the server"}, status=status.HTTP_200_OK) + except Exception as e: + return Response({"message": f"Error deleting {file_path}: {e}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) \ No newline at end of file diff --git a/api/backend/__init__.py b/api/backend/__init__.py index 52ffd13..fb989c4 100644 --- a/api/backend/__init__.py +++ b/api/backend/__init__.py @@ -1 +1,3 @@ -from celery_app import app as celery +from .celery import app as celery_app + +__all__ = ('celery_app',) diff --git a/api/backend/celery.py b/api/backend/celery.py new file mode 100644 index 0000000..3b714c3 --- /dev/null +++ b/api/backend/celery.py @@ -0,0 +1,9 @@ +import os +from celery import Celery + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings') + +app = Celery('backend') +# Load config from Django settings, using a `CELERY_` prefix +app.config_from_object('django.conf:settings', namespace='CELERY') +app.autodiscover_tasks() \ No newline at end of file diff --git a/api/backend/settings.py b/api/backend/settings.py index 45aa008..b2e1751 100644 --- a/api/backend/settings.py +++ b/api/backend/settings.py @@ -123,6 +123,10 @@ STATIC_URL = 'static/' # https://docs.djangoproject.com/en/5.1/ref/settings/#default-auto-field DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' + +# COnfigure Redis as message broker CELERY_BROKER_URL = 'redis://localhost:6379/0' CELERY_ACCEPT_CONTENT = ['json'] CELERY_TASK_SERIALIZER = 'json' +CELERY_BROKER_CONNECTION_RETRY_ON_STARTUP = True + diff --git a/api/backend/urls.py b/api/backend/urls.py index 6c468f2..727daa7 100644 --- a/api/backend/urls.py +++ b/api/backend/urls.py @@ -17,8 +17,24 @@ Including another URLconf from django.contrib import admin from django.urls import path from api.views import upload_audio +from api.views import normalize_audio +from api.views import trim_audio +from api.views import resample_audio +from api.views import separate_music +from api.views import noisereduce +from api.views import download_audio +from api.views import cleanup +from api.views import cleanup_zip urlpatterns = [ path('admin/', admin.site.urls), - path('api/upload', upload_audio, name='upload_audio') + path('api/upload', upload_audio, name='upload_audio'), + path('api/normalize', normalize_audio, name="normalize_audio"), + path('api/trim', trim_audio, name='trim_audio'), + path('api/resample', resample_audio, name='resample_audio'), + path('api/separate', separate_music, name="separate_music"), + path('api/noisereduce', noisereduce, name="noisreduce"), + path('api/download', download_audio, name="download_audio"), + path('api/cleanup', cleanup, name="cleanup"), + path('api/cleanup_zip', cleanup_zip, name="cleanup_zip") ] diff --git a/api/celery_app.py b/api/celery_app.py deleted file mode 100644 index a557171..0000000 --- a/api/celery_app.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -import sys -from celery import Celery - -# Automatically set environment variables in celery_app.py - -# Set Django settings module -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings') - -# Add the project directory to sys.path (similar to the manual PYTHONPATH) -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(project_root) - -app = Celery('backend') - -# Load configuration from Django settings, using the CELERY namespace. -app.config_from_object('django.conf:settings', namespace='CELERY') - -# Autodiscover tasks from installed apps. -app.autodiscover_tasks() diff --git a/pyproject.toml b/pyproject.toml index 717eaf4..ab36295 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,8 @@ build-backend = "setuptools.build_meta" name = "freqsplit" version = "0.1.0" dependencies = [ - "absl-py==2.1.0", - "aiohappyeyeballs==2.4.4", - "aiohttp==3.11.11", + "aiohappyeyeballs==2.4.6", + "aiohttp==3.11.13", "aiosignal==1.3.2", "amqp==5.3.1", "antlr4-python3-runtime==4.9.3", @@ -16,13 +15,12 @@ dependencies = [ "asgiref==3.8.1", "asteroid==0.7.0", "asteroid-filterbanks==0.4.0", - "astunparse==1.6.3", "attrs==25.1.0", "audioread==3.0.1", "billiard==4.2.1", "cached-property==2.0.1", "celery==5.4.0", - "certifi==2024.12.14", + "certifi==2025.1.31", "cffi==1.17.1", "charset-normalizer==3.4.1", "click==8.1.8", @@ -32,50 +30,38 @@ dependencies = [ "cloudpickle==3.1.1", "contourpy==1.3.1", "cycler==0.12.1", - "decorator==5.1.1", + "decorator==5.2.1", "DeepFilterLib==0.5.6", "DeepFilterNet==0.5.6", "demucs==4.0.1", "Django==5.1.6", + "djangorestframework==3.15.2", "dora_search==0.1.12", - "einops==0.8.0", + "einops==0.8.1", "filelock==3.17.0", - "flatbuffers==25.2.10", - "fonttools==4.55.6", + "fonttools==4.56.0", "frozenlist==1.5.0", - "fsspec==2024.12.0", - "future==1.0.0", - "gast==0.6.0", - "google-pasta==0.2.0", - "grpcio==1.70.0", - "h5py==3.13.0", - "huggingface-hub==0.28.0", + "fsspec==2025.2.0", + "huggingface-hub==0.29.1", "idna==3.10", "iniconfig==2.0.0", "Jinja2==3.1.5", "joblib==1.4.2", "julius==0.2.7", - "keras==3.8.0", "kiwisolver==1.4.8", "kombu==5.4.2", "lameenc==1.8.1", "lazy_loader==0.4", - "libclang==18.1.1", "librosa==0.10.2.post1", - "lightning-utilities==0.11.9", + "lightning-utilities==0.12.0", "llvmlite==0.44.0", "loguru==0.7.3", - "Markdown==3.7", - "markdown-it-py==3.0.0", "MarkupSafe==3.0.2", "matplotlib==3.10.0", - "mdurl==0.1.2", - "mir_eval==0.7", - "ml-dtypes==0.4.1", + "mir_eval==0.8.2", "mpmath==1.3.0", "msgpack==1.1.0", "multidict==6.1.0", - "namex==0.0.8", "networkx==3.4.2", "numba==0.61.0", "numpy==1.26.4", @@ -88,15 +74,15 @@ dependencies = [ "nvidia-curand-cu12==10.3.5.147", "nvidia-cusolver-cu12==11.6.1.9", "nvidia-cusparse-cu12==12.3.1.170", + "nvidia-cusparselt-cu12==0.6.2", "nvidia-nccl-cu12==2.21.5", "nvidia-nvjitlink-cu12==12.4.127", "nvidia-nvtx-cu12==12.4.127", "omegaconf==2.3.0", "openunmix==1.3.0", - "opt_einsum==3.4.0", - "optree==0.14.0", "packaging==23.2", "pandas==2.2.3", + "panns-inference==0.1.1", "pb-bss-eval==0.0.2", "pesq==0.0.4", "pillow==11.1.0", @@ -104,54 +90,47 @@ dependencies = [ "pluggy==1.5.0", "pooch==1.8.2", "prompt_toolkit==3.0.50", - "propcache==0.2.1", - "protobuf==5.29.3", + "propcache==0.3.0", "pycparser==2.22", - "Pygments==2.19.1", "pyparsing==3.2.1", "pystoi==0.4.1", "pytest==8.3.4", "python-dateutil==2.9.0.post0", "pytorch-lightning==2.5.0.post0", "pytorch-ranger==0.1.1", - "pytz==2024.2", + "pytz==2025.1", "PyYAML==6.0.2", "redis==5.2.1", + "regex==2024.11.6", "requests==2.32.3", "retrying==1.3.4", - "rich==13.9.4", + "safetensors==0.5.3", "scikit-learn==1.6.1", - "scipy==1.15.1", - "setuptools==75.8.0", + "scipy==1.15.2", + "setuptools==75.8.1", "six==1.17.0", "soundfile==0.13.1", "soxr==0.5.0.post1", "sqlparse==0.5.3", "submitit==1.5.2", "sympy==1.13.1", - "tensorboard==2.18.0", - "tensorboard-data-server==0.7.2", - "tensorflow==2.18.0", - "tensorflow-hub==0.16.1", - "termcolor==2.5.0", - "tf_keras==2.18.0", "threadpoolctl==3.5.0", - "torch==2.5.1", + "tokenizers==0.21.0", + "torch==2.6.0", "torch-optimizer==0.1.0", "torch-stoi==0.2.3", - "torchaudio==2.5.1", + "torchaudio==2.6.0", + "torchlibrosa==0.1.0", "torchmetrics==0.11.4", "tqdm==4.67.1", + "transformers==4.49.0", "treetable==0.2.5", - "triton==3.1.0", + "triton==3.2.0", "typing_extensions==4.12.2", "tzdata==2025.1", "urllib3==2.3.0", "vine==5.1.0", "wcwidth==0.2.13", - "Werkzeug==3.1.3", - "wheel==0.45.1", - "wrapt==1.17.2", "yarl==1.18.3", ] diff --git a/requirements.txt b/requirements.txt index abd5442..3b74e0d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ -absl-py==2.1.0 -aiohappyeyeballs==2.4.4 -aiohttp==3.11.11 +aiohappyeyeballs==2.4.6 +aiohttp==3.11.13 aiosignal==1.3.2 amqp==5.3.1 antlr4-python3-runtime==4.9.3 @@ -8,13 +7,12 @@ appdirs==1.4.4 asgiref==3.8.1 asteroid==0.7.0 asteroid-filterbanks==0.4.0 -astunparse==1.6.3 attrs==25.1.0 audioread==3.0.1 billiard==4.2.1 cached-property==2.0.1 celery==5.4.0 -certifi==2024.12.14 +certifi==2025.1.31 cffi==1.17.1 charset-normalizer==3.4.1 click==8.1.8 @@ -24,51 +22,38 @@ click-repl==0.3.0 cloudpickle==3.1.1 contourpy==1.3.1 cycler==0.12.1 -decorator==5.1.1 +decorator==5.2.1 DeepFilterLib==0.5.6 DeepFilterNet==0.5.6 demucs==4.0.1 Django==5.1.6 djangorestframework==3.15.2 dora_search==0.1.12 -einops==0.8.0 +einops==0.8.1 filelock==3.17.0 -flatbuffers==25.2.10 -fonttools==4.55.6 +fonttools==4.56.0 frozenlist==1.5.0 -fsspec==2024.12.0 -future==1.0.0 -gast==0.6.0 -google-pasta==0.2.0 -grpcio==1.70.0 -h5py==3.13.0 -huggingface-hub==0.28.0 +fsspec==2025.2.0 +huggingface-hub==0.29.1 idna==3.10 iniconfig==2.0.0 Jinja2==3.1.5 joblib==1.4.2 julius==0.2.7 -keras==3.8.0 kiwisolver==1.4.8 kombu==5.4.2 lameenc==1.8.1 lazy_loader==0.4 -libclang==18.1.1 librosa==0.10.2.post1 -lightning-utilities==0.11.9 +lightning-utilities==0.12.0 llvmlite==0.44.0 loguru==0.7.3 -Markdown==3.7 -markdown-it-py==3.0.0 MarkupSafe==3.0.2 matplotlib==3.10.0 -mdurl==0.1.2 -mir_eval==0.7 -ml-dtypes==0.4.1 +mir_eval==0.8.2 mpmath==1.3.0 msgpack==1.1.0 multidict==6.1.0 -namex==0.0.8 networkx==3.4.2 numba==0.61.0 numpy==1.26.4 @@ -81,15 +66,15 @@ nvidia-cufft-cu12==11.2.1.3 nvidia-curand-cu12==10.3.5.147 nvidia-cusolver-cu12==11.6.1.9 nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 omegaconf==2.3.0 openunmix==1.3.0 -opt_einsum==3.4.0 -optree==0.14.0 packaging==23.2 pandas==2.2.3 +panns-inference==0.1.1 pb-bss-eval==0.0.2 pesq==0.0.4 pillow==11.1.0 @@ -97,52 +82,45 @@ platformdirs==4.3.6 pluggy==1.5.0 pooch==1.8.2 prompt_toolkit==3.0.50 -propcache==0.2.1 -protobuf==5.29.3 +propcache==0.3.0 pycparser==2.22 -Pygments==2.19.1 pyparsing==3.2.1 pystoi==0.4.1 pytest==8.3.4 python-dateutil==2.9.0.post0 pytorch-lightning==2.5.0.post0 pytorch-ranger==0.1.1 -pytz==2024.2 +pytz==2025.1 PyYAML==6.0.2 redis==5.2.1 +regex==2024.11.6 requests==2.32.3 retrying==1.3.4 -rich==13.9.4 +safetensors==0.5.3 scikit-learn==1.6.1 -scipy==1.15.1 -setuptools==75.8.0 +scipy==1.15.2 +setuptools==75.8.1 six==1.17.0 soundfile==0.13.1 soxr==0.5.0.post1 sqlparse==0.5.3 submitit==1.5.2 sympy==1.13.1 -tensorboard==2.18.0 -tensorboard-data-server==0.7.2 -tensorflow==2.18.0 -tensorflow-hub==0.16.1 -termcolor==2.5.0 -tf_keras==2.18.0 threadpoolctl==3.5.0 -torch==2.5.1 +tokenizers==0.21.0 +torch==2.6.0 torch-optimizer==0.1.0 torch-stoi==0.2.3 -torchaudio==2.5.1 +torchaudio==2.6.0 +torchlibrosa==0.1.0 torchmetrics==0.11.4 tqdm==4.67.1 +transformers==4.49.0 treetable==0.2.5 -triton==3.1.0 +triton==3.2.0 typing_extensions==4.12.2 tzdata==2025.1 urllib3==2.3.0 vine==5.1.0 wcwidth==0.2.13 -Werkzeug==3.1.3 -wheel==0.45.1 -wrapt==1.17.2 yarl==1.18.3 diff --git a/src/freqsplit/input/__init__.py b/src/freqsplit/input/__init__.py index e604b16..5991e8e 100644 --- a/src/freqsplit/input/__init__.py +++ b/src/freqsplit/input/__init__.py @@ -9,4 +9,4 @@ logging.basicConfig( level = logging.INFO ) -logging.info("freq-split-enhance/input package has been imported.") \ No newline at end of file +logging.info("freqsplit/input package has been imported.") \ No newline at end of file diff --git a/src/freqsplit/input/file_reader.py b/src/freqsplit/input/file_reader.py index f18b7e1..5b6723a 100644 --- a/src/freqsplit/input/file_reader.py +++ b/src/freqsplit/input/file_reader.py @@ -1,12 +1,14 @@ import os import librosa -def read_audio(file_path): +def read_audio(file_path, sr=None, mono=None): """ Reads an audio file and returns the audio time series and sampling rate. Args: file_path (str): Path to the audio file. + sr (int): Sample rate at which the audio is to be loaded + mono (bool): True to loaded audio with single channels, else False. Returns: tuple: audio_time_series (numpy.ndarray), sampling_rate (int) @@ -15,7 +17,11 @@ def read_audio(file_path): if not os.path.exists(file_path): raise FileNotFoundError(f"File not found: {file_path}") try: - audio, sr = librosa.load(file_path, sr=None) # Load with original sampling rate. + librosa_kwargs = {"sr": sr} + if mono is not None: # Only add 'mono' if explicitly provided + librosa_kwargs["mono"] = mono + + audio, sr = librosa.load(file_path, **librosa_kwargs) return audio, sr except Exception as e: raise RuntimeError(f"Error reading the audio file: {e}") \ No newline at end of file diff --git a/src/freqsplit/postprocessing/__init__.py b/src/freqsplit/postprocessing/__init__.py index 636a4ce..477b653 100644 --- a/src/freqsplit/postprocessing/__init__.py +++ b/src/freqsplit/postprocessing/__init__.py @@ -9,4 +9,4 @@ logging.basicConfig( level = logging.INFO ) -logging.info("freq-split-enhance/postprocessing package has been imported.") \ No newline at end of file +logging.info("freqsplit/postprocessing package has been imported.") \ No newline at end of file diff --git a/src/freqsplit/postprocessing/audio_writer.py b/src/freqsplit/postprocessing/audio_writer.py index 2fd7da0..4f944df 100644 --- a/src/freqsplit/postprocessing/audio_writer.py +++ b/src/freqsplit/postprocessing/audio_writer.py @@ -12,9 +12,6 @@ def export_audio(audio, output_path, sr): """ try: - - print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}") - if audio.ndim == 2 and audio.shape[0] == 2: # Transpose stereo audio to match the expected shape audio = audio.T # From (2, num_samples) to (num_samples, 2) @@ -26,10 +23,6 @@ def export_audio(audio, output_path, sr): if np.max(np.abs(audio)) > 0: # Avoid divide by zero audio = audio / np.max(np.abs(audio)) - # Verify final format - print(f"Final audio shape: {audio.shape}, dtype: {audio.dtype}, max: {np.max(audio)}, min: {np.min(audio)}") - - sf.write(output_path, audio, sr, format='wav') print(f"Audio saved to {output_path}") except Exception as e: diff --git a/src/freqsplit/preprocessing/__init__.py b/src/freqsplit/preprocessing/__init__.py index b345187..56d90f9 100644 --- a/src/freqsplit/preprocessing/__init__.py +++ b/src/freqsplit/preprocessing/__init__.py @@ -9,4 +9,4 @@ logging.basicConfig( level = logging.INFO ) -logging.info("freq-split-enhance/preprocessing package has been imported.") \ No newline at end of file +logging.info("freqsplit/preprocessing package has been imported.") \ No newline at end of file diff --git a/src/freqsplit/preprocessing/classify.py b/src/freqsplit/preprocessing/classify.py index cfa1783..0a52d3a 100644 --- a/src/freqsplit/preprocessing/classify.py +++ b/src/freqsplit/preprocessing/classify.py @@ -1,53 +1,43 @@ -import tensorflow as tf -import tensorflow_hub as hub -import librosa import numpy as np -import csv -import os +from panns_inference import AudioTagging, labels +# Initialize PANNs model +at = AudioTagging(checkpoint_path=None, device='cuda') -# Force TensorFlow to use only CPU -tf.config.set_visible_devices([], 'GPU') - -model = hub.load('https://tfhub.dev/google/yamnet/1') - -#Find the name of the class with the top score when mean-aggregated across frames. -def class_names_from_csv(class_map_scv_text): - """Returns list of class names corresponding to score vector.""" - class_names = [] - with tf.io.gfile.GFile(class_map_scv_text) as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - class_names.append(row['display_name']) - return class_names - -# Main function to process audio and classify -def classify_audio(file_path): +def classify_audio(waveform, sr): """ Given an audio file, this function loads the audio, resamples it, - normalizes it, and runs it through the YAMNet model to classify the sound. + normalizes it, and runs it through the PANNs model to classify the sound. Args: - - file_path (str): Path to the audio file (WAV, MP3, etc.). + - waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.). + - sr (int): Sampling rate of the audio. Returns: - str: Predicted class label of the audio. """ - # Load audio using librosa (this handles both loading, resampling, and conversion to mono) - waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True) # Ensuring 16k sample rate and mono - - # Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values) - waveform = waveform / np.max(np.abs(waveform)) - - # Execute the YAMNet model - scores, embeddings, spectrogram = model(waveform) - # Extract the class names from the model - class_map_path = model.class_map_path().numpy() - class_names = class_names_from_csv(class_map_path) - - # Find the class with the highest score - scores_np = scores.numpy() - inferred_class = class_names[scores_np.mean(axis=0).argmax()] - - return inferred_class \ No newline at end of file + # Check if the sampling rate is 32000Hz + try: + if sr != 32000: + raise RuntimeError + except Exception: + raise RuntimeError(f"The audio is not sampled at 32000Hz, failed to classify audio.") + + # Normalize the waveform to [-1.0, 1.0] + waveform = waveform / np.max(np.abs(waveform)) + + # Ensure waveform shape is correct for model input + waveform = waveform[None, :] + + # Execute the PANNs model + try: + clipwise_output, _ = at.inference(waveform) + except Exception as e: + raise RuntimeError(f"Error: Failed to classify audio: {e}") + + # Get the top predicted class + predicted_index = np.argmax(clipwise_output) + inferred_class = labels[predicted_index] + + return inferred_class diff --git a/src/freqsplit/refinement/__init__.py b/src/freqsplit/refinement/__init__.py index b02a5df..4937441 100644 --- a/src/freqsplit/refinement/__init__.py +++ b/src/freqsplit/refinement/__init__.py @@ -9,4 +9,4 @@ logging.basicConfig( level = logging.INFO ) -logging.info("freq-split-enhance/refinement package has been imported.") \ No newline at end of file +logging.info("freqsplit/refinement package has been imported.") \ No newline at end of file diff --git a/src/freqsplit/refinement/deepfilternet_wrapper.py b/src/freqsplit/refinement/deepfilternet_wrapper.py index 6140a28..9579e2f 100644 --- a/src/freqsplit/refinement/deepfilternet_wrapper.py +++ b/src/freqsplit/refinement/deepfilternet_wrapper.py @@ -1,35 +1,83 @@ import os +import librosa import torch +import shutil +import soundfile as sf +import numpy as np from df.enhance import enhance, init_df, load_audio, save_audio +def split_audio(audio, sr, chunk_size=5): + """Split audio into chunks of `chunk_size` seconds.""" + samples_per_chunk = sr * chunk_size + return [audio[i:i + samples_per_chunk] for i in range(0, len(audio), samples_per_chunk)] + def noisereduce(input_audio_path, output_audio_path, model_path=None): """ - Apply noise reduction using DeepFilterNet. + Apply noise reduction using DeepFilterNet with chunking. Args: input_audio_path (str): Path to the input noisy audio file. output_audio_path (str): Path to save the enhanced audio file. - model_path (str, optional): Path to a custom DeepFilterNet model. Defaults to None (uses the pre-trained model). + model_path (str, optional): Path to a custom DeepFilterNet model. Defaults to None (uses pre-trained model). Returns: str: Path to the enhanced audio file. """ if not os.path.exists(input_audio_path): raise FileNotFoundError(f"Input file {input_audio_path} not found") + + output_dir = os.path.dirname(output_audio_path) + os.makedirs(output_dir, exist_ok=True) # Ensure the directory exists # Initialize DeepFilterNet model model, df_state, _ = init_df(model_path) # Load audio - audio, _ = load_audio(input_audio_path, sr=df_state.sr()) + audio, sr = librosa.load(input_audio_path, sr=None) - # Ensure output path exists - os.makedirs(os.path.dirname(output_audio_path), exist_ok=True) + # Ensure output and chunk directories exist + parent_dir = os.path.dirname(input_audio_path) + chunk_dir = os.path.join(parent_dir, "chunks") + output_chunk_dir = os.path.join(chunk_dir, "output") + os.makedirs(chunk_dir, exist_ok=True) + os.makedirs(output_chunk_dir, exist_ok=True) - # Apply noise reduction - enhanced_audio = enhance(model, df_state, audio) + # Split audio into 5-second chunks + chunks = split_audio(audio, sr, chunk_size=5) + chunk_paths = [] + + for i, chunk in enumerate(chunks): + chunk_path = os.path.join(chunk_dir, f"chunk_{i}.wav") + sf.write(chunk_path, chunk, sr) + chunk_paths.append(chunk_path) + + enhanced_chunk_paths = [] + + # Process each chunk sequentially to avoid OOM errors + for chunk_path in chunk_paths: + output_chunk_path = os.path.join(output_chunk_dir, os.path.basename(chunk_path)) + + # Load and enhance + chunk_audio, _ = load_audio(chunk_path, sr=df_state.sr()) + enhanced_audio = enhance(model, df_state, chunk_audio) + + # Save enhanced chunk + save_audio(output_chunk_path, enhanced_audio, df_state.sr()) + enhanced_chunk_paths.append(output_chunk_path) + + # Combine enhanced chunks back into a single audio file + final_audio = [] + for chunk_path in enhanced_chunk_paths: + chunk_audio, _ = librosa.load(chunk_path, sr=sr) # Keep original sample rate + final_audio.append(chunk_audio) + + final_audio = np.concatenate(final_audio, axis=0) + + # Save final enhanced audio + sf.write(output_audio_path, final_audio, sr) + + # Clean up temporary chunk files and directories + shutil.rmtree(chunk_dir, ignore_errors=True) - # Save the enhanced audio - save_audio(output_audio_path, enhanced_audio, df_state.sr()) return output_audio_path diff --git a/src/freqsplit/separation/__init__.py b/src/freqsplit/separation/__init__.py index 637f57f..201a961 100644 --- a/src/freqsplit/separation/__init__.py +++ b/src/freqsplit/separation/__init__.py @@ -9,4 +9,4 @@ logging.basicConfig( level = logging.INFO ) -logging.info("freq-split-enhance/separation package has been imported.") \ No newline at end of file +logging.info("freqsplit/separation package has been imported.") \ No newline at end of file diff --git a/src/freqsplit/spectogram/__init__.py b/src/freqsplit/spectogram/__init__.py index 15341f8..310be47 100644 --- a/src/freqsplit/spectogram/__init__.py +++ b/src/freqsplit/spectogram/__init__.py @@ -9,4 +9,4 @@ logging.basicConfig( level = logging.INFO ) -logging.info("freq-split-enhance/spectogram package has been imported.") \ No newline at end of file +logging.info("freqsplit/spectogram package has been imported.") \ No newline at end of file diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 7de07e9..9af97cd 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -24,8 +24,9 @@ def test_trim_audio(): def test_classify(): file_path = "tests/test_audio/cafe_crowd_talk.wav" + waveform, sr = read_audio(file_path, 32000, mono=True) expected_class = "Speech" - predicted_class = classify_audio(file_path) + predicted_class = classify_audio(waveform, sr) assert predicted_class == expected_class , f"Expected {expected_class}, but got {predicted_class}"