diff --git a/requirements/env.txt b/requirements/env.txt
index 676f0f7..128956f 100644
--- a/requirements/env.txt
+++ b/requirements/env.txt
@@ -1,7 +1,14 @@
+aiohappyeyeballs==2.4.4
+aiohttp==3.11.11
+aiosignal==1.3.2
 amqp==5.3.1
 antlr4-python3-runtime==4.9.3
+asteroid==0.7.0
+asteroid-filterbanks==0.4.0
+attrs==25.1.0
 audioread==3.0.1
 billiard==4.2.1
+cached-property==2.0.1
 celery==5.4.0
 certifi==2024.12.14
 cffi==1.17.1
@@ -19,7 +26,10 @@ dora_search==0.1.12
 einops==0.8.0
 filelock==3.17.0
 fonttools==4.55.6
+frozenlist==1.5.0
 fsspec==2024.12.0
+future==1.0.0
+huggingface-hub==0.28.0
 idna==3.10
 iniconfig==2.0.0
 Jinja2==3.1.5
@@ -30,11 +40,14 @@ kombu==5.4.2
 lameenc==1.8.1
 lazy_loader==0.4
 librosa==0.10.2.post1
+lightning-utilities==0.11.9
 llvmlite==0.44.0
 MarkupSafe==3.0.2
 matplotlib==3.10.0
+mir_eval==0.7
 mpmath==1.3.0
 msgpack==1.1.0
+multidict==6.1.0
 networkx==3.4.2
 numba==0.61.0
 numpy==2.1.3
@@ -53,15 +66,23 @@ nvidia-nvtx-cu12==12.4.127
 omegaconf==2.3.0
 openunmix==1.3.0
 packaging==24.2
+pandas==2.2.3
+pb-bss-eval==0.0.2
+pesq==0.0.4
 pillow==11.1.0
 platformdirs==4.3.6
 pluggy==1.5.0
 pooch==1.8.2
 prompt_toolkit==3.0.50
+propcache==0.2.1
 pycparser==2.22
 pyparsing==3.2.1
+pystoi==0.4.1
 pytest==8.3.4
 python-dateutil==2.9.0.post0
+pytorch-lightning==2.5.0.post0
+pytorch-ranger==0.1.1
+pytz==2024.2
 PyYAML==6.0.2
 redis==5.2.1
 requests==2.32.3
@@ -76,7 +97,10 @@ submitit==1.5.2
 sympy==1.13.1
 threadpoolctl==3.5.0
 torch==2.5.1
+torch-optimizer==0.1.0
+torch-stoi==0.2.3
 torchaudio==2.5.1
+torchmetrics==0.11.4
 tqdm==4.67.1
 treetable==0.2.5
 triton==3.1.0
@@ -85,3 +109,4 @@ tzdata==2025.1
 urllib3==2.3.0
 vine==5.1.0
 wcwidth==0.2.13
+yarl==1.18.3
diff --git a/src/postprocessing/__init__.py b/src/postprocessing/__init__.py
new file mode 100644
index 0000000..636a4ce
--- /dev/null
+++ b/src/postprocessing/__init__.py
@@ -0,0 +1,11 @@
+# __init__.py
+
+import logging
+
+# Configure logging
+logging.basicConfig(
+    format='%(asctime)s : %(message)s',
+    level=logging.INFO
+)
+
+logging.info("freq-split-enhance/postprocessing package has been imported.")
\ No newline at end of file
diff --git a/src/postprocessing/audio_writer.py b/src/postprocessing/audio_writer.py
new file mode 100644
index 0000000..2fd7da0
--- /dev/null
+++ b/src/postprocessing/audio_writer.py
@@ -0,0 +1,36 @@
+import soundfile as sf
+import numpy as np
+
+def export_audio(audio, output_path, sr):
+    """
+    Save a NumPy audio array to a WAV file.
+
+    Args:
+        audio (numpy.ndarray): The audio data to be saved.
+        output_path (str): The path where the audio file should be saved.
+        sr (int): The sampling rate of the audio.
+    """
+
+    try:
+
+        print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}")
+
+        if audio.ndim == 2 and audio.shape[0] == 2:
+            # Transpose stereo audio to the (num_samples, num_channels) layout soundfile expects
+            audio = audio.T  # From (2, num_samples) to (num_samples, 2)
+
+        # Ensure the audio data type is float32
+        audio = audio.astype('float32')
+
+        # Peak-normalize to avoid clipping
+        if np.max(np.abs(audio)) > 0:  # Avoid division by zero
+            audio = audio / np.max(np.abs(audio))
+
+        # Verify final format
+        print(f"Final audio shape: {audio.shape}, dtype: {audio.dtype}, max: {np.max(audio)}, min: {np.min(audio)}")
+
+
+        sf.write(output_path, audio, sr, format='wav')
+        print(f"Audio saved to {output_path}")
+    except Exception as e:
+        print(f"Error saving audio: {e}")
\ No newline at end of file
diff --git a/src/separation/convtasnet_wrapper.py b/src/separation/convtasnet_wrapper.py
new file mode 100644
index 0000000..ae52b38
--- /dev/null
+++ b/src/separation/convtasnet_wrapper.py
@@ -0,0 +1,39 @@
+import torch
+from asteroid.models import ConvTasNet
+
+def separate(audio, model_name='mpariente/ConvTasNet_WHAM!_sepclean'):
+    """
+    Separates audio into sources using a pretrained Asteroid model.
+
+    Args:
+        audio (numpy.ndarray): The audio time series (1D numpy array).
+        model_name (str): Name of the pretrained Asteroid model. Defaults to 'mpariente/ConvTasNet_WHAM!_sepclean'.
+
+    Returns:
+        list: List of separated sources as numpy arrays.
+    """
+    try:
+        # Select the device: GPU if available, otherwise CPU
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Using device: {device}")
+
+        # Load the pretrained model and move it to the selected device
+        model = ConvTasNet.from_pretrained(model_name).to(device)
+
+        # Convert the audio to a float32 tensor (model weights are float32), add a batch dimension, and move to device
+        audio_tensor = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)  # Shape: (1, num_samples)
+
+        # Perform source separation
+        with torch.no_grad():
+            separated_sources = model(audio_tensor)
+
+        # Remove batch dimension
+        separated_sources = separated_sources.squeeze(0)  # Shape: (num_sources, num_samples)
+
+        # Split into list of sources
+        separated_sources_np = separated_sources.cpu().numpy()  # Convert to NumPy
+        separated_sources_list = [separated_sources_np[i, :] for i in range(separated_sources_np.shape[0])]
+
+        return separated_sources_list
+    except Exception as e:
+        raise RuntimeError(f"Error during separation: {e}") from e
diff --git a/tests/test_audio/female-female-mixture.wav b/tests/test_audio/female-female-mixture.wav
new file mode 100644
index 0000000..2137997
Binary files /dev/null and b/tests/test_audio/female-female-mixture.wav differ
diff --git a/tests/test_separation.py b/tests/test_separation.py
index 00c020a..dfeb282 100644
--- a/tests/test_separation.py
+++ b/tests/test_separation.py
@@ -7,6 +7,8 @@ from src.input.file_reader import read_audio
 from src.preprocessing.trim import trim_audio
 from src.preprocessing.resample import resample
 from src.separation.demucs_wrapper import separate_audio_with_demucs
+from src.separation.convtasnet_wrapper import separate
+from src.postprocessing.audio_writer import export_audio


 def test_demucs_separation_with_preprocessing():
@@ -51,3 +53,51 @@
     for expected_file in expected_files:
         file_path = file_folder / expected_file
         assert file_path.exists(), f"Expected file {expected_file} not found in {file_name} folder."
+
+def test_convtasnet_separation_with_output_files():
+    """
+    Test to ensure ConvTasNet separation creates the expected source audio files.
+    """
+
+    input_audio_path = "tests/test_audio/female-female-mixture.wav"
+    output_dir = "/tmp/convtasnet"
+    model_name = "mpariente/ConvTasNet_WHAM!_sepclean"
+
+    audio, sr = read_audio(input_audio_path)
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+
+    separated_sources = separate(audio, model_name)
+
+    for i, source in enumerate(separated_sources):
+        output_path = os.path.join(output_dir, f"source_{i+1}.wav")
+        export_audio(source, output_path, sr)
+
+    # Check that the output directory exists
+    assert os.path.exists(output_dir), "Output directory does not exist."
+
+    # Check that source_1.wav and source_2.wav were created
+    source_1_path = os.path.join(output_dir, "source_1.wav")
+    source_2_path = os.path.join(output_dir, "source_2.wav")
+
+    assert os.path.exists(source_1_path), "source_1.wav was not created."
+    assert os.path.exists(source_2_path), "source_2.wav was not created."
+
+    # Check that the files have content, not just that they exist.
+    # A stricter check could require a minimum duration; here it is
+    # enough to verify that each file contains at least one sample.
+    import soundfile as sf
+
+    def is_file_non_empty(file_path):
+        try:
+            data, _ = sf.read(file_path)
+            return data.size > 0
+        except Exception:
+            return False
+
+    assert is_file_non_empty(source_1_path), "source_1.wav is empty."
+    assert is_file_non_empty(source_2_path), "source_2.wav is empty."
+
+    print("Test passed: source_1.wav and source_2.wav are present and non-empty.")