endpoint: /api/normalize
-Add new endpoint /api/normalize, params: file_uuid, overwrites exisisting file on server with normalized audio. -Remove print statements from src/freqsplit/postprocessing/audio_writer
This commit is contained in:
+15
-1
@@ -1,6 +1,8 @@
|
|||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from freqsplit.input.file_reader import read_audio
|
from freqsplit.input.file_reader import read_audio
|
||||||
from freqsplit.preprocessing.classify import classify_audio
|
from freqsplit.preprocessing.classify import classify_audio
|
||||||
|
from freqsplit.preprocessing.normalize import normalize_audio
|
||||||
|
from freqsplit.postprocessing.audio_writer import export_audio
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
def save_and_classify(file_path, file_content):
|
def save_and_classify(file_path, file_content):
|
||||||
@@ -14,4 +16,16 @@ def save_and_classify(file_path, file_content):
|
|||||||
# Classify the audio
|
# Classify the audio
|
||||||
audio_class = classify_audio(waveform, sr)
|
audio_class = classify_audio(waveform, sr)
|
||||||
|
|
||||||
return audio_class
|
return audio_class
|
||||||
|
|
||||||
|
@shared_task
|
||||||
|
def normalize_audio_task(file_path):
|
||||||
|
"""Celery task to normalize audio synchronously"""
|
||||||
|
try:
|
||||||
|
audio, sr = read_audio(file_path) # Read audio
|
||||||
|
normalized_audio = normalize_audio(audio) # Normalize
|
||||||
|
export_audio(normalized_audio, file_path, sr) # Save file
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
return False
|
||||||
|
|
||||||
@@ -4,6 +4,7 @@ from rest_framework.decorators import api_view
|
|||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from .tasks import save_and_classify
|
from .tasks import save_and_classify
|
||||||
|
from .tasks import normalize_audio_task
|
||||||
from freqsplit.input.format_checker import is_supported_format
|
from freqsplit.input.format_checker import is_supported_format
|
||||||
|
|
||||||
UPLOAD_DIR = "/tmp/freqsplit"
|
UPLOAD_DIR = "/tmp/freqsplit"
|
||||||
@@ -11,6 +12,7 @@ UPLOAD_DIR = "/tmp/freqsplit"
|
|||||||
# Ensure the temp directory exists
|
# Ensure the temp directory exists
|
||||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# Endpoint to upload audio and classify it to audio_class
|
||||||
@api_view(['POST'])
|
@api_view(['POST'])
|
||||||
def upload_audio(request):
|
def upload_audio(request):
|
||||||
"""Handles audio file upload and saves it to /tmp/freq-split-enhance"""
|
"""Handles audio file upload and saves it to /tmp/freq-split-enhance"""
|
||||||
@@ -45,3 +47,33 @@ def upload_audio(request):
|
|||||||
},
|
},
|
||||||
status=status.HTTP_201_CREATED,
|
status=status.HTTP_201_CREATED,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Endpoint to normalize audio
|
||||||
|
@api_view(['POST'])
|
||||||
|
def normalize_audio(request):
|
||||||
|
"""Handles audio normalization request"""
|
||||||
|
file_uuid = request.data.get("file_uuid")
|
||||||
|
|
||||||
|
if not file_uuid:
|
||||||
|
return Response({"error": "Missing file_uuid"}, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
audio_dir = os.path.join(UPLOAD_DIR, file_uuid)
|
||||||
|
|
||||||
|
if not os.path.exists(audio_dir) or not os.path.isdir(audio_dir):
|
||||||
|
return Response({"error": "File directory not found"}, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
# Get the actual file name (since there's only one file)
|
||||||
|
files = os.listdir(audio_dir)
|
||||||
|
if not files:
|
||||||
|
return Response({"error": "No file found in directory"}, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
file_name = files[0]
|
||||||
|
file_path = os.path.join(audio_dir, file_name)
|
||||||
|
|
||||||
|
# Call Celery task synchronously
|
||||||
|
task = normalize_audio_task.apply(args=(file_path,))
|
||||||
|
|
||||||
|
if task.get():
|
||||||
|
return Response({"message": "Audio normalized successfully"}, status=status.HTTP_200_OK)
|
||||||
|
else:
|
||||||
|
return Response({"error": "Failed to normalize audio"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
+3
-1
@@ -17,8 +17,10 @@ Including another URLconf
|
|||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
from django.urls import path
|
from django.urls import path
|
||||||
from api.views import upload_audio
|
from api.views import upload_audio
|
||||||
|
from api.views import normalize_audio
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('admin/', admin.site.urls),
|
path('admin/', admin.site.urls),
|
||||||
path('api/upload', upload_audio, name='upload_audio')
|
path('api/upload', upload_audio, name='upload_audio'),
|
||||||
|
path('api/normalize', normalize_audio, name="normalize_audio")
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ def export_audio(audio, output_path, sr):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}")
|
print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}, max: {np.max(audio)}, min: {np.min(audio)}")
|
||||||
|
|
||||||
if audio.ndim == 2 and audio.shape[0] == 2:
|
if audio.ndim == 2 and audio.shape[0] == 2:
|
||||||
# Transpose stereo audio to match the expected shape
|
# Transpose stereo audio to match the expected shape
|
||||||
|
|||||||
Reference in New Issue
Block a user