Replace YAMNet model for panns-inference

YAMNet model was causing issues, as loading a pytorch framework model, when the tensorflow based YAMNet is loaded in the same environment already, caused segmentation fault
This commit is contained in:
Joel Mathew Thomas
2025-02-26 17:36:27 +05:30
parent f2011b4408
commit cbf2b022a5
4 changed files with 70 additions and 160 deletions
+24 -61
View File
@@ -6,76 +6,51 @@ build-backend = "setuptools.build_meta"
name = "freqsplit"
version = "0.1.0"
dependencies = [
"absl-py==2.1.0",
"aiohappyeyeballs==2.4.4",
"aiohttp==3.11.11",
"aiohappyeyeballs==2.4.6",
"aiohttp==3.11.13",
"aiosignal==1.3.2",
"amqp==5.3.1",
"antlr4-python3-runtime==4.9.3",
"appdirs==1.4.4",
"asgiref==3.8.1",
"asteroid==0.7.0",
"asteroid-filterbanks==0.4.0",
"astunparse==1.6.3",
"attrs==25.1.0",
"audioread==3.0.1",
"billiard==4.2.1",
"cached-property==2.0.1",
"celery==5.4.0",
"certifi==2024.12.14",
"certifi==2025.1.31",
"cffi==1.17.1",
"charset-normalizer==3.4.1",
"click==8.1.8",
"click-didyoumean==0.3.1",
"click-plugins==1.1.1",
"click-repl==0.3.0",
"cloudpickle==3.1.1",
"contourpy==1.3.1",
"cycler==0.12.1",
"decorator==5.1.1",
"decorator==5.2.1",
"DeepFilterLib==0.5.6",
"DeepFilterNet==0.5.6",
"demucs==4.0.1",
"Django==5.1.6",
"dora_search==0.1.12",
"einops==0.8.0",
"einops==0.8.1",
"filelock==3.17.0",
"flatbuffers==25.2.10",
"fonttools==4.55.6",
"fonttools==4.56.0",
"frozenlist==1.5.0",
"fsspec==2024.12.0",
"future==1.0.0",
"gast==0.6.0",
"google-pasta==0.2.0",
"grpcio==1.70.0",
"h5py==3.13.0",
"huggingface-hub==0.28.0",
"fsspec==2025.2.0",
"huggingface-hub==0.29.1",
"idna==3.10",
"iniconfig==2.0.0",
"Jinja2==3.1.5",
"joblib==1.4.2",
"julius==0.2.7",
"keras==3.8.0",
"kiwisolver==1.4.8",
"kombu==5.4.2",
"lameenc==1.8.1",
"lazy_loader==0.4",
"libclang==18.1.1",
"librosa==0.10.2.post1",
"lightning-utilities==0.11.9",
"lightning-utilities==0.12.0",
"llvmlite==0.44.0",
"loguru==0.7.3",
"Markdown==3.7",
"markdown-it-py==3.0.0",
"MarkupSafe==3.0.2",
"matplotlib==3.10.0",
"mdurl==0.1.2",
"mir_eval==0.7",
"ml-dtypes==0.4.1",
"mir_eval==0.8.2",
"mpmath==1.3.0",
"msgpack==1.1.0",
"multidict==6.1.0",
"namex==0.0.8",
"networkx==3.4.2",
"numba==0.61.0",
"numpy==1.26.4",
@@ -88,70 +63,58 @@ dependencies = [
"nvidia-curand-cu12==10.3.5.147",
"nvidia-cusolver-cu12==11.6.1.9",
"nvidia-cusparse-cu12==12.3.1.170",
"nvidia-cusparselt-cu12==0.6.2",
"nvidia-nccl-cu12==2.21.5",
"nvidia-nvjitlink-cu12==12.4.127",
"nvidia-nvtx-cu12==12.4.127",
"omegaconf==2.3.0",
"openunmix==1.3.0",
"opt_einsum==3.4.0",
"optree==0.14.0",
"packaging==23.2",
"pandas==2.2.3",
"panns-inference==0.1.1",
"pb-bss-eval==0.0.2",
"pesq==0.0.4",
"pillow==11.1.0",
"platformdirs==4.3.6",
"pluggy==1.5.0",
"pooch==1.8.2",
"prompt_toolkit==3.0.50",
"propcache==0.2.1",
"protobuf==5.29.3",
"propcache==0.3.0",
"pycparser==2.22",
"Pygments==2.19.1",
"pyparsing==3.2.1",
"pystoi==0.4.1",
"pytest==8.3.4",
"python-dateutil==2.9.0.post0",
"pytorch-lightning==2.5.0.post0",
"pytorch-ranger==0.1.1",
"pytz==2024.2",
"pytz==2025.1",
"PyYAML==6.0.2",
"redis==5.2.1",
"regex==2024.11.6",
"requests==2.32.3",
"retrying==1.3.4",
"rich==13.9.4",
"safetensors==0.5.3",
"scikit-learn==1.6.1",
"scipy==1.15.1",
"setuptools==75.8.0",
"scipy==1.15.2",
"setuptools==75.8.1",
"six==1.17.0",
"soundfile==0.13.1",
"soxr==0.5.0.post1",
"sqlparse==0.5.3",
"submitit==1.5.2",
"sympy==1.13.1",
"tensorboard==2.18.0",
"tensorboard-data-server==0.7.2",
"tensorflow==2.18.0",
"tensorflow-hub==0.16.1",
"termcolor==2.5.0",
"tf_keras==2.18.0",
"threadpoolctl==3.5.0",
"torch==2.5.1",
"tokenizers==0.21.0",
"torch==2.6.0",
"torch-optimizer==0.1.0",
"torch-stoi==0.2.3",
"torchaudio==2.5.1",
"torchaudio==2.6.0",
"torchlibrosa==0.1.0",
"torchmetrics==0.11.4",
"tqdm==4.67.1",
"transformers==4.49.0",
"treetable==0.2.5",
"triton==3.1.0",
"triton==3.2.0",
"typing_extensions==4.12.2",
"tzdata==2025.1",
"urllib3==2.3.0",
"vine==5.1.0",
"wcwidth==0.2.13",
"Werkzeug==3.1.3",
"wheel==0.45.1",
"wrapt==1.17.2",
"yarl==1.18.3",
]
+24 -62
View File
@@ -1,74 +1,48 @@
absl-py==2.1.0
aiohappyeyeballs==2.4.4
aiohttp==3.11.11
aiohappyeyeballs==2.4.6
aiohttp==3.11.13
aiosignal==1.3.2
amqp==5.3.1
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
asgiref==3.8.1
asteroid==0.7.0
asteroid-filterbanks==0.4.0
astunparse==1.6.3
attrs==25.1.0
audioread==3.0.1
billiard==4.2.1
cached-property==2.0.1
celery==5.4.0
certifi==2024.12.14
certifi==2025.1.31
cffi==1.17.1
charset-normalizer==3.4.1
click==8.1.8
click-didyoumean==0.3.1
click-plugins==1.1.1
click-repl==0.3.0
cloudpickle==3.1.1
contourpy==1.3.1
cycler==0.12.1
decorator==5.1.1
decorator==5.2.1
DeepFilterLib==0.5.6
DeepFilterNet==0.5.6
demucs==4.0.1
Django==5.1.6
djangorestframework==3.15.2
dora_search==0.1.12
einops==0.8.0
einops==0.8.1
filelock==3.17.0
flatbuffers==25.2.10
fonttools==4.55.6
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2024.12.0
future==1.0.0
gast==0.6.0
google-pasta==0.2.0
grpcio==1.70.0
h5py==3.13.0
huggingface-hub==0.28.0
fsspec==2025.2.0
huggingface-hub==0.29.1
idna==3.10
iniconfig==2.0.0
Jinja2==3.1.5
joblib==1.4.2
julius==0.2.7
keras==3.8.0
kiwisolver==1.4.8
kombu==5.4.2
lameenc==1.8.1
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1
lightning-utilities==0.11.9
lightning-utilities==0.12.0
llvmlite==0.44.0
loguru==0.7.3
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.0
mdurl==0.1.2
mir_eval==0.7
ml-dtypes==0.4.1
mir_eval==0.8.2
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
namex==0.0.8
networkx==3.4.2
numba==0.61.0
numpy==1.26.4
@@ -81,68 +55,56 @@ nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0
openunmix==1.3.0
opt_einsum==3.4.0
optree==0.14.0
packaging==23.2
pandas==2.2.3
panns-inference==0.1.1
pb-bss-eval==0.0.2
pesq==0.0.4
pillow==11.1.0
platformdirs==4.3.6
pluggy==1.5.0
pooch==1.8.2
prompt_toolkit==3.0.50
propcache==0.2.1
protobuf==5.29.3
propcache==0.3.0
pycparser==2.22
Pygments==2.19.1
pyparsing==3.2.1
pystoi==0.4.1
pytest==8.3.4
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pytorch-ranger==0.1.1
pytz==2024.2
pytz==2025.1
PyYAML==6.0.2
redis==5.2.1
regex==2024.11.6
requests==2.32.3
retrying==1.3.4
rich==13.9.4
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.1
setuptools==75.8.0
scipy==1.15.2
setuptools==75.8.1
six==1.17.0
soundfile==0.13.1
soxr==0.5.0.post1
sqlparse==0.5.3
submitit==1.5.2
sympy==1.13.1
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tensorflow==2.18.0
tensorflow-hub==0.16.1
termcolor==2.5.0
tf_keras==2.18.0
threadpoolctl==3.5.0
torch==2.5.1
tokenizers==0.21.0
torch==2.6.0
torch-optimizer==0.1.0
torch-stoi==0.2.3
torchaudio==2.5.1
torchaudio==2.6.0
torchlibrosa==0.1.0
torchmetrics==0.11.4
tqdm==4.67.1
transformers==4.49.0
treetable==0.2.5
triton==3.1.0
triton==3.2.0
typing_extensions==4.12.2
tzdata==2025.1
urllib3==2.3.0
vine==5.1.0
wcwidth==0.2.13
Werkzeug==3.1.3
wheel==0.45.1
wrapt==1.17.2
yarl==1.18.3
+17 -32
View File
@@ -1,58 +1,43 @@
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import csv
from panns_inference import AudioTagging, labels
# Force TensorFlow to use only CPU
tf.config.set_visible_devices([], 'GPU')
# Initialize PANNs model
at = AudioTagging(checkpoint_path=None, device='cuda')
model = hub.load('https://tfhub.dev/google/yamnet/1')
#Find the name of the class with the top score when mean-aggregated across frames.
def class_names_from_csv(class_map_scv_text):
"""Returns list of class names corresponding to score vector."""
class_names = []
with tf.io.gfile.GFile(class_map_scv_text) as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
class_names.append(row['display_name'])
return class_names
# Main function to process audio and classify
def classify_audio(waveform, sr):
"""
Given an audio file, this function loads the audio, resamples it,
normalizes it, and runs it through the YAMNet model to classify the sound.
normalizes it, and runs it through the PANNs model to classify the sound.
Args:
- waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.).
- sr (int): Sampling rate of the audio.
Returns:
- str: Predicted class label of the audio.
"""
# Check if the sampling rate is 16000Hz
# Check if the sampling rate is 32000Hz
try:
if(sr!=16000):
if sr != 32000:
raise RuntimeError
except Exception:
raise RuntimeError(f"The audio is not sampled at 16000Hz, failed to classify audio.")
raise RuntimeError(f"The audio is not sampled at 32000Hz, failed to classify audio.")
# Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
# Normalize the waveform to [-1.0, 1.0]
waveform = waveform / np.max(np.abs(waveform))
# Execute the YAMNet model
# Ensure waveform shape is correct for model input
waveform = waveform[None, :]
# Execute the PANNs model
try:
scores, embeddings, spectrogram = model(waveform)
clipwise_output, _ = at.inference(waveform)
except Exception as e:
raise RuntimeError(f"Error: Failed to classify audio: {e}")
# Extract the class names from the model
class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path)
# Find the class with the highest score
scores_np = scores.numpy()
inferred_class = class_names[scores_np.mean(axis=0).argmax()]
# Get the top predicted class
predicted_index = np.argmax(clipwise_output)
inferred_class = labels[predicted_index]
return inferred_class
+1 -1
View File
@@ -24,7 +24,7 @@ def test_trim_audio():
def test_classify():
file_path = "tests/test_audio/cafe_crowd_talk.wav"
waveform, sr = read_audio(file_path, 16000, mono=True)
waveform, sr = read_audio(file_path, 32000, mono=True)
expected_class = "Speech"
predicted_class = classify_audio(waveform, sr)