add test case to test mpariente/ConvTasNet_WHAMsepcleanwrapper
This commit is contained in:
@@ -1,32 +0,0 @@
|
|||||||
import os
|
|
||||||
from src.input.file_reader import read_audio
|
|
||||||
from src.separation.convtasnet_wrapper import separate
|
|
||||||
from src.postprocessing.audio_writer import export_audio
|
|
||||||
|
|
||||||
def main(input_audio_path, output_dir, model_name="mpariente/ConvTasNet_WHAM!_sepclean"):
|
|
||||||
try:
|
|
||||||
audio, sr = read_audio(input_audio_path)
|
|
||||||
print(f"Loaded audio from {input_audio_path} with sampling rate {sr}")
|
|
||||||
|
|
||||||
if not os.path.exists(output_dir):
|
|
||||||
os.makedirs(output_dir)
|
|
||||||
|
|
||||||
|
|
||||||
separated_sources = separate(audio, model_name)
|
|
||||||
print(f"Separated {len(separated_sources)} sources")
|
|
||||||
|
|
||||||
for i, source in enumerate(separated_sources):
|
|
||||||
output_path = os.path.join(output_dir, f"source_{i+1}.wav")
|
|
||||||
export_audio(source, output_path, sr)
|
|
||||||
print(f"Exported separated source {i + 1} to {output_path}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error in processing: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
input_audio_path = "/home/joel/Downloads/wham/female-female-mixture.wav"
|
|
||||||
output_dir = "/tmp/convtasnet"
|
|
||||||
main(input_audio_path, output_dir)
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from src.input.file_reader import read_audio
|
||||||
|
from src.separation.convtasnet_wrapper import separate
|
||||||
|
from src.postprocessing.audio_writer import export_audio
|
||||||
|
|
||||||
|
|
||||||
|
def test_convtasnet_separation_with_output_files():
|
||||||
|
"""
|
||||||
|
Test to ensure ConvTasNet separation creates expected source audio files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_audio_path = "tests/test_audio/female-female-mixture.wav"
|
||||||
|
output_dir = "/tmp/convtasnet"
|
||||||
|
model_name = "mpariente/ConvTasNet_WHAM!_sepclean"
|
||||||
|
|
||||||
|
audio, sr = read_audio(input_audio_path)
|
||||||
|
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
separated_sources = separate(audio, model_name)
|
||||||
|
|
||||||
|
for i, source in enumerate(separated_sources):
|
||||||
|
output_path = os.path.join(output_dir, f"source_{i+1}.wav")
|
||||||
|
export_audio(source, output_path, sr)
|
||||||
|
|
||||||
|
# Check if the output directory exists
|
||||||
|
assert os.path.exists(output_dir), "Output directory does not exist."
|
||||||
|
|
||||||
|
# Check if source_1.wav and source_2.wav are created
|
||||||
|
source_1_path = os.path.join(output_dir, "source_1.wav")
|
||||||
|
source_2_path = os.path.join(output_dir, "source_2.wav")
|
||||||
|
|
||||||
|
assert os.path.exists(source_1_path), "source_1.wav was not created."
|
||||||
|
assert os.path.exists(source_2_path), "source_2.wav was not created."
|
||||||
|
|
||||||
|
# Check if the files have content (not just created)
|
||||||
|
# For example, you can check if the length of the audio files is greater than a certain threshold
|
||||||
|
# Here, we'll just verify the files are not empty.
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
def is_file_non_empty(file_path):
|
||||||
|
try:
|
||||||
|
data, _ = sf.read(file_path)
|
||||||
|
return data.size > 0
|
||||||
|
except Exception as e:
|
||||||
|
return False
|
||||||
|
|
||||||
|
assert is_file_non_empty(source_1_path), "source_1.wav is empty."
|
||||||
|
assert is_file_non_empty(source_2_path), "source_2.wav is empty."
|
||||||
|
|
||||||
|
print("Test passed: source_1.wav and source_2.wav are present and non-empty.")
|
||||||
Reference in New Issue
Block a user