refactored to load models once and run multiple times
This commit is contained in:
parent
1ef9b0b929
commit
ed91ab4a30
10
README.md
10
README.md
|
@ -2,18 +2,18 @@
|
|||
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
||||
|
||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
|
||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy`, `torch`, and `flax`.
|
||||
|
||||
### Setup
|
||||
|
||||
Run `sh setup.sh` to install dependencies and download pretrained models. The models can also be downloaded manually:
|
||||
Run `sh setup.sh` to install dependencies and download pretrained models. The `wandb` python package is installed to download DALL·E mini and DALL·E mega. Alternatively, the models can be downloaded manually here:
|
||||
[VQGan](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384),
|
||||
[DALL·E Mini](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mini-1/v0/files),
|
||||
[DALL·E Mega](https://wandb.ai/dalle-mini/dalle-mini/artifacts/DalleBart_model/mega-1-fp16/v14/files)
|
||||
|
||||
### Usage
|
||||
|
||||
Use the command line python script `image_from_text.py` to generate images. Here are some examples:
|
||||
The simplest way to get started is the command line python script `image_from_text.py` provided. Here are some examples runs:
|
||||
|
||||
```
|
||||
python image_from_text.py --text='alien life' --seed=7
|
||||
|
@ -32,3 +32,7 @@ python image_from_text.py --text='court sketch of godzilla on trial' --mega --se
|
|||
```
|
||||
|
||||
![Godzilla Trial](examples/godzilla_trial.png)
|
||||
|
||||
### Load once run multiple times
|
||||
|
||||
The command line script loads the models and parameters each time. The colab notebook demonstrates how to load the models once and run multiple times.
|
|
@ -2,8 +2,8 @@ import argparse
|
|||
import os
|
||||
from PIL import Image
|
||||
|
||||
from min_dalle.generate_image import generate_image_from_text
|
||||
|
||||
from min_dalle.min_dalle_torch import MinDalleTorch
|
||||
from min_dalle.min_dalle_flax import MinDalleFlax
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--mega', action='store_true')
|
||||
|
@ -15,7 +15,7 @@ parser.set_defaults(torch=False)
|
|||
parser.add_argument('--text', type=str)
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--image_path', type=str, default='generated')
|
||||
parser.add_argument('--image_token_count', type=int, default=256) # for debugging
|
||||
parser.add_argument('--sample_token_count', type=int, default=256) # for debugging
|
||||
|
||||
|
||||
def ascii_from_image(image: Image.Image, size: int) -> str:
|
||||
|
@ -36,19 +36,40 @@ def save_image(image: Image.Image, path: str):
|
|||
return image
|
||||
|
||||
|
||||
def generate_image(
|
||||
is_torch: bool,
|
||||
is_mega: bool,
|
||||
text: str,
|
||||
seed: int,
|
||||
image_path: str,
|
||||
sample_token_count: int
|
||||
):
|
||||
if is_torch:
|
||||
image_generator = MinDalleTorch(is_mega, sample_token_count)
|
||||
image_tokens = image_generator.generate_image_tokens(text, seed)
|
||||
|
||||
if sample_token_count < image_generator.config['image_length']:
|
||||
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
|
||||
return
|
||||
else:
|
||||
image = image_generator.generate_image(text, seed)
|
||||
|
||||
else:
|
||||
image_generator = MinDalleFlax(is_mega)
|
||||
image = image_generator.generate_image(text, seed)
|
||||
|
||||
save_image(image, image_path)
|
||||
print(ascii_from_image(image, size=128))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
|
||||
print(args)
|
||||
|
||||
image = generate_image_from_text(
|
||||
text = args.text,
|
||||
is_mega = args.mega,
|
||||
is_torch = args.torch,
|
||||
seed = args.seed,
|
||||
image_token_count = args.image_token_count
|
||||
)
|
||||
|
||||
if image != None:
|
||||
save_image(image, args.image_path)
|
||||
print(ascii_from_image(image, size=128))
|
||||
generate_image(
|
||||
is_torch=args.torch,
|
||||
is_mega=args.mega,
|
||||
text=args.text,
|
||||
seed=args.seed,
|
||||
image_path=args.image_path,
|
||||
sample_token_count=args.sample_token_count
|
||||
)
|
|
@ -1,78 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import numpy
|
||||
from PIL import Image
|
||||
from typing import Tuple, List
|
||||
import torch
|
||||
|
||||
from min_dalle.load_params import load_dalle_bart_flax_params
|
||||
from min_dalle.text_tokenizer import TextTokenizer
|
||||
from min_dalle.min_dalle_flax import generate_image_tokens_flax
|
||||
from min_dalle.min_dalle_torch import (
|
||||
generate_image_tokens_torch,
|
||||
detokenize_torch
|
||||
)
|
||||
|
||||
def load_dalle_bart_metadata(path: str) -> Tuple[dict, dict, List[str]]:
|
||||
print("parsing metadata from {}".format(path))
|
||||
for f in ['config.json', 'flax_model.msgpack', 'vocab.json', 'merges.txt']:
|
||||
assert(os.path.exists(os.path.join(path, f)))
|
||||
with open(path + '/config.json', 'r') as f:
|
||||
config = json.load(f)
|
||||
with open(path + '/vocab.json') as f:
|
||||
vocab = json.load(f)
|
||||
with open(path + '/merges.txt') as f:
|
||||
merges = f.read().split("\n")[1:-1]
|
||||
return config, vocab, merges
|
||||
|
||||
|
||||
def tokenize_text(
|
||||
text: str,
|
||||
config: dict,
|
||||
vocab: dict,
|
||||
merges: List[str]
|
||||
) -> numpy.ndarray:
|
||||
print("tokenizing text")
|
||||
tokens = TextTokenizer(vocab, merges)(text)
|
||||
print("text tokens", tokens)
|
||||
text_tokens = numpy.ones((2, config['max_text_length']), dtype=numpy.int32)
|
||||
text_tokens[0, :len(tokens)] = tokens
|
||||
text_tokens[1, :2] = [tokens[0], tokens[-1]]
|
||||
return text_tokens
|
||||
|
||||
|
||||
def generate_image_from_text(
|
||||
text: str,
|
||||
is_mega: bool = False,
|
||||
is_torch: bool = False,
|
||||
seed: int = 0,
|
||||
image_token_count: int = 256
|
||||
) -> Image.Image:
|
||||
model_name = 'mega' if is_mega else 'mini'
|
||||
model_path = './pretrained/dalle_bart_{}'.format(model_name)
|
||||
config, vocab, merges = load_dalle_bart_metadata(model_path)
|
||||
text_tokens = tokenize_text(text, config, vocab, merges)
|
||||
params_dalle_bart = load_dalle_bart_flax_params(model_path)
|
||||
|
||||
if is_torch:
|
||||
image_tokens = generate_image_tokens_torch(
|
||||
text_tokens = text_tokens,
|
||||
seed = seed,
|
||||
config = config,
|
||||
params = params_dalle_bart,
|
||||
image_token_count = image_token_count
|
||||
)
|
||||
if image_token_count == config['image_length']:
|
||||
image = detokenize_torch(image_tokens, is_torch=True)
|
||||
return Image.fromarray(image)
|
||||
else:
|
||||
print(list(image_tokens.to('cpu').detach().numpy()))
|
||||
else:
|
||||
image_tokens = generate_image_tokens_flax(
|
||||
text_tokens = text_tokens,
|
||||
seed = seed,
|
||||
config = config,
|
||||
params = params_dalle_bart,
|
||||
)
|
||||
image = detokenize_torch(torch.tensor(image_tokens), is_torch=False)
|
||||
return Image.fromarray(image)
|
38
min_dalle/min_dalle.py
Normal file
38
min_dalle/min_dalle.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import os
|
||||
import json
|
||||
import numpy
|
||||
|
||||
from .text_tokenizer import TextTokenizer
|
||||
from .load_params import load_vqgan_torch_params, load_dalle_bart_flax_params
|
||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||
|
||||
class MinDalle:
|
||||
def __init__(self, is_mega: bool):
|
||||
self.is_mega = is_mega
|
||||
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
||||
model_path = os.path.join('pretrained', model_name)
|
||||
|
||||
print("reading files from {}".format(model_path))
|
||||
with open(os.path.join(model_path, 'config.json'), 'r') as f:
|
||||
self.config = json.load(f)
|
||||
with open(os.path.join(model_path, 'vocab.json'), 'r') as f:
|
||||
vocab = json.load(f)
|
||||
with open(os.path.join(model_path, 'merges.txt'), 'r') as f:
|
||||
merges = f.read().split("\n")[1:-1]
|
||||
self.model_params = load_dalle_bart_flax_params(model_path)
|
||||
|
||||
self.tokenizer = TextTokenizer(vocab, merges)
|
||||
self.detokenizer = VQGanDetokenizer()
|
||||
vqgan_params = load_vqgan_torch_params('./pretrained/vqgan')
|
||||
self.detokenizer.load_state_dict(vqgan_params)
|
||||
|
||||
|
||||
def tokenize_text(self, text: str) -> numpy.ndarray:
|
||||
print("tokenizing text")
|
||||
tokens = self.tokenizer.tokenize(text)
|
||||
print("text tokens", tokens)
|
||||
text_token_count = self.config['max_text_length']
|
||||
text_tokens = numpy.ones((2, text_token_count), dtype=numpy.int32)
|
||||
text_tokens[0, :len(tokens)] = tokens
|
||||
text_tokens[1, :2] = [tokens[0], tokens[-1]]
|
||||
return text_tokens
|
|
@ -1,79 +1,58 @@
|
|||
import jax
|
||||
from jax import numpy as jnp
|
||||
import numpy
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
from .min_dalle import MinDalle
|
||||
from .models.dalle_bart_encoder_flax import DalleBartEncoderFlax
|
||||
from .models.dalle_bart_decoder_flax import DalleBartDecoderFlax
|
||||
|
||||
|
||||
def encode_flax(
|
||||
text_tokens: numpy.ndarray,
|
||||
config: dict,
|
||||
params: dict
|
||||
) -> jnp.ndarray:
|
||||
print("loading flax encoder")
|
||||
encoder: DalleBartEncoderFlax = DalleBartEncoderFlax(
|
||||
attention_head_count = config['encoder_attention_heads'],
|
||||
embed_count = config['d_model'],
|
||||
glu_embed_count = config['encoder_ffn_dim'],
|
||||
text_token_count = config['max_text_length'],
|
||||
text_vocab_count = config['encoder_vocab_size'],
|
||||
layer_count = config['encoder_layers']
|
||||
).bind({'params': params.pop('encoder')})
|
||||
class MinDalleFlax(MinDalle):
|
||||
def __init__(self, is_mega: bool):
|
||||
super().__init__(is_mega)
|
||||
print("initializing MinDalleFlax")
|
||||
|
||||
print("encoding text tokens")
|
||||
encoder_state = encoder(text_tokens)
|
||||
del encoder
|
||||
return encoder_state
|
||||
print("loading encoder")
|
||||
self.encoder = DalleBartEncoderFlax(
|
||||
attention_head_count = self.config['encoder_attention_heads'],
|
||||
embed_count = self.config['d_model'],
|
||||
glu_embed_count = self.config['encoder_ffn_dim'],
|
||||
text_token_count = self.config['max_text_length'],
|
||||
text_vocab_count = self.config['encoder_vocab_size'],
|
||||
layer_count = self.config['encoder_layers']
|
||||
).bind({'params': self.model_params.pop('encoder')})
|
||||
|
||||
print("loading decoder")
|
||||
self.decoder = DalleBartDecoderFlax(
|
||||
image_token_count = self.config['image_length'],
|
||||
text_token_count = self.config['max_text_length'],
|
||||
image_vocab_count = self.config['image_vocab_size'],
|
||||
attention_head_count = self.config['decoder_attention_heads'],
|
||||
embed_count = self.config['d_model'],
|
||||
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||
layer_count = self.config['decoder_layers'],
|
||||
start_token = self.config['decoder_start_token_id']
|
||||
)
|
||||
|
||||
|
||||
def decode_flax(
|
||||
text_tokens: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
config: dict,
|
||||
seed: int,
|
||||
params: dict
|
||||
) -> jnp.ndarray:
|
||||
print("loading flax decoder")
|
||||
decoder = DalleBartDecoderFlax(
|
||||
image_token_count = config['image_length'],
|
||||
text_token_count = config['max_text_length'],
|
||||
image_vocab_count = config['image_vocab_size'],
|
||||
attention_head_count = config['decoder_attention_heads'],
|
||||
embed_count = config['d_model'],
|
||||
glu_embed_count = config['decoder_ffn_dim'],
|
||||
layer_count = config['decoder_layers'],
|
||||
start_token = config['decoder_start_token_id']
|
||||
)
|
||||
print("sampling image tokens")
|
||||
image_tokens = decoder.sample_image_tokens(
|
||||
text_tokens,
|
||||
encoder_state,
|
||||
jax.random.PRNGKey(seed),
|
||||
params.pop('decoder')
|
||||
)
|
||||
del decoder
|
||||
return image_tokens
|
||||
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||
text_tokens = self.tokenize_text(text)
|
||||
|
||||
print("encoding text tokens")
|
||||
encoder_state = self.encoder(text_tokens)
|
||||
|
||||
def generate_image_tokens_flax(
|
||||
text_tokens: numpy.ndarray,
|
||||
seed: int,
|
||||
config: dict,
|
||||
params: dict
|
||||
) -> numpy.ndarray:
|
||||
encoder_state = encode_flax(
|
||||
text_tokens,
|
||||
config,
|
||||
params
|
||||
)
|
||||
image_tokens = decode_flax(
|
||||
text_tokens,
|
||||
encoder_state,
|
||||
config,
|
||||
seed,
|
||||
params
|
||||
)
|
||||
image_tokens = numpy.array(image_tokens)
|
||||
print("image tokens", list(image_tokens))
|
||||
return image_tokens
|
||||
print("sampling image tokens")
|
||||
image_tokens = self.decoder.sample_image_tokens(
|
||||
text_tokens,
|
||||
encoder_state,
|
||||
jax.random.PRNGKey(seed),
|
||||
self.model_params['decoder']
|
||||
)
|
||||
|
||||
image_tokens = torch.tensor(numpy.array(image_tokens))
|
||||
|
||||
print("detokenizing image")
|
||||
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
||||
image = Image.fromarray(image.to('cpu').detach().numpy())
|
||||
return image
|
|
@ -1,118 +1,83 @@
|
|||
from random import sample
|
||||
import numpy
|
||||
import os
|
||||
from PIL import Image
|
||||
from typing import Dict
|
||||
from torch import LongTensor, FloatTensor
|
||||
from torch import LongTensor
|
||||
import torch
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_num_threads(os.cpu_count())
|
||||
|
||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||
from .load_params import convert_dalle_bart_torch_from_flax_params
|
||||
from .min_dalle import MinDalle
|
||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||
from .models.dalle_bart_decoder_torch import DalleBartDecoderTorch
|
||||
|
||||
from .load_params import (
|
||||
load_vqgan_torch_params,
|
||||
convert_dalle_bart_torch_from_flax_params
|
||||
)
|
||||
|
||||
class MinDalleTorch(MinDalle):
|
||||
def __init__(self, is_mega: bool, sample_token_count: int = 256):
|
||||
super().__init__(is_mega)
|
||||
print("initializing MinDalleTorch")
|
||||
|
||||
print("loading encoder")
|
||||
self.encoder = DalleBartEncoderTorch(
|
||||
layer_count = self.config['encoder_layers'],
|
||||
embed_count = self.config['d_model'],
|
||||
attention_head_count = self.config['encoder_attention_heads'],
|
||||
text_vocab_count = self.config['encoder_vocab_size'],
|
||||
text_token_count = self.config['max_text_length'],
|
||||
glu_embed_count = self.config['encoder_ffn_dim']
|
||||
)
|
||||
encoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||
self.model_params.pop('encoder'),
|
||||
layer_count=self.config['encoder_layers'],
|
||||
is_encoder=True
|
||||
)
|
||||
self.encoder.load_state_dict(encoder_params, strict=False)
|
||||
|
||||
print("loading decoder")
|
||||
self.decoder = DalleBartDecoderTorch(
|
||||
image_vocab_size = self.config['image_vocab_size'],
|
||||
image_token_count = self.config['image_length'],
|
||||
sample_token_count = sample_token_count,
|
||||
embed_count = self.config['d_model'],
|
||||
attention_head_count = self.config['decoder_attention_heads'],
|
||||
glu_embed_count = self.config['decoder_ffn_dim'],
|
||||
layer_count = self.config['decoder_layers'],
|
||||
batch_count = 2,
|
||||
start_token = self.config['decoder_start_token_id'],
|
||||
is_verbose = True
|
||||
)
|
||||
decoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||
self.model_params.pop('decoder'),
|
||||
layer_count=self.config['decoder_layers'],
|
||||
is_encoder=False
|
||||
)
|
||||
self.decoder.load_state_dict(decoder_params, strict=False)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.encoder = self.encoder.cuda()
|
||||
self.decoder = self.decoder.cuda()
|
||||
self.detokenizer = self.detokenizer.cuda()
|
||||
|
||||
|
||||
def encode_torch(
|
||||
text_tokens: LongTensor,
|
||||
config: dict,
|
||||
params: dict
|
||||
) -> FloatTensor:
|
||||
print("loading torch encoder")
|
||||
encoder = DalleBartEncoderTorch(
|
||||
layer_count = config['encoder_layers'],
|
||||
embed_count = config['d_model'],
|
||||
attention_head_count = config['encoder_attention_heads'],
|
||||
text_vocab_count = config['encoder_vocab_size'],
|
||||
text_token_count = config['max_text_length'],
|
||||
glu_embed_count = config['encoder_ffn_dim']
|
||||
)
|
||||
encoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||
params.pop('encoder'),
|
||||
layer_count=config['encoder_layers'],
|
||||
is_encoder=True
|
||||
)
|
||||
encoder.load_state_dict(encoder_params, strict=False)
|
||||
del encoder_params
|
||||
if torch.cuda.is_available(): encoder = encoder.cuda()
|
||||
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
||||
text_tokens = self.tokenize_text(text)
|
||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||
|
||||
print("encoding text tokens")
|
||||
encoder_state = encoder(text_tokens)
|
||||
del encoder
|
||||
return encoder_state
|
||||
print("encoding text tokens")
|
||||
encoder_state = self.encoder.forward(text_tokens)
|
||||
|
||||
print("sampling image tokens")
|
||||
torch.manual_seed(seed)
|
||||
image_tokens = self.decoder.forward(text_tokens, encoder_state)
|
||||
return image_tokens
|
||||
|
||||
|
||||
def decode_torch(
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor,
|
||||
config: dict,
|
||||
seed: int,
|
||||
params: dict,
|
||||
image_token_count: int
|
||||
) -> LongTensor:
|
||||
print("loading torch decoder")
|
||||
decoder = DalleBartDecoderTorch(
|
||||
image_vocab_size = config['image_vocab_size'],
|
||||
image_token_count = config['image_length'],
|
||||
sample_token_count = image_token_count,
|
||||
embed_count = config['d_model'],
|
||||
attention_head_count = config['decoder_attention_heads'],
|
||||
glu_embed_count = config['decoder_ffn_dim'],
|
||||
layer_count = config['decoder_layers'],
|
||||
batch_count = 2,
|
||||
start_token = config['decoder_start_token_id'],
|
||||
is_verbose = True
|
||||
)
|
||||
decoder_params = convert_dalle_bart_torch_from_flax_params(
|
||||
params.pop('decoder'),
|
||||
layer_count=config['decoder_layers'],
|
||||
is_encoder=False
|
||||
)
|
||||
decoder.load_state_dict(decoder_params, strict=False)
|
||||
del decoder_params
|
||||
if torch.cuda.is_available(): decoder = decoder.cuda()
|
||||
|
||||
print("sampling image tokens")
|
||||
torch.manual_seed(seed)
|
||||
image_tokens = decoder.forward(text_tokens, encoder_state)
|
||||
return image_tokens
|
||||
|
||||
|
||||
def generate_image_tokens_torch(
|
||||
text_tokens: numpy.ndarray,
|
||||
seed: int,
|
||||
config: dict,
|
||||
params: dict,
|
||||
image_token_count: int
|
||||
) -> LongTensor:
|
||||
text_tokens = torch.tensor(text_tokens).to(torch.long)
|
||||
if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
|
||||
encoder_state = encode_torch(
|
||||
text_tokens,
|
||||
config,
|
||||
params
|
||||
)
|
||||
image_tokens = decode_torch(
|
||||
text_tokens,
|
||||
encoder_state,
|
||||
config,
|
||||
seed,
|
||||
params,
|
||||
image_token_count
|
||||
)
|
||||
return image_tokens
|
||||
|
||||
|
||||
def detokenize_torch(image_tokens: LongTensor, is_torch: bool) -> numpy.ndarray:
|
||||
print("detokenizing image")
|
||||
model_path = './pretrained/vqgan'
|
||||
params = load_vqgan_torch_params(model_path)
|
||||
detokenizer = VQGanDetokenizer()
|
||||
detokenizer.load_state_dict(params)
|
||||
if torch.cuda.is_available() and is_torch: detokenizer = detokenizer.cuda()
|
||||
image = detokenizer.forward(image_tokens).to(torch.uint8)
|
||||
del detokenizer, params
|
||||
return image.to('cpu').detach().numpy()
|
||||
def generate_image(self, text: str, seed: int) -> Image.Image:
|
||||
image_tokens = self.generate_image_tokens(text, seed)
|
||||
print("detokenizing image")
|
||||
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
||||
image = Image.fromarray(image.to('cpu').detach().numpy())
|
||||
return image
|
|
@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):
|
|||
|
||||
|
||||
class DecoderSelfAttentionFlax(AttentionFlax):
|
||||
def __call__(self,
|
||||
def __call__(
|
||||
self,
|
||||
decoder_state: jnp.ndarray,
|
||||
keys_state: jnp.ndarray,
|
||||
values_state: jnp.ndarray,
|
||||
|
@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
|
|||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self,
|
||||
def __call__(
|
||||
self,
|
||||
decoder_state: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
keys_state: jnp.ndarray,
|
||||
|
@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
|
|||
self.final_ln = nn.LayerNorm(use_scale=False)
|
||||
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
|
||||
|
||||
def __call__(self,
|
||||
def __call__(
|
||||
self,
|
||||
encoder_state: jnp.ndarray,
|
||||
keys_state: jnp.ndarray,
|
||||
values_state: jnp.ndarray,
|
||||
|
@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
|
|||
decoder_state = self.lm_head(decoder_state)
|
||||
return decoder_state, keys_state, values_state
|
||||
|
||||
def sample_image_tokens(self,
|
||||
def sample_image_tokens(
|
||||
self,
|
||||
text_tokens: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
prng_key: jax.random.PRNGKey,
|
||||
|
|
|
@ -26,7 +26,8 @@ class DecoderCrossAttentionTorch(AttentionTorch):
|
|||
|
||||
|
||||
class DecoderSelfAttentionTorch(AttentionTorch):
|
||||
def forward(self,
|
||||
def forward(
|
||||
self,
|
||||
decoder_state: FloatTensor,
|
||||
keys_values: FloatTensor,
|
||||
attention_mask: BoolTensor,
|
||||
|
@ -49,7 +50,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):
|
|||
|
||||
|
||||
class DecoderLayerTorch(nn.Module):
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
image_token_count: int,
|
||||
head_count: int,
|
||||
embed_count: int,
|
||||
|
@ -69,7 +71,8 @@ class DecoderLayerTorch(nn.Module):
|
|||
if torch.cuda.is_available():
|
||||
self.token_indices = self.token_indices.cuda()
|
||||
|
||||
def forward(self,
|
||||
def forward(
|
||||
self,
|
||||
decoder_state: FloatTensor,
|
||||
encoder_state: FloatTensor,
|
||||
keys_values_state: FloatTensor,
|
||||
|
@ -111,7 +114,8 @@ class DecoderLayerTorch(nn.Module):
|
|||
|
||||
|
||||
class DalleBartDecoderTorch(nn.Module):
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
image_vocab_size: int,
|
||||
image_token_count: int,
|
||||
sample_token_count: int,
|
||||
|
@ -158,7 +162,8 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
self.start_token = self.start_token.cuda()
|
||||
|
||||
|
||||
def decode_step(self,
|
||||
def decode_step(
|
||||
self,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor,
|
||||
keys_values_state: FloatTensor,
|
||||
|
@ -198,7 +203,8 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
return probs, keys_values
|
||||
|
||||
|
||||
def forward(self,
|
||||
def forward(
|
||||
self,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor
|
||||
) -> LongTensor:
|
||||
|
|
|
@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
|
|||
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
|
||||
|
||||
def forward(self,
|
||||
def forward(
|
||||
self,
|
||||
keys: jnp.ndarray,
|
||||
values: jnp.ndarray,
|
||||
queries: jnp.ndarray,
|
||||
|
@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
|
|||
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self,
|
||||
def __call__(
|
||||
self,
|
||||
encoder_state: jnp.ndarray,
|
||||
attention_mask: jnp.ndarray
|
||||
) -> jnp.ndarray:
|
||||
|
|
|
@ -37,7 +37,8 @@ class AttentionTorch(nn.Module):
|
|||
self.one = torch.ones((1, 1))
|
||||
if torch.cuda.is_available(): self.one = self.one.cuda()
|
||||
|
||||
def forward(self,
|
||||
def forward(
|
||||
self,
|
||||
keys: FloatTensor,
|
||||
values: FloatTensor,
|
||||
queries: FloatTensor,
|
||||
|
@ -105,7 +106,8 @@ class EncoderLayerTorch(nn.Module):
|
|||
|
||||
|
||||
class DalleBartEncoderTorch(nn.Module):
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
layer_count: int,
|
||||
embed_count: int,
|
||||
attention_head_count: int,
|
||||
|
|
|
@ -8,7 +8,7 @@ class TextTokenizer:
|
|||
pairs = [tuple(pair.split()) for pair in merges]
|
||||
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
|
||||
|
||||
def __call__(self, text: str) -> List[int]:
|
||||
def tokenize(self, text: str) -> List[int]:
|
||||
sep_token = self.token_from_subword['</s>']
|
||||
cls_token = self.token_from_subword['<s>']
|
||||
unk_token = self.token_from_subword['<unk>']
|
||||
|
|
Loading…
Reference in New Issue
Block a user