min-dalle-test/min_dalle/text_tokenizer.py

39 lines
1.6 KiB
Python
Raw Permalink Normal View History

2022-06-27 15:57:56 +00:00
from math import inf
from typing import List, Tuple
2022-06-27 18:34:10 +00:00
2022-06-27 15:57:56 +00:00
class TextTokenizer:
    """Byte-pair-encoding (BPE) text tokenizer.

    Converts lowercase ASCII text into a list of integer token ids using
    a subword vocabulary and an ordered list of merge rules.
    """

    def __init__(self, vocab: dict, merges: List[str]):
        """
        vocab: maps subword string -> integer token id; must contain the
            special tokens '<s>', '</s>' and '<unk>'.
        merges: merge rules, each a space-separated pair of subwords;
            earlier rules have higher priority (lower rank).
        """
        self.token_from_subword = vocab
        pairs = [tuple(pair.split()) for pair in merges]
        # rank_from_pair[(a, b)] = priority of merging adjacent subwords a, b
        self.rank_from_pair = dict(zip(pairs, range(len(pairs))))

    def tokenize(self, text: str, is_verbose: bool = False) -> List[int]:
        """Encode `text` as token ids, wrapped in <s> ... </s>.

        Text is lowercased and non-ASCII characters are dropped before
        splitting on single spaces; each word is BPE-encoded and any
        subword missing from the vocabulary maps to <unk>.
        """
        sep_token = self.token_from_subword['</s>']
        cls_token = self.token_from_subword['<s>']
        unk_token = self.token_from_subword['<unk>']
        text = text.lower().encode("ascii", errors="ignore").decode()
        tokens = [
            self.token_from_subword.get(subword, unk_token)
            for word in text.split(" ") if len(word) > 0
            for subword in self.get_byte_pair_encoding(word, is_verbose)
        ]
        return [cls_token] + tokens + [sep_token]

    def get_byte_pair_encoding(self, word: str, is_verbose: bool) -> List[str]:
        """Greedily merge the characters of `word` into subwords.

        Starts from the word-start marker chr(288) ('Ġ') followed by the
        word's characters, then repeatedly merges the adjacent pair with
        the lowest merge rank until no mergeable pair remains.
        """
        def get_pair_rank(pair: Tuple[str, str]) -> float:
            # fix: math.inf is a float, so the return annotation is float
            # (was incorrectly annotated -> int)
            return self.rank_from_pair.get(pair, inf)

        subwords = [chr(ord(" ") + 256)] + list(word)  # chr(288) marks word start
        while len(subwords) > 1:
            pairs = list(zip(subwords[:-1], subwords[1:]))
            pair_to_merge = min(pairs, key=get_pair_rank)
            if pair_to_merge not in self.rank_from_pair:
                break  # no remaining adjacent pair has a merge rule
            i = pairs.index(pair_to_merge)  # first occurrence of the best pair
            # empty slices make the i == 0 and i + 2 == len boundaries safe,
            # so no conditional guards are needed
            subwords = subwords[:i] + [subwords[i] + subwords[i + 1]] + subwords[i + 2:]
        if is_verbose:
            print(subwords)
        return subwords