diff --git a/samples/cafe_crowd_talk.wav b/samples/cafe_crowd_talk.wav new file mode 100644 index 0000000..09f2872 Binary files /dev/null and b/samples/cafe_crowd_talk.wav differ diff --git a/samples/miaow_16k.wav b/samples/miaow_16k.wav new file mode 100644 index 0000000..98478c1 Binary files /dev/null and b/samples/miaow_16k.wav differ diff --git a/samples/requirements.txt b/samples/requirements.txt new file mode 100644 index 0000000..dc35a38 --- /dev/null +++ b/samples/requirements.txt @@ -0,0 +1,80 @@ +absl-py==2.1.0 +asttokens==3.0.0 +astunparse==1.6.3 +audioread==3.0.1 +certifi==2024.12.14 +cffi==1.17.1 +charset-normalizer==3.4.0 +contourpy==1.3.1 +cycler==0.12.1 +decorator==5.1.1 +executing==2.1.0 +flatbuffers==24.12.23 +fonttools==4.55.3 +gast==0.6.0 +google-pasta==0.2.0 +grpcio==1.68.1 +h5py==3.12.1 +idna==3.10 +iniconfig==2.0.0 +ipython==8.31.0 +jedi==0.19.2 +joblib==1.4.2 +keras==3.7.0 +kiwisolver==1.4.8 +lazy_loader==0.4 +libclang==18.1.1 +librosa==0.10.2.post1 +llvmlite==0.43.0 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.0 +matplotlib-inline==0.1.7 +mdurl==0.1.2 +ml-dtypes==0.4.1 +msgpack==1.1.0 +namex==0.0.8 +numba==0.60.0 +numpy==2.0.2 +opt_einsum==3.4.0 +optree==0.13.1 +packaging==24.2 +parso==0.8.4 +pexpect==4.9.0 +pillow==11.0.0 +platformdirs==4.3.6 +pluggy==1.5.0 +pooch==1.8.2 +prompt_toolkit==3.0.48 +protobuf==5.29.2 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pycparser==2.22 +Pygments==2.18.0 +pyparsing==3.2.0 +pytest==8.3.4 +python-dateutil==2.9.0.post0 +requests==2.32.3 +rich==13.9.4 +scikit-learn==1.6.0 +scipy==1.14.1 +setuptools==75.6.0 +six==1.17.0 +soundfile==0.12.1 +soxr==0.5.0.post1 +stack-data==0.6.3 +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +tensorflow==2.18.0 +tensorflow-hub==0.16.1 +termcolor==2.5.0 +tf_keras==2.18.0 +threadpoolctl==3.5.0 +traitlets==5.14.3 +typing_extensions==4.12.2 +urllib3==2.3.0 +wcwidth==0.2.13 +Werkzeug==3.1.3 +wheel==0.45.1 +wrapt==1.17.0 diff --git 
"""Classify the dominant sound in a WAV file with the YAMNet model.

Loads YAMNet from TF-Hub, resamples the input audio to the 16 kHz mono
waveform the model expects, normalizes it to [-1.0, 1.0], runs inference,
and prints the class whose score is highest when mean-aggregated across
frames.
"""
import csv

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from scipy import signal
from scipy.io import wavfile

# YAMNet: audio event classifier trained on AudioSet (521 classes).
model = hub.load('https://tfhub.dev/google/yamnet/1')


def class_names_from_csv(class_map_csv_text):
    """Return the list of class display names from the model's class-map CSV.

    Index i of the returned list corresponds to column i of the model's
    per-frame score output.
    """
    class_names = []
    with tf.io.gfile.GFile(class_map_csv_text) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            class_names.append(row['display_name'])
    return class_names


def ensure_sample_rate(original_sample_rate, waveform,
                       desired_sample_rate=16000):
    """Resample `waveform` to `desired_sample_rate` if needed.

    Returns (sample_rate, waveform); the waveform is returned unchanged
    when the rates already match.
    """
    if original_sample_rate != desired_sample_rate:
        desired_length = int(round(
            float(len(waveform)) / original_sample_rate * desired_sample_rate))
        waveform = signal.resample(waveform, desired_length)
    return desired_sample_rate, waveform


class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path)

# wav_file_name = 'speech_whistling2.wav'
wav_file_name = 'cafe_crowd_talk.wav'
# BUG FIX: the original passed 'rb' as the second positional argument, but
# wavfile.read's second parameter is `mmap`, not a file mode — a truthy
# value silently enabled memory-mapping. Read normally instead.
sample_rate, wav_data = wavfile.read(wav_file_name)

# YAMNet expects a mono waveform; average the channels of stereo input.
# (Mono input has ndim == 1 and is left untouched.)
if wav_data.ndim > 1:
    wav_data = wav_data.mean(axis=1)

sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)

# Show some basic information about the audio.
duration = len(wav_data) / sample_rate
print(f'Sample rate: {sample_rate} Hz')
print(f'Total duration: {duration:.2f}s')
print(f'Size of the input: {len(wav_data)}')

# Normalize to [-1.0, 1.0]. NOTE(review): assumes 16-bit PCM samples; other
# sample widths would need a different divisor — confirm the sample files.
waveform = wav_data / np.iinfo(np.int16).max

# Run the model: it returns (per-frame scores, embeddings, log-mel
# spectrogram). The main sound is the class with the highest score when
# averaged over all frames.
scores, embeddings, spectrogram = model(waveform)
scores_np = scores.numpy()
inferred_class = class_names[scores_np.mean(axis=0).argmax()]
print(f'The main sound is : {inferred_class}')