Fix: Corrected shape handling for separated sources, extracting each source as a separate 1D array

This commit is contained in:
Joel Mathew Thomas
2025-01-28 23:34:08 +05:30
parent 9e7c01c19e
commit fae60a8120
4 changed files with 9 additions and 6 deletions
+7 -3
View File
@@ -27,9 +27,13 @@ def separate(audio, model_name='mpariente/ConvTasNet_WHAM!_sepclean'):
with torch.no_grad(): with torch.no_grad():
separated_sources = model(audio_tensor) separated_sources = model(audio_tensor)
# Convert separated sources to NumPy arrays and remove batch dimension # Remove batch dimension
separated_sources_np = [src.squeeze(0).cpu().numpy() for src in separated_sources] separated_sources = separated_sources.squeeze(0) # Shape: (num_sources, num_samples)
return separated_sources_np # Split into list of sources
separated_sources_np = separated_sources.cpu().numpy() # Convert to NumPy
separated_sources_list = [separated_sources_np[i, :] for i in range(separated_sources_np.shape[0])]
return separated_sources_list
except Exception as e: except Exception as e:
raise RuntimeError(f"Error during separation: {e}") raise RuntimeError(f"Error during separation: {e}")
+2 -3
View File
@@ -14,8 +14,7 @@ def main(input_audio_path, output_dir, model_name="mpariente/ConvTasNet_WHAM!_se
separated_sources = separate(audio, model_name) separated_sources = separate(audio, model_name)
print(f"Separated {len(separated_sources)} sources") print(f"Separated {len(separated_sources)} sources")
print("Separated sources are", separated_sources)
for i, source in enumerate(separated_sources): for i, source in enumerate(separated_sources):
output_path = os.path.join(output_dir, f"source_{i+1}.wav") output_path = os.path.join(output_dir, f"source_{i+1}.wav")
export_audio(source, output_path, sr) export_audio(source, output_path, sr)
@@ -26,7 +25,7 @@ def main(input_audio_path, output_dir, model_name="mpariente/ConvTasNet_WHAM!_se
if __name__ == "__main__": if __name__ == "__main__":
input_audio_path = "/home/joel/Downloads/female_female_speech.wav" input_audio_path = "/home/joel/Downloads/wham/female-female-mixture.wav"
output_dir = "/tmp/convtasnet" output_dir = "/tmp/convtasnet"
main(input_audio_path, output_dir) main(input_audio_path, output_dir)
Binary file not shown.
Binary file not shown.