Implement audio classification function using YAMNet
- Create a function to load audio, resample, and classify using YAMNet
- Ensure compatibility with different audio formats and sample rates
- Normalize audio and classify it into one of YAMNet's 521 audio event classes
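The diff below uses a module-level `model` and the `class_names_from_csv` helper without showing how they are set up. A minimal sketch of the assumed surrounding context, following the standard TensorFlow Hub YAMNet usage (the imports and hub URL here are assumptions, not part of this commit):

import csv

import librosa
import numpy as np
import tensorflow_hub as hub

# Assumed setup (not shown in this diff): load the pretrained YAMNet model.
# classify_audio relies on this module-level `model` object.
model = hub.load('https://tfhub.dev/google/yamnet/1')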
@@ -56,6 +56,7 @@ python-dateutil==2.9.0.post0
 requests==2.32.3
 rich==13.9.4
 scikit-learn==1.6.0
+scipy==1.14.1
 setuptools==75.6.0
 six==1.17.0
 soundfile==0.12.1

@@ -14,31 +14,35 @@ def class_names_from_csv(class_map_csv_text):
     reader = csv.DictReader(csvfile)
     for row in reader:
       class_names.append(row['display_name'])
 
   return class_names
 
+# Main function to process audio and classify
+def classify_audio(file_path):
+    """
+    Given an audio file, this function loads the audio, resamples it,
+    normalizes it, and runs it through the YAMNet model to classify the sound.
+
+    Args:
+    - file_path (str): Path to the audio file (WAV, MP3, etc.).
+
+    Returns:
+    - str: Predicted class label of the audio.
+    """
+    # Load audio using librosa (this handles both loading, resampling, and conversion to mono)
+    waveform, sample_rate = librosa.load(file_path, sr=16000, mono=True)  # Ensuring 16k sample rate and mono
+
+    # Normalize the waveform to [-1.0, 1.0] (librosa already returns normalized values)
+    waveform = waveform / np.max(np.abs(waveform))
+
+    # Execute the YAMNet model
+    scores, embeddings, spectrogram = model(waveform)
+
+    # Extract the class names from the model
     class_map_path = model.class_map_path().numpy()
     class_names = class_names_from_csv(class_map_path)
 
-wav_file_name = 'cafe_crowd_talk.wav'
-waveform, sample_rate = librosa.load(wav_file_name, sr=16000)
-
-# Show some basic information about the audio.
-duration = len(waveform)/sample_rate
-print(f'Sample rate: {sample_rate} Hz')
-print(f'Total duration: {duration:.2f}s')
-print(f'Size of the input: {len(waveform)}')
-
-# The waveform needs to be normalized to values in [-1.0, 1.0] (librosa load already does this)
-# No need to do this as librosa already normalizes
-
-# Execute the Model
-# Check the output.
-scores, embeddings, spectogram = model(waveform)
+    # Find the class with the highest score
     scores_np = scores.numpy()
-spectogram_np = spectogram.numpy()
-infered_class = class_names[scores_np.mean(axis=0).argmax()]
-print(f'The main sound is : {infered_class}')
-
-
-
+    inferred_class = class_names[scores_np.mean(axis=0).argmax()]
+    return inferred_class
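
YAMNet emits one score vector per audio frame, with shape (num_frames, 521), so classify_audio averages the scores over time (axis 0) before taking the argmax. A short usage sketch: 'cafe_crowd_talk.wav' is the sample file referenced in the removed code, and any format librosa can decode should work:

# Usage sketch: classify a clip and print the predicted label.
label = classify_audio('cafe_crowd_talk.wav')
print(f'The main sound is: {label}')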