#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, List, Optional, Type

from pytext.common.constants import Stage
from pytext.data import Batcher, Data
from pytext.data.bert_tensorizer import BERTTensorizer
from pytext.data.data import RowData
from pytext.data.sources import DataSource
from pytext.data.tensorizers import Tensorizer, TokenTensorizer
from pytext.data.xlm_tensorizer import XLMTensorizer


class PackedLMData(Data):
    """
    Special purpose Data object which assumes a single text tensorizer. Packs
    tokens into a square batch with no padding. Used for LM training. The
    object also takes in an optional language argument which is used for
    cross-lingual LM training.
    """

    __EXPANSIBLE__ = True

    class Config(Data.Config):
        max_seq_len: int = 128

    @classmethod
    def from_config(
        cls,
        config: Config,
        schema: Dict[str, Type],
        tensorizers: Dict[str, Tensorizer],
        language: Optional[str] = None,
        rank: int = 0,
        world_size: int = 1,
    ):
        return super(PackedLMData, cls).from_config(
            config,
            schema,
            tensorizers,
            rank,
            world_size,
            language=language,
            max_seq_len=config.max_seq_len,
        )

    def __init__(
        self,
        data_source: DataSource,
        tensorizers: Dict[str, Tensorizer],
        batcher: Optional[Batcher] = None,
        max_seq_len: int = Config.max_seq_len,
        sort_key: Optional[str] = None,
        # language is used in cross-lingual LM training
        language: Optional[str] = None,
        in_memory: Optional[bool] = False,
    ):
        super().__init__(data_source, tensorizers, batcher, sort_key, in_memory)
        assert len(self.tensorizers) == 1, "PackedLMData expects a single tensorizer"
        self.tensorizer_name, self.tensorizer = list(self.tensorizers.items())[0]
        self.remainder: Dict[str, List[int]] = {"tokens": [], "segment_labels": []}
        self.max_seq_len = max_seq_len
        self.language = language
        self.batch = {Stage.TRAIN: None, Stage.EVAL: None, Stage.TEST: None}

    def _parse_row(self, row):
        """
        The output of numberization has a different number of elements
        depending on the tensorizer used; for example, the positions tensor is
        only output by the XLMTensorizer. This function unpacks the elements
        according to the specific tensorizer used.
        Additionally, since we pack tokens into fixed-size blocks, we don't
        need the positions vector output by the call to numberize; we simply
        recreate it in `_format_output_row`.
        """
        numberized_row = self.tensorizer.numberize(row)
        if isinstance(self.tensorizer, XLMTensorizer):
            tokens, seq_len, segment_labels, _ = numberized_row
        elif isinstance(self.tensorizer, BERTTensorizer):
            tokens, segment_labels, seq_len = numberized_row
        elif isinstance(self.tensorizer, TokenTensorizer):
            tokens, seq_len, _ = numberized_row
            segment_labels = []
        else:
            raise NotImplementedError(
                "PackedLMData only supports XLMTensorizer, BERTTensorizer and "
                "TokenTensorizer."
            )
        return tokens, segment_labels, seq_len

    def _format_output_row(self, tokens, segment_labels, seq_len):
        """
        The tensorize function for different tensorizers takes in a different
        number of inputs, which may be arranged differently. This function
        formats the output dict to conform to the expectations of the
        tensorizer. In the case of the XLMTensorizer, we also need to create a
        new positions list which goes from 0 to seq_len - 1.
        """
        if isinstance(self.tensorizer, XLMTensorizer):
            positions = list(range(seq_len))
            return {self.tensorizer_name: (tokens, seq_len, segment_labels, positions)}
        elif isinstance(self.tensorizer, BERTTensorizer):
            return {self.tensorizer_name: (tokens, segment_labels, seq_len)}
        elif isinstance(self.tensorizer, TokenTensorizer):
            # dummy token_ranges
            return {self.tensorizer_name: (tokens, seq_len, [(-1, -1)] * seq_len)}
        else:
            raise NotImplementedError(
                "PackedLMData only supports XLMTensorizer, BERTTensorizer and "
                "TokenTensorizer."
            )

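    # Illustrative sketch of the dicts _format_output_row returns, assuming a
    # tensorizer named "tokens" and seq_len = 3 (not tied to a specific
    # PyText version):
    #   XLMTensorizer:   {"tokens": (tokens, 3, segment_labels, [0, 1, 2])}
    #   BERTTensorizer:  {"tokens": (tokens, segment_labels, 3)}
    #   TokenTensorizer: {"tokens": (tokens, 3, [(-1, -1)] * 3)}
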
    def _yield_and_reset(self):
        packed_tokens = list(self.remainder["tokens"])
        packed_segments = list(self.remainder["segment_labels"])
        self.remainder = {"tokens": [], "segment_labels": []}
        return RowData(
            {},  # packed LM data doesn't respect data cardinality
            self._format_output_row(packed_tokens, packed_segments, len(packed_tokens)),
        )

    def numberize_rows(self, rows):
        """
        This function does the actual packing. It processes rows until we
        obtain a block of data with length = max_seq_len.
        """
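        # Worked example (illustrative): with max_seq_len = 8, each packed
        # block holds at most 7 tokens (note the -1 below, which reserves one
        # position). If row A numberizes to 5 tokens, the remainder holds 5;
        # if row B then numberizes to 6 tokens, only 2 of them fit, so a
        # 7-token block is yielded and B's remaining 4 tokens start the next
        # remainder.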
        for row in rows:
            # If the PackedLMData object has a language member, then a
            # cross-lingual LM is being trained using monolingual data.
            # Add this language to the row, since the underlying tensorizer
            # needs it to generate language embeddings (used as
            # segment_labels below).
            if self.language:
                row["language"] = self.language

            tokens, segment_labels, seq_len = self._parse_row(row)
            remaining = self.max_seq_len - len(self.remainder["tokens"]) - 1
            while remaining < len(tokens):
                self.remainder["tokens"].extend(tokens[:remaining])
                self.remainder["segment_labels"].extend(segment_labels[:remaining])
                tokens = tokens[remaining:]
                segment_labels = segment_labels[remaining:]
                yield self._yield_and_reset()
                remaining = self.max_seq_len - 1
            self.remainder["tokens"].extend(tokens)
            self.remainder["segment_labels"].extend(segment_labels)
        if self.remainder["tokens"]:
            yield self._yield_and_reset()
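
# Minimal usage sketch (illustrative only; the tensorizer and config
# arguments below are assumptions, not pinned to a specific PyText version):
#
#     tensorizers = {"tokens": TokenTensorizer(text_column="text")}
#     data = PackedLMData.from_config(
#         PackedLMData.Config(max_seq_len=128),
#         schema={"text": str},
#         tensorizers=tensorizers,
#     )
#     for batch in data.batches(Stage.TRAIN):
#         ...  # each batch holds packed, unpadded blocks of tokens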