create ConvTasNet wrapper using asteroid
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
import torch
|
||||
from asteroid.models import ConvTasNet
|
||||
|
||||
def separate(audio, sr, model_name='mpariente/ConvTasNet_WHAMR_enhsingle'):
|
||||
"""
|
||||
Separates audio into sources using a pretrained Asteroid model.
|
||||
|
||||
Args:
|
||||
audio (numpy.ndarray): The audio time series (1D numpy array).
|
||||
sr (int): Sampling rate of the audio.
|
||||
model_name (str): Name of the pretrained model from Asteroid. Default is 'mpariente/ConvTasNet_WHAMR_enhsingle'.
|
||||
|
||||
Returns:
|
||||
list: List of separated sources as numpy arrays.
|
||||
"""
|
||||
try:
|
||||
# Select the device: GPU if available, otherwise CPU
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Load the pretrained model and move it to the selected device
|
||||
model = ConvTasNet.from_pretrained(model_name).to(device)
|
||||
|
||||
# Convert the audio array to a PyTorch tensor, add batch dimension, and move to device
|
||||
audio_tensor = torch.tensor(audio).unsqueeze(0).to(device) # Shape: (1, num_samples)
|
||||
|
||||
# Perform source separation
|
||||
with torch.no_grad():
|
||||
separated_sources = model(audio_tensor)
|
||||
|
||||
# Convert separated sources to NumPy arrays and remove batch dimension
|
||||
separated_sources_np = [src.squeeze(0).cpu().numpy() for src in separated_sources]
|
||||
|
||||
return separated_sources_np
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error during separation: {e}")
|
||||
Reference in New Issue
Block a user