Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Allow model to take byte-level input and make byte-level prediction #1187

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions pytext/data/tensorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from .utils import (
BOL,
BOS,
BYTE_BOS,
BYTE_EOS,
EOL,
EOS,
PAD,
Expand Down Expand Up @@ -360,28 +362,56 @@ class Config(Tensorizer.Config):
column: str = "text"
lower: bool = True
max_seq_len: Optional[int] = None
add_bos_token: Optional[bool] = False
add_eos_token: Optional[bool] = False
use_eos_token_for_bos: Optional[bool] = False

@classmethod
def from_config(cls, config: Config):
    """Build the tensorizer from its declarative Config.

    Every configured field — the source column, lowercasing flag,
    sequence-length cap, and the BOS/EOS marker options — is forwarded
    to the constructor unchanged.
    """
    return cls(
        text_column=config.column,
        lower=config.lower,
        max_seq_len=config.max_seq_len,
        add_bos_token=config.add_bos_token,
        add_eos_token=config.add_eos_token,
        use_eos_token_for_bos=config.use_eos_token_for_bos,
    )

def __init__(
    self,
    text_column,
    lower=True,
    max_seq_len=None,
    add_bos_token=Config.add_bos_token,
    add_eos_token=Config.add_eos_token,
    use_eos_token_for_bos=Config.use_eos_token_for_bos,
):
    """Store the configuration used by numberize().

    Args:
        text_column: name of the input column holding the raw text.
        lower: lowercase the text before encoding when True.
        max_seq_len: cap on the number of content bytes kept (markers
            are added after truncation).
        add_bos_token: prepend a begin-of-sequence marker byte.
        add_eos_token: append an end-of-sequence marker byte.
        use_eos_token_for_bos: use the EOS marker in place of BOS.
    """
    # Input selection and normalization options.
    self.text_column = text_column
    self.lower = lower
    self.max_seq_len = max_seq_len
    # Sequence-marker options.
    self.add_bos_token = add_bos_token
    self.add_eos_token = add_eos_token
    self.use_eos_token_for_bos = use_eos_token_for_bos

@property
def column_schema(self):
    """Declare the single raw-string text column this tensorizer reads."""
    schema = [(self.text_column, str)]
    return schema

def numberize(self, row):
    """Convert one row's text into a list of byte values.

    The text is stripped, optionally lowercased, UTF-8 encoded, and
    truncated to ``max_seq_len`` bytes.  Marker bytes are then added:
    a BOS marker (or the EOS marker when ``use_eos_token_for_bos`` is
    set) may be prepended and an EOS marker appended.

    Returns:
        Tuple of (byte_ids, length) where length counts any added
        marker bytes.
    """
    text = row[self.text_column].strip()
    if self.lower:
        text = text.lower()

    # Renamed from ``bytes`` to avoid shadowing the builtin.
    byte_ids = list(text.encode())

    # NOTE(review): truncation happens before markers are added, so the
    # final sequence can exceed max_seq_len by up to two bytes — this
    # preserves the original behavior; confirm it is intended.
    if self.max_seq_len:
        byte_ids = byte_ids[: self.max_seq_len]
    if self.add_bos_token:
        bos = BYTE_EOS if self.use_eos_token_for_bos else BYTE_BOS
        byte_ids = list(bos.encode()) + byte_ids
    if self.add_eos_token:
        byte_ids = byte_ids + list(BYTE_EOS.encode())
    return byte_ids, len(byte_ids)

def tensorize(self, batch):
Expand Down
4 changes: 4 additions & 0 deletions pytext/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def __eq__(self, other):
BOL = SpecialToken("__BEGIN_OF_LIST__")
EOL = SpecialToken("__END_OF_LIST__")
MASK = SpecialToken("__MASK__")
# The regular BOS and EOS special tokens are too long for a byte-level
# language model, so short single-character markers are used instead.
# TODO: find a combination of bytes with low corpus frequency and short length.
BYTE_BOS = SpecialToken("^")
BYTE_EOS = SpecialToken("#")

UNK_INDEX = 0
PAD_INDEX = 1
Expand Down
10 changes: 9 additions & 1 deletion pytext/metric_reporters/language_model_metric_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class LanguageModelMetricReporter(MetricReporter):
UTTERANCE_COLUMN = "utterance"
RAW_TEXT_COLUMN = "text"
TOKENS_COLUMN = "tokens"
LABELS_COLUMN = "labels"
lower_is_better = True

class Config(MetricReporter.Config):
Expand Down Expand Up @@ -84,7 +85,14 @@ def __init__(
if metadata:
self.pad_index = metadata.target.pad_token_idx
if tensorizers:
self.pad_index = tensorizers[self.TOKENS_COLUMN].vocab.get_pad_index()
if self.TOKENS_COLUMN in tensorizers:
column = self.TOKENS_COLUMN
elif self.LABELS_COLUMN in tensorizers:
column = self.LABELS_COLUMN
if hasattr(tensorizers[column], "vocab"):
self.pad_index = tensorizers[column].vocab.get_pad_index()
else:
self.pad_index = tensorizers[column].PAD_BYTE
self.perplexity_func = get_perplexity_func(perplexity_type)

def add_batch_stats(
Expand Down