Merge pull request #29 from joelmathewthomas/feature/api
feature/api: Add endpoints
+130 -7
@@ -1,11 +1,134 @@
-from __future__ import absolute_import, unicode_literals
+import os
+import shutil
+from pathlib import Path
 from celery import shared_task
-from src.input.file_reader import read_audio #export PYTHONPATH="/home/karthikeyan/code/MainProject/freq-split-enhance:$PYTHONPATH" for exporting the module
+from freqsplit.input.file_reader import read_audio
+from freqsplit.preprocessing.classify import classify_audio
+from freqsplit.preprocessing.normalize import normalize_audio
+from freqsplit.preprocessing.trim import trim_audio
+from freqsplit.preprocessing.resample import resample
+from freqsplit.postprocessing.audio_writer import export_audio
+from freqsplit.separation.demucs_wrapper import separate_audio_with_demucs
+from freqsplit.refinement.deepfilternet_wrapper import noisereduce
 
-import time
 
 @shared_task
-def process_uploaded_file(file_path):
-    # Simulate long-running task
-    read_audio(file_path=file_path)
-    return 'File processed'
+def save_and_classify(file_path, file_content):
+    """Save the uploaded file and classify the audio file"""
+    with open(file_path, 'wb') as destination:
+        destination.write(file_content)
+
+    # Read the saved audio file
+    _, org_sr = read_audio(file_path)  # Get the original sampling rate
+    waveform, sr = read_audio(file_path, 32000, mono=True)
+
+    # Classify the audio
+    audio_class = classify_audio(waveform, sr)
+
+    return audio_class, org_sr
+
+
+@shared_task
+def normalize_audio_task(file_path):
+    """Celery task to normalize audio synchronously"""
+    try:
+        audio, sr = read_audio(file_path)              # Read audio
+        normalized_audio = normalize_audio(audio)      # Normalize
+        export_audio(normalized_audio, file_path, sr)  # Save file
+        return True
+    except Exception:
+        return False
+
+
+@shared_task
+def trim_audio_task(file_path):
+    """Celery task to trim audio synchronously"""
+    try:
+        audio, sr = read_audio(file_path)
+        trimmed_audio = trim_audio(audio, sr)
+        export_audio(trimmed_audio, file_path, sr)
+        return True
+    except Exception:
+        return False
+
+
+@shared_task
+def resample_audio_task(file_path, sr):
+    """Celery task to resample audio synchronously"""
+    try:
+        audio, org_sr = read_audio(file_path)
+        resampled_audio, sr = resample(audio, org_sr, int(sr))
+        export_audio(resampled_audio, file_path, sr)
+        return True
+    except Exception:
+        return False
+
+
+@shared_task
+def music_separation_task(file_path):
+    """Celery task to separate music audio into sources"""
+    file_path = Path(file_path)
+    print("File path is ", file_path)
+
+    # Determine the base directory (output path)
+    output_path = file_path.parent
+
+    # Run Demucs separation
+    separate_audio_with_demucs(str(file_path), str(output_path))
+
+    # Define the expected output directory
+    demucs_dir = output_path / 'htdemucs'
+    file_folder = demucs_dir / file_path.stem
+
+    if not file_folder.exists():
+        raise RuntimeError(f"Demucs output folder not found: {file_folder}")
+
+    # Expected output files
+    expected_files = ["bass.wav", "drums.wav", "other.wav", "vocals.wav"]
+
+    # Create a "sources" directory to store the separated components
+    sources_dir = output_path / "sources"
+    sources_dir.mkdir(exist_ok=True)
+
+    # Replace the original file with vocals.wav and move the other stems into sources/
+    try:
+        vocals_path = file_folder / "vocals.wav"
+        if not vocals_path.exists():
+            raise RuntimeError("Vocals file not found in Demucs output")
+
+        # Replace the original file with vocals.wav while keeping the original name
+        shutil.move(str(vocals_path), str(file_path))
+
+        # Move the other separated files to the "sources" directory
+        for expected_file in expected_files:
+            src_file = file_folder / expected_file
+            if src_file.exists() and expected_file != "vocals.wav":
+                shutil.move(str(src_file), str(sources_dir / expected_file))
+
+        # Cleanup: remove the htdemucs directory
+        shutil.rmtree(str(demucs_dir))
+
+        return True
+
+    except Exception:
+        return False
+
+
+@shared_task
+def noisereduce_task(file_path):
+    """Celery task to remove noise from audio"""
+    file_path = Path(file_path)
+
+    # Run noise reduction
+    try:
+        noisereduce(file_path, file_path)
+        return True
+    except Exception:
+        return False
+
+
+@shared_task
+def cleanup_task(file_path):
+    """Celery task to clean up files"""
+    file_path = Path(file_path)
+
+    # Cleanup
+    try:
+        shutil.rmtree(os.path.dirname(file_path))
+        return True
+    except Exception:
+        return False
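A note on how these tasks are driven: the views added below call them with Celery's .apply(), which runs the task eagerly in the calling process and returns an EagerResult, so the endpoints work without a separate worker. A minimal sketch (not part of this diff; the file path is hypothetical):

    from api.tasks import normalize_audio_task

    # Synchronous: executes in-process; .get() yields the task's True/False return value
    result = normalize_audio_task.apply(args=("/tmp/freqsplit/abc12345/clip.wav",))
    print(result.get())

    # Asynchronous alternative: enqueue on the Redis broker for a worker to execute
    normalize_audio_task.delay("/tmp/freqsplit/abc12345/clip.wav")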
@@ -0,0 +1,16 @@
+# api/utils.py
+import os
+from rest_framework import status
+
+def get_audio_file_path(request, base_dir):
+    """Returns the full path to the audio file inside the given UUID folder."""
+
+    file_uuid = request.data.get("file_uuid")
+    if not file_uuid:
+        return False, "Missing file_uuid", status.HTTP_400_BAD_REQUEST
+
+    dir_path = os.path.join(base_dir, file_uuid)
+
+    if not os.path.exists(dir_path) or not os.listdir(dir_path):
+        return False, "No file found", status.HTTP_500_INTERNAL_SERVER_ERROR
+    return True, os.path.join(dir_path, os.listdir(dir_path)[0]), status.HTTP_200_OK  # Assumes only one file exists
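The helper reports failures through a (found, payload, http_status) triple instead of raising, which the new views below unpack directly. A small illustration (hypothetical UUID; assumes api.utils is importable):

    from types import SimpleNamespace
    from api.utils import get_audio_file_path

    fake_request = SimpleNamespace(data={"file_uuid": "abc12345"})
    ok, result, code = get_audio_file_path(fake_request, "/tmp/freqsplit")
    if ok:
        print("audio file:", result)    # e.g. /tmp/freqsplit/abc12345/clip.wav
    else:
        print(result, "->", code)       # e.g. "No file found" -> 500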
+193 -15
@@ -1,14 +1,28 @@
 import os
+import uuid
+import zipfile
+from django.http import FileResponse, HttpResponse
 from rest_framework.decorators import api_view
 from rest_framework.response import Response
 from rest_framework import status
-from freqsplit.preprocessing.classify import classify_audio
+from .utils import get_audio_file_path
+from .tasks import save_and_classify
+from .tasks import normalize_audio_task
+from .tasks import trim_audio_task
+from .tasks import resample_audio_task
+from .tasks import music_separation_task
+from .tasks import noisereduce_task
+from .tasks import cleanup_task
+from freqsplit.input.format_checker import is_supported_format
 
-UPLOAD_DIR = "/tmp/freq-split-enhance"
+UPLOAD_DIR = "/tmp/freqsplit"
 
 # Ensure the temp directory exists
 os.makedirs(UPLOAD_DIR, exist_ok=True)
 
+
+# Endpoint to upload audio and classify it to audio_class
 @api_view(['POST'])
 def upload_audio(request):
     """Handles audio file upload and saves it to /tmp/freq-split-enhance"""
@@ -16,19 +30,183 @@ def upload_audio(request):
         return Response({"Error: No file provided"}, status=status.HTTP_400_BAD_REQUEST)
 
     audio_file = request.FILES['file']
-    file_path = os.path.join(UPLOAD_DIR, audio_file.name)
+
+    # Check the file format before proceeding
+    if not is_supported_format(audio_file.name):
+        return Response({"error": "Unsupported file format"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+    # Generate a unique ID for this upload
+    file_uuid = str(uuid.uuid4())[:8]
+
+    # Create a subdirectory for this upload
+    upload_dir = os.path.join(UPLOAD_DIR, file_uuid)
+    os.makedirs(upload_dir, exist_ok=True)
+
+    file_path = os.path.join(upload_dir, audio_file.name)
 
     # Save the uploaded file
-    with open(file_path, 'wb') as destination:
-        for chunk in audio_file.chunks():
-            destination.write(chunk)
-
-    audio_class = classify_audio(file_path)
-
-    return Response(
-        {
-            "Status": "File uploaded successfully",
-            "file_path": file_path,
-            "audio_class": audio_class
-        }, status=status.HTTP_201_CREATED,
-    )
+    task = save_and_classify.apply(args=(file_path, audio_file.read()))
+
+    if task.successful():
+        audio_class = task.result[0]
+        return Response(
+            {
+                "Status": "File uploaded successfully",
+                "file_uuid": file_uuid,
+                "audio_class": audio_class,
+                "sr": task.result[1]
+            },
+            status=status.HTTP_201_CREATED,
+        )
+
+
+# Endpoint to normalize audio
+@api_view(['POST'])
+def normalize_audio(request):
+    """Handles an audio normalization request"""
+    stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR)
+    if not stat:
+        return Response({"error": result}, status=status_code)
+
+    # Call the Celery task synchronously
+    task = normalize_audio_task.apply(args=(result,))
+
+    if task.get():
+        return Response({"message": "Audio normalized successfully"}, status=status.HTTP_200_OK)
+    else:
+        return Response({"error": "Failed to normalize audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+# Endpoint to trim audio
+@api_view(['POST'])
+def trim_audio(request):
+    """Handles trimming of leading and trailing silence from an audio clip"""
+    stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR)
+    if not stat:
+        return Response({"error": result}, status=status_code)
+
+    # Call the Celery task synchronously
+    task = trim_audio_task.apply(args=(result,))
+
+    if task.get():
+        return Response({"message": "Audio trimmed successfully"}, status=status.HTTP_200_OK)
+    else:
+        return Response({"error": "Failed to trim audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+# Endpoint to resample audio
+@api_view(['POST'])
+def resample_audio(request):
+    """Handles the resampling of audio"""
+    stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR)
+    if not stat:
+        return Response({"error": result}, status=status_code)
+
+    sr = request.data.get("sr")
+    if not sr:
+        return Response({"error": "Missing sr"}, status=status.HTTP_400_BAD_REQUEST)
+
+    # Call the Celery task synchronously
+    task = resample_audio_task.apply(args=(result, sr))
+
+    if task.get():
+        return Response({"message": f"Audio resampled to {sr} successfully"}, status=status.HTTP_200_OK)
+    else:
+        return Response({"error": "Failed to resample audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+# Endpoint to separate music into sources
+@api_view(['POST'])
+def separate_music(request):
+    """Handles the separation of music audio into source components"""
+    stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR)
+    if not stat:
+        return Response({"error": result}, status=status_code)
+
+    # Call the Celery task synchronously
+    task = music_separation_task.apply(args=(result,))
+
+    if task.get():
+        return Response({"message": "Audio separated into sources successfully"}, status=status.HTTP_200_OK)
+    else:
+        return Response({"error": "Failed to separate audio into sources"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+# Endpoint to reduce noise from audio
+@api_view(['POST'])
+def noisereduce(request):
+    """Handles the reduction of noise from audio"""
+    stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR)
+    if not stat:
+        return Response({"error": result}, status=status_code)
+
+    # Call the Celery task synchronously
+    task = noisereduce_task.apply(args=(result,))
+
+    if task.get():
+        return Response({"message": "Removed noise from audio successfully"}, status=status.HTTP_200_OK)
+    else:
+        return Response({"error": "Failed to remove noise from audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+# Endpoint to download an audio file or zipped directory
+@api_view(['GET'])
+def download_audio(request):
+    """Handles downloading an audio file or a zipped directory."""
+    file_uuid = request.query_params.get("file_uuid")
+
+    if not file_uuid:
+        return Response({"error": "file_uuid is required"}, status=status.HTTP_400_BAD_REQUEST)
+
+    dir_path = os.path.join(UPLOAD_DIR, file_uuid)
+
+    if not os.path.exists(dir_path) or not os.listdir(dir_path):
+        return Response({"error": "No file found"}, status=status.HTTP_404_NOT_FOUND)
+
+    audio_files = [f for f in os.listdir(dir_path) if f.endswith(('.wav', '.mp3', '.flac'))]
+    sources_folder = "sources" in os.listdir(dir_path)
+
+    # If only one audio file exists (no sources/)
+    if len(audio_files) == 1 and not sources_folder:
+        file_path = os.path.join(dir_path, audio_files[0])
+        return FileResponse(open(file_path, "rb"), as_attachment=True, filename=os.path.basename(file_path))
+
+    # If there are multiple audio files or a sources/ directory, create a ZIP
+    zip_file_path = os.path.join(UPLOAD_DIR, f"{file_uuid}.zip")
+
+    # Ensure the ZIP file is always fresh
+    if os.path.exists(zip_file_path):
+        os.remove(zip_file_path)
+
+    # Create the ZIP file
+    with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf:
+        for root, _, files in os.walk(dir_path):
+            for file in files:
+                file_path = os.path.join(root, file)
+                arcname = os.path.relpath(file_path, dir_path)  # Preserve folder structure inside the ZIP
+                zipf.write(file_path, arcname)
+
+    # Stream the ZIP file
+    return FileResponse(open(zip_file_path, "rb"), as_attachment=True, filename=os.path.basename(zip_file_path))
+
+
+@api_view(['POST'])
+def cleanup(request):
+    """Handles file cleanup after pipeline processing"""
+    stat, result, status_code = get_audio_file_path(request, UPLOAD_DIR)
+    if not stat:
+        return Response({"error": result}, status=status_code)
+
+    # Call the Celery task synchronously
+    task = cleanup_task.apply(args=(result,))
+
+    if task.get():
+        return Response({"message": "Successfully cleaned up files on the server"}, status=status.HTTP_200_OK)
+    else:
+        return Response({"error": "Failed to clean up files on the server"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+@api_view(['POST'])
+def cleanup_zip(request):
+    """Handles cleanup of all zip files left over by api/download"""
+    # Delete all ZIP files in UPLOAD_DIR
+    for file in os.listdir(UPLOAD_DIR):
+        if file.endswith(".zip"):
+            file_path = os.path.join(UPLOAD_DIR, file)
+            try:
+                os.remove(file_path)
+            except Exception as e:
+                return Response({"message": f"Error deleting {file_path}: {e}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+    return Response({"message": "Cleaned up zipfiles on the server"}, status=status.HTTP_200_OK)
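The endpoints above compose into a simple pipeline driven entirely by file_uuid. A hypothetical client-side walkthrough using the requests package pinned in this PR (host, port, and file names are assumptions):

    import requests

    BASE = "http://localhost:8000/api"

    # Upload: returns file_uuid, the detected audio class, and the original sample rate
    with open("clip.wav", "rb") as f:
        meta = requests.post(f"{BASE}/upload", files={"file": f}).json()
    uid = meta["file_uuid"]

    # Processing steps operate in place on the uploaded file
    requests.post(f"{BASE}/normalize", data={"file_uuid": uid})
    requests.post(f"{BASE}/trim", data={"file_uuid": uid})
    requests.post(f"{BASE}/resample", data={"file_uuid": uid, "sr": 44100})

    # Download the result (a single file, or a ZIP when sources/ exists), then clean up
    audio = requests.get(f"{BASE}/download", params={"file_uuid": uid})
    with open("enhanced.wav", "wb") as out:
        out.write(audio.content)
    requests.post(f"{BASE}/cleanup", data={"file_uuid": uid})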
@@ -1 +1,3 @@
-from celery_app import app as celery
+from .celery import app as celery_app
+
+__all__ = ('celery_app',)
@@ -0,0 +1,9 @@
+import os
+from celery import Celery
+
+os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings')
+
+app = Celery('backend')
+# Load config from Django settings, using a `CELERY_` prefix
+app.config_from_object('django.conf:settings', namespace='CELERY')
+app.autodiscover_tasks()
@@ -123,6 +123,10 @@ STATIC_URL = 'static/'
 # https://docs.djangoproject.com/en/5.1/ref/settings/#default-auto-field
 
 DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
 
+# Configure Redis as the message broker
 CELERY_BROKER_URL = 'redis://localhost:6379/0'
 CELERY_ACCEPT_CONTENT = ['json']
 CELERY_TASK_SERIALIZER = 'json'
+CELERY_BROKER_CONNECTION_RETRY_ON_STARTUP = True
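With namespace='CELERY', config_from_object() strips the prefix, so the settings above surface on the Celery app as broker_url, accept_content, and task_serializer. A quick check (assumes backend.settings is importable from the working directory; a worker, only needed if tasks were enqueued with .delay(), would be started with `celery -A backend worker --loglevel=info`):

    import os
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings")

    from backend.celery import app as celery_app

    print(celery_app.conf.broker_url)       # redis://localhost:6379/0
    print(celery_app.conf.task_serializer)  # json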
+17 -1
@@ -17,8 +17,24 @@ Including another URLconf
 from django.contrib import admin
 from django.urls import path
 from api.views import upload_audio
+from api.views import normalize_audio
+from api.views import trim_audio
+from api.views import resample_audio
+from api.views import separate_music
+from api.views import noisereduce
+from api.views import download_audio
+from api.views import cleanup
+from api.views import cleanup_zip
 
 urlpatterns = [
     path('admin/', admin.site.urls),
-    path('api/upload', upload_audio, name='upload_audio')
+    path('api/upload', upload_audio, name='upload_audio'),
+    path('api/normalize', normalize_audio, name="normalize_audio"),
+    path('api/trim', trim_audio, name='trim_audio'),
+    path('api/resample', resample_audio, name='resample_audio'),
+    path('api/separate', separate_music, name="separate_music"),
+    path('api/noisereduce', noisereduce, name="noisereduce"),
+    path('api/download', download_audio, name="download_audio"),
+    path('api/cleanup', cleanup, name="cleanup"),
+    path('api/cleanup_zip', cleanup_zip, name="cleanup_zip")
 ]
@@ -1,20 +0,0 @@
-import os
-import sys
-from celery import Celery
-
-# Automatically set environment variables in celery_app.py
-
-# Set Django settings module
-os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'backend.settings')
-
-# Add the project directory to sys.path (similar to the manual PYTHONPATH)
-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
-sys.path.append(project_root)
-
-app = Celery('backend')
-
-# Load configuration from Django settings, using the CELERY namespace.
-app.config_from_object('django.conf:settings', namespace='CELERY')
-
-# Autodiscover tasks from installed apps.
-app.autodiscover_tasks()
+25 -46
@@ -6,9 +6,8 @@ build-backend = "setuptools.build_meta"
 name = "freqsplit"
 version = "0.1.0"
 dependencies = [
-    "absl-py==2.1.0",
-    "aiohappyeyeballs==2.4.4",
-    "aiohttp==3.11.11",
+    "aiohappyeyeballs==2.4.6",
+    "aiohttp==3.11.13",
     "aiosignal==1.3.2",
     "amqp==5.3.1",
     "antlr4-python3-runtime==4.9.3",
@@ -16,13 +15,12 @@ dependencies = [
     "asgiref==3.8.1",
     "asteroid==0.7.0",
     "asteroid-filterbanks==0.4.0",
-    "astunparse==1.6.3",
     "attrs==25.1.0",
     "audioread==3.0.1",
     "billiard==4.2.1",
     "cached-property==2.0.1",
     "celery==5.4.0",
-    "certifi==2024.12.14",
+    "certifi==2025.1.31",
     "cffi==1.17.1",
     "charset-normalizer==3.4.1",
     "click==8.1.8",
@@ -32,50 +30,38 @@ dependencies = [
    "cloudpickle==3.1.1",
    "contourpy==1.3.1",
    "cycler==0.12.1",
-    "decorator==5.1.1",
+    "decorator==5.2.1",
    "DeepFilterLib==0.5.6",
    "DeepFilterNet==0.5.6",
    "demucs==4.0.1",
    "Django==5.1.6",
+    "djangorestframework==3.15.2",
    "dora_search==0.1.12",
-    "einops==0.8.0",
+    "einops==0.8.1",
    "filelock==3.17.0",
-    "flatbuffers==25.2.10",
-    "fonttools==4.55.6",
+    "fonttools==4.56.0",
    "frozenlist==1.5.0",
-    "fsspec==2024.12.0",
-    "future==1.0.0",
-    "gast==0.6.0",
-    "google-pasta==0.2.0",
-    "grpcio==1.70.0",
-    "h5py==3.13.0",
-    "huggingface-hub==0.28.0",
+    "fsspec==2025.2.0",
+    "huggingface-hub==0.29.1",
    "idna==3.10",
    "iniconfig==2.0.0",
    "Jinja2==3.1.5",
    "joblib==1.4.2",
    "julius==0.2.7",
-    "keras==3.8.0",
    "kiwisolver==1.4.8",
    "kombu==5.4.2",
    "lameenc==1.8.1",
    "lazy_loader==0.4",
-    "libclang==18.1.1",
    "librosa==0.10.2.post1",
-    "lightning-utilities==0.11.9",
+    "lightning-utilities==0.12.0",
    "llvmlite==0.44.0",
    "loguru==0.7.3",
-    "Markdown==3.7",
-    "markdown-it-py==3.0.0",
    "MarkupSafe==3.0.2",
    "matplotlib==3.10.0",
-    "mdurl==0.1.2",
-    "mir_eval==0.7",
-    "ml-dtypes==0.4.1",
+    "mir_eval==0.8.2",
    "mpmath==1.3.0",
    "msgpack==1.1.0",
    "multidict==6.1.0",
-    "namex==0.0.8",
    "networkx==3.4.2",
    "numba==0.61.0",
    "numpy==1.26.4",
@@ -88,15 +74,15 @@ dependencies = [
    "nvidia-curand-cu12==10.3.5.147",
    "nvidia-cusolver-cu12==11.6.1.9",
    "nvidia-cusparse-cu12==12.3.1.170",
+    "nvidia-cusparselt-cu12==0.6.2",
    "nvidia-nccl-cu12==2.21.5",
    "nvidia-nvjitlink-cu12==12.4.127",
    "nvidia-nvtx-cu12==12.4.127",
    "omegaconf==2.3.0",
    "openunmix==1.3.0",
-    "opt_einsum==3.4.0",
-    "optree==0.14.0",
    "packaging==23.2",
    "pandas==2.2.3",
+    "panns-inference==0.1.1",
    "pb-bss-eval==0.0.2",
    "pesq==0.0.4",
    "pillow==11.1.0",
@@ -104,54 +90,47 @@ dependencies = [
    "pluggy==1.5.0",
    "pooch==1.8.2",
    "prompt_toolkit==3.0.50",
-    "propcache==0.2.1",
-    "protobuf==5.29.3",
+    "propcache==0.3.0",
    "pycparser==2.22",
-    "Pygments==2.19.1",
    "pyparsing==3.2.1",
    "pystoi==0.4.1",
    "pytest==8.3.4",
    "python-dateutil==2.9.0.post0",
    "pytorch-lightning==2.5.0.post0",
    "pytorch-ranger==0.1.1",
-    "pytz==2024.2",
+    "pytz==2025.1",
    "PyYAML==6.0.2",
    "redis==5.2.1",
+    "regex==2024.11.6",
    "requests==2.32.3",
    "retrying==1.3.4",
-    "rich==13.9.4",
+    "safetensors==0.5.3",
    "scikit-learn==1.6.1",
-    "scipy==1.15.1",
-    "setuptools==75.8.0",
+    "scipy==1.15.2",
+    "setuptools==75.8.1",
    "six==1.17.0",
    "soundfile==0.13.1",
    "soxr==0.5.0.post1",
    "sqlparse==0.5.3",
    "submitit==1.5.2",
    "sympy==1.13.1",
-    "tensorboard==2.18.0",
-    "tensorboard-data-server==0.7.2",
-    "tensorflow==2.18.0",
-    "tensorflow-hub==0.16.1",
-    "termcolor==2.5.0",
-    "tf_keras==2.18.0",
    "threadpoolctl==3.5.0",
-    "torch==2.5.1",
+    "tokenizers==0.21.0",
+    "torch==2.6.0",
    "torch-optimizer==0.1.0",
    "torch-stoi==0.2.3",
-    "torchaudio==2.5.1",
+    "torchaudio==2.6.0",
+    "torchlibrosa==0.1.0",
    "torchmetrics==0.11.4",
    "tqdm==4.67.1",
+    "transformers==4.49.0",
    "treetable==0.2.5",
-    "triton==3.1.0",
+    "triton==3.2.0",
    "typing_extensions==4.12.2",
    "tzdata==2025.1",
    "urllib3==2.3.0",
    "vine==5.1.0",
    "wcwidth==0.2.13",
-    "Werkzeug==3.1.3",
-    "wheel==0.45.1",
-    "wrapt==1.17.2",
    "yarl==1.18.3",
 ]
+24 -46
@@ -1,6 +1,5 @@
-absl-py==2.1.0
-aiohappyeyeballs==2.4.4
-aiohttp==3.11.11
+aiohappyeyeballs==2.4.6
+aiohttp==3.11.13
 aiosignal==1.3.2
 amqp==5.3.1
 antlr4-python3-runtime==4.9.3
@@ -8,13 +7,12 @@ appdirs==1.4.4
 asgiref==3.8.1
 asteroid==0.7.0
 asteroid-filterbanks==0.4.0
-astunparse==1.6.3
 attrs==25.1.0
 audioread==3.0.1
 billiard==4.2.1
 cached-property==2.0.1
 celery==5.4.0
-certifi==2024.12.14
+certifi==2025.1.31
 cffi==1.17.1
 charset-normalizer==3.4.1
 click==8.1.8
@@ -24,51 +22,38 @@ click-repl==0.3.0
 cloudpickle==3.1.1
 contourpy==1.3.1
 cycler==0.12.1
-decorator==5.1.1
+decorator==5.2.1
 DeepFilterLib==0.5.6
 DeepFilterNet==0.5.6
 demucs==4.0.1
 Django==5.1.6
 djangorestframework==3.15.2
 dora_search==0.1.12
-einops==0.8.0
+einops==0.8.1
 filelock==3.17.0
-flatbuffers==25.2.10
-fonttools==4.55.6
+fonttools==4.56.0
 frozenlist==1.5.0
-fsspec==2024.12.0
-future==1.0.0
-gast==0.6.0
-google-pasta==0.2.0
-grpcio==1.70.0
-h5py==3.13.0
-huggingface-hub==0.28.0
+fsspec==2025.2.0
+huggingface-hub==0.29.1
 idna==3.10
 iniconfig==2.0.0
 Jinja2==3.1.5
 joblib==1.4.2
 julius==0.2.7
-keras==3.8.0
 kiwisolver==1.4.8
 kombu==5.4.2
 lameenc==1.8.1
 lazy_loader==0.4
-libclang==18.1.1
 librosa==0.10.2.post1
-lightning-utilities==0.11.9
+lightning-utilities==0.12.0
 llvmlite==0.44.0
 loguru==0.7.3
-Markdown==3.7
-markdown-it-py==3.0.0
 MarkupSafe==3.0.2
 matplotlib==3.10.0
-mdurl==0.1.2
-mir_eval==0.7
-ml-dtypes==0.4.1
+mir_eval==0.8.2
 mpmath==1.3.0
 msgpack==1.1.0
 multidict==6.1.0
-namex==0.0.8
 networkx==3.4.2
 numba==0.61.0
 numpy==1.26.4
@@ -81,15 +66,15 @@ nvidia-cufft-cu12==11.2.1.3
 nvidia-curand-cu12==10.3.5.147
 nvidia-cusolver-cu12==11.6.1.9
 nvidia-cusparse-cu12==12.3.1.170
+nvidia-cusparselt-cu12==0.6.2
 nvidia-nccl-cu12==2.21.5
 nvidia-nvjitlink-cu12==12.4.127
 nvidia-nvtx-cu12==12.4.127
 omegaconf==2.3.0
 openunmix==1.3.0
-opt_einsum==3.4.0
-optree==0.14.0
 packaging==23.2
 pandas==2.2.3
+panns-inference==0.1.1
 pb-bss-eval==0.0.2
 pesq==0.0.4
 pillow==11.1.0
@@ -97,52 +82,45 @@ platformdirs==4.3.6
 pluggy==1.5.0
 pooch==1.8.2
 prompt_toolkit==3.0.50
-propcache==0.2.1
-protobuf==5.29.3
+propcache==0.3.0
 pycparser==2.22
-Pygments==2.19.1
 pyparsing==3.2.1
 pystoi==0.4.1
 pytest==8.3.4
 python-dateutil==2.9.0.post0
 pytorch-lightning==2.5.0.post0
 pytorch-ranger==0.1.1
-pytz==2024.2
+pytz==2025.1
 PyYAML==6.0.2
 redis==5.2.1
+regex==2024.11.6
 requests==2.32.3
 retrying==1.3.4
-rich==13.9.4
+safetensors==0.5.3
 scikit-learn==1.6.1
-scipy==1.15.1
-setuptools==75.8.0
+scipy==1.15.2
+setuptools==75.8.1
 six==1.17.0
 soundfile==0.13.1
 soxr==0.5.0.post1
 sqlparse==0.5.3
 submitit==1.5.2
 sympy==1.13.1
-tensorboard==2.18.0
-tensorboard-data-server==0.7.2
-tensorflow==2.18.0
-tensorflow-hub==0.16.1
-termcolor==2.5.0
-tf_keras==2.18.0
 threadpoolctl==3.5.0
-torch==2.5.1
+tokenizers==0.21.0
+torch==2.6.0
 torch-optimizer==0.1.0
 torch-stoi==0.2.3
-torchaudio==2.5.1
+torchaudio==2.6.0
+torchlibrosa==0.1.0
 torchmetrics==0.11.4
 tqdm==4.67.1
+transformers==4.49.0
 treetable==0.2.5
-triton==3.1.0
+triton==3.2.0
 typing_extensions==4.12.2
 tzdata==2025.1
 urllib3==2.3.0
 vine==5.1.0
 wcwidth==0.2.13
-Werkzeug==3.1.3
-wheel==0.45.1
-wrapt==1.17.2
 yarl==1.18.3
@@ -9,4 +9,4 @@ logging.basicConfig(
     level = logging.INFO
 )
 
-logging.info("freq-split-enhance/input package has been imported.")
+logging.info("freqsplit/input package has been imported.")
@@ -1,12 +1,14 @@
 import os
 import librosa
 
-def read_audio(file_path):
+def read_audio(file_path, sr=None, mono=None):
     """
     Reads an audio file and returns the audio time series and sampling rate.
 
     Args:
         file_path (str): Path to the audio file.
+        sr (int): Sample rate at which the audio is to be loaded.
+        mono (bool): True to load the audio as a single channel, else False.
 
     Returns:
         tuple: audio_time_series (numpy.ndarray), sampling_rate (int)
@@ -15,7 +17,11 @@ def read_audio(file_path):
     if not os.path.exists(file_path):
         raise FileNotFoundError(f"File not found: {file_path}")
     try:
-        audio, sr = librosa.load(file_path, sr=None) # Load with original sampling rate.
+        librosa_kwargs = {"sr": sr}
+        if mono is not None:  # Only add 'mono' if explicitly provided
+            librosa_kwargs["mono"] = mono
+
+        audio, sr = librosa.load(file_path, **librosa_kwargs)
         return audio, sr
     except Exception as e:
         raise RuntimeError(f"Error reading the audio file: {e}")
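Usage sketch for the extended signature (the file name is hypothetical); when sr is left as None librosa keeps the file's native sample rate, and when mono is omitted librosa's default downmix to mono applies:

    from freqsplit.input.file_reader import read_audio

    audio, native_sr = read_audio("clip.wav")                # native sample rate
    waveform, sr = read_audio("clip.wav", 32000, mono=True)  # force 32 kHz mono, as the classifier expects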
@@ -9,4 +9,4 @@ logging.basicConfig(
     level = logging.INFO
 )
 
-logging.info("freq-split-enhance/postprocessing package has been imported.")
+logging.info("freqsplit/postprocessing package has been imported.")
@@ -12,9 +12,6 @@ def export_audio(audio, output_path, sr):
     """
 
     try:
-
-        print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}")
-
         if audio.ndim == 2 and audio.shape[0] == 2:
             # Transpose stereo audio to match the expected shape
             audio = audio.T  # From (2, num_samples) to (num_samples, 2)
@@ -26,10 +23,6 @@ def export_audio(audio, output_path, sr):
         if np.max(np.abs(audio)) > 0:  # Avoid divide by zero
             audio = audio / np.max(np.abs(audio))
 
-        # Verify final format
-        print(f"Final audio shape: {audio.shape}, dtype: {audio.dtype}, max: {np.max(audio)}, min: {np.min(audio)}")
-
-
         sf.write(output_path, audio, sr, format='wav')
         print(f"Audio saved to {output_path}")
     except Exception as e:
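A quick illustration of the stereo transpose above (sketch only): librosa returns stereo audio as (channels, samples), while soundfile expects (frames, channels):

    import numpy as np

    stereo = np.zeros((2, 48000), dtype=np.float32)  # (channels, samples), as librosa loads it
    print(stereo.T.shape)                            # (48000, 2), the layout sf.write expects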
@@ -9,4 +9,4 @@ logging.basicConfig(
     level = logging.INFO
 )
 
-logging.info("freq-split-enhance/preprocessing package has been imported.")
+logging.info("freqsplit/preprocessing package has been imported.")
@@ -1,53 +1,43 @@
-import tensorflow as tf
-import tensorflow_hub as hub
-import librosa
 import numpy as np
-import csv
-import os
+from panns_inference import AudioTagging, labels
 
-# Force TensorFlow to use only CPU
-tf.config.set_visible_devices([], 'GPU')
+# Initialize PANNs model
+at = AudioTagging(checkpoint_path=None, device='cuda')
 
-model = hub.load('https://tfhub.dev/google/yamnet/1')
-
-# Find the name of the class with the top score when mean-aggregated across frames.
-def class_names_from_csv(class_map_scv_text):
-    """Returns list of class names corresponding to score vector."""
-    class_names = []
-    with tf.io.gfile.GFile(class_map_scv_text) as csvfile:
-        reader = csv.DictReader(csvfile)
-        for row in reader:
-            class_names.append(row['display_name'])
-    return class_names
-
-# Main function to process audio and classify
-def classify_audio(file_path):
+def classify_audio(waveform, sr):
     """
     Given an audio file, this function loads the audio, resamples it,
-    normalizes it, and runs it through the YAMNet model to classify the sound.
+    normalizes it, and runs it through the PANNs model to classify the sound.
 
     Args:
-    - file_path (str): Path to the audio file (WAV, MP3, etc.).
+    - waveform (numpy.ndarray): Waveform of the audio file (WAV, MP3, etc.).
+    - sr (int): Sampling rate of the audio.
 
     Returns:
     - str: Predicted class label of the audio.
     """
-    # Load audio using librosa (this handles both loading, resampling, and conversion to mono)
-    waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True) # Ensuring 16k sample rate and mono
+    # Check that the sampling rate is 32000Hz
+    if sr != 32000:
+        raise RuntimeError("The audio is not sampled at 32000Hz, failed to classify audio.")
 
-    # Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
+    # Normalize the waveform to [-1.0, 1.0]
     waveform = waveform / np.max(np.abs(waveform))
 
-    # Execute the YAMNet model
-    scores, embeddings, spectrogram = model(waveform)
+    # Ensure the waveform shape is correct for model input
+    waveform = waveform[None, :]
 
-    # Extract the class names from the model
-    class_map_path = model.class_map_path().numpy()
-    class_names = class_names_from_csv(class_map_path)
+    # Execute the PANNs model
+    try:
+        clipwise_output, _ = at.inference(waveform)
+    except Exception as e:
+        raise RuntimeError(f"Error: Failed to classify audio: {e}")
 
-    # Find the class with the highest score
-    scores_np = scores.numpy()
-    inferred_class = class_names[scores_np.mean(axis=0).argmax()]
+    # Get the top predicted class
+    predicted_index = np.argmax(clipwise_output)
+    inferred_class = labels[predicted_index]
 
     return inferred_class
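Callers are now expected to load and resample the audio themselves before classifying. A sketch of the new call pattern, mirroring the updated test at the bottom of this PR:

    from freqsplit.input.file_reader import read_audio
    from freqsplit.preprocessing.classify import classify_audio

    waveform, sr = read_audio("tests/test_audio/cafe_crowd_talk.wav", 32000, mono=True)
    print(classify_audio(waveform, sr))  # e.g. "Speech"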
@@ -9,4 +9,4 @@ logging.basicConfig(
     level = logging.INFO
 )
 
-logging.info("freq-split-enhance/refinement package has been imported.")
+logging.info("freqsplit/refinement package has been imported.")
@@ -1,35 +1,83 @@
 import os
+import librosa
 import torch
+import shutil
+import soundfile as sf
+import numpy as np
 from df.enhance import enhance, init_df, load_audio, save_audio
 
+
+def split_audio(audio, sr, chunk_size=5):
+    """Split audio into chunks of `chunk_size` seconds."""
+    samples_per_chunk = sr * chunk_size
+    return [audio[i:i + samples_per_chunk] for i in range(0, len(audio), samples_per_chunk)]
+
+
 def noisereduce(input_audio_path, output_audio_path, model_path=None):
     """
-    Apply noise reduction using DeepFilterNet.
+    Apply noise reduction using DeepFilterNet with chunking.
 
     Args:
         input_audio_path (str): Path to the input noisy audio file.
         output_audio_path (str): Path to save the enhanced audio file.
-        model_path (str, optional): Path to a custom DeepFilterNet model. Defaults to None (uses the pre-trained model).
+        model_path (str, optional): Path to a custom DeepFilterNet model. Defaults to None (uses pre-trained model).
 
     Returns:
         str: Path to the enhanced audio file.
     """
     if not os.path.exists(input_audio_path):
         raise FileNotFoundError(f"Input file {input_audio_path} not found")
 
+    output_dir = os.path.dirname(output_audio_path)
+    os.makedirs(output_dir, exist_ok=True)  # Ensure the directory exists
+
     # Initialize DeepFilterNet model
     model, df_state, _ = init_df(model_path)
 
     # Load audio
-    audio, _ = load_audio(input_audio_path, sr=df_state.sr())
+    audio, sr = librosa.load(input_audio_path, sr=None)
 
-    # Ensure output path exists
-    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
+    # Ensure output and chunk directories exist
+    parent_dir = os.path.dirname(input_audio_path)
+    chunk_dir = os.path.join(parent_dir, "chunks")
+    output_chunk_dir = os.path.join(chunk_dir, "output")
+    os.makedirs(chunk_dir, exist_ok=True)
+    os.makedirs(output_chunk_dir, exist_ok=True)
 
-    # Apply noise reduction
-    enhanced_audio = enhance(model, df_state, audio)
+    # Split audio into 5-second chunks
+    chunks = split_audio(audio, sr, chunk_size=5)
+    chunk_paths = []
+
+    for i, chunk in enumerate(chunks):
+        chunk_path = os.path.join(chunk_dir, f"chunk_{i}.wav")
+        sf.write(chunk_path, chunk, sr)
+        chunk_paths.append(chunk_path)
+
+    enhanced_chunk_paths = []
+
+    # Process each chunk sequentially to avoid OOM errors
+    for chunk_path in chunk_paths:
+        output_chunk_path = os.path.join(output_chunk_dir, os.path.basename(chunk_path))
+
+        # Load and enhance
+        chunk_audio, _ = load_audio(chunk_path, sr=df_state.sr())
+        enhanced_audio = enhance(model, df_state, chunk_audio)
+
+        # Save enhanced chunk
+        save_audio(output_chunk_path, enhanced_audio, df_state.sr())
+        enhanced_chunk_paths.append(output_chunk_path)
+
+    # Combine enhanced chunks back into a single audio file
+    final_audio = []
+    for chunk_path in enhanced_chunk_paths:
+        chunk_audio, _ = librosa.load(chunk_path, sr=sr)  # Keep original sample rate
+        final_audio.append(chunk_audio)
+
+    final_audio = np.concatenate(final_audio, axis=0)
+
+    # Save final enhanced audio
+    sf.write(output_audio_path, final_audio, sr)
+
+    # Clean up temporary chunk files and directories
+    shutil.rmtree(chunk_dir, ignore_errors=True)
 
-    # Save the enhanced audio
-    save_audio(output_audio_path, enhanced_audio, df_state.sr())
-
     return output_audio_path
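The chunking keeps DeepFilterNet's working set bounded: each 5-second chunk is enhanced on its own and the results are concatenated. A tiny sketch of the slicing arithmetic in split_audio (toy numbers, not real audio):

    import numpy as np

    audio = np.arange(12)   # pretend 12 samples
    sr, chunk_size = 2, 3   # 2 samples/sec * 3 s = 6 samples per chunk
    step = sr * chunk_size
    chunks = [audio[i:i + step] for i in range(0, len(audio), step)]
    assert [len(c) for c in chunks] == [6, 6]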
@@ -9,4 +9,4 @@ logging.basicConfig(
     level = logging.INFO
 )
 
-logging.info("freq-split-enhance/separation package has been imported.")
+logging.info("freqsplit/separation package has been imported.")
@@ -9,4 +9,4 @@ logging.basicConfig(
     level = logging.INFO
 )
 
-logging.info("freq-split-enhance/spectogram package has been imported.")
+logging.info("freqsplit/spectogram package has been imported.")
@@ -24,8 +24,9 @@ def test_trim_audio():
 
 def test_classify():
     file_path = "tests/test_audio/cafe_crowd_talk.wav"
+    waveform, sr = read_audio(file_path, 32000, mono=True)
     expected_class = "Speech"
-    predicted_class = classify_audio(file_path)
+    predicted_class = classify_audio(waveform, sr)
 
     assert predicted_class == expected_class , f"Expected {expected_class}, but got {predicted_class}"