diff --git a/api/api/tasks.py b/api/api/tasks.py index 6cd0880..3446893 100644 --- a/api/api/tasks.py +++ b/api/api/tasks.py @@ -1,7 +1,17 @@ from celery import shared_task +from freqsplit.input.file_reader import read_audio +from freqsplit.preprocessing.classify import classify_audio @shared_task -def save_uploaded_file(file_path, file_content): - """Save uploaded file asynchronously""" +def save_and_classify(file_path, file_content): + """Save uploaded file asynchronously and classify the audio file""" with open(file_path, 'wb') as destination: - destination.write(file_content) \ No newline at end of file + destination.write(file_content) + + # Read the saved audio file + waveform, sr = read_audio(file_path, 16000, mono=True) + + # Classify the audio + audio_class = classify_audio(waveform, sr) + + return audio_class \ No newline at end of file diff --git a/api/api/views.py b/api/api/views.py index 78f45c7..8a993a4 100644 --- a/api/api/views.py +++ b/api/api/views.py @@ -3,7 +3,7 @@ import uuid from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework import status -from .tasks import save_uploaded_file +from .tasks import save_and_classify UPLOAD_DIR = "/tmp/freqsplit" @@ -28,12 +28,15 @@ def upload_audio(request): file_path = os.path.join(upload_dir, audio_file.name) # Save the uploaded file - save_uploaded_file.delay(file_path, audio_file.read()) - - return Response( - { - "Status": "File uploaded successfully", - "file_uuid": file_uuid, - }, - status=status.HTTP_201_CREATED, - ) + task = save_and_classify.apply(args=(file_path, audio_file.read())) + + if task.successful(): + audio_class = task.result + return Response( + { + "Status": "File uploaded successfully", + "file_uuid": file_uuid, + "audio_class": audio_class + }, + status=status.HTTP_201_CREATED, + )