Replace YAMNet model with panns-inference
The YAMNet model was causing issues: loading a PyTorch model while the TensorFlow-based YAMNet was already loaded in the same environment caused a segmentation fault.
This commit is contained in:
+24
-61
@@ -6,76 +6,51 @@ build-backend = "setuptools.build_meta"
|
||||
name = "freqsplit"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"absl-py==2.1.0",
|
||||
"aiohappyeyeballs==2.4.4",
|
||||
"aiohttp==3.11.11",
|
||||
"aiohappyeyeballs==2.4.6",
|
||||
"aiohttp==3.11.13",
|
||||
"aiosignal==1.3.2",
|
||||
"amqp==5.3.1",
|
||||
"antlr4-python3-runtime==4.9.3",
|
||||
"appdirs==1.4.4",
|
||||
"asgiref==3.8.1",
|
||||
"asteroid==0.7.0",
|
||||
"asteroid-filterbanks==0.4.0",
|
||||
"astunparse==1.6.3",
|
||||
"attrs==25.1.0",
|
||||
"audioread==3.0.1",
|
||||
"billiard==4.2.1",
|
||||
"cached-property==2.0.1",
|
||||
"celery==5.4.0",
|
||||
"certifi==2024.12.14",
|
||||
"certifi==2025.1.31",
|
||||
"cffi==1.17.1",
|
||||
"charset-normalizer==3.4.1",
|
||||
"click==8.1.8",
|
||||
"click-didyoumean==0.3.1",
|
||||
"click-plugins==1.1.1",
|
||||
"click-repl==0.3.0",
|
||||
"cloudpickle==3.1.1",
|
||||
"contourpy==1.3.1",
|
||||
"cycler==0.12.1",
|
||||
"decorator==5.1.1",
|
||||
"decorator==5.2.1",
|
||||
"DeepFilterLib==0.5.6",
|
||||
"DeepFilterNet==0.5.6",
|
||||
"demucs==4.0.1",
|
||||
"Django==5.1.6",
|
||||
"dora_search==0.1.12",
|
||||
"einops==0.8.0",
|
||||
"einops==0.8.1",
|
||||
"filelock==3.17.0",
|
||||
"flatbuffers==25.2.10",
|
||||
"fonttools==4.55.6",
|
||||
"fonttools==4.56.0",
|
||||
"frozenlist==1.5.0",
|
||||
"fsspec==2024.12.0",
|
||||
"future==1.0.0",
|
||||
"gast==0.6.0",
|
||||
"google-pasta==0.2.0",
|
||||
"grpcio==1.70.0",
|
||||
"h5py==3.13.0",
|
||||
"huggingface-hub==0.28.0",
|
||||
"fsspec==2025.2.0",
|
||||
"huggingface-hub==0.29.1",
|
||||
"idna==3.10",
|
||||
"iniconfig==2.0.0",
|
||||
"Jinja2==3.1.5",
|
||||
"joblib==1.4.2",
|
||||
"julius==0.2.7",
|
||||
"keras==3.8.0",
|
||||
"kiwisolver==1.4.8",
|
||||
"kombu==5.4.2",
|
||||
"lameenc==1.8.1",
|
||||
"lazy_loader==0.4",
|
||||
"libclang==18.1.1",
|
||||
"librosa==0.10.2.post1",
|
||||
"lightning-utilities==0.11.9",
|
||||
"lightning-utilities==0.12.0",
|
||||
"llvmlite==0.44.0",
|
||||
"loguru==0.7.3",
|
||||
"Markdown==3.7",
|
||||
"markdown-it-py==3.0.0",
|
||||
"MarkupSafe==3.0.2",
|
||||
"matplotlib==3.10.0",
|
||||
"mdurl==0.1.2",
|
||||
"mir_eval==0.7",
|
||||
"ml-dtypes==0.4.1",
|
||||
"mir_eval==0.8.2",
|
||||
"mpmath==1.3.0",
|
||||
"msgpack==1.1.0",
|
||||
"multidict==6.1.0",
|
||||
"namex==0.0.8",
|
||||
"networkx==3.4.2",
|
||||
"numba==0.61.0",
|
||||
"numpy==1.26.4",
|
||||
@@ -88,70 +63,58 @@ dependencies = [
|
||||
"nvidia-curand-cu12==10.3.5.147",
|
||||
"nvidia-cusolver-cu12==11.6.1.9",
|
||||
"nvidia-cusparse-cu12==12.3.1.170",
|
||||
"nvidia-cusparselt-cu12==0.6.2",
|
||||
"nvidia-nccl-cu12==2.21.5",
|
||||
"nvidia-nvjitlink-cu12==12.4.127",
|
||||
"nvidia-nvtx-cu12==12.4.127",
|
||||
"omegaconf==2.3.0",
|
||||
"openunmix==1.3.0",
|
||||
"opt_einsum==3.4.0",
|
||||
"optree==0.14.0",
|
||||
"packaging==23.2",
|
||||
"pandas==2.2.3",
|
||||
"panns-inference==0.1.1",
|
||||
"pb-bss-eval==0.0.2",
|
||||
"pesq==0.0.4",
|
||||
"pillow==11.1.0",
|
||||
"platformdirs==4.3.6",
|
||||
"pluggy==1.5.0",
|
||||
"pooch==1.8.2",
|
||||
"prompt_toolkit==3.0.50",
|
||||
"propcache==0.2.1",
|
||||
"protobuf==5.29.3",
|
||||
"propcache==0.3.0",
|
||||
"pycparser==2.22",
|
||||
"Pygments==2.19.1",
|
||||
"pyparsing==3.2.1",
|
||||
"pystoi==0.4.1",
|
||||
"pytest==8.3.4",
|
||||
"python-dateutil==2.9.0.post0",
|
||||
"pytorch-lightning==2.5.0.post0",
|
||||
"pytorch-ranger==0.1.1",
|
||||
"pytz==2024.2",
|
||||
"pytz==2025.1",
|
||||
"PyYAML==6.0.2",
|
||||
"redis==5.2.1",
|
||||
"regex==2024.11.6",
|
||||
"requests==2.32.3",
|
||||
"retrying==1.3.4",
|
||||
"rich==13.9.4",
|
||||
"safetensors==0.5.3",
|
||||
"scikit-learn==1.6.1",
|
||||
"scipy==1.15.1",
|
||||
"setuptools==75.8.0",
|
||||
"scipy==1.15.2",
|
||||
"setuptools==75.8.1",
|
||||
"six==1.17.0",
|
||||
"soundfile==0.13.1",
|
||||
"soxr==0.5.0.post1",
|
||||
"sqlparse==0.5.3",
|
||||
"submitit==1.5.2",
|
||||
"sympy==1.13.1",
|
||||
"tensorboard==2.18.0",
|
||||
"tensorboard-data-server==0.7.2",
|
||||
"tensorflow==2.18.0",
|
||||
"tensorflow-hub==0.16.1",
|
||||
"termcolor==2.5.0",
|
||||
"tf_keras==2.18.0",
|
||||
"threadpoolctl==3.5.0",
|
||||
"torch==2.5.1",
|
||||
"tokenizers==0.21.0",
|
||||
"torch==2.6.0",
|
||||
"torch-optimizer==0.1.0",
|
||||
"torch-stoi==0.2.3",
|
||||
"torchaudio==2.5.1",
|
||||
"torchaudio==2.6.0",
|
||||
"torchlibrosa==0.1.0",
|
||||
"torchmetrics==0.11.4",
|
||||
"tqdm==4.67.1",
|
||||
"transformers==4.49.0",
|
||||
"treetable==0.2.5",
|
||||
"triton==3.1.0",
|
||||
"triton==3.2.0",
|
||||
"typing_extensions==4.12.2",
|
||||
"tzdata==2025.1",
|
||||
"urllib3==2.3.0",
|
||||
"vine==5.1.0",
|
||||
"wcwidth==0.2.13",
|
||||
"Werkzeug==3.1.3",
|
||||
"wheel==0.45.1",
|
||||
"wrapt==1.17.2",
|
||||
"yarl==1.18.3",
|
||||
]
|
||||
|
||||
|
||||
+24
-62
@@ -1,74 +1,48 @@
|
||||
absl-py==2.1.0
|
||||
aiohappyeyeballs==2.4.4
|
||||
aiohttp==3.11.11
|
||||
aiohappyeyeballs==2.4.6
|
||||
aiohttp==3.11.13
|
||||
aiosignal==1.3.2
|
||||
amqp==5.3.1
|
||||
antlr4-python3-runtime==4.9.3
|
||||
appdirs==1.4.4
|
||||
asgiref==3.8.1
|
||||
asteroid==0.7.0
|
||||
asteroid-filterbanks==0.4.0
|
||||
astunparse==1.6.3
|
||||
attrs==25.1.0
|
||||
audioread==3.0.1
|
||||
billiard==4.2.1
|
||||
cached-property==2.0.1
|
||||
celery==5.4.0
|
||||
certifi==2024.12.14
|
||||
certifi==2025.1.31
|
||||
cffi==1.17.1
|
||||
charset-normalizer==3.4.1
|
||||
click==8.1.8
|
||||
click-didyoumean==0.3.1
|
||||
click-plugins==1.1.1
|
||||
click-repl==0.3.0
|
||||
cloudpickle==3.1.1
|
||||
contourpy==1.3.1
|
||||
cycler==0.12.1
|
||||
decorator==5.1.1
|
||||
decorator==5.2.1
|
||||
DeepFilterLib==0.5.6
|
||||
DeepFilterNet==0.5.6
|
||||
demucs==4.0.1
|
||||
Django==5.1.6
|
||||
djangorestframework==3.15.2
|
||||
dora_search==0.1.12
|
||||
einops==0.8.0
|
||||
einops==0.8.1
|
||||
filelock==3.17.0
|
||||
flatbuffers==25.2.10
|
||||
fonttools==4.55.6
|
||||
fonttools==4.56.0
|
||||
frozenlist==1.5.0
|
||||
fsspec==2024.12.0
|
||||
future==1.0.0
|
||||
gast==0.6.0
|
||||
google-pasta==0.2.0
|
||||
grpcio==1.70.0
|
||||
h5py==3.13.0
|
||||
huggingface-hub==0.28.0
|
||||
fsspec==2025.2.0
|
||||
huggingface-hub==0.29.1
|
||||
idna==3.10
|
||||
iniconfig==2.0.0
|
||||
Jinja2==3.1.5
|
||||
joblib==1.4.2
|
||||
julius==0.2.7
|
||||
keras==3.8.0
|
||||
kiwisolver==1.4.8
|
||||
kombu==5.4.2
|
||||
lameenc==1.8.1
|
||||
lazy_loader==0.4
|
||||
libclang==18.1.1
|
||||
librosa==0.10.2.post1
|
||||
lightning-utilities==0.11.9
|
||||
lightning-utilities==0.12.0
|
||||
llvmlite==0.44.0
|
||||
loguru==0.7.3
|
||||
Markdown==3.7
|
||||
markdown-it-py==3.0.0
|
||||
MarkupSafe==3.0.2
|
||||
matplotlib==3.10.0
|
||||
mdurl==0.1.2
|
||||
mir_eval==0.7
|
||||
ml-dtypes==0.4.1
|
||||
mir_eval==0.8.2
|
||||
mpmath==1.3.0
|
||||
msgpack==1.1.0
|
||||
multidict==6.1.0
|
||||
namex==0.0.8
|
||||
networkx==3.4.2
|
||||
numba==0.61.0
|
||||
numpy==1.26.4
|
||||
@@ -81,68 +55,56 @@ nvidia-cufft-cu12==11.2.1.3
|
||||
nvidia-curand-cu12==10.3.5.147
|
||||
nvidia-cusolver-cu12==11.6.1.9
|
||||
nvidia-cusparse-cu12==12.3.1.170
|
||||
nvidia-cusparselt-cu12==0.6.2
|
||||
nvidia-nccl-cu12==2.21.5
|
||||
nvidia-nvjitlink-cu12==12.4.127
|
||||
nvidia-nvtx-cu12==12.4.127
|
||||
omegaconf==2.3.0
|
||||
openunmix==1.3.0
|
||||
opt_einsum==3.4.0
|
||||
optree==0.14.0
|
||||
packaging==23.2
|
||||
pandas==2.2.3
|
||||
panns-inference==0.1.1
|
||||
pb-bss-eval==0.0.2
|
||||
pesq==0.0.4
|
||||
pillow==11.1.0
|
||||
platformdirs==4.3.6
|
||||
pluggy==1.5.0
|
||||
pooch==1.8.2
|
||||
prompt_toolkit==3.0.50
|
||||
propcache==0.2.1
|
||||
protobuf==5.29.3
|
||||
propcache==0.3.0
|
||||
pycparser==2.22
|
||||
Pygments==2.19.1
|
||||
pyparsing==3.2.1
|
||||
pystoi==0.4.1
|
||||
pytest==8.3.4
|
||||
python-dateutil==2.9.0.post0
|
||||
pytorch-lightning==2.5.0.post0
|
||||
pytorch-ranger==0.1.1
|
||||
pytz==2024.2
|
||||
pytz==2025.1
|
||||
PyYAML==6.0.2
|
||||
redis==5.2.1
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
retrying==1.3.4
|
||||
rich==13.9.4
|
||||
safetensors==0.5.3
|
||||
scikit-learn==1.6.1
|
||||
scipy==1.15.1
|
||||
setuptools==75.8.0
|
||||
scipy==1.15.2
|
||||
setuptools==75.8.1
|
||||
six==1.17.0
|
||||
soundfile==0.13.1
|
||||
soxr==0.5.0.post1
|
||||
sqlparse==0.5.3
|
||||
submitit==1.5.2
|
||||
sympy==1.13.1
|
||||
tensorboard==2.18.0
|
||||
tensorboard-data-server==0.7.2
|
||||
tensorflow==2.18.0
|
||||
tensorflow-hub==0.16.1
|
||||
termcolor==2.5.0
|
||||
tf_keras==2.18.0
|
||||
threadpoolctl==3.5.0
|
||||
torch==2.5.1
|
||||
tokenizers==0.21.0
|
||||
torch==2.6.0
|
||||
torch-optimizer==0.1.0
|
||||
torch-stoi==0.2.3
|
||||
torchaudio==2.5.1
|
||||
torchaudio==2.6.0
|
||||
torchlibrosa==0.1.0
|
||||
torchmetrics==0.11.4
|
||||
tqdm==4.67.1
|
||||
transformers==4.49.0
|
||||
treetable==0.2.5
|
||||
triton==3.1.0
|
||||
triton==3.2.0
|
||||
typing_extensions==4.12.2
|
||||
tzdata==2025.1
|
||||
urllib3==2.3.0
|
||||
vine==5.1.0
|
||||
wcwidth==0.2.13
|
||||
Werkzeug==3.1.3
|
||||
wheel==0.45.1
|
||||
wrapt==1.17.2
|
||||
yarl==1.18.3
|
||||
|
||||
@@ -1,58 +1,43 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow_hub as hub
|
||||
import numpy as np
|
||||
import csv
|
||||
from panns_inference import AudioTagging, labels
|
||||
|
||||
# Force TensorFlow to use only CPU
|
||||
tf.config.set_visible_devices([], 'GPU')
|
||||
# Initialize PANNs model
|
||||
at = AudioTagging(checkpoint_path=None, device='cuda')
|
||||
|
||||
model = hub.load('https://tfhub.dev/google/yamnet/1')
|
||||
|
||||
#Find the name of the class with the top score when mean-aggregated across frames.
|
||||
def class_names_from_csv(class_map_scv_text):
|
||||
"""Returns list of class names corresponding to score vector."""
|
||||
class_names = []
|
||||
with tf.io.gfile.GFile(class_map_scv_text) as csvfile:
|
||||
reader = csv.DictReader(csvfile)
|
||||
for row in reader:
|
||||
class_names.append(row['display_name'])
|
||||
return class_names
|
||||
|
||||
# Main function to process audio and classify
|
||||
def classify_audio(waveform, sr):
|
||||
"""
|
||||
Given an audio waveform, this function checks its sampling rate,
|
||||
normalizes it, and runs it through the YAMNet model to classify the sound.
|
||||
normalizes it, and runs it through the PANNs model to classify the sound.
|
||||
|
||||
Args:
|
||||
- waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.).
|
||||
- sr (int): Sampling rate of the audio.
|
||||
|
||||
Returns:
|
||||
- str: Predicted class label of the audio.
|
||||
"""
|
||||
|
||||
# Check if the sampling rate is 16000Hz
|
||||
# Check if the sampling rate is 32000Hz
|
||||
try:
|
||||
if(sr!=16000):
|
||||
if sr != 32000:
|
||||
raise RuntimeError
|
||||
except Exception:
|
||||
raise RuntimeError(f"The audio is not sampled at 16000Hz, failed to classify audio.")
|
||||
raise RuntimeError(f"The audio is not sampled at 32000Hz, failed to classify audio.")
|
||||
|
||||
# Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
|
||||
# Normalize the waveform to [-1.0, 1.0]
|
||||
waveform = waveform / np.max(np.abs(waveform))
|
||||
|
||||
# Execute the YAMNet model
|
||||
|
||||
# Ensure waveform shape is correct for model input
|
||||
waveform = waveform[None, :]
|
||||
|
||||
# Execute the PANNs model
|
||||
try:
|
||||
scores, embeddings, spectrogram = model(waveform)
|
||||
clipwise_output, _ = at.inference(waveform)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error: Failed to classify audio: {e}")
|
||||
|
||||
# Extract the class names from the model
|
||||
class_map_path = model.class_map_path().numpy()
|
||||
class_names = class_names_from_csv(class_map_path)
|
||||
|
||||
# Find the class with the highest score
|
||||
scores_np = scores.numpy()
|
||||
inferred_class = class_names[scores_np.mean(axis=0).argmax()]
|
||||
|
||||
return inferred_class
|
||||
|
||||
# Get the top predicted class
|
||||
predicted_index = np.argmax(clipwise_output)
|
||||
inferred_class = labels[predicted_index]
|
||||
|
||||
return inferred_class
|
||||
|
||||
@@ -24,7 +24,7 @@ def test_trim_audio():
|
||||
|
||||
def test_classify():
|
||||
file_path = "tests/test_audio/cafe_crowd_talk.wav"
|
||||
waveform, sr = read_audio(file_path, 16000, mono=True)
|
||||
waveform, sr = read_audio(file_path, 32000, mono=True)
|
||||
expected_class = "Speech"
|
||||
predicted_class = classify_audio(waveform, sr)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user