Review note: this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed. Whitespace-only context lines (including the
position of the blank line next to the closing docstring quotes in the second
hunk) are inferred; check them against the target file before applying. The
"index" lines keep the original blob ids and do not reflect the edited content.

diff --git a/src/postprocessing/audio_writer.py b/src/postprocessing/audio_writer.py
index ba8228c..2fd7da0 100644
--- a/src/postprocessing/audio_writer.py
+++ b/src/postprocessing/audio_writer.py
@@ -1,4 +1,5 @@
 import soundfile as sf
+import numpy as np
 
 def export_audio(audio, output_path, sr):
     """
@@ -11,7 +12,25 @@ def export_audio(audio, output_path, sr):
     """
 
     try:
-        sf.write(output_path, audio, sr)
+        print(f"Initial audio shape: {audio.shape}, dtype: {audio.dtype}")
+
+        if audio.ndim == 2 and audio.shape[0] == 2:
+            # Transpose channels-first stereo, (2, num_samples), into the
+            # (num_samples, 2) layout that sf.write takes for multichannel data.
+            audio = audio.T
+
+        # Ensure the audio data type is float32
+        audio = audio.astype('float32')
+
+        # Peak-normalize to a max magnitude of 1.0; all-zero audio is left as-is (no divide by zero)
+        peak = np.max(np.abs(audio))
+        if peak > 0:
+            audio = audio / peak
+
+        # Verify final format
+        print(f"Final audio shape: {audio.shape}, dtype: {audio.dtype}, max: {np.max(audio)}, min: {np.min(audio)}")
+
+        sf.write(output_path, audio, sr, format='wav')
         print(f"Audio saved to {output_path}")
     except Exception as e:
         print(f"Error saving audio: {e}")
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..9ed305e
--- /dev/null
+++ b/test.py
@@ -0,0 +1,43 @@
+import os
+import sys
+
+from src.input.file_reader import read_audio
+from src.separation.convtasnet_wrapper import separate
+from src.postprocessing.audio_writer import export_audio
+
+
+def main(input_audio_path, output_dir, model_name="mpariente/ConvTasNet_WHAM!_sepclean"):
+    """
+    Separate an audio file into sources and write each one as a WAV file.
+
+    Reads `input_audio_path` with read_audio, passes the audio and `model_name`
+    to separate, and exports every returned source with export_audio to
+    `output_dir`/source_<n>.wav (numbered from 1) using the input sampling rate.
+    Any exception is caught and reported on stdout rather than propagated.
+    """
+    try:
+        audio, sr = read_audio(input_audio_path)
+        print(f"Loaded audio from {input_audio_path} with sampling rate {sr}")
+
+        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
+        os.makedirs(output_dir, exist_ok=True)
+
+        separated_sources = separate(audio, model_name)
+        print(f"Separated {len(separated_sources)} sources")
+        print("Separated sources are", separated_sources)
+
+        for i, source in enumerate(separated_sources):
+            output_path = os.path.join(output_dir, f"source_{i+1}.wav")
+            export_audio(source, output_path, sr)
+            print(f"Exported separated source {i + 1} to {output_path}")
+
+    except Exception as e:
+        print(f"Error in processing: {e}")
+
+
+if __name__ == "__main__":
+    # Optional CLI overrides: python test.py [input_wav] [output_dir].
+    # The fallbacks are the paths this script was originally hard-coded with.
+    input_audio_path = sys.argv[1] if len(sys.argv) > 1 else "/home/joel/Downloads/female_female_speech.wav"
+    output_dir = sys.argv[2] if len(sys.argv) > 2 else "/tmp/convtasnet"
+    main(input_audio_path, output_dir)
diff --git a/tests/test_audio/noise.wav b/tests/test_audio/noise.wav
new file mode 100644
index 0000000..6e28687
Binary files /dev/null and b/tests/test_audio/noise.wav differ