Replace YAMNet model with panns-inference

The YAMNet model was causing issues: loading a PyTorch-based model while the TensorFlow-based YAMNet was already loaded in the same environment caused a segmentation fault.
This commit is contained in:
Joel Mathew Thomas
2025-02-26 17:36:27 +05:30
parent f2011b4408
commit cbf2b022a5
4 changed files with 70 additions and 160 deletions
+24 -61
View File
@@ -6,76 +6,51 @@ build-backend = "setuptools.build_meta"
name = "freqsplit" name = "freqsplit"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"absl-py==2.1.0", "aiohappyeyeballs==2.4.6",
"aiohappyeyeballs==2.4.4", "aiohttp==3.11.13",
"aiohttp==3.11.11",
"aiosignal==1.3.2", "aiosignal==1.3.2",
"amqp==5.3.1",
"antlr4-python3-runtime==4.9.3", "antlr4-python3-runtime==4.9.3",
"appdirs==1.4.4", "appdirs==1.4.4",
"asgiref==3.8.1",
"asteroid==0.7.0", "asteroid==0.7.0",
"asteroid-filterbanks==0.4.0", "asteroid-filterbanks==0.4.0",
"astunparse==1.6.3",
"attrs==25.1.0", "attrs==25.1.0",
"audioread==3.0.1", "audioread==3.0.1",
"billiard==4.2.1",
"cached-property==2.0.1", "cached-property==2.0.1",
"celery==5.4.0", "certifi==2025.1.31",
"certifi==2024.12.14",
"cffi==1.17.1", "cffi==1.17.1",
"charset-normalizer==3.4.1", "charset-normalizer==3.4.1",
"click==8.1.8",
"click-didyoumean==0.3.1",
"click-plugins==1.1.1",
"click-repl==0.3.0",
"cloudpickle==3.1.1", "cloudpickle==3.1.1",
"contourpy==1.3.1", "contourpy==1.3.1",
"cycler==0.12.1", "cycler==0.12.1",
"decorator==5.1.1", "decorator==5.2.1",
"DeepFilterLib==0.5.6", "DeepFilterLib==0.5.6",
"DeepFilterNet==0.5.6", "DeepFilterNet==0.5.6",
"demucs==4.0.1", "demucs==4.0.1",
"Django==5.1.6",
"dora_search==0.1.12", "dora_search==0.1.12",
"einops==0.8.0", "einops==0.8.1",
"filelock==3.17.0", "filelock==3.17.0",
"flatbuffers==25.2.10", "fonttools==4.56.0",
"fonttools==4.55.6",
"frozenlist==1.5.0", "frozenlist==1.5.0",
"fsspec==2024.12.0", "fsspec==2025.2.0",
"future==1.0.0", "huggingface-hub==0.29.1",
"gast==0.6.0",
"google-pasta==0.2.0",
"grpcio==1.70.0",
"h5py==3.13.0",
"huggingface-hub==0.28.0",
"idna==3.10", "idna==3.10",
"iniconfig==2.0.0", "iniconfig==2.0.0",
"Jinja2==3.1.5", "Jinja2==3.1.5",
"joblib==1.4.2", "joblib==1.4.2",
"julius==0.2.7", "julius==0.2.7",
"keras==3.8.0",
"kiwisolver==1.4.8", "kiwisolver==1.4.8",
"kombu==5.4.2",
"lameenc==1.8.1", "lameenc==1.8.1",
"lazy_loader==0.4", "lazy_loader==0.4",
"libclang==18.1.1",
"librosa==0.10.2.post1", "librosa==0.10.2.post1",
"lightning-utilities==0.11.9", "lightning-utilities==0.12.0",
"llvmlite==0.44.0", "llvmlite==0.44.0",
"loguru==0.7.3", "loguru==0.7.3",
"Markdown==3.7",
"markdown-it-py==3.0.0",
"MarkupSafe==3.0.2", "MarkupSafe==3.0.2",
"matplotlib==3.10.0", "matplotlib==3.10.0",
"mdurl==0.1.2", "mir_eval==0.8.2",
"mir_eval==0.7",
"ml-dtypes==0.4.1",
"mpmath==1.3.0", "mpmath==1.3.0",
"msgpack==1.1.0", "msgpack==1.1.0",
"multidict==6.1.0", "multidict==6.1.0",
"namex==0.0.8",
"networkx==3.4.2", "networkx==3.4.2",
"numba==0.61.0", "numba==0.61.0",
"numpy==1.26.4", "numpy==1.26.4",
@@ -88,70 +63,58 @@ dependencies = [
"nvidia-curand-cu12==10.3.5.147", "nvidia-curand-cu12==10.3.5.147",
"nvidia-cusolver-cu12==11.6.1.9", "nvidia-cusolver-cu12==11.6.1.9",
"nvidia-cusparse-cu12==12.3.1.170", "nvidia-cusparse-cu12==12.3.1.170",
"nvidia-cusparselt-cu12==0.6.2",
"nvidia-nccl-cu12==2.21.5", "nvidia-nccl-cu12==2.21.5",
"nvidia-nvjitlink-cu12==12.4.127", "nvidia-nvjitlink-cu12==12.4.127",
"nvidia-nvtx-cu12==12.4.127", "nvidia-nvtx-cu12==12.4.127",
"omegaconf==2.3.0", "omegaconf==2.3.0",
"openunmix==1.3.0", "openunmix==1.3.0",
"opt_einsum==3.4.0",
"optree==0.14.0",
"packaging==23.2", "packaging==23.2",
"pandas==2.2.3", "pandas==2.2.3",
"panns-inference==0.1.1",
"pb-bss-eval==0.0.2", "pb-bss-eval==0.0.2",
"pesq==0.0.4", "pesq==0.0.4",
"pillow==11.1.0", "pillow==11.1.0",
"platformdirs==4.3.6", "platformdirs==4.3.6",
"pluggy==1.5.0", "pluggy==1.5.0",
"pooch==1.8.2", "pooch==1.8.2",
"prompt_toolkit==3.0.50", "propcache==0.3.0",
"propcache==0.2.1",
"protobuf==5.29.3",
"pycparser==2.22", "pycparser==2.22",
"Pygments==2.19.1",
"pyparsing==3.2.1", "pyparsing==3.2.1",
"pystoi==0.4.1", "pystoi==0.4.1",
"pytest==8.3.4", "pytest==8.3.4",
"python-dateutil==2.9.0.post0", "python-dateutil==2.9.0.post0",
"pytorch-lightning==2.5.0.post0", "pytorch-lightning==2.5.0.post0",
"pytorch-ranger==0.1.1", "pytorch-ranger==0.1.1",
"pytz==2024.2", "pytz==2025.1",
"PyYAML==6.0.2", "PyYAML==6.0.2",
"redis==5.2.1", "regex==2024.11.6",
"requests==2.32.3", "requests==2.32.3",
"retrying==1.3.4", "retrying==1.3.4",
"rich==13.9.4", "safetensors==0.5.3",
"scikit-learn==1.6.1", "scikit-learn==1.6.1",
"scipy==1.15.1", "scipy==1.15.2",
"setuptools==75.8.0", "setuptools==75.8.1",
"six==1.17.0", "six==1.17.0",
"soundfile==0.13.1", "soundfile==0.13.1",
"soxr==0.5.0.post1", "soxr==0.5.0.post1",
"sqlparse==0.5.3",
"submitit==1.5.2", "submitit==1.5.2",
"sympy==1.13.1", "sympy==1.13.1",
"tensorboard==2.18.0",
"tensorboard-data-server==0.7.2",
"tensorflow==2.18.0",
"tensorflow-hub==0.16.1",
"termcolor==2.5.0",
"tf_keras==2.18.0",
"threadpoolctl==3.5.0", "threadpoolctl==3.5.0",
"torch==2.5.1", "tokenizers==0.21.0",
"torch==2.6.0",
"torch-optimizer==0.1.0", "torch-optimizer==0.1.0",
"torch-stoi==0.2.3", "torch-stoi==0.2.3",
"torchaudio==2.5.1", "torchaudio==2.6.0",
"torchlibrosa==0.1.0",
"torchmetrics==0.11.4", "torchmetrics==0.11.4",
"tqdm==4.67.1", "tqdm==4.67.1",
"transformers==4.49.0",
"treetable==0.2.5", "treetable==0.2.5",
"triton==3.1.0", "triton==3.2.0",
"typing_extensions==4.12.2", "typing_extensions==4.12.2",
"tzdata==2025.1", "tzdata==2025.1",
"urllib3==2.3.0", "urllib3==2.3.0",
"vine==5.1.0",
"wcwidth==0.2.13",
"Werkzeug==3.1.3",
"wheel==0.45.1",
"wrapt==1.17.2",
"yarl==1.18.3", "yarl==1.18.3",
] ]
+24 -62
View File
@@ -1,74 +1,48 @@
absl-py==2.1.0 aiohappyeyeballs==2.4.6
aiohappyeyeballs==2.4.4 aiohttp==3.11.13
aiohttp==3.11.11
aiosignal==1.3.2 aiosignal==1.3.2
amqp==5.3.1
antlr4-python3-runtime==4.9.3 antlr4-python3-runtime==4.9.3
appdirs==1.4.4 appdirs==1.4.4
asgiref==3.8.1
asteroid==0.7.0 asteroid==0.7.0
asteroid-filterbanks==0.4.0 asteroid-filterbanks==0.4.0
astunparse==1.6.3
attrs==25.1.0 attrs==25.1.0
audioread==3.0.1 audioread==3.0.1
billiard==4.2.1
cached-property==2.0.1 cached-property==2.0.1
celery==5.4.0 certifi==2025.1.31
certifi==2024.12.14
cffi==1.17.1 cffi==1.17.1
charset-normalizer==3.4.1 charset-normalizer==3.4.1
click==8.1.8
click-didyoumean==0.3.1
click-plugins==1.1.1
click-repl==0.3.0
cloudpickle==3.1.1 cloudpickle==3.1.1
contourpy==1.3.1 contourpy==1.3.1
cycler==0.12.1 cycler==0.12.1
decorator==5.1.1 decorator==5.2.1
DeepFilterLib==0.5.6 DeepFilterLib==0.5.6
DeepFilterNet==0.5.6 DeepFilterNet==0.5.6
demucs==4.0.1 demucs==4.0.1
Django==5.1.6
djangorestframework==3.15.2
dora_search==0.1.12 dora_search==0.1.12
einops==0.8.0 einops==0.8.1
filelock==3.17.0 filelock==3.17.0
flatbuffers==25.2.10 fonttools==4.56.0
fonttools==4.55.6
frozenlist==1.5.0 frozenlist==1.5.0
fsspec==2024.12.0 fsspec==2025.2.0
future==1.0.0 huggingface-hub==0.29.1
gast==0.6.0
google-pasta==0.2.0
grpcio==1.70.0
h5py==3.13.0
huggingface-hub==0.28.0
idna==3.10 idna==3.10
iniconfig==2.0.0 iniconfig==2.0.0
Jinja2==3.1.5 Jinja2==3.1.5
joblib==1.4.2 joblib==1.4.2
julius==0.2.7 julius==0.2.7
keras==3.8.0
kiwisolver==1.4.8 kiwisolver==1.4.8
kombu==5.4.2
lameenc==1.8.1 lameenc==1.8.1
lazy_loader==0.4 lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1 librosa==0.10.2.post1
lightning-utilities==0.11.9 lightning-utilities==0.12.0
llvmlite==0.44.0 llvmlite==0.44.0
loguru==0.7.3 loguru==0.7.3
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2 MarkupSafe==3.0.2
matplotlib==3.10.0 matplotlib==3.10.0
mdurl==0.1.2 mir_eval==0.8.2
mir_eval==0.7
ml-dtypes==0.4.1
mpmath==1.3.0 mpmath==1.3.0
msgpack==1.1.0 msgpack==1.1.0
multidict==6.1.0 multidict==6.1.0
namex==0.0.8
networkx==3.4.2 networkx==3.4.2
numba==0.61.0 numba==0.61.0
numpy==1.26.4 numpy==1.26.4
@@ -81,68 +55,56 @@ nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147 nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9 nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170 nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5 nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127 nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0 omegaconf==2.3.0
openunmix==1.3.0 openunmix==1.3.0
opt_einsum==3.4.0
optree==0.14.0
packaging==23.2 packaging==23.2
pandas==2.2.3 pandas==2.2.3
panns-inference==0.1.1
pb-bss-eval==0.0.2 pb-bss-eval==0.0.2
pesq==0.0.4 pesq==0.0.4
pillow==11.1.0 pillow==11.1.0
platformdirs==4.3.6 platformdirs==4.3.6
pluggy==1.5.0 pluggy==1.5.0
pooch==1.8.2 pooch==1.8.2
prompt_toolkit==3.0.50 propcache==0.3.0
propcache==0.2.1
protobuf==5.29.3
pycparser==2.22 pycparser==2.22
Pygments==2.19.1
pyparsing==3.2.1 pyparsing==3.2.1
pystoi==0.4.1 pystoi==0.4.1
pytest==8.3.4 pytest==8.3.4
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0 pytorch-lightning==2.5.0.post0
pytorch-ranger==0.1.1 pytorch-ranger==0.1.1
pytz==2024.2 pytz==2025.1
PyYAML==6.0.2 PyYAML==6.0.2
redis==5.2.1 regex==2024.11.6
requests==2.32.3 requests==2.32.3
retrying==1.3.4 retrying==1.3.4
rich==13.9.4 safetensors==0.5.3
scikit-learn==1.6.1 scikit-learn==1.6.1
scipy==1.15.1 scipy==1.15.2
setuptools==75.8.0 setuptools==75.8.1
six==1.17.0 six==1.17.0
soundfile==0.13.1 soundfile==0.13.1
soxr==0.5.0.post1 soxr==0.5.0.post1
sqlparse==0.5.3
submitit==1.5.2 submitit==1.5.2
sympy==1.13.1 sympy==1.13.1
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tensorflow==2.18.0
tensorflow-hub==0.16.1
termcolor==2.5.0
tf_keras==2.18.0
threadpoolctl==3.5.0 threadpoolctl==3.5.0
torch==2.5.1 tokenizers==0.21.0
torch==2.6.0
torch-optimizer==0.1.0 torch-optimizer==0.1.0
torch-stoi==0.2.3 torch-stoi==0.2.3
torchaudio==2.5.1 torchaudio==2.6.0
torchlibrosa==0.1.0
torchmetrics==0.11.4 torchmetrics==0.11.4
tqdm==4.67.1 tqdm==4.67.1
transformers==4.49.0
treetable==0.2.5 treetable==0.2.5
triton==3.1.0 triton==3.2.0
typing_extensions==4.12.2 typing_extensions==4.12.2
tzdata==2025.1 tzdata==2025.1
urllib3==2.3.0 urllib3==2.3.0
vine==5.1.0
wcwidth==0.2.13
Werkzeug==3.1.3
wheel==0.45.1
wrapt==1.17.2
yarl==1.18.3 yarl==1.18.3
+21 -36
View File
@@ -1,58 +1,43 @@
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np import numpy as np
import csv from panns_inference import AudioTagging, labels
# Force TensorFlow to use only CPU # Initialize PANNs model
tf.config.set_visible_devices([], 'GPU') at = AudioTagging(checkpoint_path=None, device='cuda')
model = hub.load('https://tfhub.dev/google/yamnet/1')
#Find the name of the class with the top score when mean-aggregated across frames.
def class_names_from_csv(class_map_scv_text):
"""Returns list of class names corresponding to score vector."""
class_names = []
with tf.io.gfile.GFile(class_map_scv_text) as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
class_names.append(row['display_name'])
return class_names
# Main function to process audio and classify
def classify_audio(waveform, sr): def classify_audio(waveform, sr):
""" """
Given an audio file, this function loads the audio, resamples it, Given an audio file, this function loads the audio, resamples it,
normalizes it, and runs it through the YAMNet model to classify the sound. normalizes it, and runs it through the PANNs model to classify the sound.
Args: Args:
- waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.). - waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.).
- sr (int): Sampling rate of the audio.
Returns: Returns:
- str: Predicted class label of the audio. - str: Predicted class label of the audio.
""" """
# Check if the sampling rate is 16000Hz # Check if the sampling rate is 32000Hz
try: try:
if(sr!=16000): if sr != 32000:
raise RuntimeError raise RuntimeError
except Exception: except Exception:
raise RuntimeError(f"The audio is not sampled at 16000Hz, failed to classify audio.") raise RuntimeError(f"The audio is not sampled at 32000Hz, failed to classify audio.")
# Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values) # Normalize the waveform to [-1.0, 1.0]
waveform = waveform / np.max(np.abs(waveform)) waveform = waveform / np.max(np.abs(waveform))
# Execute the YAMNet model # Ensure waveform shape is correct for model input
waveform = waveform[None, :]
# Execute the PANNs model
try: try:
scores, embeddings, spectrogram = model(waveform) clipwise_output, _ = at.inference(waveform)
except Exception as e: except Exception as e:
raise RuntimeError(f"Error: Failed to classify audio: {e}") raise RuntimeError(f"Error: Failed to classify audio: {e}")
# Extract the class names from the model # Get the top predicted class
class_map_path = model.class_map_path().numpy() predicted_index = np.argmax(clipwise_output)
class_names = class_names_from_csv(class_map_path) inferred_class = labels[predicted_index]
# Find the class with the highest score return inferred_class
scores_np = scores.numpy()
inferred_class = class_names[scores_np.mean(axis=0).argmax()]
return inferred_class
+1 -1
View File
@@ -24,7 +24,7 @@ def test_trim_audio():
def test_classify(): def test_classify():
file_path = "tests/test_audio/cafe_crowd_talk.wav" file_path = "tests/test_audio/cafe_crowd_talk.wav"
waveform, sr = read_audio(file_path, 16000, mono=True) waveform, sr = read_audio(file_path, 32000, mono=True)
expected_class = "Speech" expected_class = "Speech"
predicted_class = classify_audio(waveform, sr) predicted_class = classify_audio(waveform, sr)