diff --git a/pyproject.toml b/pyproject.toml index 717eaf4..04d4c3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,76 +6,51 @@ build-backend = "setuptools.build_meta" name = "freqsplit" version = "0.1.0" dependencies = [ - "absl-py==2.1.0", - "aiohappyeyeballs==2.4.4", - "aiohttp==3.11.11", + "aiohappyeyeballs==2.4.6", + "aiohttp==3.11.13", "aiosignal==1.3.2", - "amqp==5.3.1", "antlr4-python3-runtime==4.9.3", "appdirs==1.4.4", - "asgiref==3.8.1", "asteroid==0.7.0", "asteroid-filterbanks==0.4.0", - "astunparse==1.6.3", "attrs==25.1.0", "audioread==3.0.1", - "billiard==4.2.1", "cached-property==2.0.1", - "celery==5.4.0", - "certifi==2024.12.14", + "certifi==2025.1.31", "cffi==1.17.1", "charset-normalizer==3.4.1", - "click==8.1.8", - "click-didyoumean==0.3.1", - "click-plugins==1.1.1", - "click-repl==0.3.0", "cloudpickle==3.1.1", "contourpy==1.3.1", "cycler==0.12.1", - "decorator==5.1.1", + "decorator==5.2.1", "DeepFilterLib==0.5.6", "DeepFilterNet==0.5.6", "demucs==4.0.1", - "Django==5.1.6", "dora_search==0.1.12", - "einops==0.8.0", + "einops==0.8.1", "filelock==3.17.0", - "flatbuffers==25.2.10", - "fonttools==4.55.6", + "fonttools==4.56.0", "frozenlist==1.5.0", - "fsspec==2024.12.0", - "future==1.0.0", - "gast==0.6.0", - "google-pasta==0.2.0", - "grpcio==1.70.0", - "h5py==3.13.0", - "huggingface-hub==0.28.0", + "fsspec==2025.2.0", + "huggingface-hub==0.29.1", "idna==3.10", "iniconfig==2.0.0", "Jinja2==3.1.5", "joblib==1.4.2", "julius==0.2.7", - "keras==3.8.0", "kiwisolver==1.4.8", - "kombu==5.4.2", "lameenc==1.8.1", "lazy_loader==0.4", - "libclang==18.1.1", "librosa==0.10.2.post1", - "lightning-utilities==0.11.9", + "lightning-utilities==0.12.0", "llvmlite==0.44.0", "loguru==0.7.3", - "Markdown==3.7", - "markdown-it-py==3.0.0", "MarkupSafe==3.0.2", "matplotlib==3.10.0", - "mdurl==0.1.2", - "mir_eval==0.7", - "ml-dtypes==0.4.1", + "mir_eval==0.8.2", "mpmath==1.3.0", "msgpack==1.1.0", "multidict==6.1.0", - "namex==0.0.8", "networkx==3.4.2", "numba==0.61.0", "numpy==1.26.4", @@ -88,70 +63,58 @@ dependencies = [ "nvidia-curand-cu12==10.3.5.147", "nvidia-cusolver-cu12==11.6.1.9", "nvidia-cusparse-cu12==12.3.1.170", + "nvidia-cusparselt-cu12==0.6.2", "nvidia-nccl-cu12==2.21.5", "nvidia-nvjitlink-cu12==12.4.127", "nvidia-nvtx-cu12==12.4.127", "omegaconf==2.3.0", "openunmix==1.3.0", - "opt_einsum==3.4.0", - "optree==0.14.0", "packaging==23.2", "pandas==2.2.3", + "panns-inference==0.1.1", "pb-bss-eval==0.0.2", "pesq==0.0.4", "pillow==11.1.0", "platformdirs==4.3.6", "pluggy==1.5.0", "pooch==1.8.2", - "prompt_toolkit==3.0.50", - "propcache==0.2.1", - "protobuf==5.29.3", + "propcache==0.3.0", "pycparser==2.22", - "Pygments==2.19.1", "pyparsing==3.2.1", "pystoi==0.4.1", "pytest==8.3.4", "python-dateutil==2.9.0.post0", "pytorch-lightning==2.5.0.post0", "pytorch-ranger==0.1.1", - "pytz==2024.2", + "pytz==2025.1", "PyYAML==6.0.2", - "redis==5.2.1", + "regex==2024.11.6", "requests==2.32.3", "retrying==1.3.4", - "rich==13.9.4", + "safetensors==0.5.3", "scikit-learn==1.6.1", - "scipy==1.15.1", - "setuptools==75.8.0", + "scipy==1.15.2", + "setuptools==75.8.1", "six==1.17.0", "soundfile==0.13.1", "soxr==0.5.0.post1", - "sqlparse==0.5.3", "submitit==1.5.2", "sympy==1.13.1", - "tensorboard==2.18.0", - "tensorboard-data-server==0.7.2", - "tensorflow==2.18.0", - "tensorflow-hub==0.16.1", - "termcolor==2.5.0", - "tf_keras==2.18.0", "threadpoolctl==3.5.0", - "torch==2.5.1", + "tokenizers==0.21.0", + "torch==2.6.0", "torch-optimizer==0.1.0", "torch-stoi==0.2.3", - "torchaudio==2.5.1", + "torchaudio==2.6.0", + "torchlibrosa==0.1.0", "torchmetrics==0.11.4", "tqdm==4.67.1", + "transformers==4.49.0", "treetable==0.2.5", - "triton==3.1.0", + "triton==3.2.0", "typing_extensions==4.12.2", "tzdata==2025.1", "urllib3==2.3.0", - "vine==5.1.0", - "wcwidth==0.2.13", - "Werkzeug==3.1.3", - "wheel==0.45.1", - "wrapt==1.17.2", "yarl==1.18.3", ] diff --git a/requirements.txt b/requirements.txt index abd5442..cd78deb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,74 +1,48 @@ -absl-py==2.1.0 -aiohappyeyeballs==2.4.4 -aiohttp==3.11.11 +aiohappyeyeballs==2.4.6 +aiohttp==3.11.13 aiosignal==1.3.2 -amqp==5.3.1 antlr4-python3-runtime==4.9.3 appdirs==1.4.4 -asgiref==3.8.1 asteroid==0.7.0 asteroid-filterbanks==0.4.0 -astunparse==1.6.3 attrs==25.1.0 audioread==3.0.1 -billiard==4.2.1 cached-property==2.0.1 -celery==5.4.0 -certifi==2024.12.14 +certifi==2025.1.31 cffi==1.17.1 charset-normalizer==3.4.1 -click==8.1.8 -click-didyoumean==0.3.1 -click-plugins==1.1.1 -click-repl==0.3.0 cloudpickle==3.1.1 contourpy==1.3.1 cycler==0.12.1 -decorator==5.1.1 +decorator==5.2.1 DeepFilterLib==0.5.6 DeepFilterNet==0.5.6 demucs==4.0.1 -Django==5.1.6 -djangorestframework==3.15.2 dora_search==0.1.12 -einops==0.8.0 +einops==0.8.1 filelock==3.17.0 -flatbuffers==25.2.10 -fonttools==4.55.6 +fonttools==4.56.0 frozenlist==1.5.0 -fsspec==2024.12.0 -future==1.0.0 -gast==0.6.0 -google-pasta==0.2.0 -grpcio==1.70.0 -h5py==3.13.0 -huggingface-hub==0.28.0 +fsspec==2025.2.0 +huggingface-hub==0.29.1 idna==3.10 iniconfig==2.0.0 Jinja2==3.1.5 joblib==1.4.2 julius==0.2.7 -keras==3.8.0 kiwisolver==1.4.8 -kombu==5.4.2 lameenc==1.8.1 lazy_loader==0.4 -libclang==18.1.1 librosa==0.10.2.post1 -lightning-utilities==0.11.9 +lightning-utilities==0.12.0 llvmlite==0.44.0 loguru==0.7.3 -Markdown==3.7 -markdown-it-py==3.0.0 MarkupSafe==3.0.2 matplotlib==3.10.0 -mdurl==0.1.2 -mir_eval==0.7 -ml-dtypes==0.4.1 +mir_eval==0.8.2 mpmath==1.3.0 msgpack==1.1.0 multidict==6.1.0 -namex==0.0.8 networkx==3.4.2 numba==0.61.0 numpy==1.26.4 @@ -81,68 +55,56 @@ nvidia-cufft-cu12==11.2.1.3 nvidia-curand-cu12==10.3.5.147 nvidia-cusolver-cu12==11.6.1.9 nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 omegaconf==2.3.0 openunmix==1.3.0 -opt_einsum==3.4.0 -optree==0.14.0 packaging==23.2 pandas==2.2.3 +panns-inference==0.1.1 pb-bss-eval==0.0.2 pesq==0.0.4 pillow==11.1.0 platformdirs==4.3.6 pluggy==1.5.0 pooch==1.8.2 -prompt_toolkit==3.0.50 -propcache==0.2.1 -protobuf==5.29.3 +propcache==0.3.0 pycparser==2.22 -Pygments==2.19.1 pyparsing==3.2.1 pystoi==0.4.1 pytest==8.3.4 python-dateutil==2.9.0.post0 pytorch-lightning==2.5.0.post0 pytorch-ranger==0.1.1 -pytz==2024.2 +pytz==2025.1 PyYAML==6.0.2 -redis==5.2.1 +regex==2024.11.6 requests==2.32.3 retrying==1.3.4 -rich==13.9.4 +safetensors==0.5.3 scikit-learn==1.6.1 -scipy==1.15.1 -setuptools==75.8.0 +scipy==1.15.2 +setuptools==75.8.1 six==1.17.0 soundfile==0.13.1 soxr==0.5.0.post1 -sqlparse==0.5.3 submitit==1.5.2 sympy==1.13.1 -tensorboard==2.18.0 -tensorboard-data-server==0.7.2 -tensorflow==2.18.0 -tensorflow-hub==0.16.1 -termcolor==2.5.0 -tf_keras==2.18.0 threadpoolctl==3.5.0 -torch==2.5.1 +tokenizers==0.21.0 +torch==2.6.0 torch-optimizer==0.1.0 torch-stoi==0.2.3 -torchaudio==2.5.1 +torchaudio==2.6.0 +torchlibrosa==0.1.0 torchmetrics==0.11.4 tqdm==4.67.1 +transformers==4.49.0 treetable==0.2.5 -triton==3.1.0 +triton==3.2.0 typing_extensions==4.12.2 tzdata==2025.1 urllib3==2.3.0 -vine==5.1.0 -wcwidth==0.2.13 -Werkzeug==3.1.3 -wheel==0.45.1 -wrapt==1.17.2 yarl==1.18.3 diff --git a/src/freqsplit/preprocessing/classify.py b/src/freqsplit/preprocessing/classify.py index 3baccd8..0a52d3a 100644 --- a/src/freqsplit/preprocessing/classify.py +++ b/src/freqsplit/preprocessing/classify.py @@ -1,58 +1,43 @@ -import tensorflow as tf -import tensorflow_hub as hub import numpy as np -import csv +from panns_inference import AudioTagging, labels -# Force TensorFlow to use only CPU -tf.config.set_visible_devices([], 'GPU') +# Initialize PANNs model +at = AudioTagging(checkpoint_path=None, device='cuda') -model = hub.load('https://tfhub.dev/google/yamnet/1') - -#Find the name of the class with the top score when mean-aggregated across frames. -def class_names_from_csv(class_map_scv_text): - """Returns list of class names corresponding to score vector.""" - class_names = [] - with tf.io.gfile.GFile(class_map_scv_text) as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - class_names.append(row['display_name']) - return class_names - -# Main function to process audio and classify def classify_audio(waveform, sr): """ Given an audio file, this function loads the audio, resamples it, - normalizes it, and runs it through the YAMNet model to classify the sound. + normalizes it, and runs it through the PANNs model to classify the sound. Args: - waveform (numpy.ndarray): waveform of the audio file (WAV, MP3, etc.). + - sr (int): Sampling rate of the audio. Returns: - str: Predicted class label of the audio. """ - # Check if the sampling rate is 16000Hz + # Check if the sampling rate is 32000Hz try: - if(sr!=16000): + if sr != 32000: raise RuntimeError except Exception: - raise RuntimeError(f"The audio is not sampled at 16000Hz, failed to classify audio.") + raise RuntimeError(f"The audio is not sampled at 32000Hz, failed to classify audio.") - # Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values) + # Normalize the waveform to [-1.0, 1.0] waveform = waveform / np.max(np.abs(waveform)) - - # Execute the YAMNet model + + # Ensure waveform shape is correct for model input + waveform = waveform[None, :] + + # Execute the PANNs model try: - scores, embeddings, spectrogram = model(waveform) + clipwise_output, _ = at.inference(waveform) except Exception as e: raise RuntimeError(f"Error: Failed to classify audio: {e}") - - # Extract the class names from the model - class_map_path = model.class_map_path().numpy() - class_names = class_names_from_csv(class_map_path) - - # Find the class with the highest score - scores_np = scores.numpy() - inferred_class = class_names[scores_np.mean(axis=0).argmax()] - - return inferred_class \ No newline at end of file + + # Get the top predicted class + predicted_index = np.argmax(clipwise_output) + inferred_class = labels[predicted_index] + + return inferred_class diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 787fd65..9af97cd 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -24,7 +24,7 @@ def test_trim_audio(): def test_classify(): file_path = "tests/test_audio/cafe_crowd_talk.wav" - waveform, sr = read_audio(file_path, 16000, mono=True) + waveform, sr = read_audio(file_path, 32000, mono=True) expected_class = "Speech" predicted_class = classify_audio(waveform, sr)