refactor code, use package libraries for freqsplit/preprocessing/classify

- add params sr=None and mono=False to freqsplit/input/file_reader/read_audio()
- remove audio loading via librosa in freqsplit/preprocessing/classify/classify_audio(); it now takes a waveform and sample rate
- add error handling if the sample rate is not 16 kHz
- raise RuntimeError if the YAMNet model fails
- update tests/test_preprocessing
Joel Mathew Thomas
2025-02-25 20:06:49 +05:30
parent 825da48712
commit 3074084ac1
3 changed files with 21 additions and 12 deletions
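Taken together, the change moves audio loading out of classify_audio and into the package's own reader. A minimal sketch of the new call pattern, mirroring the updated test below (import paths are assumed from the module paths named in the commit message):

from freqsplit.input.file_reader import read_audio
from freqsplit.preprocessing.classify import classify_audio

# Load resampled to 16 kHz mono, which classify_audio now requires
waveform, sr = read_audio("tests/test_audio/cafe_crowd_talk.wav", sr=16000, mono=True)
label = classify_audio(waveform, sr)  # e.g. "Speech" for the crowd-talk fixture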
+4 -2 freqsplit/input/file_reader
@@ -1,12 +1,14 @@
import os
import librosa
def read_audio(file_path):
def read_audio(file_path, sr=None, mono=False):
"""
Reads an audio file and returns the audio time series and sampling rate.
Args:
file_path (str): Path to the audio file.
sr (int, optional): Sample rate at which the audio is loaded; None keeps the file's original rate.
mono (bool): True to load the audio as a single (mono) channel, else False.
Returns:
tuple: audio_time_series (numpy.ndarray), sampling_rate (int)
@@ -15,7 +17,7 @@ def read_audio(file_path):
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
try:
audio, sr = librosa.load(file_path, sr=None) # Load with original sampling rate.
audio, sr = librosa.load(file_path, sr=sr, mono=mono) # Load at the requested sample rate (original rate if sr is None).
return audio, sr
except Exception as e:
raise RuntimeError(f"Error reading the audio file: {e}")
+15 -9 freqsplit/preprocessing/classify
@@ -3,8 +3,6 @@ import tensorflow_hub as hub
import librosa
import numpy as np
import csv
import os
# Force TensorFlow to use only CPU
tf.config.set_visible_devices([], 'GPU')
@@ -22,26 +20,34 @@ def class_names_from_csv(class_map_scv_text):
return class_names
# Main function to process audio and classify
def classify_audio(file_path):
def classify_audio(waveform, sr):
"""
Given an audio waveform and its sample rate, this function verifies the rate is 16 kHz,
normalizes the waveform, and runs it through the YAMNet model to classify the sound.
Args:
- file_path (str): Path to the audio file (WAV, MP3, etc.).
- waveform (numpy.ndarray): Audio waveform of the file (WAV, MP3, etc.), sampled at 16 kHz.
- sr (int): Sample rate of the waveform; must be 16000.
Returns:
- str: Predicted class label of the audio.
"""
# Load audio using librosa (this handles both loading, resampling, and conversion to mono)
waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True) # Ensuring 16k sample rate and mono
# Check that the sampling rate is 16000 Hz
if sr != 16000:
raise RuntimeError("The audio is not sampled at 16000 Hz, failed to classify audio.")
# Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
waveform = waveform / np.max(np.abs(waveform))
# Execute the YAMNet model
scores, embeddings, spectrogram = model(waveform)
try:
scores, embeddings, spectrogram = model(waveform)
except Exception as e:
raise RuntimeError(f"Error: Failed to classify audio: {e}")
# Extract the class names from the model
class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path)
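The hunk ends before the final prediction step. The usual YAMNet pattern (an assumption here, not shown in this diff) averages the per-frame scores and looks up the top class:

# Assumed continuation: average the frame-level scores and pick the most likely class
mean_scores = np.mean(scores.numpy(), axis=0)
predicted_class = class_names[np.argmax(mean_scores)]
return predicted_class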
+2 -1 tests/test_preprocessing
@@ -24,8 +24,9 @@ def test_trim_audio():
def test_classify():
file_path = "tests/test_audio/cafe_crowd_talk.wav"
waveform, sr = read_audio(file_path, sr=16000, mono=True)
expected_class = "Speech"
predicted_class = classify_audio(file_path)
predicted_class = classify_audio(waveform, sr)
assert predicted_class == expected_class, f"Expected {expected_class}, but got {predicted_class}"
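Assuming the test module is tests/test_preprocessing.py, the updated test can be run on its own with:

pytest tests/test_preprocessing.py::test_classify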