diff --git a/src/separation/convtasnet_wrapper.py b/src/separation/convtasnet_wrapper.py index a277762..ae52b38 100644 --- a/src/separation/convtasnet_wrapper.py +++ b/src/separation/convtasnet_wrapper.py @@ -27,9 +27,13 @@ def separate(audio, model_name='mpariente/ConvTasNet_WHAM!_sepclean'): with torch.no_grad(): separated_sources = model(audio_tensor) - # Convert separated sources to NumPy arrays and remove batch dimension - separated_sources_np = [src.squeeze(0).cpu().numpy() for src in separated_sources] + # Remove batch dimension + separated_sources = separated_sources.squeeze(0) # Shape: (num_sources, num_samples) - return separated_sources_np + # Split into list of sources + separated_sources_np = separated_sources.cpu().numpy() # Convert to NumPy + separated_sources_list = [separated_sources_np[i, :] for i in range(separated_sources_np.shape[0])] + + return separated_sources_list except Exception as e: raise RuntimeError(f"Error during separation: {e}") diff --git a/test.py b/test.py index 9ed305e..272e3d8 100644 --- a/test.py +++ b/test.py @@ -14,8 +14,7 @@ def main(input_audio_path, output_dir, model_name="mpariente/ConvTasNet_WHAM!_se separated_sources = separate(audio, model_name) print(f"Separated {len(separated_sources)} sources") - print("Separated sources are", separated_sources) - + for i, source in enumerate(separated_sources): output_path = os.path.join(output_dir, f"source_{i+1}.wav") export_audio(source, output_path, sr) @@ -26,7 +25,7 @@ def main(input_audio_path, output_dir, model_name="mpariente/ConvTasNet_WHAM!_se if __name__ == "__main__": - input_audio_path = "/home/joel/Downloads/female_female_speech.wav" + input_audio_path = "/home/joel/Downloads/wham/female-female-mixture.wav" output_dir = "/tmp/convtasnet" main(input_audio_path, output_dir) diff --git a/tests/test_audio/female-female-mixture.wav b/tests/test_audio/female-female-mixture.wav new file mode 100644 index 0000000..2137997 Binary files /dev/null and b/tests/test_audio/female-female-mixture.wav differ diff --git a/tests/test_audio/noise.wav b/tests/test_audio/noise.wav deleted file mode 100644 index 6e28687..0000000 Binary files a/tests/test_audio/noise.wav and /dev/null differ