Use Speaker Diarization with Async Chunking
This guide uses AssemblyAI and NVIDIA's NeMo framework. We'll be using TitaNet, a state-of-the-art open-source model trained for speaker recognition tasks. TitaNet lets us generate speaker embeddings from audio, which we can compare to determine whether two audio segments come from the same speaker.
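To give a quick sense of the NeMo API this guide relies on, here is a minimal sketch that loads TitaNet and checks whether two audio files contain the same speaker. The file paths are placeholders; verify_speakers returns a boolean based on the similarity of the two speaker embeddings.

import nemo.collections.asr as nemo_asr

# Load the pretrained TitaNet speaker verification model
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
    "nvidia/speakerverification_en_titanet_large"
)

# Compare two WAV files (placeholder paths) and print whether
# the model believes they contain the same speaker
same_speaker = speaker_model.verify_speakers("speaker_a.wav", "speaker_b.wav")
print("Same speaker?", same_speaker)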
Quickstart
import assemblyai as aai
import requests
import json
import time
import copy
import os
import nemo.collections.asr as nemo_asr
from pydub import AudioSegment

speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("nvidia/speakerverification_en_titanet_large")

assemblyai_key = "YOUR_API_KEY"

headers = {
    "authorization": assemblyai_key
}

def get_transcript(transcript_id):
    polling_endpoint = f"https://api.assemblyai.com/v2/transcript/{transcript_id}"

    while True:
        transcription_result = requests.get(polling_endpoint, headers=headers).json()

        if transcription_result['status'] == 'completed':
            return transcription_result

        elif transcription_result['status'] == 'error':
            raise RuntimeError(f"Transcription failed: {transcription_result['error']}")

        else:
            time.sleep(3)

def download_wav(presigned_url, output_filename):
    # Download the WAV file from the presigned URL
    response = requests.get(presigned_url)
    if response.status_code == 200:
        print("downloading...")
        with open(output_filename, 'wb') as f:
            f.write(response.content)
        print("successfully downloaded file:", output_filename)
    else:
        raise Exception("Failed to download file, status code: {}".format(response.status_code))

# Function to identify the longest monologue of each speaker from each clip.
# You pass in the utterances and it returns the longest monologue from each speaker on that file.
def find_longest_monologues(utterances):
    longest_monologues = {}
    current_monologue = {}
    last_speaker = None  # Track the last speaker to identify interruptions

    for utterance in utterances:
        speaker = utterance['speaker']
        start_time = utterance['start']
        end_time = utterance['end']

        if speaker not in current_monologue:
            current_monologue[speaker] = {"start": start_time, "end": end_time}
            longest_monologues[speaker] = []
        else:
            # Extend monologue only if it's the same speaker speaking continuously
            if current_monologue[speaker]["end"] == start_time and last_speaker == speaker:
                current_monologue[speaker]["end"] = end_time
            else:
                monologue_length = current_monologue[speaker]["end"] - current_monologue[speaker]["start"]
                new_entry = (monologue_length, copy.deepcopy(current_monologue[speaker]))

                if len(longest_monologues[speaker]) < 1 or monologue_length > min(longest_monologues[speaker], key=lambda x: x[0])[0]:
                    if len(longest_monologues[speaker]) == 1:
                        longest_monologues[speaker].remove(min(longest_monologues[speaker], key=lambda x: x[0]))

                    longest_monologues[speaker].append(new_entry)

                current_monologue[speaker] = {"start": start_time, "end": end_time}

        last_speaker = speaker  # Update the last speaker

    # Check the last monologue for each speaker
    for speaker, monologue in current_monologue.items():
        monologue_length = monologue["end"] - monologue["start"]
        new_entry = (monologue_length, monologue)
        if len(longest_monologues[speaker]) < 1 or monologue_length > min(longest_monologues[speaker], key=lambda x: x[0])[0]:
            if len(longest_monologues[speaker]) == 1:
                longest_monologues[speaker].remove(min(longest_monologues[speaker], key=lambda x: x[0]))
            longest_monologues[speaker].append(new_entry)

    return longest_monologues

# Create clips of each long monologue and embed the clip.
# You pass in the file path and the longest monologue objects returned by the find_longest_monologues function.
# This function will create new audio file clips which contain only the longest monologue from each speaker.
def clip_and_store_utterances(audio_file, longest_monologues):
    # Load the full conversation audio
    full_audio = AudioSegment.from_wav(audio_file)
    full_audio = full_audio.set_channels(1)

    utterance_clips = []

    for speaker, monologues in longest_monologues.items():
        for _, monologue in monologues:
            start_ms = monologue['start']
            end_ms = monologue['end']
            clip = full_audio[start_ms:end_ms]
            clip_filename = f"{speaker}_monologue_{start_ms}_{end_ms}.wav"
            clip.export(clip_filename, format="wav")

            utterance_clips.append({
                'clip_filename': clip_filename,
                'start': start_ms,
                'end': end_ms,
                'speaker': speaker
            })

    print("Total Number of Monologue Clips Found: ", len(utterance_clips))

    return utterance_clips

# This function uses NeMo to compare two files
def compare_embeddings(utterance_clip, reference_file):
    verification_result = speaker_model.verify_speakers(
        utterance_clip,
        reference_file
    )
    return verification_result

file_one = "YOUR_FILE_1"
file_two = "YOUR_FILE_2"
file_three = "YOUR_FILE_3"

download_wav(file_one, "testone.wav")
download_wav(file_two, "testtwo.wav")
download_wav(file_three, "testthree.wav")

# Store utterances from each clip, keyed by clip index
clip_utterances = {}

# Dictionary to track known speaker identities across all clips.
# Maps current clip speaker labels to a unified speaker label.
speaker_identity_map = {}

def process_clips(clip_transcript_ids, audio_files):
    global clip_utterances, speaker_identity_map

    # This will store the longest clip filenames for each speaker from the previous clips
    previous_speaker_clips = {}

    for clip_index, (transcript_id, audio_file) in enumerate(zip(clip_transcript_ids, audio_files)):
        transcript = get_transcript(transcript_id)
        utterances = transcript['utterances']
        clip_utterances[clip_index] = utterances  # Store utterances for the current clip

        longest_monologues = find_longest_monologues(utterances)

        # Process the longest monologues for clipping and storing
        current_speaker_clips = {}
        for speaker, monologue_data in longest_monologues.items():
            clip_and_store_utterances(audio_file, {speaker: monologue_data})
            longest_clip = f"{speaker}_monologue_{monologue_data[0][1]['start']}_{monologue_data[0][1]['end']}.wav"
            current_speaker_clips[speaker] = longest_clip

        if clip_index == 0:
            speaker_identity_map = {speaker: speaker for speaker in longest_monologues.keys()}
            previous_speaker_clips = current_speaker_clips.copy()
        else:
            # Compare all new speakers against all base speakers from previous clips
            for new_speaker, new_clip in current_speaker_clips.items():
                for base_speaker, base_clip in previous_speaker_clips.items():
                    if compare_embeddings(new_clip, base_clip):
                        speaker_identity_map[new_speaker] = base_speaker
                        break
                else:
                    # If no match is found, assign a new label
                    new_label = chr(ord(max(speaker_identity_map.values(), key=lambda x: ord(x))) + 1)
                    speaker_identity_map[new_speaker] = new_label

            # Update the previous_speaker_clips for the next iteration
            previous_speaker_clips.update(current_speaker_clips)

        # Update utterances with the new speaker labels for the current clip
        for utterance in clip_utterances[clip_index]:
            original_speaker = utterance['speaker']
            # Update only if there's a change in speaker identity
            if original_speaker in speaker_identity_map:
                utterance['speaker'] = speaker_identity_map[original_speaker]


# Add your clip transcript IDs
clip_transcript_ids = [
    "YOUR_TRANSCRIPT_ID_1",
    "YOUR_TRANSCRIPT_ID_2",
    "YOUR_TRANSCRIPT_ID_3"
]

# Add filepaths to your downloaded files
audio_files = [
    "testone.wav",
    "testtwo.wav",
    "testthree.wav"
]

process_clips(clip_transcript_ids, audio_files)

def display_transcript(transcript_data):
    for clip_index, utterances in transcript_data.items():
        print(f"Clip {clip_index + 1}:")
        for utterance in utterances:
            speaker = utterance['speaker']
            text = utterance['text']
            print(f"  Speaker {speaker}: {text}")
        print("\n")  # Add an extra newline for spacing between clips


display_transcript(clip_utterances)
Get Started
Before we begin, make sure you have an AssemblyAI account and an API key. You can sign up for an AssemblyAI account and get your API key from your dashboard.
Step-by-step instructions
Install Dependencies
pip install torch
pip install nemo_toolkit['all']
pip install ffmpeg
pip install pydub
pip install assemblyai
AssemblyAI Setup, Transcript Setup, and Load the Model Using NeMo
In this section, we'll import our dependencies, configure the AssemblyAI API key, and load the TitaNet model with NeMo. If you still need to transcribe your audio chunks and collect their transcript IDs, a sketch for doing so follows the setup code.
import assemblyai as aai
import requests
import json
import time
import copy
import os
import nemo.collections.asr as nemo_asr
from pydub import AudioSegment

speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("nvidia/speakerverification_en_titanet_large")

assemblyai_key = "YOUR_API_KEY"

headers = {
    "authorization": assemblyai_key
}
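The setup above assumes you already have transcript IDs for your audio chunks. If you still need to create them, here is a minimal sketch that submits each chunk to AssemblyAI with speaker diarization enabled and collects the resulting transcript IDs. The audio URLs are placeholders; replace them with your own.

# Minimal sketch: submit each audio chunk for transcription with
# speaker_labels enabled and keep the returned transcript IDs.
# The audio URLs below are placeholders.
transcript_endpoint = "https://api.assemblyai.com/v2/transcript"

audio_urls = [
    "https://example.com/chunk_one.wav",
    "https://example.com/chunk_two.wav",
]

clip_transcript_ids = []
for url in audio_urls:
    response = requests.post(
        transcript_endpoint,
        headers=headers,
        json={"audio_url": url, "speaker_labels": True}
    ).json()
    clip_transcript_ids.append(response["id"])

print(clip_transcript_ids)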
Helper Functions
The function below polls the AssemblyAI API for a transcript by its transcript ID and returns the result once transcription completes.
def get_transcript(transcript_id):
    polling_endpoint = f"https://api.assemblyai.com/v2/transcript/{transcript_id}"

    while True:
        transcription_result = requests.get(polling_endpoint, headers=headers).json()

        if transcription_result['status'] == 'completed':
            return transcription_result

        elif transcription_result['status'] == 'error':
            raise RuntimeError(f"Transcription failed: {transcription_result['error']}")

        else:
            time.sleep(3)
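As a usage sketch (the transcript ID is a placeholder, and the transcript is assumed to have been created with speaker_labels enabled), each utterance returned by the API carries the speaker label, text, and millisecond timestamps that the helper functions below rely on:

# Placeholder transcript ID; replace with one of your own
transcript = get_transcript("YOUR_TRANSCRIPT_ID")

# Each utterance carries the fields used later in this guide:
# 'speaker', 'text', and millisecond 'start'/'end' timestamps
for utterance in transcript['utterances'][:3]:
    print(utterance['speaker'], utterance['start'], utterance['end'], utterance['text'])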
Our main inference function will make use of these functions:
def download_wav(presigned_url, output_filename):
    # Download the WAV file from the presigned URL
    response = requests.get(presigned_url)
    if response.status_code == 200:
        print("downloading...")
        with open(output_filename, 'wb') as f:
            f.write(response.content)
        print("successfully downloaded file:", output_filename)
    else:
        raise Exception("Failed to download file, status code: {}".format(response.status_code))

# Function to identify the longest monologue of each speaker from each clip.
# You pass in the utterances and it returns the longest monologue from each speaker on that file.
def find_longest_monologues(utterances):
    longest_monologues = {}
    current_monologue = {}
    last_speaker = None  # Track the last speaker to identify interruptions

    for utterance in utterances:
        speaker = utterance['speaker']
        start_time = utterance['start']
        end_time = utterance['end']

        if speaker not in current_monologue:
            current_monologue[speaker] = {"start": start_time, "end": end_time}
            longest_monologues[speaker] = []
        else:
            # Extend monologue only if it's the same speaker speaking continuously
            if current_monologue[speaker]["end"] == start_time and last_speaker == speaker:
                current_monologue[speaker]["end"] = end_time
            else:
                monologue_length = current_monologue[speaker]["end"] - current_monologue[speaker]["start"]
                new_entry = (monologue_length, copy.deepcopy(current_monologue[speaker]))

                if len(longest_monologues[speaker]) < 1 or monologue_length > min(longest_monologues[speaker], key=lambda x: x[0])[0]:
                    if len(longest_monologues[speaker]) == 1:
                        longest_monologues[speaker].remove(min(longest_monologues[speaker], key=lambda x: x[0]))

                    longest_monologues[speaker].append(new_entry)

                current_monologue[speaker] = {"start": start_time, "end": end_time}

        last_speaker = speaker  # Update the last speaker

    # Check the last monologue for each speaker
    for speaker, monologue in current_monologue.items():
        monologue_length = monologue["end"] - monologue["start"]
        new_entry = (monologue_length, monologue)
        if len(longest_monologues[speaker]) < 1 or monologue_length > min(longest_monologues[speaker], key=lambda x: x[0])[0]:
            if len(longest_monologues[speaker]) == 1:
                longest_monologues[speaker].remove(min(longest_monologues[speaker], key=lambda x: x[0]))
            longest_monologues[speaker].append(new_entry)

    return longest_monologues

# Create clips of each long monologue and embed the clip.
# You pass in the file path and the longest monologue objects returned by the find_longest_monologues function.
# This function will create new audio file clips which contain only the longest monologue from each speaker.
def clip_and_store_utterances(audio_file, longest_monologues):
    # Load the full conversation audio
    full_audio = AudioSegment.from_wav(audio_file)
    full_audio = full_audio.set_channels(1)

    utterance_clips = []

    for speaker, monologues in longest_monologues.items():
        for _, monologue in monologues:
            start_ms = monologue['start']
            end_ms = monologue['end']
            clip = full_audio[start_ms:end_ms]
            clip_filename = f"{speaker}_monologue_{start_ms}_{end_ms}.wav"
            clip.export(clip_filename, format="wav")

            utterance_clips.append({
                'clip_filename': clip_filename,
                'start': start_ms,
                'end': end_ms,
                'speaker': speaker
            })

    print("Total Number of Monologue Clips Found: ", len(utterance_clips))

    return utterance_clips

# This function uses NeMo to compare two files
def compare_embeddings(utterance_clip, reference_file):
    verification_result = speaker_model.verify_speakers(
        utterance_clip,
        reference_file
    )
    return verification_result
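To make the expected input and output of find_longest_monologues concrete, here is a small hand-made example. The utterance values are illustrative only, not real transcript data.

# Illustrative, hand-made utterances (millisecond timestamps)
sample_utterances = [
    {"speaker": "A", "start": 0,     "end": 4000,  "text": "..."},
    {"speaker": "B", "start": 4000,  "end": 6000,  "text": "..."},
    {"speaker": "A", "start": 6000,  "end": 7000,  "text": "..."},
    {"speaker": "B", "start": 7000,  "end": 15000, "text": "..."},
]

print(find_longest_monologues(sample_utterances))
# Expected shape: {speaker: [(length_ms, {"start": ..., "end": ...})]}
# Here speaker A's longest monologue is 0-4000 (4000 ms) and
# speaker B's is 7000-15000 (8000 ms).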
Inference
Add the URLs of the WAV file clips you have stored on your server. The download_wav function saves a local copy of each clip.
file_one = "YOUR_FILE_1"
file_two = "YOUR_FILE_2"
file_three = "YOUR_FILE_3"

download_wav(file_one, "testone.wav")
download_wav(file_two, "testtwo.wav")
download_wav(file_three, "testthree.wav")
from pydub import AudioSegment

# Store utterances from each clip, keyed by clip index
clip_utterances = {}

# Dictionary to track known speaker identities across all clips.
# Maps current clip speaker labels to a unified speaker label.
speaker_identity_map = {}

def process_clips(clip_transcript_ids, audio_files):
    global clip_utterances, speaker_identity_map

    # This will store the longest clip filenames for each speaker from the previous clips
    previous_speaker_clips = {}

    for clip_index, (transcript_id, audio_file) in enumerate(zip(clip_transcript_ids, audio_files)):
        transcript = get_transcript(transcript_id)
        utterances = transcript['utterances']
        clip_utterances[clip_index] = utterances  # Store utterances for the current clip

        longest_monologues = find_longest_monologues(utterances)

        # Process the longest monologues for clipping and storing
        current_speaker_clips = {}
        for speaker, monologue_data in longest_monologues.items():
            clip_and_store_utterances(audio_file, {speaker: monologue_data})
            longest_clip = f"{speaker}_monologue_{monologue_data[0][1]['start']}_{monologue_data[0][1]['end']}.wav"
            current_speaker_clips[speaker] = longest_clip

        if clip_index == 0:
            speaker_identity_map = {speaker: speaker for speaker in longest_monologues.keys()}
            previous_speaker_clips = current_speaker_clips.copy()
        else:
            # Compare all new speakers against all base speakers from previous clips
            for new_speaker, new_clip in current_speaker_clips.items():
                for base_speaker, base_clip in previous_speaker_clips.items():
                    if compare_embeddings(new_clip, base_clip):
                        speaker_identity_map[new_speaker] = base_speaker
                        break
                else:
                    # If no match is found, assign a new label
                    new_label = chr(ord(max(speaker_identity_map.values(), key=lambda x: ord(x))) + 1)
                    speaker_identity_map[new_speaker] = new_label

            # Update the previous_speaker_clips for the next iteration
            previous_speaker_clips.update(current_speaker_clips)

        # Update utterances with the new speaker labels for the current clip
        for utterance in clip_utterances[clip_index]:
            original_speaker = utterance['speaker']
            # Update only if there's a change in speaker identity
            if original_speaker in speaker_identity_map:
                utterance['speaker'] = speaker_identity_map[original_speaker]


# Add your clip transcript IDs
clip_transcript_ids = [
    "YOUR_TRANSCRIPT_ID_1",
    "YOUR_TRANSCRIPT_ID_2",
    "YOUR_TRANSCRIPT_ID_3"
]

# Add filepaths to your downloaded files
audio_files = [
    "testone.wav",
    "testtwo.wav",
    "testthree.wav"
]

process_clips(clip_transcript_ids, audio_files)
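Once process_clips finishes, you can check how speaker labels from later clips were unified with the labels from the first clip. The mapping shown in the comment below is purely illustrative; the actual labels depend on your audio.

# Inspect the unified speaker labels produced by process_clips.
# The example output in the comment is illustrative only.
print(speaker_identity_map)
# e.g. {'A': 'A', 'B': 'B', 'C': 'A'}  -> speaker "C" in a later clip
# was verified as the same voice as speaker "A" from the first clip

# Count utterances per unified speaker across all clips
from collections import Counter

speaker_counts = Counter(
    utterance['speaker']
    for utterances in clip_utterances.values()
    for utterance in utterances
)
print(speaker_counts)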
New Clip Utterances
After process_clips runs, the clip_utterances dictionary contains the utterances with the corrected, unified speaker labels. You can inspect them by printing clip_utterances directly or by using the display_transcript function.
print(clip_utterances)
def display_transcript(transcript_data):
    for clip_index, utterances in transcript_data.items():
        print(f"Clip {clip_index + 1}:")
        for utterance in utterances:
            speaker = utterance['speaker']
            text = utterance['text']
            print(f"  Speaker {speaker}: {text}")
        print("\n")  # Add an extra newline for spacing between clips


display_transcript(clip_utterances)