From 3074084ac15e519d4786c68b88d62647cdc14f14 Mon Sep 17 00:00:00 2001 From: Joel Mathew Thomas <90510078+joelmathewthomas@users.noreply.github.com> Date: Tue, 25 Feb 2025 20:06:49 +0530 Subject: [PATCH] refactor code, use package libraries for freqsplit/preprocessing/classify - add additional params: sr=None and mono=False for freqsplit/input/file_reader/read_audio() - remove loading audio using librosa in freqsplit/preprocessing/classify/classify_audio() - add error handling, if the sr is not 16 kHz - raise RuntimeError, if YAMNet model fails - update tests/test_preprocessing --- src/freqsplit/input/file_reader.py | 6 ++++-- src/freqsplit/preprocessing/classify.py | 24 +++++++++++++++--------- tests/test_preprocessing.py | 3 ++- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/freqsplit/input/file_reader.py b/src/freqsplit/input/file_reader.py index f18b7e1..20c8c8a 100644 --- a/src/freqsplit/input/file_reader.py +++ b/src/freqsplit/input/file_reader.py @@ -1,12 +1,14 @@ import os import librosa -def read_audio(file_path): +def read_audio(file_path, sr=None, mono=False): """ Reads an audio file and returns the audio time series and sampling rate. Args: file_path (str): Path to the audio file. + sr (int, optional): Sample rate at which the audio is to be loaded; None keeps the file's original rate. + mono (bool): True to load the audio as mono (single channel), else False. Returns: tuple: audio_time_series (numpy.ndarray), sampling_rate (int) @@ -15,7 +17,7 @@ def read_audio(file_path): if not os.path.exists(file_path): raise FileNotFoundError(f"File not found: {file_path}") try: - audio, sr = librosa.load(file_path, sr=None) # Load with original sampling rate. + audio, sr = librosa.load(file_path, sr=sr, mono=mono) # Load at the requested sample rate (None keeps the original). 
return audio, sr except Exception as e: raise RuntimeError(f"Error reading the audio file: {e}") \ No newline at end of file diff --git a/src/freqsplit/preprocessing/classify.py b/src/freqsplit/preprocessing/classify.py index cfa1783..dadf7eb 100644 --- a/src/freqsplit/preprocessing/classify.py +++ b/src/freqsplit/preprocessing/classify.py @@ -3,8 +3,6 @@ import tensorflow_hub as hub import librosa import numpy as np import csv -import os - # Force TensorFlow to use only CPU tf.config.set_visible_devices([], 'GPU') @@ -22,26 +20,34 @@ def class_names_from_csv(class_map_scv_text): return class_names # Main function to process audio and classify -def classify_audio(file_path): +def classify_audio(waveform, sr): """ Given an audio file, this function loads the audio, resamples it, normalizes it, and runs it through the YAMNet model to classify the sound. Args: - - file_path (str): Path to the audio file (WAV, MP3, etc.). + - waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.). Returns: - str: Predicted class label of the audio. 
""" - # Load audio using librosa (this handles both loading, resampling, and conversion to mono) - waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True) # Ensuring 16k sample rate and mono - + + # Check if the sampling rate is 16000Hz + try: + if(sr!=16000): + raise RuntimeError + except Exception: + raise RuntimeError(f"The audio is not sampled at 16000Hz, failed to classify audio.") + # Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values) waveform = waveform / np.max(np.abs(waveform)) # Execute the YAMNet model - scores, embeddings, spectrogram = model(waveform) - + try: + scores, embeddings, spectrogram = model(waveform) + except Exception as e: + raise RuntimeError(f"Error: Failed to classify audio: {e}") + # Extract the class names from the model class_map_path = model.class_map_path().numpy() class_names = class_names_from_csv(class_map_path) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 7de07e9..787fd65 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -24,8 +24,9 @@ def test_trim_audio(): def test_classify(): file_path = "tests/test_audio/cafe_crowd_talk.wav" + waveform, sr = read_audio(file_path, 16000, mono=True) expected_class = "Speech" - predicted_class = classify_audio(file_path) + predicted_class = classify_audio(waveform, sr) assert predicted_class == expected_class , f"Expected {expected_class}, but got {predicted_class}"