Use Speaker Diarization with Async Chunking

This guide uses AssemblyAI and NVIDIA's NeMo framework. We'll be using TitaNet, a state-of-the-art open-source model trained for speaker recognition tasks. TitaNet lets us generate audio embeddings for speakers, which we can compare to determine whether two clips contain the same speaker.

Quickstart

import assemblyai as aai
import requests
import json
import time
import copy
import os
from pydub import AudioSegment
import nemo.collections.asr as nemo_asr

speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("nvidia/speakerverification_en_titanet_large")

assemblyai_key = "YOUR_API_KEY"

headers = {
    "authorization": assemblyai_key
}

def get_transcript(transcript_id):
    polling_endpoint = f"https://api.assemblyai.com/v2/transcript/{transcript_id}"

    while True:
        transcription_result = requests.get(polling_endpoint, headers=headers).json()

        if transcription_result['status'] == 'completed':
            return transcription_result

        elif transcription_result['status'] == 'error':
            raise RuntimeError(f"Transcription failed: {transcription_result['error']}")

        else:
            time.sleep(3)

def download_wav(presigned_url, output_filename):
    # Download the WAV file from the presigned URL
    response = requests.get(presigned_url)
    if response.status_code == 200:
        print("downloading...")
        with open(output_filename, 'wb') as f:
            f.write(response.content)
        print("successfully downloaded file:", output_filename)
    else:
        raise Exception(f"Failed to download file, status code: {response.status_code}")

# Identify the longest monologue of each speaker in a clip.
# Pass in the utterances; this returns the longest monologue of each speaker in that file.
def find_longest_monologues(utterances):
    longest_monologues = {}
    current_monologue = {}
    last_speaker = None  # Track the last speaker to identify interruptions

    for utterance in utterances:
        speaker = utterance['speaker']
        start_time = utterance['start']
        end_time = utterance['end']

        if speaker not in current_monologue:
            current_monologue[speaker] = {"start": start_time, "end": end_time}
            longest_monologues[speaker] = []
        else:
            # Extend the monologue only if the same speaker is speaking continuously
            if current_monologue[speaker]["end"] == start_time and last_speaker == speaker:
                current_monologue[speaker]["end"] = end_time
            else:
                monologue_length = current_monologue[speaker]["end"] - current_monologue[speaker]["start"]
                new_entry = (monologue_length, copy.deepcopy(current_monologue[speaker]))

                if len(longest_monologues[speaker]) < 1 or monologue_length > min(longest_monologues[speaker], key=lambda x: x[0])[0]:
                    if len(longest_monologues[speaker]) == 1:
                        longest_monologues[speaker].remove(min(longest_monologues[speaker], key=lambda x: x[0]))

                    longest_monologues[speaker].append(new_entry)

                current_monologue[speaker] = {"start": start_time, "end": end_time}

        last_speaker = speaker  # Update the last speaker

    # Check the final monologue for each speaker
    for speaker, monologue in current_monologue.items():
        monologue_length = monologue["end"] - monologue["start"]
        new_entry = (monologue_length, monologue)
        if len(longest_monologues[speaker]) < 1 or monologue_length > min(longest_monologues[speaker], key=lambda x: x[0])[0]:
            if len(longest_monologues[speaker]) == 1:
                longest_monologues[speaker].remove(min(longest_monologues[speaker], key=lambda x: x[0]))
            longest_monologues[speaker].append(new_entry)

    return longest_monologues

# Create a clip of each long monologue so it can be embedded.
# Pass in the audio file path and the longest monologues returned by find_longest_monologues.
# This function creates new audio files that each contain only the longest monologue of one speaker.
def clip_and_store_utterances(audio_file, longest_monologues):
    # Load the full conversation audio
    full_audio = AudioSegment.from_wav(audio_file)
    full_audio = full_audio.set_channels(1)

    utterance_clips = []

    for speaker, monologues in longest_monologues.items():
        for _, monologue in monologues:
            start_ms = monologue['start']
            end_ms = monologue['end']
            clip = full_audio[start_ms:end_ms]
            clip_filename = f"{speaker}_monologue_{start_ms}_{end_ms}.wav"
            clip.export(clip_filename, format="wav")

            utterance_clips.append({
                'clip_filename': clip_filename,
                'start': start_ms,
                'end': end_ms,
                'speaker': speaker
            })

    print("Total Number of Monologue Clips Found:", len(utterance_clips))

    return utterance_clips

# This function uses NeMo to compare the speakers in two files
def compare_embeddings(utterance_clip, reference_file):
    verification_result = speaker_model.verify_speakers(
        utterance_clip,
        reference_file
    )
    return verification_result

file_one = "YOUR_FILE_1"
file_two = "YOUR_FILE_2"
file_three = "YOUR_FILE_3"

download_wav(file_one, "testone.wav")
download_wav(file_two, "testtwo.wav")
download_wav(file_three, "testthree.wav")

# Store utterances from each clip, keyed by clip index
clip_utterances = {}

# Dictionary to track known speaker identities across all clips
# Maps current clip speaker labels to a unified speaker label
speaker_identity_map = {}

def process_clips(clip_transcript_ids, audio_files):
    global clip_utterances, speaker_identity_map

    # Stores the longest clip filename for each speaker from previous clips
    previous_speaker_clips = {}

    for clip_index, (transcript_id, audio_file) in enumerate(zip(clip_transcript_ids, audio_files)):
        transcript = get_transcript(transcript_id)
        utterances = transcript['utterances']
        clip_utterances[clip_index] = utterances  # Store utterances for the current clip

        longest_monologues = find_longest_monologues(utterances)

        # Clip and store the longest monologues
        current_speaker_clips = {}
        for speaker, monologue_data in longest_monologues.items():
            clip_and_store_utterances(audio_file, {speaker: monologue_data})
            longest_clip = f"{speaker}_monologue_{monologue_data[0][1]['start']}_{monologue_data[0][1]['end']}.wav"
            current_speaker_clips[speaker] = longest_clip

        if clip_index == 0:
            speaker_identity_map = {speaker: speaker for speaker in longest_monologues.keys()}
            previous_speaker_clips = current_speaker_clips.copy()
        else:
            # Compare each new speaker against the base speakers from previous clips
            for new_speaker, new_clip in current_speaker_clips.items():
                for base_speaker, base_clip in previous_speaker_clips.items():
                    if compare_embeddings(new_clip, base_clip):
                        speaker_identity_map[new_speaker] = base_speaker
                        break
                else:
                    # If no match is found, assign a new label
                    new_label = chr(ord(max(speaker_identity_map.values(), key=lambda x: ord(x))) + 1)
                    speaker_identity_map[new_speaker] = new_label

            # Update previous_speaker_clips for the next iteration
            previous_speaker_clips.update(current_speaker_clips)

        # Update utterances with the unified speaker labels for the current clip
        for utterance in clip_utterances[clip_index]:
            original_speaker = utterance['speaker']
            if original_speaker in speaker_identity_map:
                utterance['speaker'] = speaker_identity_map[original_speaker]


# Add your clip transcript IDs
clip_transcript_ids = [
    "YOUR_TRANSCRIPT_ID_1",
    "YOUR_TRANSCRIPT_ID_2",
    "YOUR_TRANSCRIPT_ID_3"
]

# Add the filepaths of your downloaded files
audio_files = [
    "testone.wav",
    "testtwo.wav",
    "testthree.wav"
]

process_clips(clip_transcript_ids, audio_files)

def display_transcript(transcript_data):
    for clip_index, utterances in transcript_data.items():
        print(f"Clip {clip_index + 1}:")
        for utterance in utterances:
            speaker = utterance['speaker']
            text = utterance['text']
            print(f"  Speaker {speaker}: {text}")
        print("\n")  # Add an extra newline for spacing between clips


display_transcript(clip_utterances)

Get Started

Before we begin, make sure you have an AssemblyAI account and an API key. You can sign up for an account and find your API key on your dashboard.

Step-by-step instructions

Install Dependencies

pip install torch
pip install "nemo_toolkit[all]"
pip install pydub
pip install assemblyai

Note that pydub requires the ffmpeg binary to be installed on your system (for example, via apt install ffmpeg or brew install ffmpeg).

AssemblyAI Setup, Transcript Setup, and Loading the Model with NeMo

In this section, we'll import our dependencies, configure our AssemblyAI API key, and load the TitaNet model through NeMo. If you still need to transcribe your files, there's a short sketch of a transcription request after the setup code.

import assemblyai as aai
import requests
import json
import time
import copy
import os
from pydub import AudioSegment
import nemo.collections.asr as nemo_asr

speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("nvidia/speakerverification_en_titanet_large")

assemblyai_key = "YOUR_API_KEY"

headers = {
    "authorization": assemblyai_key
}
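If you haven't transcribed your files yet, here is a minimal sketch of how you could request a transcript with speaker labels enabled through the same REST API. The request_transcript helper is illustrative, the audio_url value is a placeholder, and headers is the dictionary defined above:

# Request a transcript with speaker diarization enabled.
# audio_url must point to an audio file that AssemblyAI can access.
def request_transcript(audio_url):
    response = requests.post(
        "https://api.assemblyai.com/v2/transcript",
        headers=headers,
        json={"audio_url": audio_url, "speaker_labels": True}
    )
    response.raise_for_status()
    # Store the returned ID so you can poll for the finished transcript later
    return response.json()["id"]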

Helper Functions

The function below polls for a transcript by its transcript ID and returns it once transcription has completed.

def get_transcript(transcript_id):
    polling_endpoint = f"https://api.assemblyai.com/v2/transcript/{transcript_id}"

    while True:
        transcription_result = requests.get(polling_endpoint, headers=headers).json()

        if transcription_result['status'] == 'completed':
            return transcription_result

        elif transcription_result['status'] == 'error':
            raise RuntimeError(f"Transcription failed: {transcription_result['error']}")

        else:
            time.sleep(3)
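As a usage example (the transcript ID below is a placeholder), the returned dictionary includes the diarized utterances that the rest of this guide relies on:

transcript = get_transcript("YOUR_TRANSCRIPT_ID")
print(len(transcript['utterances']), "utterances found")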

Our main inference function will make use of these functions:

def download_wav(presigned_url, output_filename):
    # Download the WAV file from the presigned URL
    response = requests.get(presigned_url)
    if response.status_code == 200:
        print("downloading...")
        with open(output_filename, 'wb') as f:
            f.write(response.content)
        print("successfully downloaded file:", output_filename)
    else:
        raise Exception(f"Failed to download file, status code: {response.status_code}")

# Identify the longest monologue of each speaker in a clip.
# Pass in the utterances; this returns the longest monologue of each speaker in that file.
def find_longest_monologues(utterances):
    longest_monologues = {}
    current_monologue = {}
    last_speaker = None  # Track the last speaker to identify interruptions

    for utterance in utterances:
        speaker = utterance['speaker']
        start_time = utterance['start']
        end_time = utterance['end']

        if speaker not in current_monologue:
            current_monologue[speaker] = {"start": start_time, "end": end_time}
            longest_monologues[speaker] = []
        else:
            # Extend the monologue only if the same speaker is speaking continuously
            if current_monologue[speaker]["end"] == start_time and last_speaker == speaker:
                current_monologue[speaker]["end"] = end_time
            else:
                monologue_length = current_monologue[speaker]["end"] - current_monologue[speaker]["start"]
                new_entry = (monologue_length, copy.deepcopy(current_monologue[speaker]))

                if len(longest_monologues[speaker]) < 1 or monologue_length > min(longest_monologues[speaker], key=lambda x: x[0])[0]:
                    if len(longest_monologues[speaker]) == 1:
                        longest_monologues[speaker].remove(min(longest_monologues[speaker], key=lambda x: x[0]))

                    longest_monologues[speaker].append(new_entry)

                current_monologue[speaker] = {"start": start_time, "end": end_time}

        last_speaker = speaker  # Update the last speaker

    # Check the final monologue for each speaker
    for speaker, monologue in current_monologue.items():
        monologue_length = monologue["end"] - monologue["start"]
        new_entry = (monologue_length, monologue)
        if len(longest_monologues[speaker]) < 1 or monologue_length > min(longest_monologues[speaker], key=lambda x: x[0])[0]:
            if len(longest_monologues[speaker]) == 1:
                longest_monologues[speaker].remove(min(longest_monologues[speaker], key=lambda x: x[0]))
            longest_monologues[speaker].append(new_entry)

    return longest_monologues

# Create a clip of each long monologue so it can be embedded.
# Pass in the audio file path and the longest monologues returned by find_longest_monologues.
# This function creates new audio files that each contain only the longest monologue of one speaker.
def clip_and_store_utterances(audio_file, longest_monologues):
    # Load the full conversation audio
    full_audio = AudioSegment.from_wav(audio_file)
    full_audio = full_audio.set_channels(1)

    utterance_clips = []

    for speaker, monologues in longest_monologues.items():
        for _, monologue in monologues:
            start_ms = monologue['start']
            end_ms = monologue['end']
            clip = full_audio[start_ms:end_ms]
            clip_filename = f"{speaker}_monologue_{start_ms}_{end_ms}.wav"
            clip.export(clip_filename, format="wav")

            utterance_clips.append({
                'clip_filename': clip_filename,
                'start': start_ms,
                'end': end_ms,
                'speaker': speaker
            })

    print("Total Number of Monologue Clips Found:", len(utterance_clips))

    return utterance_clips

# This function uses NeMo to compare the speakers in two files
def compare_embeddings(utterance_clip, reference_file):
    verification_result = speaker_model.verify_speakers(
        utterance_clip,
        reference_file
    )
    return verification_result
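To make the return shape of find_longest_monologues concrete, here's an illustrative run on hand-made utterances (timestamps in milliseconds, mirroring the structure of AssemblyAI's utterance objects):

sample_utterances = [
    {"speaker": "A", "start": 0, "end": 4000, "text": "..."},
    {"speaker": "B", "start": 4000, "end": 5000, "text": "..."},
    {"speaker": "A", "start": 5000, "end": 12000, "text": "..."}
]

print(find_longest_monologues(sample_utterances))
# {'A': [(7000, {'start': 5000, 'end': 12000})], 'B': [(1000, {'start': 4000, 'end': 5000})]}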

Inference

Add the links (for example, presigned URLs) to the WAV files you have stored on your server.

file_one = "YOUR_FILE_1"
file_two = "YOUR_FILE_2"
file_three = "YOUR_FILE_3"

download_wav(file_one, "testone.wav")
download_wav(file_two, "testtwo.wav")
download_wav(file_three, "testthree.wav")
# Store utterances from each clip, keyed by clip index
clip_utterances = {}

# Dictionary to track known speaker identities across all clips
# Maps current clip speaker labels to a unified speaker label
speaker_identity_map = {}

def process_clips(clip_transcript_ids, audio_files):
    global clip_utterances, speaker_identity_map

    # Stores the longest clip filename for each speaker from previous clips
    previous_speaker_clips = {}

    for clip_index, (transcript_id, audio_file) in enumerate(zip(clip_transcript_ids, audio_files)):
        transcript = get_transcript(transcript_id)
        utterances = transcript['utterances']
        clip_utterances[clip_index] = utterances  # Store utterances for the current clip

        longest_monologues = find_longest_monologues(utterances)

        # Clip and store the longest monologues
        current_speaker_clips = {}
        for speaker, monologue_data in longest_monologues.items():
            clip_and_store_utterances(audio_file, {speaker: monologue_data})
            longest_clip = f"{speaker}_monologue_{monologue_data[0][1]['start']}_{monologue_data[0][1]['end']}.wav"
            current_speaker_clips[speaker] = longest_clip

        if clip_index == 0:
            speaker_identity_map = {speaker: speaker for speaker in longest_monologues.keys()}
            previous_speaker_clips = current_speaker_clips.copy()
        else:
            # Compare each new speaker against the base speakers from previous clips
            for new_speaker, new_clip in current_speaker_clips.items():
                for base_speaker, base_clip in previous_speaker_clips.items():
                    if compare_embeddings(new_clip, base_clip):
                        speaker_identity_map[new_speaker] = base_speaker
                        break
                else:
                    # If no match is found, assign a new label
                    new_label = chr(ord(max(speaker_identity_map.values(), key=lambda x: ord(x))) + 1)
                    speaker_identity_map[new_speaker] = new_label

            # Update previous_speaker_clips for the next iteration
            previous_speaker_clips.update(current_speaker_clips)

        # Update utterances with the unified speaker labels for the current clip
        for utterance in clip_utterances[clip_index]:
            original_speaker = utterance['speaker']
            if original_speaker in speaker_identity_map:
                utterance['speaker'] = speaker_identity_map[original_speaker]


# Add your clip transcript IDs
clip_transcript_ids = [
    "YOUR_TRANSCRIPT_ID_1",
    "YOUR_TRANSCRIPT_ID_2",
    "YOUR_TRANSCRIPT_ID_3"
]

# Add the filepaths of your downloaded files
audio_files = [
    "testone.wav",
    "testtwo.wav",
    "testthree.wav"
]

process_clips(clip_transcript_ids, audio_files)

New Clip Utterances

After process_clips runs, the clip_utterances dictionary contains the relabeled utterances. You can inspect them by printing the dictionary or by using the display_transcript function below.

print(clip_utterances)
def display_transcript(transcript_data):
    for clip_index, utterances in transcript_data.items():
        print(f"Clip {clip_index + 1}:")
        for utterance in utterances:
            speaker = utterance['speaker']
            text = utterance['text']
            print(f"  Speaker {speaker}: {text}")
        print("\n")  # Add an extra newline for spacing between clips


display_transcript(clip_utterances)
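With two recurring speakers across clips, the output takes roughly this shape (text shortened for illustration):

Clip 1:
  Speaker A: Hello, and welcome to the show.
  Speaker B: Thanks for having me.

Clip 2:
  Speaker A: Picking up where we left off...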

Additional Resources