From 7f973db1d0b952511e755c1e3097fda54babd32b Mon Sep 17 00:00:00 2001 From: Joel Mathew Thomas <90510078+joelmathewthomas@users.noreply.github.com> Date: Tue, 28 Jan 2025 23:51:21 +0530 Subject: [PATCH] add test case to test mpariente/ConvTasNet_WHAMsepcleanwrapper --- test.py | 32 ----------- tests/test_convtasnet_wrapper.py | 54 +++++++++++++++++++ ...t_separation.py => test_demucs_wrapper.py} | 0 3 files changed, 54 insertions(+), 32 deletions(-) delete mode 100644 test.py create mode 100644 tests/test_convtasnet_wrapper.py rename tests/{test_separation.py => test_demucs_wrapper.py} (100%) diff --git a/test.py b/test.py deleted file mode 100644 index 272e3d8..0000000 --- a/test.py +++ /dev/null @@ -1,32 +0,0 @@ -import os -from src.input.file_reader import read_audio -from src.separation.convtasnet_wrapper import separate -from src.postprocessing.audio_writer import export_audio - -def main(input_audio_path, output_dir, model_name="mpariente/ConvTasNet_WHAM!_sepclean"): - try: - audio, sr = read_audio(input_audio_path) - print(f"Loaded audio from {input_audio_path} with sampling rate {sr}") - - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - - separated_sources = separate(audio, model_name) - print(f"Separated {len(separated_sources)} sources") - - for i, source in enumerate(separated_sources): - output_path = os.path.join(output_dir, f"source_{i+1}.wav") - export_audio(source, output_path, sr) - print(f"Exported separated source {i + 1} to {output_path}") - - except Exception as e: - print(f"Error in processing: {e}") - - -if __name__ == "__main__": - input_audio_path = "/home/joel/Downloads/wham/female-female-mixture.wav" - output_dir = "/tmp/convtasnet" - main(input_audio_path, output_dir) - - \ No newline at end of file diff --git a/tests/test_convtasnet_wrapper.py b/tests/test_convtasnet_wrapper.py new file mode 100644 index 0000000..ca9b865 --- /dev/null +++ b/tests/test_convtasnet_wrapper.py @@ -0,0 +1,54 @@ +import os +import pytest +from src.input.file_reader import read_audio +from src.separation.convtasnet_wrapper import separate +from src.postprocessing.audio_writer import export_audio + + +def test_convtasnet_separation_with_output_files(): + """ + Test to ensure ConvTasNet separation creates expected source audio files. + """ + + input_audio_path = "tests/test_audio/female-female-mixture.wav" + output_dir = "/tmp/convtasnet" + model_name = "mpariente/ConvTasNet_WHAM!_sepclean" + + audio, sr = read_audio(input_audio_path) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + + separated_sources = separate(audio, model_name) + + for i, source in enumerate(separated_sources): + output_path = os.path.join(output_dir, f"source_{i+1}.wav") + export_audio(source, output_path, sr) + + # Check if the output directory exists + assert os.path.exists(output_dir), "Output directory does not exist." + + # Check if source_1.wav and source_2.wav are created + source_1_path = os.path.join(output_dir, "source_1.wav") + source_2_path = os.path.join(output_dir, "source_2.wav") + + assert os.path.exists(source_1_path), "source_1.wav was not created." + assert os.path.exists(source_2_path), "source_2.wav was not created." + + # Check if the files have content (not just created) + # For example, you can check if the length of the audio files is greater than a certain threshold + # Here, we'll just verify the files are not empty. + import soundfile as sf + + def is_file_non_empty(file_path): + try: + data, _ = sf.read(file_path) + return data.size > 0 + except Exception as e: + return False + + assert is_file_non_empty(source_1_path), "source_1.wav is empty." + assert is_file_non_empty(source_2_path), "source_2.wav is empty." + + print("Test passed: source_1.wav and source_2.wav are present and non-empty.") diff --git a/tests/test_separation.py b/tests/test_demucs_wrapper.py similarity index 100% rename from tests/test_separation.py rename to tests/test_demucs_wrapper.py