diff --git a/src/separation/convtasnet_wrapper.py b/src/separation/convtasnet_wrapper.py index 0c08b2a..a277762 100644 --- a/src/separation/convtasnet_wrapper.py +++ b/src/separation/convtasnet_wrapper.py @@ -1,13 +1,13 @@ import torch from asteroid.models import ConvTasNet -def separate(audio, model_name='mpariente/ConvTasNet_WHAMR_enhsingle'): +def separate(audio, model_name='mpariente/ConvTasNet_WHAM!_sepclean'): """ Separates audio into sources using a pretrained Asteroid model. Args: audio (numpy.ndarray): The audio time series (1D numpy array). - model_name (str): Name of the pretrained model from Asteroid. Default is 'mpariente/ConvTasNet_WHAMR_enhsingle'. + model_name (str): Name of the pretrained model from Asteroid. Default is 'mpariente/ConvTasNet_WHAM!_sepclean'. Returns: list: List of separated sources as numpy arrays.