Merge pull request #23 from joelmathewthomas/feature/asteroid-wrapper
Feature/asteroid wrapper
This commit is contained in:
@@ -1,7 +1,14 @@
|
|||||||
|
aiohappyeyeballs==2.4.4
|
||||||
|
aiohttp==3.11.11
|
||||||
|
aiosignal==1.3.2
|
||||||
amqp==5.3.1
|
amqp==5.3.1
|
||||||
antlr4-python3-runtime==4.9.3
|
antlr4-python3-runtime==4.9.3
|
||||||
|
asteroid==0.7.0
|
||||||
|
asteroid-filterbanks==0.4.0
|
||||||
|
attrs==25.1.0
|
||||||
audioread==3.0.1
|
audioread==3.0.1
|
||||||
billiard==4.2.1
|
billiard==4.2.1
|
||||||
|
cached-property==2.0.1
|
||||||
celery==5.4.0
|
celery==5.4.0
|
||||||
certifi==2024.12.14
|
certifi==2024.12.14
|
||||||
cffi==1.17.1
|
cffi==1.17.1
|
||||||
@@ -19,7 +26,10 @@ dora_search==0.1.12
|
|||||||
einops==0.8.0
|
einops==0.8.0
|
||||||
filelock==3.17.0
|
filelock==3.17.0
|
||||||
fonttools==4.55.6
|
fonttools==4.55.6
|
||||||
|
frozenlist==1.5.0
|
||||||
fsspec==2024.12.0
|
fsspec==2024.12.0
|
||||||
|
future==1.0.0
|
||||||
|
huggingface-hub==0.28.0
|
||||||
idna==3.10
|
idna==3.10
|
||||||
iniconfig==2.0.0
|
iniconfig==2.0.0
|
||||||
Jinja2==3.1.5
|
Jinja2==3.1.5
|
||||||
@@ -30,11 +40,14 @@ kombu==5.4.2
|
|||||||
lameenc==1.8.1
|
lameenc==1.8.1
|
||||||
lazy_loader==0.4
|
lazy_loader==0.4
|
||||||
librosa==0.10.2.post1
|
librosa==0.10.2.post1
|
||||||
|
lightning-utilities==0.11.9
|
||||||
llvmlite==0.44.0
|
llvmlite==0.44.0
|
||||||
MarkupSafe==3.0.2
|
MarkupSafe==3.0.2
|
||||||
matplotlib==3.10.0
|
matplotlib==3.10.0
|
||||||
|
mir_eval==0.7
|
||||||
mpmath==1.3.0
|
mpmath==1.3.0
|
||||||
msgpack==1.1.0
|
msgpack==1.1.0
|
||||||
|
multidict==6.1.0
|
||||||
networkx==3.4.2
|
networkx==3.4.2
|
||||||
numba==0.61.0
|
numba==0.61.0
|
||||||
numpy==2.1.3
|
numpy==2.1.3
|
||||||
@@ -53,15 +66,23 @@ nvidia-nvtx-cu12==12.4.127
|
|||||||
omegaconf==2.3.0
|
omegaconf==2.3.0
|
||||||
openunmix==1.3.0
|
openunmix==1.3.0
|
||||||
packaging==24.2
|
packaging==24.2
|
||||||
|
pandas==2.2.3
|
||||||
|
pb-bss-eval==0.0.2
|
||||||
|
pesq==0.0.4
|
||||||
pillow==11.1.0
|
pillow==11.1.0
|
||||||
platformdirs==4.3.6
|
platformdirs==4.3.6
|
||||||
pluggy==1.5.0
|
pluggy==1.5.0
|
||||||
pooch==1.8.2
|
pooch==1.8.2
|
||||||
prompt_toolkit==3.0.50
|
prompt_toolkit==3.0.50
|
||||||
|
propcache==0.2.1
|
||||||
pycparser==2.22
|
pycparser==2.22
|
||||||
pyparsing==3.2.1
|
pyparsing==3.2.1
|
||||||
|
pystoi==0.4.1
|
||||||
pytest==8.3.4
|
pytest==8.3.4
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
|
pytorch-lightning==2.5.0.post0
|
||||||
|
pytorch-ranger==0.1.1
|
||||||
|
pytz==2024.2
|
||||||
PyYAML==6.0.2
|
PyYAML==6.0.2
|
||||||
redis==5.2.1
|
redis==5.2.1
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
@@ -76,7 +97,10 @@ submitit==1.5.2
|
|||||||
sympy==1.13.1
|
sympy==1.13.1
|
||||||
threadpoolctl==3.5.0
|
threadpoolctl==3.5.0
|
||||||
torch==2.5.1
|
torch==2.5.1
|
||||||
|
torch-optimizer==0.1.0
|
||||||
|
torch-stoi==0.2.3
|
||||||
torchaudio==2.5.1
|
torchaudio==2.5.1
|
||||||
|
torchmetrics==0.11.4
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
treetable==0.2.5
|
treetable==0.2.5
|
||||||
triton==3.1.0
|
triton==3.1.0
|
||||||
@@ -85,3 +109,4 @@ tzdata==2025.1
|
|||||||
urllib3==2.3.0
|
urllib3==2.3.0
|
||||||
vine==5.1.0
|
vine==5.1.0
|
||||||
wcwidth==0.2.13
|
wcwidth==0.2.13
|
||||||
|
yarl==1.18.3
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
# __init__.py
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Package-level logging setup: timestamped messages at INFO and above.
# NOTE: this runs as an import-time side effect of the package.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s : %(message)s',
)

# Announce the import so it is visible in the application log.
logging.info("freq-split-enhance/postprocessing package has been imported.")
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
import soundfile as sf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def export_audio(audio, output_path, sr):
    """
    Save a NumPy audio array to a specified audio file.

    The audio is converted to float32 and peak-normalized so that its largest
    absolute sample becomes 1.0 (all-zero audio is left untouched), then
    written with ``soundfile.write`` using ``format='wav'``.

    This is a best-effort helper: any exception raised while preparing or
    writing the audio is caught and reported on stdout instead of being
    propagated, so callers that need certainty must check that the file
    exists afterwards.

    Args:
        audio (numpy.ndarray): The audio data to be saved. A 2-D array whose
            first axis has length 2 is treated as (2, num_samples) stereo and
            transposed to (num_samples, 2) before writing.
        output_path (str): The path where the audio file should be saved.
        sr (int): The sampling rate of the audio.

    Returns:
        None
    """
    try:
        print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}")

        if audio.ndim == 2 and audio.shape[0] == 2:
            # Transpose stereo audio to match the expected shape
            audio = audio.T  # From (2, num_samples) to (num_samples, 2)

        # Ensure the audio data type is float32
        audio = audio.astype('float32')

        # Normalize to the peak amplitude. The peak is computed once and
        # reused; a zero peak (pure silence) is skipped to avoid 0/0.
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak

        # Verify final format
        print(f"Final audio shape: {audio.shape}, dtype: {audio.dtype}, max: {np.max(audio)}, min: {np.min(audio)}")

        sf.write(output_path, audio, sr, format='wav')
        print(f"Audio saved to {output_path}")
    except Exception as e:
        # Deliberate best-effort: report the failure rather than raising.
        print(f"Error saving audio: {e}")
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
import torch
|
||||||
|
from asteroid.models import ConvTasNet
|
||||||
|
|
||||||
|
def separate(audio, model_name='mpariente/ConvTasNet_WHAM!_sepclean'):
    """
    Separates audio into sources using a pretrained Asteroid model.

    The model is loaded with ``ConvTasNet.from_pretrained`` on every call and
    run on the GPU when CUDA is available, otherwise on the CPU.

    Args:
        audio (numpy.ndarray): The audio time series (1D numpy array).
        model_name (str): Name of the pretrained model from Asteroid. Default is 'mpariente/ConvTasNet_WHAM!_sepclean'.

    Returns:
        list: List of separated sources as numpy arrays.

    Raises:
        RuntimeError: If loading the model or running the separation fails;
            the original exception is attached as ``__cause__``.
    """
    try:
        # Select the device: GPU if available, otherwise CPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Load the pretrained model and move it to the selected device
        model = ConvTasNet.from_pretrained(model_name).to(device)

        # Cast the input to the model's own parameter dtype. Without this a
        # float64 array (e.g. soundfile's default read dtype) would be fed to
        # float32 weights unchanged.
        model_dtype = next(model.parameters()).dtype

        # Convert the audio array to a PyTorch tensor, add batch dimension, and move to device
        audio_tensor = torch.tensor(audio, dtype=model_dtype).unsqueeze(0).to(device)  # Shape: (1, num_samples)

        # Perform source separation
        with torch.no_grad():
            separated_sources = model(audio_tensor)

        # Remove batch dimension
        separated_sources = separated_sources.squeeze(0)  # Shape: (num_sources, num_samples)

        # Convert to NumPy and split along the first axis, one array per source
        separated_sources_np = separated_sources.cpu().numpy()
        return list(separated_sources_np)
    except Exception as e:
        # Chain explicitly so the root cause stays attached to the wrapper.
        raise RuntimeError(f"Error during separation: {e}") from e
|
||||||
Binary file not shown.
@@ -7,6 +7,8 @@ from src.input.file_reader import read_audio
|
|||||||
from src.preprocessing.trim import trim_audio
|
from src.preprocessing.trim import trim_audio
|
||||||
from src.preprocessing.resample import resample
|
from src.preprocessing.resample import resample
|
||||||
from src.separation.demucs_wrapper import separate_audio_with_demucs
|
from src.separation.demucs_wrapper import separate_audio_with_demucs
|
||||||
|
from src.separation.convtasnet_wrapper import separate
|
||||||
|
from src.postprocessing.audio_writer import export_audio
|
||||||
|
|
||||||
|
|
||||||
def test_demucs_separation_with_preprocessing():
|
def test_demucs_separation_with_preprocessing():
|
||||||
@@ -51,3 +53,51 @@ def test_demucs_separation_with_preprocessing():
|
|||||||
for expected_file in expected_files:
|
for expected_file in expected_files:
|
||||||
file_path = file_folder / expected_file
|
file_path = file_folder / expected_file
|
||||||
assert file_path.exists(), f"Expected file {expected_file} not found in {file_name} folder."
|
assert file_path.exists(), f"Expected file {expected_file} not found in {file_name} folder."
|
||||||
|
|
||||||
|
def test_convtasnet_separation_with_output_files():
    """
    Test to ensure ConvTasNet separation creates expected source audio files.

    The mixture is read, separated, and each source is exported as
    ``source_<n>.wav`` into a fresh temporary directory. A fresh directory is
    used on purpose: export_audio reports write failures instead of raising,
    so files left over from an earlier run in a fixed location could make
    this test pass without anything having been written.
    """
    # Local imports keep the test self-contained regardless of the module header.
    import tempfile
    from pathlib import Path

    import soundfile as sf

    input_audio_path = "tests/test_audio/female-female-mixture.wav"
    model_name = "mpariente/ConvTasNet_WHAM!_sepclean"

    audio, sr = read_audio(input_audio_path)
    separated_sources = separate(audio, model_name)

    with tempfile.TemporaryDirectory(prefix="convtasnet_") as output_dir:
        output_root = Path(output_dir)

        for index, source in enumerate(separated_sources, start=1):
            export_audio(source, str(output_root / f"source_{index}.wav"), sr)

        # Check that source_1.wav and source_2.wav are created and hold samples.
        for file_name in ("source_1.wav", "source_2.wav"):
            file_path = output_root / file_name
            assert file_path.exists(), f"{file_name} was not created."

            # A read error propagates with its real cause instead of being
            # reported as an "empty" file.
            data, _ = sf.read(str(file_path))
            assert data.size > 0, f"{file_name} is empty."

    print("Test passed: source_1.wav and source_2.wav are present and non-empty.")
|
||||||
|
|||||||
Reference in New Issue
Block a user