create ConvTasNet wrapper using asteroid

This commit is contained in:
Joel Mathew Thomas
2025-01-28 20:20:48 +05:30
parent 8e3e7dd062
commit 7a7c27caa3
+36
View File
@@ -0,0 +1,36 @@
import torch
from asteroid.models import ConvTasNet
def separate(audio, sr, model_name='mpariente/ConvTasNet_WHAMR_enhsingle'):
"""
Separates audio into sources using a pretrained Asteroid model.
Args:
audio (numpy.ndarray): The audio time series (1D numpy array).
sr (int): Sampling rate of the audio.
model_name (str): Name of the pretrained model from Asteroid. Default is 'mpariente/ConvTasNet_WHAMR_enhsingle'.
Returns:
list: List of separated sources as numpy arrays.
"""
try:
# Select the device: GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load the pretrained model and move it to the selected device
model = ConvTasNet.from_pretrained(model_name).to(device)
# Convert the audio array to a PyTorch tensor, add batch dimension, and move to device
audio_tensor = torch.tensor(audio).unsqueeze(0).to(device) # Shape: (1, num_samples)
# Perform source separation
with torch.no_grad():
separated_sources = model(audio_tensor)
# Convert separated sources to NumPy arrays and remove batch dimension
separated_sources_np = [src.squeeze(0).cpu().numpy() for src in separated_sources]
return separated_sources_np
except Exception as e:
raise RuntimeError(f"Error during separation: {e}")