Merge pull request #23 from joelmathewthomas/feature/asteroid-wrapper

Feature/asteroid wrapper
This commit is contained in:
Joel Mathew Thomas
2025-01-29 00:51:33 +05:30
committed by GitHub
6 changed files with 162 additions and 0 deletions
+25
View File
@@ -1,7 +1,14 @@
aiohappyeyeballs==2.4.4
aiohttp==3.11.11
aiosignal==1.3.2
amqp==5.3.1
antlr4-python3-runtime==4.9.3
asteroid==0.7.0
asteroid-filterbanks==0.4.0
attrs==25.1.0
audioread==3.0.1
billiard==4.2.1
cached-property==2.0.1
celery==5.4.0
certifi==2024.12.14
cffi==1.17.1
@@ -19,7 +26,10 @@ dora_search==0.1.12
einops==0.8.0
filelock==3.17.0
fonttools==4.55.6
frozenlist==1.5.0
fsspec==2024.12.0
future==1.0.0
huggingface-hub==0.28.0
idna==3.10
iniconfig==2.0.0
Jinja2==3.1.5
@@ -30,11 +40,14 @@ kombu==5.4.2
lameenc==1.8.1
lazy_loader==0.4
librosa==0.10.2.post1
lightning-utilities==0.11.9
llvmlite==0.44.0
MarkupSafe==3.0.2
matplotlib==3.10.0
mir_eval==0.7
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
networkx==3.4.2
numba==0.61.0
numpy==2.1.3
@@ -53,15 +66,23 @@ nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0
openunmix==1.3.0
packaging==24.2
pandas==2.2.3
pb-bss-eval==0.0.2
pesq==0.0.4
pillow==11.1.0
platformdirs==4.3.6
pluggy==1.5.0
pooch==1.8.2
prompt_toolkit==3.0.50
propcache==0.2.1
pycparser==2.22
pyparsing==3.2.1
pystoi==0.4.1
pytest==8.3.4
python-dateutil==2.9.0.post0
pytorch-lightning==2.5.0.post0
pytorch-ranger==0.1.1
pytz==2024.2
PyYAML==6.0.2
redis==5.2.1
requests==2.32.3
@@ -76,7 +97,10 @@ submitit==1.5.2
sympy==1.13.1
threadpoolctl==3.5.0
torch==2.5.1
torch-optimizer==0.1.0
torch-stoi==0.2.3
torchaudio==2.5.1
torchmetrics==0.11.4
tqdm==4.67.1
treetable==0.2.5
triton==3.1.0
@@ -85,3 +109,4 @@ tzdata==2025.1
urllib3==2.3.0
vine==5.1.0
wcwidth==0.2.13
yarl==1.18.3
+12
View File
@@ -0,0 +1,12 @@
# __init__.py
import logging
from datetime import datetime

# Root-logger setup for the package: every record is rendered as
# "<timestamp> : <message>" and INFO-and-above is emitted.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s : %(message)s')

# Announce the import so package loading is visible in the logs.
logging.info("freq-split-enhance/postprocessing package has been imported.")
+36
View File
@@ -0,0 +1,36 @@
import soundfile as sf
import numpy as np
def export_audio(audio, output_path, sr):
    """
    Save a NumPy audio array to a specified audio file.

    Args:
        audio (numpy.ndarray): The audio data to be saved. Either 1D mono
            or (2, num_samples) stereo.
        output_path (str): The path where the audio file should be saved.
        sr (int): The sampling rate of the audio.

    Note:
        Best-effort: any failure is printed rather than raised, so callers
        are never interrupted by a bad write.
    """
    try:
        print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}")
        if audio.ndim == 2 and audio.shape[0] == 2:
            # soundfile expects (num_samples, channels); transpose
            # (2, num_samples) stereo into that layout.
            audio = audio.T
        # Write float32 samples; avoids surprises from float64/int input.
        audio = audio.astype('float32')
        # Peak-normalize to avoid clipping distortion. Compute the peak once
        # (the original recomputed np.max(np.abs(...)) twice) and skip silent
        # audio to avoid a divide by zero.
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak
        # Verify final format
        print(f"Final audio shape: {audio.shape}, dtype: {audio.dtype}, max: {np.max(audio)}, min: {np.min(audio)}")
        sf.write(output_path, audio, sr, format='wav')
        print(f"Audio saved to {output_path}")
    except Exception as e:
        # Best-effort: report the failure instead of crashing the caller.
        print(f"Error saving audio: {e}")
+39
View File
@@ -0,0 +1,39 @@
import torch
from asteroid.models import ConvTasNet
def separate(audio, model_name='mpariente/ConvTasNet_WHAM!_sepclean'):
    """
    Separates audio into sources using a pretrained Asteroid model.

    Args:
        audio (numpy.ndarray): The audio time series (1D numpy array).
        model_name (str): Name of the pretrained model from Asteroid.
            Default is 'mpariente/ConvTasNet_WHAM!_sepclean'.

    Returns:
        list: List of separated sources as numpy arrays (one per source).

    Raises:
        RuntimeError: If model loading or separation fails; the original
            exception is attached as the cause.
    """
    try:
        # Select the device: GPU if available, otherwise CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Load the pretrained model and put it in inference mode —
        # no_grad() alone does not disable dropout / batch-norm updates.
        model = ConvTasNet.from_pretrained(model_name).to(device)
        model.eval()

        # Cast explicitly to float32: model weights are float32, and a
        # float64 NumPy array (librosa/soundfile default) would otherwise
        # fail on a dtype mismatch inside the network.
        audio_tensor = torch.as_tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)  # (1, num_samples)

        # Perform source separation without tracking gradients.
        with torch.no_grad():
            separated_sources = model(audio_tensor)

        # Remove batch dimension -> (num_sources, num_samples).
        separated_sources = separated_sources.squeeze(0)

        # Move to CPU and split into one NumPy array per source.
        separated_sources_np = separated_sources.cpu().numpy()
        return [separated_sources_np[i, :] for i in range(separated_sources_np.shape[0])]
    except Exception as e:
        # Chain the cause so the underlying error stays in the traceback.
        raise RuntimeError(f"Error during separation: {e}") from e
Binary file not shown.
+50
View File
@@ -7,6 +7,8 @@ from src.input.file_reader import read_audio
from src.preprocessing.trim import trim_audio
from src.preprocessing.resample import resample
from src.separation.demucs_wrapper import separate_audio_with_demucs
from src.separation.convtasnet_wrapper import separate
from src.postprocessing.audio_writer import export_audio
def test_demucs_separation_with_preprocessing():
@@ -51,3 +53,51 @@ def test_demucs_separation_with_preprocessing():
for expected_file in expected_files:
file_path = file_folder / expected_file
assert file_path.exists(), f"Expected file {expected_file} not found in {file_name} folder."
def test_convtasnet_separation_with_output_files():
    """
    Test to ensure ConvTasNet separation creates expected source audio files.

    Runs the full pipeline (read -> separate -> export) on a two-speaker
    mixture and asserts that both separated sources are written out as
    non-empty WAV files.
    """
    # Function-scope import keeps soundfile local to this test, matching the
    # original, but hoisted to the top instead of buried mid-function.
    import soundfile as sf

    input_audio_path = "tests/test_audio/female-female-mixture.wav"
    output_dir = "/tmp/convtasnet"
    model_name = "mpariente/ConvTasNet_WHAM!_sepclean"

    audio, sr = read_audio(input_audio_path)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    separated_sources = separate(audio, model_name)
    for i, source in enumerate(separated_sources):
        output_path = os.path.join(output_dir, f"source_{i+1}.wav")
        export_audio(source, output_path, sr)

    # Check if the output directory exists.
    assert os.path.exists(output_dir), "Output directory does not exist."

    def is_file_non_empty(file_path):
        # A file can exist but be empty if the export failed midway; read it
        # back with soundfile to confirm it actually contains samples.
        try:
            data, _ = sf.read(file_path)
            return data.size > 0
        except Exception:
            return False

    # Both separated sources must exist and hold actual audio data.
    for name in ("source_1.wav", "source_2.wav"):
        path = os.path.join(output_dir, name)
        assert os.path.exists(path), f"{name} was not created."
        assert is_file_non_empty(path), f"{name} is empty."

    print("Test passed: source_1.wav and source_2.wav are present and non-empty.")