Fix llama tokenizer #22402

Merged · 23 commits · Apr 3, 2023
18 changes: 4 additions & 14 deletions src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -20,7 +20,7 @@

import torch

from transformers import LlamaConfig, LlamaForCausalLM
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer


"""
@@ -233,19 +233,9 @@ def permute(w):

def write_tokenizer(tokenizer_path, input_tokenizer_path):
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
os.makedirs(tokenizer_path, exist_ok=True)
write_json({}, os.path.join(tokenizer_path, "special_tokens_map.json"))
write_json(
{
"bos_token": "",
"eos_token": "",
"model_max_length": int(1e30),
"tokenizer_class": "LlamaTokenizer",
"unk_token": "",
},
os.path.join(tokenizer_path, "tokenizer_config.json"),
)
shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
# Initialize the tokenizer based on the `spm` model
tokenizer = LlamaTokenizer(input_tokenizer_path)
tokenizer.save_pretrained(tokenizer_path)


def main():
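For context, here is a minimal sketch of what the rewritten `write_tokenizer` now amounts to; the paths are placeholders, not values from this PR:

```python
from transformers import LlamaTokenizer

# Placeholder paths; point these at your own LLaMA download and output directory.
input_tokenizer_path = "/path/to/llama/tokenizer.model"
output_dir = "/path/to/converted/tokenizer"

# The slow tokenizer is built directly from the SentencePiece model, so
# save_pretrained writes tokenizer.model, tokenizer_config.json and
# special_tokens_map.json with the real <s>, </s> and <unk> special tokens,
# instead of the empty strings the old hand-written config contained.
tokenizer = LlamaTokenizer(input_tokenizer_path)
tokenizer.save_pretrained(output_dir)
```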
110 changes: 63 additions & 47 deletions src/transformers/models/llama/tokenization_llama.py
@@ -25,15 +25,25 @@

import sentencepiece as spm

from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

PRETRAINED_VOCAB_FILES_MAP = {}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
},
"tokenizer_file": {
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"hf-internal-testing/llama-tokenizer": 2048,
}


class LlamaTokenizer(PreTrainedTokenizer):
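As a rough usage check of the new default mapping, the test checkpoint registered above can be loaded directly. This is a sketch; the exact ids depend on the checkpoint:

```python
from transformers import LlamaTokenizer

# "hf-internal-testing/llama-tokenizer" is the test checkpoint registered in the
# new PRETRAINED_VOCAB_FILES_MAP; it is not an official LLaMA release.
tok = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

enc = tok("Hello world")
print(enc.input_ids)                             # starts with the <s> id (add_bos_token=True)
print(tok.convert_ids_to_tokens(enc.input_ids))  # e.g. ['<s>', '▁Hello', '▁world']
```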
@@ -47,6 +57,7 @@ class LlamaTokenizer(PreTrainedTokenizer):

vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]

def __init__(
@@ -55,51 +66,50 @@ def __init__(
unk_token="<unk>",
bos_token="<s>",
eos_token="</s>",
pad_token=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
decode_with_prefix_space=False,
clean_up_tokenization_spaces=False,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.decode_with_prefix_space = decode_with_prefix_space
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
self._no_prefix_space_tokens = None

""" Initialisation"""
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state

@property
def no_prefix_space_tokens(self):
if self._no_prefix_space_tokens is None:
vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
return self._no_prefix_space_tokens
def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)

@property
def vocab_size(self):
"""Returns vocab size"""
return self.sp_model.get_piece_size()

@property
def bos_token_id(self) -> Optional[int]:
return self.sp_model.bos_id()

@property
def eos_token_id(self) -> Optional[int]:
return self.sp_model.eos_id()

def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
@@ -119,21 +129,15 @@ def _convert_id_to_token(self, index):
token = self.sp_model.IdToPiece(index)
return token

def _maybe_add_prefix_space(self, tokens, decoded):
if tokens and tokens[0] not in self.no_prefix_space_tokens:
return " " + decoded
else:
return decoded

def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ""
prev_is_special = False
for token in tokens:
for i, token in enumerate(tokens):
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special:
if not prev_is_special and i != 0:
out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
Expand All @@ -142,7 +146,6 @@ def convert_tokens_to_string(self, tokens):
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
return out_string

def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
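To make the decoding change above concrete, here is a standalone re-statement of the new loop, with a toy decoder standing in for `sp_model.decode` (the names and the toy decoder are illustrative only). The `i != 0` check is what drops the spurious leading space the old prefix-space logic tried to manage:

```python
def convert_tokens_to_string(tokens, special_tokens, sp_decode):
    current_sub_tokens, out_string, prev_is_special = [], "", False
    for i, token in enumerate(tokens):
        if token in special_tokens:
            # Special tokens are emitted verbatim, separated by a single space,
            # except when they start the sequence.
            if not prev_is_special and i != 0:
                out_string += " "
            out_string += sp_decode(current_sub_tokens) + token
            prev_is_special = True
            current_sub_tokens = []
        else:
            current_sub_tokens.append(token)
            prev_is_special = False
    out_string += sp_decode(current_sub_tokens)
    return out_string

# Toy stand-in for sp_model.decode: join pieces and turn "▁" back into spaces.
toy_decode = lambda pieces: "".join(pieces).replace("▁", " ").strip()
print(convert_tokens_to_string(["<s>", "▁Hello", "▁world"], {"<s>"}, toy_decode))
# -> "<s>Hello world" (no leading space before "Hello")
```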
@@ -173,18 +176,13 @@ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None)
return (out_vocab_file,)

def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
if self.add_bos_token:
bos_token_ids = [self.bos_token_id]
else:
bos_token_ids = []
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []

output = bos_token_ids + token_ids_0
output = bos_token_id + token_ids_0 + eos_token_id

if token_ids_1 is not None:
output = output + token_ids_1

if self.add_eos_token:
output = output + [self.eos_token_id]
output = output + bos_token_id + token_ids_1 + eos_token_id

return output
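The net effect of the rewritten `build_inputs_with_special_tokens` is that each segment is wrapped according to `add_bos_token`/`add_eos_token`, instead of a single EOS being appended at the very end. A standalone sketch with made-up ids (`bos_id=1` and `eos_id=2` are assumptions, not read from a model):

```python
from typing import List, Optional

def build_inputs(token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
                 add_bos_token: bool = True, add_eos_token: bool = False,
                 bos_id: int = 1, eos_id: int = 2) -> List[int]:
    bos = [bos_id] if add_bos_token else []
    eos = [eos_id] if add_eos_token else []
    output = bos + token_ids_0 + eos
    if token_ids_1 is not None:
        # Each segment now gets its own BOS/EOS, mirroring the new behavior.
        output = output + bos + token_ids_1 + eos
    return output

print(build_inputs([10, 11, 12]))        # [1, 10, 11, 12]
print(build_inputs([10, 11], [20, 21]))  # [1, 10, 11, 1, 20, 21]
```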

@@ -211,28 +209,46 @@ def get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)

bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []

if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return (
bos_token_id
+ ([0] * len(token_ids_0))
+ eos_token_id
+ bos_token_id
+ ([0] * len(token_ids_1))
+ eos_token_id
)

def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
use of token type ids, therefore a list of zeros is returned.
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. The sequence
pair mask has the following format:

```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```

if token_ids_1 is None, only returns the first portion of the mask (0s).

Args:
token_ids_0 (`List[int]`):
List of IDs.
List of ids.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.

Returns:
`List[int]`: List of zeros.
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
eos = [self.eos_token_id]
sep = [self.sep_token_id]
cls = [self.cls_token_id]

if token_ids_1 is None:
return len(token_ids_0 + eos) * [0]
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
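For illustration, the token-type-id layout described in the new docstring can be reproduced from sequence lengths alone; this is a sketch that simply counts one position each for the cls/sep slots:

```python
from typing import List, Optional

def token_type_ids(len_seq_0: int, len_seq_1: Optional[int] = None) -> List[int]:
    first = [0] * (1 + len_seq_0 + 1)     # cls + sequence 0 + sep
    if len_seq_1 is None:
        return first
    return first + [1] * (len_seq_1 + 1)  # sequence 1 + sep

print(token_type_ids(3))     # [0, 0, 0, 0, 0]
print(token_type_ids(3, 2))  # [0, 0, 0, 0, 0, 1, 1, 1]
```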