refactor code, use package libraries for freqsplit/preprocessing/classify
- add additional params sr=None and mono=False to freqsplit/input/file_reader/read_audio() - remove loading audio via librosa in freqsplit/preprocessing/classify/classify_audio() - add error handling: raise RuntimeError if the sample rate is not 16 kHz, and raise RuntimeError if the YAMNet model fails - update tests/test_preprocessing
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
import os
|
||||
import librosa
|
||||
|
||||
def read_audio(file_path, sr=None, mono=False):
    """
    Reads an audio file and returns the audio time series and sampling rate.

    Args:
        file_path (str): Path to the audio file.
        sr (int, optional): Sample rate at which the audio is to be loaded.
            Defaults to None, which preserves the file's native sampling rate.
        mono (bool): True to down-mix the audio to a single channel, else
            False. Defaults to False (original channel layout is preserved).

    Returns:
        tuple: audio_time_series (numpy.ndarray), sampling_rate (int)

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        RuntimeError: If the audio file cannot be read or decoded.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    try:
        # sr=None keeps the file's native rate; mono=False keeps all channels.
        audio, sr = librosa.load(file_path, sr=sr, mono=mono)
        return audio, sr
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Error reading the audio file: {e}") from e
|
||||
@@ -3,8 +3,6 @@ import tensorflow_hub as hub
|
||||
import librosa
|
||||
import numpy as np
|
||||
import csv
|
||||
import os
|
||||
|
||||
|
||||
# Force TensorFlow to use only CPU
|
||||
tf.config.set_visible_devices([], 'GPU')
|
||||
@@ -22,25 +20,33 @@ def class_names_from_csv(class_map_scv_text):
|
||||
return class_names
|
||||
|
||||
# Main function to process audio and classify
|
||||
def classify_audio(file_path):
|
||||
def classify_audio(waveform, sr):
|
||||
"""
|
||||
Given an audio file, this function loads the audio, resamples it,
|
||||
normalizes it, and runs it through the YAMNet model to classify the sound.
|
||||
|
||||
Args:
|
||||
- file_path (str): Path to the audio file (WAV, MP3, etc.).
|
||||
- waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.).
|
||||
|
||||
Returns:
|
||||
- str: Predicted class label of the audio.
|
||||
"""
|
||||
# Load audio using librosa (this handles both loading, resampling, and conversion to mono)
|
||||
waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True) # Ensuring 16k sample rate and mono
|
||||
|
||||
# Check if the sampling rate is 16000Hz
|
||||
try:
|
||||
if(sr!=16000):
|
||||
raise RuntimeError
|
||||
except Exception:
|
||||
raise RuntimeError(f"The audio is not sampled at 16000Hz, failed to classify audio.")
|
||||
|
||||
# Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
|
||||
waveform = waveform / np.max(np.abs(waveform))
|
||||
|
||||
# Execute the YAMNet model
|
||||
try:
|
||||
scores, embeddings, spectrogram = model(waveform)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error: Failed to classify audio: {e}")
|
||||
|
||||
# Extract the class names from the model
|
||||
class_map_path = model.class_map_path().numpy()
|
||||
|
||||
@@ -24,8 +24,9 @@ def test_trim_audio():
|
||||
|
||||
def test_classify():
    """classify_audio should label the cafe-crowd recording as Speech."""
    file_path = "tests/test_audio/cafe_crowd_talk.wav"
    expected_class = "Speech"

    # Load at 16 kHz mono, as required by the YAMNet-based classifier.
    waveform, sr = read_audio(file_path, 16000, mono=True)
    predicted_class = classify_audio(waveform, sr)

    assert predicted_class == expected_class , f"Expected {expected_class}, but got {predicted_class}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user