refactored to load models once and run multiple times

This commit is contained in:
Brett Kuprel 2022-06-29 09:42:12 -04:00
parent 1ef9b0b929
commit ed91ab4a30
11 changed files with 225 additions and 282 deletions

View File

@ -2,18 +2,18 @@
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model. This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy`, `torch`, and `flax`.
### Setup ### Setup
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually: Run `sh setup.sh` to install dependencies and download pretrained models. The `wandb` python package is installed to download DALL·E mini and DALL·E mega. Alternatively, the models can be downloaded manually here:
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384), [VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files), [DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files) [DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
### Usage ### Usage
Use the command line python script `image_from_text.py` to generate images. Here are some examples: The simplest way to get started is the command line python script `image_from_text.py` provided. Here are some examples runs:
``` ```
python image_from_text.py --text='alien life' --seed=7 python image_from_text.py --text='alien life' --seed=7
@ -32,3 +32,7 @@ python image_from_text.py --text='court sketch of godzilla on trial' --mega --se
``` ```
![Godzilla Trial](examples/godzilla_trial.png) ![Godzilla Trial](examples/godzilla_trial.png)
### Load once run multiple times
The command line script loads the models and parameters each time. The colab notebook demonstrates how to load the models once and run multiple times.

View File

@ -2,8 +2,8 @@ import argparse
import os import os
from PIL import Image from PIL import Image
from min_dalle.generate_image import generate_image_from_text from min_dalle.min_dalle_torch import MinDalleTorch
from min_dalle.min_dalle_flax import MinDalleFlax
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--mega', action='store_true') parser.add_argument('--mega', action='store_true')
@ -15,7 +15,7 @@ parser.set_defaults(torch=False)
parser.add_argument('--text', type=str) parser.add_argument('--text', type=str)
parser.add_argument('--seed', type=int, default=0) parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--image_path', type=str, default='generated') parser.add_argument('--image_path', type=str, default='generated')
parser.add_argument('--image_token_count', type=int, default=256) # for debugging parser.add_argument('--sample_token_count', type=int, default=256) # for debugging
def ascii_from_image(image: Image.Image, size: int) -> str: def ascii_from_image(image: Image.Image, size: int) -> str:
@ -36,19 +36,40 @@ def save_image(image: Image.Image, path: str):
return image return image
def generate_image(
is_torch: bool,
is_mega: bool,
text: str,
seed: int,
image_path: str,
sample_token_count: int
):
if is_torch:
image_generator = MinDalleTorch(is_mega, sample_token_count)
image_tokens = image_generator.generate_image_tokens(text, seed)
if sample_token_count < image_generator.config['image_length']:
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
return
else:
image = image_generator.generate_image(text, seed)
else:
image_generator = MinDalleFlax(is_mega)
image = image_generator.generate_image(text, seed)
save_image(image, image_path)
print(ascii_from_image(image, size=128))
if __name__ == '__main__': if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
generate_image(
image = generate_image_from_text( is_torch=args.torch,
text = args.text, is_mega=args.mega,
is_mega = args.mega, text=args.text,
is_torch = args.torch, seed=args.seed,
seed = args.seed, image_path=args.image_path,
image_token_count = args.image_token_count sample_token_count=args.sample_token_count
) )
if image != None:
save_image(image, args.image_path)
print(ascii_from_image(image, size=128))

View File

@ -1,78 +0,0 @@
import os
import json
import numpy
from PIL import Image
from typing import Tuple, List
import torch
from min_dalle.load_params import load_dalle_bart_flax_params
from min_dalle.text_tokenizer import TextTokenizer
from min_dalle.min_dalle_flax import generate_image_tokens_flax
from min_dalle.min_dalle_torch import (
generate_image_tokens_torch,
detokenize_torch
)
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
print("parsing metadata from {}".format(path))
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
assert(os.path.exists(os.path.join(path, f)))
with open(path + '/config.json', 'r') as f:
config = json.load(f)
with open(path + '/vocab.json') as f:
vocab = json.load(f)
with open(path + '/merges.txt') as f:
merges = f.read().split("\n")[1:-1]
return config, vocab, merges
def tokenize_text(
text: str,
config: dict,
vocab: dict,
merges: List[str]
) -> numpy.ndarray:
print("tokenizing text")
tokens = TextTokenizer(vocab, merges)(text)
print("text tokens", tokens)
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens
def generate_image_from_text(
text: str,
is_mega: bool = False,
is_torch: bool = False,
seed: int = 0,
image_token_count: int = 256
) -> Image.Image:
model_name = 'mega' if is_mega else 'mini'
model_path = './pretrained/dalle_bart_{}'.format(model_name)
config, vocab, merges = load_dalle_bart_metadata(model_path)
text_tokens = tokenize_text(text, config, vocab, merges)
params_dalle_bart = load_dalle_bart_flax_params(model_path)
if is_torch:
image_tokens = generate_image_tokens_torch(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
image_token_count = image_token_count
)
if image_token_count == config['image_length']:
image = detokenize_torch(image_tokens, is_torch=True)
return Image.fromarray(image)
else:
print(list(image_tokens.to('cpu').detach().numpy()))
else:
image_tokens = generate_image_tokens_flax(
text_tokens = text_tokens,
seed = seed,
config = config,
params = params_dalle_bart,
)
image = detokenize_torch(torch.tensor(image_tokens), is_torch=False)
return Image.fromarray(image)

38
min_dalle/min_dalle.py Normal file
View File

@ -0,0 +1,38 @@
import os
import json
import numpy
from .text_tokenizer import TextTokenizer
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
from .models.vqgan_detokenizer import VQGanDetokenizer
class MinDalle:
def __init__(self, is_mega: bool):
self.is_mega = is_mega
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
model_path = os.path.join('pretrained', model_name)
print("reading files from {}".format(model_path))
with open(os.path.join(model_path, 'config.json'), 'r') as f:
self.config = json.load(f)
with open(os.path.join(model_path, 'vocab.json'), 'r') as f:
vocab = json.load(f)
with open(os.path.join(model_path, 'merges.txt'), 'r') as f:
merges = f.read().split("\n")[1:-1]
self.model_params = load_dalle_bart_flax_params(model_path)
self.tokenizer = TextTokenizer(vocab, merges)
self.detokenizer = VQGanDetokenizer()
vqgan_params = load_vqgan_torch_params('./pretrained/vqgan')
self.detokenizer.load_state_dict(vqgan_params)
def tokenize_text(self, text: str) -> numpy.ndarray:
print("tokenizing text")
tokens = self.tokenizer.tokenize(text)
print("text tokens", tokens)
text_token_count = self.config['max_text_length']
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
text_tokens[0, :len(tokens)] = tokens
text_tokens[1, :2] = [tokens[0], tokens[-1]]
return text_tokens

View File

@ -1,79 +1,58 @@
import jax import jax
from jax import numpy as jnp
import numpy import numpy
from PIL import Image
import torch
from .min_dalle import MinDalle
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
def encode_flax( class MinDalleFlax(MinDalle):
text_tokens: numpy.ndarray, def __init__(self, is_mega: bool):
config: dict, super().__init__(is_mega)
params: dict print("initializing MinDalleFlax")
) -> jnp.ndarray:
print("loading flax encoder") print("loading encoder")
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax( self.encoder = DalleBartEncoderFlax(
attention_head_count = config['encoder_attention_heads'], attention_head_count = self.config['encoder_attention_heads'],
embed_count = config['d_model'], embed_count = self.config['d_model'],
glu_embed_count = config['encoder_ffn_dim'], glu_embed_count = self.config['encoder_ffn_dim'],
text_token_count = config['max_text_length'], text_token_count = self.config['max_text_length'],
text_vocab_count = config['encoder_vocab_size'], text_vocab_count = self.config['encoder_vocab_size'],
layer_count = config['encoder_layers'] layer_count = self.config['encoder_layers']
).bind({'params': params.pop('encoder')}) ).bind({'params': self.model_params.pop('encoder')})
print("loading decoder")
self.decoder = DalleBartDecoderFlax(
image_token_count = self.config['image_length'],
text_token_count = self.config['max_text_length'],
image_vocab_count = self.config['image_vocab_size'],
attention_head_count = self.config['decoder_attention_heads'],
embed_count = self.config['d_model'],
glu_embed_count = self.config['decoder_ffn_dim'],
layer_count = self.config['decoder_layers'],
start_token = self.config['decoder_start_token_id']
)
def generate_image(self, text: str, seed: int) -> Image.Image:
text_tokens = self.tokenize_text(text)
print("encoding text tokens") print("encoding text tokens")
encoder_state = encoder(text_tokens) encoder_state = self.encoder(text_tokens)
del encoder
return encoder_state
def decode_flax(
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
config: dict,
seed: int,
params: dict
) -> jnp.ndarray:
print("loading flax decoder")
decoder = DalleBartDecoderFlax(
image_token_count = config['image_length'],
text_token_count = config['max_text_length'],
image_vocab_count = config['image_vocab_size'],
attention_head_count = config['decoder_attention_heads'],
embed_count = config['d_model'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
start_token = config['decoder_start_token_id']
)
print("sampling image tokens") print("sampling image tokens")
image_tokens = decoder.sample_image_tokens( image_tokens = self.decoder.sample_image_tokens(
text_tokens, text_tokens,
encoder_state, encoder_state,
jax.random.PRNGKey(seed), jax.random.PRNGKey(seed),
params.pop('decoder') self.model_params['decoder']
) )
del decoder
return image_tokens
image_tokens = torch.tensor(numpy.array(image_tokens))
def generate_image_tokens_flax( print("detokenizing image")
text_tokens: numpy.ndarray, image = self.detokenizer.forward(image_tokens).to(torch.uint8)
seed: int, image = Image.fromarray(image.to('cpu').detach().numpy())
config: dict, return image
params: dict
) -> numpy.ndarray:
encoder_state = encode_flax(
text_tokens,
config,
params
)
image_tokens = decode_flax(
text_tokens,
encoder_state,
config,
seed,
params
)
image_tokens = numpy.array(image_tokens)
print("image tokens", list(image_tokens))
return image_tokens

View File

@ -1,118 +1,83 @@
from random import sample
import numpy import numpy
import os import os
from PIL import Image
from typing import Dict from typing import Dict
from torch import LongTensor, FloatTensor from torch import LongTensor
import torch import torch
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count()) torch.set_num_threads(os.cpu_count())
from .models.vqgan_detokenizer import VQGanDetokenizer from .load_params import convert_dalle_bart_torch_from_flax_params
from .min_dalle import MinDalle
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
from .load_params import (
load_vqgan_torch_params,
convert_dalle_bart_torch_from_flax_params
)
class MinDalleTorch(MinDalle):
def __init__(self, is_mega: bool, sample_token_count: int = 256):
super().__init__(is_mega)
print("initializing MinDalleTorch")
def encode_torch( print("loading encoder")
text_tokens: LongTensor, self.encoder = DalleBartEncoderTorch(
config: dict, layer_count = self.config['encoder_layers'],
params: dict embed_count = self.config['d_model'],
) -> FloatTensor: attention_head_count = self.config['encoder_attention_heads'],
print("loading torch encoder") text_vocab_count = self.config['encoder_vocab_size'],
encoder = DalleBartEncoderTorch( text_token_count = self.config['max_text_length'],
layer_count = config['encoder_layers'], glu_embed_count = self.config['encoder_ffn_dim']
embed_count = config['d_model'],
attention_head_count = config['encoder_attention_heads'],
text_vocab_count = config['encoder_vocab_size'],
text_token_count = config['max_text_length'],
glu_embed_count = config['encoder_ffn_dim']
) )
encoder_params = convert_dalle_bart_torch_from_flax_params( encoder_params = convert_dalle_bart_torch_from_flax_params(
params.pop('encoder'), self.model_params.pop('encoder'),
layer_count=config['encoder_layers'], layer_count=self.config['encoder_layers'],
is_encoder=True is_encoder=True
) )
encoder.load_state_dict(encoder_params, strict=False) self.encoder.load_state_dict(encoder_params, strict=False)
del encoder_params
if torch.cuda.is_available(): encoder = encoder.cuda()
print("encoding text tokens") print("loading decoder")
encoder_state = encoder(text_tokens) self.decoder = DalleBartDecoderTorch(
del encoder image_vocab_size = self.config['image_vocab_size'],
return encoder_state image_token_count = self.config['image_length'],
sample_token_count = sample_token_count,
embed_count = self.config['d_model'],
def decode_torch( attention_head_count = self.config['decoder_attention_heads'],
text_tokens: LongTensor, glu_embed_count = self.config['decoder_ffn_dim'],
encoder_state: FloatTensor, layer_count = self.config['decoder_layers'],
config: dict,
seed: int,
params: dict,
image_token_count: int
) -> LongTensor:
print("loading torch decoder")
decoder = DalleBartDecoderTorch(
image_vocab_size = config['image_vocab_size'],
image_token_count = config['image_length'],
sample_token_count = image_token_count,
embed_count = config['d_model'],
attention_head_count = config['decoder_attention_heads'],
glu_embed_count = config['decoder_ffn_dim'],
layer_count = config['decoder_layers'],
batch_count = 2, batch_count = 2,
start_token = config['decoder_start_token_id'], start_token = self.config['decoder_start_token_id'],
is_verbose = True is_verbose = True
) )
decoder_params = convert_dalle_bart_torch_from_flax_params( decoder_params = convert_dalle_bart_torch_from_flax_params(
params.pop('decoder'), self.model_params.pop('decoder'),
layer_count=config['decoder_layers'], layer_count=self.config['decoder_layers'],
is_encoder=False is_encoder=False
) )
decoder.load_state_dict(decoder_params, strict=False) self.decoder.load_state_dict(decoder_params, strict=False)
del decoder_params
if torch.cuda.is_available(): decoder = decoder.cuda() if torch.cuda.is_available():
self.encoder = self.encoder.cuda()
self.decoder = self.decoder.cuda()
self.detokenizer = self.detokenizer.cuda()
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
text_tokens = self.tokenize_text(text)
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
print("encoding text tokens")
encoder_state = self.encoder.forward(text_tokens)
print("sampling image tokens") print("sampling image tokens")
torch.manual_seed(seed) torch.manual_seed(seed)
image_tokens = decoder.forward(text_tokens, encoder_state) image_tokens = self.decoder.forward(text_tokens, encoder_state)
return image_tokens return image_tokens
def generate_image_tokens_torch( def generate_image(self, text: str, seed: int) -> Image.Image:
text_tokens: numpy.ndarray, image_tokens = self.generate_image_tokens(text, seed)
seed: int,
config: dict,
params: dict,
image_token_count: int
) -> LongTensor:
text_tokens = torch.tensor(text_tokens).to(torch.long)
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
encoder_state = encode_torch(
text_tokens,
config,
params
)
image_tokens = decode_torch(
text_tokens,
encoder_state,
config,
seed,
params,
image_token_count
)
return image_tokens
def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
print("detokenizing image") print("detokenizing image")
model_path = './pretrained/vqgan' image = self.detokenizer.forward(image_tokens).to(torch.uint8)
params = load_vqgan_torch_params(model_path) image = Image.fromarray(image.to('cpu').detach().numpy())
detokenizer = VQGanDetokenizer() return image
detokenizer.load_state_dict(params)
if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
image = detokenizer.forward(image_tokens).to(torch.uint8)
del detokenizer, params
return image.to('cpu').detach().numpy()

View File

@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):
class DecoderSelfAttentionFlax(AttentionFlax): class DecoderSelfAttentionFlax(AttentionFlax):
def __call__(self, def __call__(
self,
decoder_state: jnp.ndarray, decoder_state: jnp.ndarray,
keys_state: jnp.ndarray, keys_state: jnp.ndarray,
values_state: jnp.ndarray, values_state: jnp.ndarray,
@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count) self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact @nn.compact
def __call__(self, def __call__(
self,
decoder_state: jnp.ndarray, decoder_state: jnp.ndarray,
encoder_state: jnp.ndarray, encoder_state: jnp.ndarray,
keys_state: jnp.ndarray, keys_state: jnp.ndarray,
@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
self.final_ln = nn.LayerNorm(use_scale=False) self.final_ln = nn.LayerNorm(use_scale=False)
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False) self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
def __call__(self, def __call__(
self,
encoder_state: jnp.ndarray, encoder_state: jnp.ndarray,
keys_state: jnp.ndarray, keys_state: jnp.ndarray,
values_state: jnp.ndarray, values_state: jnp.ndarray,
@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
decoder_state = self.lm_head(decoder_state) decoder_state = self.lm_head(decoder_state)
return decoder_state, keys_state, values_state return decoder_state, keys_state, values_state
def sample_image_tokens(self, def sample_image_tokens(
self,
text_tokens: jnp.ndarray, text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray, encoder_state: jnp.ndarray,
prng_key: jax.random.PRNGKey, prng_key: jax.random.PRNGKey,

View File

@ -26,7 +26,8 @@ class DecoderCrossAttentionTorch(AttentionTorch):
class DecoderSelfAttentionTorch(AttentionTorch): class DecoderSelfAttentionTorch(AttentionTorch):
def forward(self, def forward(
self,
decoder_state: FloatTensor, decoder_state: FloatTensor,
keys_values: FloatTensor, keys_values: FloatTensor,
attention_mask: BoolTensor, attention_mask: BoolTensor,
@ -49,7 +50,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):
class DecoderLayerTorch(nn.Module): class DecoderLayerTorch(nn.Module):
def __init__(self, def __init__(
self,
image_token_count: int, image_token_count: int,
head_count: int, head_count: int,
embed_count: int, embed_count: int,
@ -69,7 +71,8 @@ class DecoderLayerTorch(nn.Module):
if torch.cuda.is_available(): if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda() self.token_indices = self.token_indices.cuda()
def forward(self, def forward(
self,
decoder_state: FloatTensor, decoder_state: FloatTensor,
encoder_state: FloatTensor, encoder_state: FloatTensor,
keys_values_state: FloatTensor, keys_values_state: FloatTensor,
@ -111,7 +114,8 @@ class DecoderLayerTorch(nn.Module):
class DalleBartDecoderTorch(nn.Module): class DalleBartDecoderTorch(nn.Module):
def __init__(self, def __init__(
self,
image_vocab_size: int, image_vocab_size: int,
image_token_count: int, image_token_count: int,
sample_token_count: int, sample_token_count: int,
@ -158,7 +162,8 @@ class DalleBartDecoderTorch(nn.Module):
self.start_token = self.start_token.cuda() self.start_token = self.start_token.cuda()
def decode_step(self, def decode_step(
self,
text_tokens: LongTensor, text_tokens: LongTensor,
encoder_state: FloatTensor, encoder_state: FloatTensor,
keys_values_state: FloatTensor, keys_values_state: FloatTensor,
@ -198,7 +203,8 @@ class DalleBartDecoderTorch(nn.Module):
return probs, keys_values return probs, keys_values
def forward(self, def forward(
self,
text_tokens: LongTensor, text_tokens: LongTensor,
encoder_state: FloatTensor encoder_state: FloatTensor
) -> LongTensor: ) -> LongTensor:

View File

@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
self.v_proj = nn.Dense(self.embed_count, use_bias=False) self.v_proj = nn.Dense(self.embed_count, use_bias=False)
self.out_proj = nn.Dense(self.embed_count, use_bias=False) self.out_proj = nn.Dense(self.embed_count, use_bias=False)
def forward(self, def forward(
self,
keys: jnp.ndarray, keys: jnp.ndarray,
values: jnp.ndarray, values: jnp.ndarray,
queries: jnp.ndarray, queries: jnp.ndarray,
@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count) self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact @nn.compact
def __call__(self, def __call__(
self,
encoder_state: jnp.ndarray, encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray attention_mask: jnp.ndarray
) -> jnp.ndarray: ) -> jnp.ndarray:

View File

@ -37,7 +37,8 @@ class AttentionTorch(nn.Module):
self.one = torch.ones((1, 1)) self.one = torch.ones((1, 1))
if torch.cuda.is_available(): self.one = self.one.cuda() if torch.cuda.is_available(): self.one = self.one.cuda()
def forward(self, def forward(
self,
keys: FloatTensor, keys: FloatTensor,
values: FloatTensor, values: FloatTensor,
queries: FloatTensor, queries: FloatTensor,
@ -105,7 +106,8 @@ class EncoderLayerTorch(nn.Module):
class DalleBartEncoderTorch(nn.Module): class DalleBartEncoderTorch(nn.Module):
def __init__(self, def __init__(
self,
layer_count: int, layer_count: int,
embed_count: int, embed_count: int,
attention_head_count: int, attention_head_count: int,

View File

@ -8,7 +8,7 @@ class TextTokenizer:
pairs = [tuple(pair.split()) for pair in merges] pairs = [tuple(pair.split()) for pair in merges]
self.rank_from_pair = dict(zip(pairs, range(len(pairs)))) self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
def __call__(self, text: str) -> List[int]: def tokenize(self, text: str) -> List[int]:
sep_token = self.token_from_subword['</s>'] sep_token = self.token_from_subword['</s>']
cls_token = self.token_from_subword['<s>'] cls_token = self.token_from_subword['<s>']
unk_token = self.token_from_subword['<unk>'] unk_token = self.token_from_subword['<unk>']