diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..56bb660 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12.7 diff --git a/requirements.txt b/requirements.txt index 5b77076..ac44dc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,28 +1,78 @@ +absl-py==2.1.0 +asttokens==3.0.0 +astunparse==1.6.3 audioread==3.0.1 certifi==2024.12.14 cffi==1.17.1 charset-normalizer==3.4.0 +contourpy==1.3.1 +cycler==0.12.1 decorator==5.1.1 +executing==2.1.0 +flatbuffers==24.12.23 +fonttools==4.55.3 +gast==0.6.0 +google-pasta==0.2.0 +grpcio==1.68.1 +h5py==3.12.1 idna==3.10 iniconfig==2.0.0 +jedi==0.19.2 joblib==1.4.2 +keras==3.7.0 +kiwisolver==1.4.8 lazy_loader==0.4 +libclang==18.1.1 librosa==0.10.2.post1 llvmlite==0.43.0 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +ml-dtypes==0.4.1 msgpack==1.1.0 +namex==0.0.8 numba==0.60.0 numpy==2.0.2 +opt_einsum==3.4.0 +optree==0.13.1 packaging==24.2 +parso==0.8.4 +pexpect==4.9.0 +pillow==11.0.0 platformdirs==4.3.6 pluggy==1.5.0 pooch==1.8.2 +prompt_toolkit==3.0.48 +protobuf==5.29.2 +ptyprocess==0.7.0 +pure_eval==0.2.3 pycparser==2.22 +Pygments==2.18.0 +pyparsing==3.2.0 pytest==8.3.4 +python-dateutil==2.9.0.post0 requests==2.32.3 +rich==13.9.4 scikit-learn==1.6.0 scipy==1.14.1 +setuptools==75.6.0 +six==1.17.0 soundfile==0.12.1 soxr==0.5.0.post1 +stack-data==0.6.3 +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +tensorflow==2.18.0 +tensorflow-hub==0.16.1 +termcolor==2.5.0 +tf_keras==2.18.0 threadpoolctl==3.5.0 +traitlets==5.14.3 typing_extensions==4.12.2 urllib3==2.3.0 +wcwidth==0.2.13 +Werkzeug==3.1.3 +wheel==0.45.1 +wrapt==1.17.0 diff --git a/samples/cafe_crowd_talk.wav b/samples/cafe_crowd_talk.wav new file mode 100644 index 0000000..09f2872 Binary files /dev/null and b/samples/cafe_crowd_talk.wav differ diff --git a/samples/miaow_16k.wav b/samples/miaow_16k.wav new file mode 100644 index 0000000..98478c1 Binary files /dev/null and 
b/samples/miaow_16k.wav differ
diff --git a/samples/speech_whistling2.wav b/samples/speech_whistling2.wav
new file mode 100644
index 0000000..76b1ebb
Binary files /dev/null and b/samples/speech_whistling2.wav differ
diff --git a/src/preprocessing/classify.py b/src/preprocessing/classify.py
new file mode 100644
index 0000000..59eadd6
--- /dev/null
+++ b/src/preprocessing/classify.py
@@ -0,0 +1,51 @@
+import tensorflow as tf
+import tensorflow_hub as hub
+import librosa
+import numpy as np
+import csv
+
+model = hub.load('https://tfhub.dev/google/yamnet/1')
+
+# Find the name of the class with the top score when mean-aggregated across frames.
+def class_names_from_csv(class_map_csv_text):
+    """Returns list of class names corresponding to score vector."""
+    class_names = []
+    with tf.io.gfile.GFile(class_map_csv_text) as csvfile:
+        reader = csv.DictReader(csvfile)
+        for row in reader:
+            class_names.append(row['display_name'])
+    return class_names
+
+# Main function to process audio and classify
+def classify_audio(file_path):
+    """
+    Given an audio file, this function loads the audio, resamples it,
+    normalizes it, and runs it through the YAMNet model to classify the sound.
+
+    Args:
+    - file_path (str): Path to the audio file (WAV, MP3, etc.).
+
+    Returns:
+    - str: Predicted class label of the audio.
+    """
+    # Load audio using librosa (this handles both loading, resampling, and conversion to mono)
+    waveform, _ = librosa.load(file_path, sr=16000, mono=True)  # Ensuring 16k sample rate and mono
+
+    # Normalize the waveform to [-1.0, 1.0]; skip silent/empty input to avoid a
+    # division by zero that would fill the waveform with NaNs.
+    peak = np.max(np.abs(waveform)) if waveform.size else 0.0
+    if peak > 0:
+        waveform = waveform / peak
+
+    # Execute the YAMNet model
+    scores, embeddings, spectrogram = model(waveform)
+
+    # Extract the class names from the model
+    class_map_path = model.class_map_path().numpy()
+    class_names = class_names_from_csv(class_map_path)
+
+    # Find the class with the highest score (mean-aggregated across frames)
+    scores_np = scores.numpy()
+    inferred_class = class_names[scores_np.mean(axis=0).argmax()]
+
+    return inferred_class
diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
index 208bb5b..0a57873 100644
--- a/tests/test_preprocessing.py
+++ b/tests/test_preprocessing.py
@@ -2,6 +2,7 @@ import pytest
 import librosa
 from src.preprocessing.normalize import normalize_audio
 from src.preprocessing.trim import trim_audio
+from src.preprocessing.classify import classify_audio
 from src.input.file_reader import read_audio
 
 def test_normalize_audio():
@@ -17,4 +18,11 @@ def test_trim_audio():
     audio, sr = read_audio(file_path)
     trimmed_audio = trim_audio(audio, sr)
 
-    assert len(trimmed_audio) <= len(audio)
\ No newline at end of file
+    assert len(trimmed_audio) <= len(audio)
+
+def test_classify():
+    file_path = "samples/cafe_crowd_talk.wav"
+    expected_class = "Speech"
+    predicted_class = classify_audio(file_path)
+
+    assert predicted_class == expected_class, f"Expected {expected_class}, but got {predicted_class}"