Enabling word-level timestamps for Wav2Vec 2.0 #3627

Closed

Conversation

Nithin-Holla
Contributor

@Nithin-Holla commented Jun 17, 2021

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #3371.

Currently, the output from Wav2Vec 2.0 decoding does not contain word-level start/end times, which can be useful for certain applications of ASR. Based on the discussion here, they could be computed from the output of the Flashlight decoder. For the KenLM decoder, we could first obtain the frame number corresponding to each non-blank token. Next, the timestamp of each character could be computed as segment_start + frame_no/total_frames * segment_duration. Finally, the start and end time of each word could be calculated based on the timestamp of the word boundary characters. In order to enable this, the frame number of each non-blank character is returned as a result of KenLM decoding. This is similar to the timesteps output from the ctcdecode library.
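
For illustration, here is a rough sketch (not the exact code in this PR) of how the returned frame numbers could be combined into word-level start/end times, assuming '|' is the word-boundary character and the segment's start time, duration and total frame count are known:

def word_times(chars, frame_nos, segment_start, segment_duration, total_frames):
    # Convert each non-blank character's frame number into a timestamp in seconds.
    times = [segment_start + f / total_frames * segment_duration for f in frame_nos]
    words, start, prev_t = [], None, None
    for ch, t in zip(chars, times):
        if ch == '|':  # word boundary: close the current word, if any
            if start is not None:
                words.append((start, prev_t))
                start = None
        else:
            if start is None:
                start = t  # first character of a new word
            prev_t = t
    if start is not None:  # flush the last word if the sequence doesn't end on '|'
        words.append((start, prev_t))
    return words  # list of (start_time, end_time) tuples, one per word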

PR review

@alexeib

@alexeib
Contributor

alexeib commented Jun 20, 2021

thanks - could you also add it to decoders here, which are soon to replace the old ones you improved:

https://github.com/pytorch/fairseq/tree/master/examples/speech_recognition/new/decoders

@Nithin-Holla
Contributor Author

Done! By the way, I only added this to the KenLM decoder. Do you think the same approach would work for FairseqLMDecoder?

@alexeib
Contributor

alexeib commented Jun 22, 2021

Done! By the way, I only added this to the KenLM decoder. Do you think the same approach would work for FairseqLMDecoder?

yes, it should work for all decoders

@facebook-github-bot
Contributor

@alexeib has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@alexeib merged this pull request in 3c4a8e4.

@harveenchadha
Contributor

Hey @Nithin-Holla,

How do I interpret these timesteps?

Are the timesteps in seconds?

@Nithin-Holla
Contributor Author

@harveenchadha These timesteps are similar to the ones returned by the ctcdecode library. They indicate the frame number corresponding to each of the predicted characters. If you have the start time and duration of an audio segment, one simple way I can think of to convert these timesteps into seconds is segment_start + timestep/total_frames * segment_duration.
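
For example, a minimal sketch of that conversion (the variable names here are just for illustration):

def timesteps_to_seconds(timesteps, segment_start, segment_duration, total_frames):
    # timesteps: frame numbers of the non-blank tokens returned by the decoder;
    # total_frames: number of frames in the emission matrix for this segment.
    return [segment_start + t / total_frames * segment_duration for t in timesteps]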

@Hemant6174

Hemant6174 commented Jul 23, 2021

Hey @Nithin-Holla
How can we generate timestamps for the Viterbi decoder?
And can this method be used to interpret timestamps in seconds for that as well?

@Nithin-Holla
Contributor Author

@hpaliwal1225 I haven't checked it for the Viterbi decoder. But if the expression viterbi_path[b].tolist() is similar to result.tokens in the KenLM decoder, i.e., token predictions for every frame, then this approach should work for it too.
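
Conceptually, an (untested) sketch of that, mirroring what the KenLM decoder's get_timesteps() does:

def viterbi_timesteps(frame_tokens, blank_idx):
    # frame_tokens: one token index per frame, e.g. viterbi_path[b].tolist().
    # Keep the frame index of each non-blank token, collapsing repeated tokens.
    timesteps = []
    for i, tok in enumerate(frame_tokens):
        if tok == blank_idx:
            continue
        if i == 0 or tok != frame_tokens[i - 1]:
            timesteps.append(i)
    return timesteps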

@micahjon

@Nithin-Holla, thanks for this awesome contribution.

Could you share a working code example? I've been trying to work through some of the code in the speech recognition page but realize I'm in over my head. In particular, I'm not sure why I need to download and pre-process the entire LibriSpeech corpus to run inference on a tiny wav file.
https://github.com/pytorch/fairseq/tree/master/examples/speech_recognition

If you could share any code/example for running inference (with timings) on a small wav file, that'd be really helpful. Thanks much!

@shahzebali42

@harveenchadha These timesteps are similar to the ones returned by the ctcdecode library. They indicate the frame number corresponding to each of the predicted characters. If you have the start time and duration of an audio segment, one simple way I can think of to convert these timesteps into seconds is segment_start + timestep/total_frames * segment_duration.

This formula segment_start + timestep/total_frames * segment_duration only works when the file duration is less than 60s. What should I do if the file is longer than 60s?

@abarcovschi

Hi guys, could anyone please explain to me where I can get the values for segment_start, segment_duration and total_frames in the decoder? I can't find these values anywhere. Maybe I need to search elsewhere?

@abarcovschi

abarcovschi commented Dec 17, 2023

I figured out how to get the word-level timestamps from the fairseq W2lDecoder subclasses' outputs. In examples/speech_recognition/w2l_decoder.py, I added the following code:

  1. In W2lDecoder.__init__(): self.symbols = tgt_dict.symbols # symbols (usually chars) understood by the ASR model, which are predicted in the emission matrix.
  2. W2lKenLMDecoder already had a get_timesteps() method, but I extended this functionality to all the other decoders by adding get_timesteps() to the parent class W2lDecoder, and I also created a get_symbols() method:
    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
        """Returns frame numbers corresponding to every non-blank token.

        Parameters
        ----------
        token_idxs : List[int]
            IDs of decoded tokens (including blank tokens), i.e. list of tokens spanning all frames of the emission matrix.

        Returns
        -------
        List[int]
            Frame numbers corresponding to every non-blank token.
        """
        timesteps = []
        for i, token_idx in enumerate(token_idxs):
            if token_idx == self.blank:
                continue
            if i == 0 or token_idx != token_idxs[i-1]:
                timesteps.append(i)
                
        return timesteps

    def get_symbols(self, token_idxs: List[int]) -> List[str]:
        """Returns characters corresponding to every non-blank token.

        Parameters
        ----------
        token_idxs : List[int]
            IDs of non-blank tokens.

        Returns
        -------
        List[str]
            Characters corresponding to every non-blank token.
        """
        chars = []
        for token_idx in token_idxs:
            chars.append(self.symbols[token_idx])

        return chars
  3. I slightly modified the decode() method of each decoder to ensure the character symbols, the timesteps corresponding to each symbol, and the list of words are returned in the hypos object.
  • The new decode() method for W2lViterbiDecoder:
    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )

        for b in range(B):
            tokens = self.get_tokens(viterbi_path[b].tolist()).tolist()
            hypos.append(
                [
                    {
                        "tokens": tokens,  # non-blank token idxs.
                        "symbols": self.get_symbols(
                            tokens
                        ),  # characters (symbols) corresponding to non-blank token idxs.
                        "score": 0,
                        "timesteps": self.get_timesteps(
                            viterbi_path[b].tolist()
                        ),  # frame numbers of non-blank tokens.
                        "words": post_process(
                            self.tgt_dict.string(tokens), "letter"
                        ).split(
                            " "
                        ),  # the transcript as a list of words.
                    }
                ]
            )

        return hypos
  • The new decode() method for W2lKenLMDecoder:
    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    {
                        "tokens": tokens,  # non-blank token idxs.
                        "symbols": self.get_symbols(
                            tokens
                        ),  # characters (symbols) corresponding to non-blank token idxs.
                        "score": result.score,
                        "timesteps": self.get_timesteps(
                            result.tokens
                        ),  # frame numbers of non-blank tokens.
                        "words": [
                            self.word_dict.get_entry(x) for x in result.words if x >= 0
                        ],  # the transcript as a list of words. Empty if lexicon-free decoding.
                    }
                    for result in nbest_results
                    if (
                        tokens := self.get_tokens(result.tokens).tolist()
                    )  # tokens is a local variable for the list comprehension.
                ]
            )
        return hypos
  • The new decode() method for W2lFairseqLMDecoder:
    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []

        def idx_to_word(idx):
            if self.unit_lm:
                return self.idx_to_wrd[idx]
            else:
                return self.word_dict[idx]

        def make_hypo(result):
            hypo = {
                        "tokens": self.get_tokens(result.tokens).tolist(),  # non-blank token idxs.
                        "score": result.score
                    }
            if self.lexicon:
                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0] # the transcript as a list of words.
            hypo["symbols"] = self.get_symbols(hypo["tokens"]) # characters (symbols) corresponding to non-blank token idxs.
            hypo["timesteps"] = self.get_timesteps(result.tokens) # frame numbers of non-blank tokens.

            return hypo

        for b in range(B):
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append([make_hypo(result) for result in nbest_results])
            self.lm.empty_cache()

        return hypos

I then postprocess the results in my own custom script to get the word-level time alignments (in seconds) for each hypothesis:

  1. For results of beam search-based decoders (W2lKenLMDecoder, W2lFairseqLMDecoder) I use the following function to process the result of decode():
def beam_search_decode_fairseq(hypos, emission_mx, audio_lens, num_hyps, time_aligns):
"""Process the results of a W2lDecoder object from fairseq.

    Args:
        hypos (Union[List[Dict], List[List[Dict]]]):
            List of results for each audio file returned by a W2lDecoder object. If the number of hypotheses to return (W2lDecoder.nbest) is 1, hypos will be a list of just the best hypotheses dicts.
             If W2lDecoder.nbest > 1, hypos will be a list of lists, where for each audio file there will be N best hypotheses dicts.
        emission_mx (torch.Tensor of shape (B, T, N)):
            The batched emission matrix outputted by the w2v2 acoustic model trained in fairseq.
        audio_lens (List[int]):
            The lengths of the original audio files in the batch, measured in number of samples.
        num_hyps (int):
            The number of best hypotheses to return per audio file.
        time_aligns (bool):
            Flag used to specify whether to calculate word-level time alignment in seconds for each hypothesis.

    Returns:
        transcripts (Union[List[Dict], List[List[Dict]]]):
            List of processed results for each audio file. If W2lDecoder.nbest = 1, transcripts will be a list of just the best hypotheses dicts.
             If W2lDecoder.nbest > 1, transcripts will be a list of lists, where for each audio file there will be N best hypotheses dicts.
            A hypothesis dict has the following fields:
                'pred_txt': (str) the transcript hypothesis itself.
                'timestamps_word': (List[Dict]) List of word Dict objects, one for each word in the transcript, with the following fields:
                    'word': the word itself.
                    'start_time': the start time of the word in seconds in the corresponding audio file.
                    'end_time': the end time of the word in seconds in the corresponding audio file.
    """
    transcripts = []
    for i in range(emission_mx.size(dim=0)):
        # if the batch_size is > 1, use the maximum original audio length in the batch, as all other audio files are padded to the max length during preprocessing.
        audio_len = audio_lens[i] if emission_mx.size(dim=0) == 1 else max(audio_lens)
        if num_hyps > 1:
            all_results = []
            for hyp in hypos[i]:
                hyp_dict = dict()
                if hyp['words']:
                    # 'words' field is not empty if using a lexicon.
                    transcript = ' '.join(hyp['words']).lower()
                else:
                    # 'words' field is [] if lexicon-free decoding, convert from non-blank symbols to words instead.
                    tokens_str = ''.join(hyp['symbols'])
                    transcript = ' '.join(tokens_str.split('|')).strip().lower()
                hyp_dict['pred_txt'] = transcript
                if time_aligns:
                    word_times = get_word_time_alignments_fairseq(audio_len, emission_mx.size(dim=1), 16000, hyp['symbols'], hyp['timesteps'])
                    timestamps_word = normalize_timestamp_output_w2v2(hyp_dict['pred_txt'].split(' '), word_times)
                    hyp_dict['timestamps_word'] = timestamps_word
                # add a hypothesis dict
                all_results.append(hyp_dict)
                
            transcripts.append(all_results)
        else:
            hyp_dict = dict()
            # append the decoded phrase (as a list of words) from the prediction of the first beam [0] (most likely transcript).
            if hypos[i][0]['words']:
                # 'words' field is not empty if using a lexicon.
                transcript = ' '.join(hypos[i][0]['words']).lower()
            else:
                # 'words' field is [] if lexicon-free decoding, convert from non-blank symbols to words instead.
                tokens_str = ''.join(hypos[i][0]['symbols'])
                transcript = ' '.join(tokens_str.split('|')).strip().lower()
            hyp_dict['pred_txt'] = transcript
            if time_aligns:
                word_times = get_word_time_alignments_fairseq(audio_len, emission_mx.size(dim=1), 16000, hypos[i][0]['symbols'], hypos[i][0]['timesteps'])
                timestamps_word = normalize_timestamp_output_w2v2(hyp_dict['pred_txt'].split(' '), word_times)
                hyp_dict['timestamps_word'] = timestamps_word
            # add a hypothesis dict
            transcripts.append(hyp_dict)

    return transcripts
  2. For the Viterbi decoder (W2lViterbiDecoder) I use the following code to process the result of decode():
        transcripts = []
        for i in range(emission_mx.size(dim=0)):
            # if the batch_size is > 1, use the maximum original audio length in the batch, as all other audio files are padded to the max length during preprocessing.
            audio_len = audio_lens[i] if emission_mx.size(dim=0) == 1 else max(audio_lens)
            hyp_dict = dict()
            # append the decoded phrase (as a list of words) from the prediction of the first beam [0] (most likely transcript).
            transcript = ' '.join(hypos[i][0]['words']).lower()
            hyp_dict['pred_txt'] = transcript
            if self.time_aligns:
                word_times = get_word_time_alignments_fairseq(audio_len, emission_mx.size(dim=1), 16000, hypos[i][0]['symbols'], hypos[i][0]['timesteps'])
                timestamps_word = normalize_timestamp_output_w2v2(hyp_dict['pred_txt'].split(' '), word_times)
                hyp_dict['timestamps_word'] = timestamps_word
            # add a hypothesis dict
            transcripts.append(hyp_dict)

        return transcripts

Most importantly, get_word_time_alignments_fairseq() is where I calculate the word-level time alignments:

def get_word_time_alignments_fairseq(audio_len, num_frames, sample_rate, symbols, timesteps):
    """Get word time alignments information for a hypothesis transcript input by converting from timesteps to seconds.
    Args:
        audio_len (int):
            The length of audio file in number of samples.
        num_frames (int):
            The number of frames in the ASR acoustic model emission matrix.
        sample_rate (int):
            The sample rate of the loaded audio file.
        symbols (List[str]):
            Decoded list of characters corresponding to the non-blank tokens returned by the decoder.
        timesteps (List[int]):
            Frame numbers corresponding to the non-blank tokens/symbols.

    Returns:
        word_times (List[Tuple[float, float]]):
            List of tuples of start_time and stop_time in seconds for each word in the transcript.
    """
    # list of times in seconds in the corresponding audio file for the non-blank tokens/symbols.
    timestamps = []
    # get the timestep in seconds corresponding to each non-blank token.
    for frame_num in timesteps:
        timestamp = frame_num * (audio_len / (num_frames * sample_rate))
        timestamps.append(timestamp)

    # NOTE: algorithm only works if the first and last symbols are '|', so add them in if that's not the case.
    frame_offset = 0
    if symbols[0] != '|':
        symbols.insert(0, '|')
        # if adding a symbol at index 0, all symbols will have their frame idx increased by 1, so an offset of -1 is created.
        frame_offset = -1
    if symbols[-1] != '|':
        symbols.append('|')

    word_boundary_idxs = [] # tuples of word start and stop indices.
    # get the indices of all word-boundary tokens (|).
    wb_tokens_idxs = [i for i in range(len(symbols)) if symbols[i] == '|']

    # create tuples for each word that contains the indices of its start symbol and end symbol.
    tup = [] # initialise the first tuple of word start character and word end character indices.
    # loop through the indices of the '|' tokens and find the indices of the word-boundary symbols/characters that are the start and end characters of each word.
    for wb_tokens_idx in wb_tokens_idxs:
        try:
            if symbols[wb_tokens_idx-1] != '|' and tup:
                # there is a start index in tuple, but no end index yet.
                # end index has been found.
                if wb_tokens_idx-1 == tup[0]:
                    # word is composed of only one character, add the index of this '|' token as the end character index for the word.
                    tup.append(wb_tokens_idx)
                else:
                    # word is composed of more than one character.
                    tup.append(wb_tokens_idx-1) # add an end character index for the word.
                # add the tuple as complete word to the list of word start and end index tuples.
                word_boundary_idxs.append(tup)
                tup = [] # reset the tuple.
                # continue onto the next if statement as this '|' token may be the boundary between two words.
            if symbols[wb_tokens_idx+1] != '|':
                # start character of new word reached.
                tup.append(wb_tokens_idx+1) # add a start character index for the word.
        except IndexError:
            continue
    
    # create tuples of start and stop times for each word
    word_times = [(timestamps[start_idx + frame_offset], timestamps[end_idx + frame_offset]) for start_idx, end_idx in word_boundary_idxs]

    return word_times

And normalize_timestamp_output_w2v2() is just a utility function to create a Dict containing time alignment information for each word:

def normalize_timestamp_output_w2v2(words, word_time_tuples):
    """Get word Dict objects with time information for each word in the hypothesis transcript.

    Args:
        words (List[str]):
            List of words in the transcript.
        word_time_tuples (List[Tuple[float,float]]):
            List of tuples of start_time and stop_time in seconds for each word in the transcript.

    Returns:
        values (List[Dict]):
            List of dict objects where each dict has the following fields:
                'word': (str) the word itself.
                'start_time': (float) the start time in seconds of the word in the corresponding audio file.
                'end_time': (float) the end time in seconds of the word in the corresponding audio file.
    """
    values = []
    for word, (word_start, word_end) in zip(words, word_time_tuples):
        vals_dict = dict()
        vals_dict['word'] = word
        vals_dict['start_time'] = word_start
        vals_dict['end_time'] = word_end
        values.append(vals_dict)
    
    return values

The formula I use to calculate the time in seconds in the corresponding audio for each non-blank symbol in the transcript is the following (a small worked example with made-up numbers is sketched after the list below):

timestamp = frame_num * (audio_len / (num_frames * sample_rate))

where:

  • frame_num = the timestep of the symbol, as returned in the 'timesteps' field of W2lDecoder.decode() outputs.
  • audio_len = the number of samples in the loaded audio file corresponding to the transcript (if using batched w2v2 acoustic model inference, will be zero padded to the length of the longest loaded audio file in the batch).
  • num_frames = the number of frames in the emission matrix returned by the w2v2 acoustic model inference for that audio file (if using batched inference, the number of frames for each audio file will be the same as in this case all loaded audio files are padded to the length of the longest audio file in the batch).
  • sample_rate = sample rate of loaded audio files (usually 16000 Hz).
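
For example, with made-up numbers (assuming a 10-second, 16 kHz file whose emission matrix happens to have 499 frames):

audio_len = 160000      # samples in the loaded audio file (10 s at 16 kHz)
num_frames = 499        # frames in the emission matrix for this file
sample_rate = 16000     # Hz
frame_num = 100         # 'timesteps' value of some non-blank symbol
timestamp = frame_num * (audio_len / (num_frames * sample_rate))
print(round(timestamp, 3))  # -> 2.004, i.e. the symbol occurs roughly 2 s into the audio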
