
Commit 3c4a8e4

Nithin-Holla authored and facebook-github-bot committed
Enabling word-level timestamps for Wav2Vec 2.0 (#3627)
Summary:

# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?

Fixes #3371. Currently, the output of Wav2Vec 2.0 decoding does not contain word-level start/end times, which are useful for many ASR applications. Based on the discussion [here](flashlight/flashlight#618), they can be computed from the output of the Flashlight decoder. For the KenLM decoder, we first obtain the frame number corresponding to each non-blank token. Next, the timestamp of each character can be computed as `segment_start + frame_no/total_frames * segment_duration`. Finally, the start and end time of each word can be derived from the timestamps of its word-boundary characters. To enable this, the frame number of each non-blank character is now returned as part of the KenLM decoding result, similar to the `timesteps` output of the [ctcdecode](https://github.com/parlance/ctcdecode#outputs-from-the-decode-method) library.

## PR review

alexeib

Pull Request resolved: #3627

Reviewed By: michaelauli

Differential Revision: D29282488

Pulled By: alexeib

fbshipit-source-id: b5fe64bf50abd7ef8e9539f4e338937c866eb0ca
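The three steps above are mechanical to implement on top of the new output; a minimal sketch (not part of this commit), assuming fairseq's `|` word-boundary character and hypothetical `segment_start`, `segment_duration`, and `total_frames` inputs:

```python
from typing import List, Tuple


def word_times(
    chars: List[str],         # decoded characters, blanks/repeats removed
    timesteps: List[int],     # frame index per character, from get_timesteps()
    total_frames: int,        # T, the number of emission frames
    segment_start: float,     # segment offset in seconds (assumed known)
    segment_duration: float,  # segment length in seconds (assumed known)
    boundary: str = "|",      # word-boundary character in the letter dictionary
) -> List[Tuple[float, float]]:
    # Step 2: timestamp of each character.
    stamps = [segment_start + t / total_frames * segment_duration for t in timesteps]
    # Step 3: word start/end times from the word-boundary characters.
    words, start = [], None
    for ch, ts in zip(chars, stamps):
        if ch == boundary:
            if start is not None:
                words.append((start, ts))
            start = None
        elif start is None:
            start = ts
    if start is not None:  # trailing word with no closing boundary
        words.append((start, stamps[-1]))
    return words
```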
1 parent 900a607 commit 3c4a8e4

File tree

2 files changed: +44 / -0 lines

Diff for: examples/speech_recognition/new/decoders/flashlight_decoder.py

+22
```diff
@@ -118,6 +118,27 @@ def __init__(self, cfg: FlashlightDecoderConfig, tgt_dict: Dictionary) -> None:
             self.decoder_opts, self.lm, self.silence, self.blank, []
         )
 
+    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
+        """Returns frame numbers corresponding to every non-blank token.
+
+        Parameters
+        ----------
+        token_idxs : List[int]
+            IDs of decoded tokens.
+
+        Returns
+        -------
+        List[int]
+            Frame numbers corresponding to every non-blank token.
+        """
+        timesteps = []
+        for i, token_idx in enumerate(token_idxs):
+            if token_idx == self.blank:
+                continue
+            if i == 0 or token_idx != token_idxs[i-1]:
+                timesteps.append(i)
+        return timesteps
+
     def decode(
         self,
         emissions: torch.FloatTensor,
@@ -134,6 +155,7 @@ def decode(
                     {
                         "tokens": self.get_tokens(result.tokens),
                         "score": result.score,
+                        "timesteps": self.get_timesteps(result.tokens),
                         "words": [
                             self.word_dict.get_entry(x) for x in result.words if x >= 0
                         ],
```
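The collapse rule mirrors CTC decoding: blank frames are dropped, and within a run of repeated tokens only the first frame is kept. A hypothetical call (token IDs and blank index invented for illustration):

```python
# Assume self.blank == 0 for this decoder instance.
token_idxs = [0, 5, 5, 0, 7, 7, 3]
# frame:      0  1  2  3  4  5  6
decoder.get_timesteps(token_idxs)  # -> [1, 4, 6]
# Frames 2 and 5 repeat the previous token and are collapsed into the
# first frame of their runs; the blanks at frames 0 and 3 are skipped.
```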

Diff for: examples/speech_recognition/w2l_decoder.py

+22
```diff
@@ -12,6 +12,7 @@
 import gc
 import itertools as it
 import os.path as osp
+from typing import List
 import warnings
 from collections import deque, namedtuple
 
@@ -194,6 +195,26 @@ def __init__(self, args, tgt_dict):
             self.decoder_opts, self.lm, self.silence, self.blank, []
         )
 
+    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
+        """Returns frame numbers corresponding to every non-blank token.
+
+        Parameters
+        ----------
+        token_idxs : List[int]
+            IDs of decoded tokens.
+
+        Returns
+        -------
+        List[int]
+            Frame numbers corresponding to every non-blank token.
+        """
+        timesteps = []
+        for i, token_idx in enumerate(token_idxs):
+            if token_idx == self.blank:
+                continue
+            if i == 0 or token_idx != token_idxs[i-1]:
+                timesteps.append(i)
+        return timesteps
 
     def decode(self, emissions):
         B, T, N = emissions.size()
@@ -208,6 +229,7 @@ def decode(self, emissions):
                     {
                         "tokens": self.get_tokens(result.tokens),
                         "score": result.score,
+                        "timesteps": self.get_timesteps(result.tokens),
                         "words": [
                             self.word_dict.get_entry(x) for x in result.words if x >= 0
                         ],
```
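Since `get_tokens` applies the same blank-and-repeat collapsing, the `tokens` and `timesteps` lists in each hypothesis line up one-to-one. A usage sketch (decoder construction and `emissions` elided; they follow the existing w2l interface):

```python
# hypos[b] is the n-best list for batch element b, best hypothesis first.
hypos = decoder.decode(emissions)
best = hypos[0][0]
# The i-th entry of "timesteps" is the emission-frame index at which the
# i-th decoded token first appears.
for token, frame in zip(best["tokens"], best["timesteps"]):
    print(int(token), frame)
```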
