Merge pull request #4 from joelmathewthomas/feature/preprocessing-classify
Implement audio classification using YAMNet in preprocessing pipeline
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
3.12.7
|
||||||
@@ -1,28 +1,78 @@
|
|||||||
|
absl-py==2.1.0
|
||||||
|
asttokens==3.0.0
|
||||||
|
astunparse==1.6.3
|
||||||
audioread==3.0.1
|
audioread==3.0.1
|
||||||
certifi==2024.12.14
|
certifi==2024.12.14
|
||||||
cffi==1.17.1
|
cffi==1.17.1
|
||||||
charset-normalizer==3.4.0
|
charset-normalizer==3.4.0
|
||||||
|
contourpy==1.3.1
|
||||||
|
cycler==0.12.1
|
||||||
decorator==5.1.1
|
decorator==5.1.1
|
||||||
|
executing==2.1.0
|
||||||
|
flatbuffers==24.12.23
|
||||||
|
fonttools==4.55.3
|
||||||
|
gast==0.6.0
|
||||||
|
google-pasta==0.2.0
|
||||||
|
grpcio==1.68.1
|
||||||
|
h5py==3.12.1
|
||||||
idna==3.10
|
idna==3.10
|
||||||
iniconfig==2.0.0
|
iniconfig==2.0.0
|
||||||
|
jedi==0.19.2
|
||||||
joblib==1.4.2
|
joblib==1.4.2
|
||||||
|
keras==3.7.0
|
||||||
|
kiwisolver==1.4.8
|
||||||
lazy_loader==0.4
|
lazy_loader==0.4
|
||||||
|
libclang==18.1.1
|
||||||
librosa==0.10.2.post1
|
librosa==0.10.2.post1
|
||||||
llvmlite==0.43.0
|
llvmlite==0.43.0
|
||||||
|
Markdown==3.7
|
||||||
|
markdown-it-py==3.0.0
|
||||||
|
MarkupSafe==3.0.2
|
||||||
|
matplotlib-inline==0.1.7
|
||||||
|
mdurl==0.1.2
|
||||||
|
ml-dtypes==0.4.1
|
||||||
msgpack==1.1.0
|
msgpack==1.1.0
|
||||||
|
namex==0.0.8
|
||||||
numba==0.60.0
|
numba==0.60.0
|
||||||
numpy==2.0.2
|
numpy==2.0.2
|
||||||
|
opt_einsum==3.4.0
|
||||||
|
optree==0.13.1
|
||||||
packaging==24.2
|
packaging==24.2
|
||||||
|
parso==0.8.4
|
||||||
|
pexpect==4.9.0
|
||||||
|
pillow==11.0.0
|
||||||
platformdirs==4.3.6
|
platformdirs==4.3.6
|
||||||
pluggy==1.5.0
|
pluggy==1.5.0
|
||||||
pooch==1.8.2
|
pooch==1.8.2
|
||||||
|
prompt_toolkit==3.0.48
|
||||||
|
protobuf==5.29.2
|
||||||
|
ptyprocess==0.7.0
|
||||||
|
pure_eval==0.2.3
|
||||||
pycparser==2.22
|
pycparser==2.22
|
||||||
|
Pygments==2.18.0
|
||||||
|
pyparsing==3.2.0
|
||||||
pytest==8.3.4
|
pytest==8.3.4
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
|
rich==13.9.4
|
||||||
scikit-learn==1.6.0
|
scikit-learn==1.6.0
|
||||||
scipy==1.14.1
|
scipy==1.14.1
|
||||||
|
setuptools==75.6.0
|
||||||
|
six==1.17.0
|
||||||
soundfile==0.12.1
|
soundfile==0.12.1
|
||||||
soxr==0.5.0.post1
|
soxr==0.5.0.post1
|
||||||
|
stack-data==0.6.3
|
||||||
|
tensorboard==2.18.0
|
||||||
|
tensorboard-data-server==0.7.2
|
||||||
|
tensorflow==2.18.0
|
||||||
|
tensorflow-hub==0.16.1
|
||||||
|
termcolor==2.5.0
|
||||||
|
tf_keras==2.18.0
|
||||||
threadpoolctl==3.5.0
|
threadpoolctl==3.5.0
|
||||||
|
traitlets==5.14.3
|
||||||
typing_extensions==4.12.2
|
typing_extensions==4.12.2
|
||||||
urllib3==2.3.0
|
urllib3==2.3.0
|
||||||
|
wcwidth==0.2.13
|
||||||
|
Werkzeug==3.1.3
|
||||||
|
wheel==0.45.1
|
||||||
|
wrapt==1.17.0
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,48 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow_hub as hub
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import csv
|
||||||
|
|
||||||
|
model = hub.load('https://tfhub.dev/google/yamnet/1')
|
||||||
|
|
||||||
|
#Find the name of the class with the top score when mean-aggregated across frames.
|
||||||
|
def class_names_from_csv(class_map_scv_text):
|
||||||
|
"""Returns list of class names corresponding to score vector."""
|
||||||
|
class_names = []
|
||||||
|
with tf.io.gfile.GFile(class_map_scv_text) as csvfile:
|
||||||
|
reader = csv.DictReader(csvfile)
|
||||||
|
for row in reader:
|
||||||
|
class_names.append(row['display_name'])
|
||||||
|
return class_names
|
||||||
|
|
||||||
|
# Main function to process audio and classify
|
||||||
|
def classify_audio(file_path):
|
||||||
|
"""
|
||||||
|
Given an audio file, this function loads the audio, resamples it,
|
||||||
|
normalizes it, and runs it through the YAMNet model to classify the sound.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- file_path (str): Path to the audio file (WAV, MP3, etc.).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- str: Predicted class label of the audio.
|
||||||
|
"""
|
||||||
|
# Load audio using librosa (this handles both loading, resampling, and conversion to mono)
|
||||||
|
waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True) # Ensuring 16k sample rate and mono
|
||||||
|
|
||||||
|
# Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
|
||||||
|
waveform = waveform / np.max(np.abs(waveform))
|
||||||
|
|
||||||
|
# Execute the YAMNet model
|
||||||
|
scores, embeddings, spectrogram = model(waveform)
|
||||||
|
|
||||||
|
# Extract the class names from the model
|
||||||
|
class_map_path = model.class_map_path().numpy()
|
||||||
|
class_names = class_names_from_csv(class_map_path)
|
||||||
|
|
||||||
|
# Find the class with the highest score
|
||||||
|
scores_np = scores.numpy()
|
||||||
|
inferred_class = class_names[scores_np.mean(axis=0).argmax()]
|
||||||
|
|
||||||
|
return inferred_class
|
||||||
@@ -2,6 +2,7 @@ import pytest
|
|||||||
import librosa
|
import librosa
|
||||||
from src.preprocessing.normalize import normalize_audio
|
from src.preprocessing.normalize import normalize_audio
|
||||||
from src.preprocessing.trim import trim_audio
|
from src.preprocessing.trim import trim_audio
|
||||||
|
from src.preprocessing.classify import classify_audio
|
||||||
from src.input.file_reader import read_audio
|
from src.input.file_reader import read_audio
|
||||||
|
|
||||||
def test_normalize_audio():
|
def test_normalize_audio():
|
||||||
@@ -18,3 +19,10 @@ def test_trim_audio():
|
|||||||
trimmed_audio = trim_audio(audio, sr)
|
trimmed_audio = trim_audio(audio, sr)
|
||||||
|
|
||||||
assert len(trimmed_audio) <= len(audio)
|
assert len(trimmed_audio) <= len(audio)
|
||||||
|
|
||||||
|
def test_classify():
|
||||||
|
file_path = "samples/cafe_crowd_talk.wav"
|
||||||
|
expected_class = "Speech"
|
||||||
|
predicted_class = classify_audio(file_path)
|
||||||
|
|
||||||
|
assert predicted_class == expected_class , f"Expected {expected_class}, but got {predicted_class}"
|
||||||
Reference in New Issue
Block a user