refactor code to use package libraries for freqsplit/preprocessing/classify

- add additional parameters sr=None and mono=False to freqsplit/input/file_reader/read_audio()
- remove audio loading via librosa from freqsplit/preprocessing/classify/classify_audio()
- add error handling for when the sample rate is not 16 kHz
- raise RuntimeError if the YAMNet model fails
- update tests/test_preprocessing accordingly
This commit is contained in:
Joel Mathew Thomas
2025-02-25 20:06:49 +05:30
parent 825da48712
commit 3074084ac1
3 changed files with 21 additions and 12 deletions
+4 -2
View File
@@ -1,12 +1,14 @@
import os import os
import librosa import librosa
def read_audio(file_path): def read_audio(file_path, sr=None, mono=False):
""" """
Reads an audio file and returns the audio time series and sampling rate. Reads an audio file and returns the audio time series and sampling rate.
Args: Args:
file_path (str): Path to the audio file. file_path (str): Path to the audio file.
sr (int): Sample rate at which the audio is to be loaded
mono (bool): True to loaded audio with single channels, else False.
Returns: Returns:
tuple: audio_time_series (numpy.ndarray), sampling_rate (int) tuple: audio_time_series (numpy.ndarray), sampling_rate (int)
@@ -15,7 +17,7 @@ def read_audio(file_path):
if not os.path.exists(file_path): if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}") raise FileNotFoundError(f"File not found: {file_path}")
try: try:
audio, sr = librosa.load(file_path, sr=None) # Load with original sampling rate. audio, sr = librosa.load(file_path, sr=sr, mono=mono) # Load with original sampling rate.
return audio, sr return audio, sr
except Exception as e: except Exception as e:
raise RuntimeError(f"Error reading the audio file: {e}") raise RuntimeError(f"Error reading the audio file: {e}")
+15 -9
View File
@@ -3,8 +3,6 @@ import tensorflow_hub as hub
import librosa import librosa
import numpy as np import numpy as np
import csv import csv
import os
# Force TensorFlow to use only CPU # Force TensorFlow to use only CPU
tf.config.set_visible_devices([], 'GPU') tf.config.set_visible_devices([], 'GPU')
@@ -22,26 +20,34 @@ def class_names_from_csv(class_map_scv_text):
return class_names return class_names
# Main function to process audio and classify # Main function to process audio and classify
def classify_audio(file_path): def classify_audio(waveform, sr):
""" """
Given an audio file, this function loads the audio, resamples it, Given an audio file, this function loads the audio, resamples it,
normalizes it, and runs it through the YAMNet model to classify the sound. normalizes it, and runs it through the YAMNet model to classify the sound.
Args: Args:
- file_path (str): Path to the audio file (WAV, MP3, etc.). - waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.).
Returns: Returns:
- str: Predicted class label of the audio. - str: Predicted class label of the audio.
""" """
# Load audio using librosa (this handles both loading, resampling, and conversion to mono)
waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True) # Ensuring 16k sample rate and mono # Check if the sampling rate is 16000Hz
try:
if(sr!=16000):
raise RuntimeError
except Exception:
raise RuntimeError(f"The audio is not sampled at 16000Hz, failed to classify audio.")
# Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values) # Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
waveform = waveform / np.max(np.abs(waveform)) waveform = waveform / np.max(np.abs(waveform))
# Execute the YAMNet model # Execute the YAMNet model
scores, embeddings, spectrogram = model(waveform) try:
scores, embeddings, spectrogram = model(waveform)
except Exception as e:
raise RuntimeError(f"Error: Failed to classify audio: {e}")
# Extract the class names from the model # Extract the class names from the model
class_map_path = model.class_map_path().numpy() class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path) class_names = class_names_from_csv(class_map_path)
+2 -1
View File
@@ -24,8 +24,9 @@ def test_trim_audio():
def test_classify(): def test_classify():
file_path = "tests/test_audio/cafe_crowd_talk.wav" file_path = "tests/test_audio/cafe_crowd_talk.wav"
waveform, sr = read_audio(file_path, 16000, mono=True)
expected_class = "Speech" expected_class = "Speech"
predicted_class = classify_audio(file_path) predicted_class = classify_audio(waveform, sr)
assert predicted_class == expected_class , f"Expected {expected_class}, but got {predicted_class}" assert predicted_class == expected_class , f"Expected {expected_class}, but got {predicted_class}"