From e4abb070db1ccab493bc4ee59fb8e99764c02e90 Mon Sep 17 00:00:00 2001
From: Joel Mathew Thomas <90510078+joelmathewthomas@users.noreply.github.com>
Date: Thu, 26 Dec 2024 00:58:53 +0530
Subject: [PATCH] Implement audio classification function using YAMNet

- Create a function to load audio, resample, and classify using YAMNet
- Ensure compatibility with different audio formats and sample rates
- Normalize audio and classify it into one of 521 categories
---
 requirements.txt              |  1 +
 src/preprocessing/classify.py | 44 +++++++++++++++++++----------------
 2 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index af1627b..ac44dc3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -56,6 +56,7 @@ python-dateutil==2.9.0.post0
 requests==2.32.3
 rich==13.9.4
 scikit-learn==1.6.0
+scipy==1.14.1
 setuptools==75.6.0
 six==1.17.0
 soundfile==0.12.1

diff --git a/src/preprocessing/classify.py b/src/preprocessing/classify.py
index d1fc76c..59eadd6 100644
--- a/src/preprocessing/classify.py
+++ b/src/preprocessing/classify.py
@@ -14,31 +14,35 @@ def class_names_from_csv(class_map_scv_text):
     reader = csv.DictReader(csvfile)
     for row in reader:
       class_names.append(row['display_name'])
-
   return class_names
 
 
-class_map_path = model.class_map_path().numpy()
-class_names = class_names_from_csv(class_map_path)
+# Main function to process audio and classify
+def classify_audio(file_path):
+    """
+    Given an audio file, this function loads the audio, resamples it,
+    normalizes it, and runs it through the YAMNet model to classify the sound.
 
 
-wav_file_name = 'cafe_crowd_talk.wav'
-waveform, sample_rate = librosa.load(wav_file_name, sr=16000)
+    Args:
+    - file_path (str): Path to the audio file (WAV, MP3, etc.).
 
 
-# Show some basic information about the audio.
-duration = len(waveform)/sample_rate
-print(f'Sample rate: {sample_rate} Hz')
-print(f'Total duration: {duration:.2f}s')
-print(f'Size of the input: {len(waveform)}')
+    Returns:
+    - str: Predicted class label of the audio.
+    """
+    # Load audio using librosa (this handles both loading, resampling, and conversion to mono)
+    waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True) # Ensuring 16k sample rate and mono
 
-# The waveform needs to be normalized to values in [-1.0, 1.0] (librosa load already does this)
-# No need to do this as librosa already normalizes# The wav_data needs to be normalized to values in [-1.0, 1.0]
-
-# Execute the Model
-# Check the output.
-scores, embeddings, spectogram = model(waveform)
-scores_np = scores.numpy()
-spectogram_np = spectogram.numpy()
-infered_class = class_names[scores_np.mean(axis=0).argmax()]
-print(f'The main sound is : {infered_class}')
+    # Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
+    waveform = waveform / np.max(np.abs(waveform))
+    # Execute the YAMNet model
+    scores, embeddings, spectrogram = model(waveform)
+
+    # Extract the class names from the model
+    class_map_path = model.class_map_path().numpy()
+    class_names = class_names_from_csv(class_map_path)
+    # Find the class with the highest score
+    scores_np = scores.numpy()
+    inferred_class = class_names[scores_np.mean(axis=0).argmax()]
+    return inferred_class
\ No newline at end of file