Replace YAMNet model with panns-inference
The YAMNet model was causing issues: loading a PyTorch model while the TensorFlow-based YAMNet was already loaded in the same environment caused a segmentation fault.
This commit is contained in:
+24
-61
@@ -6,76 +6,51 @@ build-backend = "setuptools.build_meta"
|
|||||||
name = "freqsplit"
|
name = "freqsplit"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"absl-py==2.1.0",
|
"aiohappyeyeballs==2.4.6",
|
||||||
"aiohappyeyeballs==2.4.4",
|
"aiohttp==3.11.13",
|
||||||
"aiohttp==3.11.11",
|
|
||||||
"aiosignal==1.3.2",
|
"aiosignal==1.3.2",
|
||||||
"amqp==5.3.1",
|
|
||||||
"antlr4-python3-runtime==4.9.3",
|
"antlr4-python3-runtime==4.9.3",
|
||||||
"appdirs==1.4.4",
|
"appdirs==1.4.4",
|
||||||
"asgiref==3.8.1",
|
|
||||||
"asteroid==0.7.0",
|
"asteroid==0.7.0",
|
||||||
"asteroid-filterbanks==0.4.0",
|
"asteroid-filterbanks==0.4.0",
|
||||||
"astunparse==1.6.3",
|
|
||||||
"attrs==25.1.0",
|
"attrs==25.1.0",
|
||||||
"audioread==3.0.1",
|
"audioread==3.0.1",
|
||||||
"billiard==4.2.1",
|
|
||||||
"cached-property==2.0.1",
|
"cached-property==2.0.1",
|
||||||
"celery==5.4.0",
|
"certifi==2025.1.31",
|
||||||
"certifi==2024.12.14",
|
|
||||||
"cffi==1.17.1",
|
"cffi==1.17.1",
|
||||||
"charset-normalizer==3.4.1",
|
"charset-normalizer==3.4.1",
|
||||||
"click==8.1.8",
|
|
||||||
"click-didyoumean==0.3.1",
|
|
||||||
"click-plugins==1.1.1",
|
|
||||||
"click-repl==0.3.0",
|
|
||||||
"cloudpickle==3.1.1",
|
"cloudpickle==3.1.1",
|
||||||
"contourpy==1.3.1",
|
"contourpy==1.3.1",
|
||||||
"cycler==0.12.1",
|
"cycler==0.12.1",
|
||||||
"decorator==5.1.1",
|
"decorator==5.2.1",
|
||||||
"DeepFilterLib==0.5.6",
|
"DeepFilterLib==0.5.6",
|
||||||
"DeepFilterNet==0.5.6",
|
"DeepFilterNet==0.5.6",
|
||||||
"demucs==4.0.1",
|
"demucs==4.0.1",
|
||||||
"Django==5.1.6",
|
|
||||||
"dora_search==0.1.12",
|
"dora_search==0.1.12",
|
||||||
"einops==0.8.0",
|
"einops==0.8.1",
|
||||||
"filelock==3.17.0",
|
"filelock==3.17.0",
|
||||||
"flatbuffers==25.2.10",
|
"fonttools==4.56.0",
|
||||||
"fonttools==4.55.6",
|
|
||||||
"frozenlist==1.5.0",
|
"frozenlist==1.5.0",
|
||||||
"fsspec==2024.12.0",
|
"fsspec==2025.2.0",
|
||||||
"future==1.0.0",
|
"huggingface-hub==0.29.1",
|
||||||
"gast==0.6.0",
|
|
||||||
"google-pasta==0.2.0",
|
|
||||||
"grpcio==1.70.0",
|
|
||||||
"h5py==3.13.0",
|
|
||||||
"huggingface-hub==0.28.0",
|
|
||||||
"idna==3.10",
|
"idna==3.10",
|
||||||
"iniconfig==2.0.0",
|
"iniconfig==2.0.0",
|
||||||
"Jinja2==3.1.5",
|
"Jinja2==3.1.5",
|
||||||
"joblib==1.4.2",
|
"joblib==1.4.2",
|
||||||
"julius==0.2.7",
|
"julius==0.2.7",
|
||||||
"keras==3.8.0",
|
|
||||||
"kiwisolver==1.4.8",
|
"kiwisolver==1.4.8",
|
||||||
"kombu==5.4.2",
|
|
||||||
"lameenc==1.8.1",
|
"lameenc==1.8.1",
|
||||||
"lazy_loader==0.4",
|
"lazy_loader==0.4",
|
||||||
"libclang==18.1.1",
|
|
||||||
"librosa==0.10.2.post1",
|
"librosa==0.10.2.post1",
|
||||||
"lightning-utilities==0.11.9",
|
"lightning-utilities==0.12.0",
|
||||||
"llvmlite==0.44.0",
|
"llvmlite==0.44.0",
|
||||||
"loguru==0.7.3",
|
"loguru==0.7.3",
|
||||||
"Markdown==3.7",
|
|
||||||
"markdown-it-py==3.0.0",
|
|
||||||
"MarkupSafe==3.0.2",
|
"MarkupSafe==3.0.2",
|
||||||
"matplotlib==3.10.0",
|
"matplotlib==3.10.0",
|
||||||
"mdurl==0.1.2",
|
"mir_eval==0.8.2",
|
||||||
"mir_eval==0.7",
|
|
||||||
"ml-dtypes==0.4.1",
|
|
||||||
"mpmath==1.3.0",
|
"mpmath==1.3.0",
|
||||||
"msgpack==1.1.0",
|
"msgpack==1.1.0",
|
||||||
"multidict==6.1.0",
|
"multidict==6.1.0",
|
||||||
"namex==0.0.8",
|
|
||||||
"networkx==3.4.2",
|
"networkx==3.4.2",
|
||||||
"numba==0.61.0",
|
"numba==0.61.0",
|
||||||
"numpy==1.26.4",
|
"numpy==1.26.4",
|
||||||
@@ -88,70 +63,58 @@ dependencies = [
|
|||||||
"nvidia-curand-cu12==10.3.5.147",
|
"nvidia-curand-cu12==10.3.5.147",
|
||||||
"nvidia-cusolver-cu12==11.6.1.9",
|
"nvidia-cusolver-cu12==11.6.1.9",
|
||||||
"nvidia-cusparse-cu12==12.3.1.170",
|
"nvidia-cusparse-cu12==12.3.1.170",
|
||||||
|
"nvidia-cusparselt-cu12==0.6.2",
|
||||||
"nvidia-nccl-cu12==2.21.5",
|
"nvidia-nccl-cu12==2.21.5",
|
||||||
"nvidia-nvjitlink-cu12==12.4.127",
|
"nvidia-nvjitlink-cu12==12.4.127",
|
||||||
"nvidia-nvtx-cu12==12.4.127",
|
"nvidia-nvtx-cu12==12.4.127",
|
||||||
"omegaconf==2.3.0",
|
"omegaconf==2.3.0",
|
||||||
"openunmix==1.3.0",
|
"openunmix==1.3.0",
|
||||||
"opt_einsum==3.4.0",
|
|
||||||
"optree==0.14.0",
|
|
||||||
"packaging==23.2",
|
"packaging==23.2",
|
||||||
"pandas==2.2.3",
|
"pandas==2.2.3",
|
||||||
|
"panns-inference==0.1.1",
|
||||||
"pb-bss-eval==0.0.2",
|
"pb-bss-eval==0.0.2",
|
||||||
"pesq==0.0.4",
|
"pesq==0.0.4",
|
||||||
"pillow==11.1.0",
|
"pillow==11.1.0",
|
||||||
"platformdirs==4.3.6",
|
"platformdirs==4.3.6",
|
||||||
"pluggy==1.5.0",
|
"pluggy==1.5.0",
|
||||||
"pooch==1.8.2",
|
"pooch==1.8.2",
|
||||||
"prompt_toolkit==3.0.50",
|
"propcache==0.3.0",
|
||||||
"propcache==0.2.1",
|
|
||||||
"protobuf==5.29.3",
|
|
||||||
"pycparser==2.22",
|
"pycparser==2.22",
|
||||||
"Pygments==2.19.1",
|
|
||||||
"pyparsing==3.2.1",
|
"pyparsing==3.2.1",
|
||||||
"pystoi==0.4.1",
|
"pystoi==0.4.1",
|
||||||
"pytest==8.3.4",
|
"pytest==8.3.4",
|
||||||
"python-dateutil==2.9.0.post0",
|
"python-dateutil==2.9.0.post0",
|
||||||
"pytorch-lightning==2.5.0.post0",
|
"pytorch-lightning==2.5.0.post0",
|
||||||
"pytorch-ranger==0.1.1",
|
"pytorch-ranger==0.1.1",
|
||||||
"pytz==2024.2",
|
"pytz==2025.1",
|
||||||
"PyYAML==6.0.2",
|
"PyYAML==6.0.2",
|
||||||
"redis==5.2.1",
|
"regex==2024.11.6",
|
||||||
"requests==2.32.3",
|
"requests==2.32.3",
|
||||||
"retrying==1.3.4",
|
"retrying==1.3.4",
|
||||||
"rich==13.9.4",
|
"safetensors==0.5.3",
|
||||||
"scikit-learn==1.6.1",
|
"scikit-learn==1.6.1",
|
||||||
"scipy==1.15.1",
|
"scipy==1.15.2",
|
||||||
"setuptools==75.8.0",
|
"setuptools==75.8.1",
|
||||||
"six==1.17.0",
|
"six==1.17.0",
|
||||||
"soundfile==0.13.1",
|
"soundfile==0.13.1",
|
||||||
"soxr==0.5.0.post1",
|
"soxr==0.5.0.post1",
|
||||||
"sqlparse==0.5.3",
|
|
||||||
"submitit==1.5.2",
|
"submitit==1.5.2",
|
||||||
"sympy==1.13.1",
|
"sympy==1.13.1",
|
||||||
"tensorboard==2.18.0",
|
|
||||||
"tensorboard-data-server==0.7.2",
|
|
||||||
"tensorflow==2.18.0",
|
|
||||||
"tensorflow-hub==0.16.1",
|
|
||||||
"termcolor==2.5.0",
|
|
||||||
"tf_keras==2.18.0",
|
|
||||||
"threadpoolctl==3.5.0",
|
"threadpoolctl==3.5.0",
|
||||||
"torch==2.5.1",
|
"tokenizers==0.21.0",
|
||||||
|
"torch==2.6.0",
|
||||||
"torch-optimizer==0.1.0",
|
"torch-optimizer==0.1.0",
|
||||||
"torch-stoi==0.2.3",
|
"torch-stoi==0.2.3",
|
||||||
"torchaudio==2.5.1",
|
"torchaudio==2.6.0",
|
||||||
|
"torchlibrosa==0.1.0",
|
||||||
"torchmetrics==0.11.4",
|
"torchmetrics==0.11.4",
|
||||||
"tqdm==4.67.1",
|
"tqdm==4.67.1",
|
||||||
|
"transformers==4.49.0",
|
||||||
"treetable==0.2.5",
|
"treetable==0.2.5",
|
||||||
"triton==3.1.0",
|
"triton==3.2.0",
|
||||||
"typing_extensions==4.12.2",
|
"typing_extensions==4.12.2",
|
||||||
"tzdata==2025.1",
|
"tzdata==2025.1",
|
||||||
"urllib3==2.3.0",
|
"urllib3==2.3.0",
|
||||||
"vine==5.1.0",
|
|
||||||
"wcwidth==0.2.13",
|
|
||||||
"Werkzeug==3.1.3",
|
|
||||||
"wheel==0.45.1",
|
|
||||||
"wrapt==1.17.2",
|
|
||||||
"yarl==1.18.3",
|
"yarl==1.18.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
+24
-62
@@ -1,74 +1,48 @@
|
|||||||
absl-py==2.1.0
|
aiohappyeyeballs==2.4.6
|
||||||
aiohappyeyeballs==2.4.4
|
aiohttp==3.11.13
|
||||||
aiohttp==3.11.11
|
|
||||||
aiosignal==1.3.2
|
aiosignal==1.3.2
|
||||||
amqp==5.3.1
|
|
||||||
antlr4-python3-runtime==4.9.3
|
antlr4-python3-runtime==4.9.3
|
||||||
appdirs==1.4.4
|
appdirs==1.4.4
|
||||||
asgiref==3.8.1
|
|
||||||
asteroid==0.7.0
|
asteroid==0.7.0
|
||||||
asteroid-filterbanks==0.4.0
|
asteroid-filterbanks==0.4.0
|
||||||
astunparse==1.6.3
|
|
||||||
attrs==25.1.0
|
attrs==25.1.0
|
||||||
audioread==3.0.1
|
audioread==3.0.1
|
||||||
billiard==4.2.1
|
|
||||||
cached-property==2.0.1
|
cached-property==2.0.1
|
||||||
celery==5.4.0
|
certifi==2025.1.31
|
||||||
certifi==2024.12.14
|
|
||||||
cffi==1.17.1
|
cffi==1.17.1
|
||||||
charset-normalizer==3.4.1
|
charset-normalizer==3.4.1
|
||||||
click==8.1.8
|
|
||||||
click-didyoumean==0.3.1
|
|
||||||
click-plugins==1.1.1
|
|
||||||
click-repl==0.3.0
|
|
||||||
cloudpickle==3.1.1
|
cloudpickle==3.1.1
|
||||||
contourpy==1.3.1
|
contourpy==1.3.1
|
||||||
cycler==0.12.1
|
cycler==0.12.1
|
||||||
decorator==5.1.1
|
decorator==5.2.1
|
||||||
DeepFilterLib==0.5.6
|
DeepFilterLib==0.5.6
|
||||||
DeepFilterNet==0.5.6
|
DeepFilterNet==0.5.6
|
||||||
demucs==4.0.1
|
demucs==4.0.1
|
||||||
Django==5.1.6
|
|
||||||
djangorestframework==3.15.2
|
|
||||||
dora_search==0.1.12
|
dora_search==0.1.12
|
||||||
einops==0.8.0
|
einops==0.8.1
|
||||||
filelock==3.17.0
|
filelock==3.17.0
|
||||||
flatbuffers==25.2.10
|
fonttools==4.56.0
|
||||||
fonttools==4.55.6
|
|
||||||
frozenlist==1.5.0
|
frozenlist==1.5.0
|
||||||
fsspec==2024.12.0
|
fsspec==2025.2.0
|
||||||
future==1.0.0
|
huggingface-hub==0.29.1
|
||||||
gast==0.6.0
|
|
||||||
google-pasta==0.2.0
|
|
||||||
grpcio==1.70.0
|
|
||||||
h5py==3.13.0
|
|
||||||
huggingface-hub==0.28.0
|
|
||||||
idna==3.10
|
idna==3.10
|
||||||
iniconfig==2.0.0
|
iniconfig==2.0.0
|
||||||
Jinja2==3.1.5
|
Jinja2==3.1.5
|
||||||
joblib==1.4.2
|
joblib==1.4.2
|
||||||
julius==0.2.7
|
julius==0.2.7
|
||||||
keras==3.8.0
|
|
||||||
kiwisolver==1.4.8
|
kiwisolver==1.4.8
|
||||||
kombu==5.4.2
|
|
||||||
lameenc==1.8.1
|
lameenc==1.8.1
|
||||||
lazy_loader==0.4
|
lazy_loader==0.4
|
||||||
libclang==18.1.1
|
|
||||||
librosa==0.10.2.post1
|
librosa==0.10.2.post1
|
||||||
lightning-utilities==0.11.9
|
lightning-utilities==0.12.0
|
||||||
llvmlite==0.44.0
|
llvmlite==0.44.0
|
||||||
loguru==0.7.3
|
loguru==0.7.3
|
||||||
Markdown==3.7
|
|
||||||
markdown-it-py==3.0.0
|
|
||||||
MarkupSafe==3.0.2
|
MarkupSafe==3.0.2
|
||||||
matplotlib==3.10.0
|
matplotlib==3.10.0
|
||||||
mdurl==0.1.2
|
mir_eval==0.8.2
|
||||||
mir_eval==0.7
|
|
||||||
ml-dtypes==0.4.1
|
|
||||||
mpmath==1.3.0
|
mpmath==1.3.0
|
||||||
msgpack==1.1.0
|
msgpack==1.1.0
|
||||||
multidict==6.1.0
|
multidict==6.1.0
|
||||||
namex==0.0.8
|
|
||||||
networkx==3.4.2
|
networkx==3.4.2
|
||||||
numba==0.61.0
|
numba==0.61.0
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
@@ -81,68 +55,56 @@ nvidia-cufft-cu12==11.2.1.3
|
|||||||
nvidia-curand-cu12==10.3.5.147
|
nvidia-curand-cu12==10.3.5.147
|
||||||
nvidia-cusolver-cu12==11.6.1.9
|
nvidia-cusolver-cu12==11.6.1.9
|
||||||
nvidia-cusparse-cu12==12.3.1.170
|
nvidia-cusparse-cu12==12.3.1.170
|
||||||
|
nvidia-cusparselt-cu12==0.6.2
|
||||||
nvidia-nccl-cu12==2.21.5
|
nvidia-nccl-cu12==2.21.5
|
||||||
nvidia-nvjitlink-cu12==12.4.127
|
nvidia-nvjitlink-cu12==12.4.127
|
||||||
nvidia-nvtx-cu12==12.4.127
|
nvidia-nvtx-cu12==12.4.127
|
||||||
omegaconf==2.3.0
|
omegaconf==2.3.0
|
||||||
openunmix==1.3.0
|
openunmix==1.3.0
|
||||||
opt_einsum==3.4.0
|
|
||||||
optree==0.14.0
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
pandas==2.2.3
|
pandas==2.2.3
|
||||||
|
panns-inference==0.1.1
|
||||||
pb-bss-eval==0.0.2
|
pb-bss-eval==0.0.2
|
||||||
pesq==0.0.4
|
pesq==0.0.4
|
||||||
pillow==11.1.0
|
pillow==11.1.0
|
||||||
platformdirs==4.3.6
|
platformdirs==4.3.6
|
||||||
pluggy==1.5.0
|
pluggy==1.5.0
|
||||||
pooch==1.8.2
|
pooch==1.8.2
|
||||||
prompt_toolkit==3.0.50
|
propcache==0.3.0
|
||||||
propcache==0.2.1
|
|
||||||
protobuf==5.29.3
|
|
||||||
pycparser==2.22
|
pycparser==2.22
|
||||||
Pygments==2.19.1
|
|
||||||
pyparsing==3.2.1
|
pyparsing==3.2.1
|
||||||
pystoi==0.4.1
|
pystoi==0.4.1
|
||||||
pytest==8.3.4
|
pytest==8.3.4
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
pytorch-lightning==2.5.0.post0
|
pytorch-lightning==2.5.0.post0
|
||||||
pytorch-ranger==0.1.1
|
pytorch-ranger==0.1.1
|
||||||
pytz==2024.2
|
pytz==2025.1
|
||||||
PyYAML==6.0.2
|
PyYAML==6.0.2
|
||||||
redis==5.2.1
|
regex==2024.11.6
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
retrying==1.3.4
|
retrying==1.3.4
|
||||||
rich==13.9.4
|
safetensors==0.5.3
|
||||||
scikit-learn==1.6.1
|
scikit-learn==1.6.1
|
||||||
scipy==1.15.1
|
scipy==1.15.2
|
||||||
setuptools==75.8.0
|
setuptools==75.8.1
|
||||||
six==1.17.0
|
six==1.17.0
|
||||||
soundfile==0.13.1
|
soundfile==0.13.1
|
||||||
soxr==0.5.0.post1
|
soxr==0.5.0.post1
|
||||||
sqlparse==0.5.3
|
|
||||||
submitit==1.5.2
|
submitit==1.5.2
|
||||||
sympy==1.13.1
|
sympy==1.13.1
|
||||||
tensorboard==2.18.0
|
|
||||||
tensorboard-data-server==0.7.2
|
|
||||||
tensorflow==2.18.0
|
|
||||||
tensorflow-hub==0.16.1
|
|
||||||
termcolor==2.5.0
|
|
||||||
tf_keras==2.18.0
|
|
||||||
threadpoolctl==3.5.0
|
threadpoolctl==3.5.0
|
||||||
torch==2.5.1
|
tokenizers==0.21.0
|
||||||
|
torch==2.6.0
|
||||||
torch-optimizer==0.1.0
|
torch-optimizer==0.1.0
|
||||||
torch-stoi==0.2.3
|
torch-stoi==0.2.3
|
||||||
torchaudio==2.5.1
|
torchaudio==2.6.0
|
||||||
|
torchlibrosa==0.1.0
|
||||||
torchmetrics==0.11.4
|
torchmetrics==0.11.4
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
|
transformers==4.49.0
|
||||||
treetable==0.2.5
|
treetable==0.2.5
|
||||||
triton==3.1.0
|
triton==3.2.0
|
||||||
typing_extensions==4.12.2
|
typing_extensions==4.12.2
|
||||||
tzdata==2025.1
|
tzdata==2025.1
|
||||||
urllib3==2.3.0
|
urllib3==2.3.0
|
||||||
vine==5.1.0
|
|
||||||
wcwidth==0.2.13
|
|
||||||
Werkzeug==3.1.3
|
|
||||||
wheel==0.45.1
|
|
||||||
wrapt==1.17.2
|
|
||||||
yarl==1.18.3
|
yarl==1.18.3
|
||||||
|
|||||||
@@ -1,58 +1,43 @@
|
|||||||
import tensorflow as tf
|
|
||||||
import tensorflow_hub as hub
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import csv
|
from panns_inference import AudioTagging, labels
|
||||||
|
|
||||||
# Force TensorFlow to use only CPU
|
# Initialize PANNs model
|
||||||
tf.config.set_visible_devices([], 'GPU')
|
at = AudioTagging(checkpoint_path=None, device='cuda')
|
||||||
|
|
||||||
model = hub.load('https://tfhub.dev/google/yamnet/1')
|
|
||||||
|
|
||||||
#Find the name of the class with the top score when mean-aggregated across frames.
|
|
||||||
def class_names_from_csv(class_map_scv_text):
|
|
||||||
"""Returns list of class names corresponding to score vector."""
|
|
||||||
class_names = []
|
|
||||||
with tf.io.gfile.GFile(class_map_scv_text) as csvfile:
|
|
||||||
reader = csv.DictReader(csvfile)
|
|
||||||
for row in reader:
|
|
||||||
class_names.append(row['display_name'])
|
|
||||||
return class_names
|
|
||||||
|
|
||||||
# Main function to process audio and classify
|
|
||||||
def classify_audio(waveform, sr):
|
def classify_audio(waveform, sr):
|
||||||
"""
|
"""
|
||||||
Given an audio file, this function loads the audio, resamples it,
|
Given an audio file, this function loads the audio, resamples it,
|
||||||
normalizes it, and runs it through the YAMNet model to classify the sound.
|
normalizes it, and runs it through the PANNs model to classify the sound.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.).
|
- waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.).
|
||||||
|
- sr (int): Sampling rate of the audio.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- str: Predicted class label of the audio.
|
- str: Predicted class label of the audio.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Check if the sampling rate is 16000Hz
|
# Check if the sampling rate is 32000Hz
|
||||||
try:
|
try:
|
||||||
if(sr!=16000):
|
if sr != 32000:
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
except Exception:
|
except Exception:
|
||||||
raise RuntimeError(f"The audio is not sampled at 16000Hz, failed to classify audio.")
|
raise RuntimeError(f"The audio is not sampled at 32000Hz, failed to classify audio.")
|
||||||
|
|
||||||
# Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
|
# Normalize the waveform to [-1.0, 1.0]
|
||||||
waveform = waveform / np.max(np.abs(waveform))
|
waveform = waveform / np.max(np.abs(waveform))
|
||||||
|
|
||||||
# Execute the YAMNet model
|
# Ensure waveform shape is correct for model input
|
||||||
|
waveform = waveform[None, :]
|
||||||
|
|
||||||
|
# Execute the PANNs model
|
||||||
try:
|
try:
|
||||||
scores, embeddings, spectrogram = model(waveform)
|
clipwise_output, _ = at.inference(waveform)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error: Failed to classify audio: {e}")
|
raise RuntimeError(f"Error: Failed to classify audio: {e}")
|
||||||
|
|
||||||
# Extract the class names from the model
|
# Get the top predicted class
|
||||||
class_map_path = model.class_map_path().numpy()
|
predicted_index = np.argmax(clipwise_output)
|
||||||
class_names = class_names_from_csv(class_map_path)
|
inferred_class = labels[predicted_index]
|
||||||
|
|
||||||
# Find the class with the highest score
|
return inferred_class
|
||||||
scores_np = scores.numpy()
|
|
||||||
inferred_class = class_names[scores_np.mean(axis=0).argmax()]
|
|
||||||
|
|
||||||
return inferred_class
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ def test_trim_audio():
|
|||||||
|
|
||||||
def test_classify():
|
def test_classify():
|
||||||
file_path = "tests/test_audio/cafe_crowd_talk.wav"
|
file_path = "tests/test_audio/cafe_crowd_talk.wav"
|
||||||
waveform, sr = read_audio(file_path, 16000, mono=True)
|
waveform, sr = read_audio(file_path, 32000, mono=True)
|
||||||
expected_class = "Speech"
|
expected_class = "Speech"
|
||||||
predicted_class = classify_audio(waveform, sr)
|
predicted_class = classify_audio(waveform, sr)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user