Merge pull request #23 from joelmathewthomas/feature/asteroid-wrapper
Feature/asteroid wrapper
This commit is contained in:
@@ -1,7 +1,14 @@
|
||||
aiohappyeyeballs==2.4.4
|
||||
aiohttp==3.11.11
|
||||
aiosignal==1.3.2
|
||||
amqp==5.3.1
|
||||
antlr4-python3-runtime==4.9.3
|
||||
asteroid==0.7.0
|
||||
asteroid-filterbanks==0.4.0
|
||||
attrs==25.1.0
|
||||
audioread==3.0.1
|
||||
billiard==4.2.1
|
||||
cached-property==2.0.1
|
||||
celery==5.4.0
|
||||
certifi==2024.12.14
|
||||
cffi==1.17.1
|
||||
@@ -19,7 +26,10 @@ dora_search==0.1.12
|
||||
einops==0.8.0
|
||||
filelock==3.17.0
|
||||
fonttools==4.55.6
|
||||
frozenlist==1.5.0
|
||||
fsspec==2024.12.0
|
||||
future==1.0.0
|
||||
huggingface-hub==0.28.0
|
||||
idna==3.10
|
||||
iniconfig==2.0.0
|
||||
Jinja2==3.1.5
|
||||
@@ -30,11 +40,14 @@ kombu==5.4.2
|
||||
lameenc==1.8.1
|
||||
lazy_loader==0.4
|
||||
librosa==0.10.2.post1
|
||||
lightning-utilities==0.11.9
|
||||
llvmlite==0.44.0
|
||||
MarkupSafe==3.0.2
|
||||
matplotlib==3.10.0
|
||||
mir_eval==0.7
|
||||
mpmath==1.3.0
|
||||
msgpack==1.1.0
|
||||
multidict==6.1.0
|
||||
networkx==3.4.2
|
||||
numba==0.61.0
|
||||
numpy==2.1.3
|
||||
@@ -53,15 +66,23 @@ nvidia-nvtx-cu12==12.4.127
|
||||
omegaconf==2.3.0
|
||||
openunmix==1.3.0
|
||||
packaging==24.2
|
||||
pandas==2.2.3
|
||||
pb-bss-eval==0.0.2
|
||||
pesq==0.0.4
|
||||
pillow==11.1.0
|
||||
platformdirs==4.3.6
|
||||
pluggy==1.5.0
|
||||
pooch==1.8.2
|
||||
prompt_toolkit==3.0.50
|
||||
propcache==0.2.1
|
||||
pycparser==2.22
|
||||
pyparsing==3.2.1
|
||||
pystoi==0.4.1
|
||||
pytest==8.3.4
|
||||
python-dateutil==2.9.0.post0
|
||||
pytorch-lightning==2.5.0.post0
|
||||
pytorch-ranger==0.1.1
|
||||
pytz==2024.2
|
||||
PyYAML==6.0.2
|
||||
redis==5.2.1
|
||||
requests==2.32.3
|
||||
@@ -76,7 +97,10 @@ submitit==1.5.2
|
||||
sympy==1.13.1
|
||||
threadpoolctl==3.5.0
|
||||
torch==2.5.1
|
||||
torch-optimizer==0.1.0
|
||||
torch-stoi==0.2.3
|
||||
torchaudio==2.5.1
|
||||
torchmetrics==0.11.4
|
||||
tqdm==4.67.1
|
||||
treetable==0.2.5
|
||||
triton==3.1.0
|
||||
@@ -85,3 +109,4 @@ tzdata==2025.1
|
||||
urllib3==2.3.0
|
||||
vine==5.1.0
|
||||
wcwidth==0.2.13
|
||||
yarl==1.18.3
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
# __init__.py
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s : %(message)s',
|
||||
level = logging.INFO
|
||||
)
|
||||
|
||||
logging.info("freq-split-enhance/postprocessing package has been imported.")
|
||||
@@ -0,0 +1,36 @@
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
def export_audio(audio, output_path, sr):
    """
    Save a NumPy audio array to a specified audio file.

    The audio is transposed to (num_samples, channels) layout if it arrives
    as channel-first stereo, converted to float32, and peak-normalized
    before being written as a WAV file.

    Args:
        audio (numpy.ndarray): The audio data to be saved.
        output_path (str): The path where the audio file should be saved.
        sr (int): The sampling rate of the audio.
    """
    try:
        print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}")

        # soundfile expects (num_samples, channels); incoming stereo from the
        # separator is (2, num_samples), so transpose channel-first input.
        if audio.ndim == 2 and audio.shape[0] == 2:
            audio = audio.T  # From (2, num_samples) to (num_samples, 2)

        # Ensure the audio data type is float32
        audio = audio.astype('float32')

        # Normalize to unit peak to avoid distortion; compute the peak once
        # (the original recomputed np.max(np.abs(audio)) twice).
        peak = np.max(np.abs(audio))
        if peak > 0:  # avoid divide-by-zero on all-silent audio
            audio = audio / peak

        # Verify final format
        print(f"Final audio shape: {audio.shape}, dtype: {audio.dtype}, max: {np.max(audio)}, min: {np.min(audio)}")

        sf.write(output_path, audio, sr, format='wav')
        print(f"Audio saved to {output_path}")
    except Exception as e:
        # Best-effort writer: report the failure instead of propagating,
        # matching the original contract (callers are not expected to catch).
        print(f"Error saving audio: {e}")
|
||||
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
from asteroid.models import ConvTasNet
|
||||
|
||||
def separate(audio, model_name='mpariente/ConvTasNet_WHAM!_sepclean'):
    """
    Separates audio into sources using a pretrained Asteroid model.

    Args:
        audio (numpy.ndarray): The audio time series (1D numpy array).
        model_name (str): Name of the pretrained model from Asteroid. Default is 'mpariente/ConvTasNet_WHAM!_sepclean'.

    Returns:
        list: List of separated sources as numpy arrays.

    Raises:
        RuntimeError: If model loading or separation fails; the original
            exception is attached as ``__cause__``.
    """
    try:
        # Select the device: GPU if available, otherwise CPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Load the pretrained model, move it to the selected device, and put
        # it in inference mode (.eval() disables training-only behavior such
        # as dropout — the original skipped this step).
        model = ConvTasNet.from_pretrained(model_name).to(device)
        model.eval()

        # Convert the audio array to a tensor with a batch dimension:
        # (num_samples,) -> (1, num_samples), on the selected device.
        audio_tensor = torch.tensor(audio).unsqueeze(0).to(device)

        # Perform source separation without building an autograd graph
        with torch.no_grad():
            separated_sources = model(audio_tensor)

        # Remove batch dimension: (num_sources, num_samples)
        separated_sources = separated_sources.squeeze(0)

        # Move to CPU and split into a list of per-source 1-D numpy arrays
        separated_sources_np = separated_sources.cpu().numpy()
        return [separated_sources_np[i, :] for i in range(separated_sources_np.shape[0])]
    except Exception as e:
        # Chain the cause so the real failure is visible in the traceback
        # (the original `raise RuntimeError(...)` discarded it).
        raise RuntimeError(f"Error during separation: {e}") from e
|
||||
Binary file not shown.
@@ -7,6 +7,8 @@ from src.input.file_reader import read_audio
|
||||
from src.preprocessing.trim import trim_audio
|
||||
from src.preprocessing.resample import resample
|
||||
from src.separation.demucs_wrapper import separate_audio_with_demucs
|
||||
from src.separation.convtasnet_wrapper import separate
|
||||
from src.postprocessing.audio_writer import export_audio
|
||||
|
||||
|
||||
def test_demucs_separation_with_preprocessing():
|
||||
@@ -51,3 +53,51 @@ def test_demucs_separation_with_preprocessing():
|
||||
for expected_file in expected_files:
|
||||
file_path = file_folder / expected_file
|
||||
assert file_path.exists(), f"Expected file {expected_file} not found in {file_name} folder."
|
||||
|
||||
def test_convtasnet_separation_with_output_files():
    """
    Test to ensure ConvTasNet separation creates expected source audio files.

    Runs the full pipeline (read -> separate -> export) on a two-speaker
    mixture and asserts that source_1.wav and source_2.wav exist and contain
    samples.
    """
    import soundfile as sf  # local import: only this test needs it

    input_audio_path = "tests/test_audio/female-female-mixture.wav"
    output_dir = "/tmp/convtasnet"
    model_name = "mpariente/ConvTasNet_WHAM!_sepclean"

    audio, sr = read_audio(input_audio_path)

    # exist_ok avoids the racy exists()/makedirs() pair the original used.
    os.makedirs(output_dir, exist_ok=True)

    separated_sources = separate(audio, model_name)

    # Write each separated source to its own WAV file.
    for i, source in enumerate(separated_sources):
        output_path = os.path.join(output_dir, f"source_{i+1}.wav")
        export_audio(source, output_path, sr)

    # Check if the output directory exists
    assert os.path.exists(output_dir), "Output directory does not exist."

    # Check if source_1.wav and source_2.wav are created
    source_1_path = os.path.join(output_dir, "source_1.wav")
    source_2_path = os.path.join(output_dir, "source_2.wav")

    assert os.path.exists(source_1_path), "source_1.wav was not created."
    assert os.path.exists(source_2_path), "source_2.wav was not created."

    def is_file_non_empty(file_path):
        # Unreadable or zero-sample files both count as "empty" here, since
        # either way the export failed to produce usable audio.
        try:
            data, _ = sf.read(file_path)
            return data.size > 0
        except Exception:
            return False

    assert is_file_non_empty(source_1_path), "source_1.wav is empty."
    assert is_file_non_empty(source_2_path), "source_2.wav is empty."

    print("Test passed: source_1.wav and source_2.wav are present and non-empty.")
|
||||
|
||||
Reference in New Issue
Block a user