2022-06-29 13:42:12 +00:00
|
|
|
import os
|
|
|
|
import json
|
|
|
|
import numpy
|
|
|
|
|
|
|
|
from .text_tokenizer import TextTokenizer
|
|
|
|
|
2022-06-30 10:43:10 +00:00
|
|
|
class MinDalleBase:
    """Shared base for min-DALL·E model wrappers.

    Locates the pretrained checkpoint directory for the chosen model
    size, loads the BPE vocab/merges files, and converts prompt text
    into the fixed (2, 64) int32 token layout the model expects.
    """

    def __init__(self, is_mega: bool):
        """Resolve the model directory and build the text tokenizer.

        :param is_mega: if True use the 'mega' checkpoint, else 'mini'.

        Reads ``vocab.json`` and ``merges.txt`` from
        ``pretrained/dalle_bart_{mini|mega}``.
        """
        self.is_mega = is_mega
        model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
        self.model_path = os.path.join('pretrained', model_name)

        print("reading files from {}".format(self.model_path))
        vocab_path = os.path.join(self.model_path, 'vocab.json')
        merges_path = os.path.join(self.model_path, 'merges.txt')

        with open(vocab_path, 'r', encoding='utf8') as f:
            vocab = json.load(f)
        with open(merges_path, 'r', encoding='utf8') as f:
            # The first line of merges.txt is a version header and the
            # file ends with a trailing newline — drop both ([1:-1]).
            merges = f.read().split("\n")[1:-1]

        self.tokenizer = TextTokenizer(vocab, merges)

    def tokenize_text(self, text: str) -> numpy.ndarray:
        """Tokenize *text* into the model's (2, 64) int32 input array.

        Row 0 carries only the first and last tokens (presumably the
        BOS/EOS sentinels produced by the tokenizer — the padded prompt
        row); row 1 carries the full token sequence.  Unused positions
        keep the pad value 1.

        :param text: prompt string to tokenize.
        :return: numpy array of shape (2, 64), dtype int32.
        """
        print("tokenizing text")
        tokens = self.tokenizer.tokenize(text)
        print("text tokens", tokens)
        # Fix: a long prompt can yield more than 64 tokens, which would
        # make the row-1 assignment below raise ValueError (shape
        # mismatch against the 64-wide row).  Clamp to the model's
        # 64-token context instead of crashing.
        if len(tokens) > 64:
            tokens = tokens[:64]
        text_tokens = numpy.ones((2, 64), dtype=numpy.int32)
        text_tokens[0, :2] = [tokens[0], tokens[-1]]
        text_tokens[1, :len(tokens)] = tokens
        return text_tokens
|