added to pypi
This commit is contained in:
@@ -5,23 +5,29 @@ import numpy
|
||||
from torch import LongTensor
|
||||
import torch
|
||||
import json
|
||||
import requests
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_num_threads(os.cpu_count())
|
||||
|
||||
from .text_tokenizer import TextTokenizer
|
||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
||||
|
||||
from .text_tokenizer import TextTokenizer
|
||||
from .models import (
|
||||
DalleBartEncoderTorch,
|
||||
DalleBartDecoderTorch,
|
||||
VQGanDetokenizer
|
||||
)
|
||||
|
||||
class MinDalleTorch:
|
||||
def __init__(
|
||||
self,
|
||||
self,
|
||||
is_mega: bool,
|
||||
is_reusable: bool = True,
|
||||
models_root: str = 'pretrained',
|
||||
sample_token_count: int = 256
|
||||
):
|
||||
print("initializing MinDalleTorch")
|
||||
self.is_mega = is_mega
|
||||
self.is_reusable = is_reusable
|
||||
self.sample_token_count = sample_token_count
|
||||
self.batch_count = 2
|
||||
@@ -35,10 +41,15 @@ class MinDalleTorch:
|
||||
self.image_vocab_count = 16415 if is_mega else 16384
|
||||
|
||||
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
||||
self.model_path = os.path.join('pretrained', model_name)
|
||||
self.encoder_params_path = os.path.join(self.model_path, 'encoder.pt')
|
||||
self.decoder_params_path = os.path.join(self.model_path, 'decoder.pt')
|
||||
self.detoker_params_path = os.path.join('pretrained', 'vqgan', 'detoker.pt')
|
||||
dalle_path = os.path.join(models_root, model_name)
|
||||
vqgan_path = os.path.join(models_root, 'vqgan')
|
||||
if not os.path.exists(dalle_path): os.makedirs(dalle_path)
|
||||
if not os.path.exists(vqgan_path): os.makedirs(vqgan_path)
|
||||
self.vocab_path = os.path.join(dalle_path, 'vocab.json')
|
||||
self.merges_path = os.path.join(dalle_path, 'merges.txt')
|
||||
self.encoder_params_path = os.path.join(dalle_path, 'encoder.pt')
|
||||
self.decoder_params_path = os.path.join(dalle_path, 'decoder.pt')
|
||||
self.detoker_params_path = os.path.join(vqgan_path, 'detoker.pt')
|
||||
|
||||
self.init_tokenizer()
|
||||
if is_reusable:
|
||||
@@ -46,18 +57,51 @@ class MinDalleTorch:
|
||||
self.init_decoder()
|
||||
self.init_detokenizer()
|
||||
|
||||
|
||||
def download_tokenizer(self):
|
||||
print("downloading tokenizer params")
|
||||
suffix = '' if self.is_mega else '_mini'
|
||||
vocab = requests.get(MIN_DALLE_REPO + 'vocab{}.json'.format(suffix))
|
||||
merges = requests.get(MIN_DALLE_REPO + 'merges{}.txt'.format(suffix))
|
||||
with open(self.vocab_path, 'wb') as f: f.write(vocab.content)
|
||||
with open(self.merges_path, 'wb') as f: f.write(merges.content)
|
||||
|
||||
|
||||
def download_encoder(self):
|
||||
print("downloading encoder params")
|
||||
suffix = '' if self.is_mega else '_mini'
|
||||
params = requests.get(MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix))
|
||||
with open(self.encoder_params_path, 'wb') as f: f.write(params.content)
|
||||
|
||||
|
||||
def download_decoder(self):
|
||||
print("downloading decoder params")
|
||||
suffix = '' if self.is_mega else '_mini'
|
||||
params = requests.get(MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix))
|
||||
with open(self.decoder_params_path, 'wb') as f: f.write(params.content)
|
||||
|
||||
|
||||
def download_detokenizer(self):
|
||||
print("downloading detokenizer params")
|
||||
params = requests.get(MIN_DALLE_REPO + 'detoker.pt')
|
||||
with open(self.detoker_params_path, 'wb') as f: f.write(params.content)
|
||||
|
||||
|
||||
def init_tokenizer(self):
|
||||
print("reading files from {}".format(self.model_path))
|
||||
vocab_path = os.path.join(self.model_path, 'vocab.json')
|
||||
merges_path = os.path.join(self.model_path, 'merges.txt')
|
||||
with open(vocab_path, 'r', encoding='utf8') as f:
|
||||
is_downloaded = os.path.exists(self.vocab_path)
|
||||
is_downloaded &= os.path.exists(self.merges_path)
|
||||
if not is_downloaded: self.download_tokenizer()
|
||||
print("intializing TextTokenizer")
|
||||
with open(self.vocab_path, 'r', encoding='utf8') as f:
|
||||
vocab = json.load(f)
|
||||
with open(merges_path, 'r', encoding='utf8') as f:
|
||||
with open(self.merges_path, 'r', encoding='utf8') as f:
|
||||
merges = f.read().split("\n")[1:-1]
|
||||
self.tokenizer = TextTokenizer(vocab, merges)
|
||||
|
||||
|
||||
def init_encoder(self):
|
||||
is_downloaded = os.path.exists(self.encoder_params_path)
|
||||
if not is_downloaded: self.download_encoder()
|
||||
print("initializing DalleBartEncoderTorch")
|
||||
self.encoder = DalleBartEncoderTorch(
|
||||
attention_head_count = self.attention_head_count,
|
||||
@@ -74,6 +118,8 @@ class MinDalleTorch:
|
||||
|
||||
|
||||
def init_decoder(self):
|
||||
is_downloaded = os.path.exists(self.decoder_params_path)
|
||||
if not is_downloaded: self.download_decoder()
|
||||
print("initializing DalleBartDecoderTorch")
|
||||
self.decoder = DalleBartDecoderTorch(
|
||||
sample_token_count = self.sample_token_count,
|
||||
@@ -93,6 +139,8 @@ class MinDalleTorch:
|
||||
|
||||
|
||||
def init_detokenizer(self):
|
||||
is_downloaded = os.path.exists(self.detoker_params_path)
|
||||
if not is_downloaded: self.download_detokenizer()
|
||||
print("initializing VQGanDetokenizer")
|
||||
self.detokenizer = VQGanDetokenizer()
|
||||
params = torch.load(self.detoker_params_path)
|
||||
|
3
min_dalle/models/__init__.py
Normal file
3
min_dalle/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||
from .dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||
from .vqgan_detokenizer import VQGanDetokenizer
|
Reference in New Issue
Block a user