Merge branch 'main' into patch-1
This commit is contained in:
commit
8b77428102
14
README.md
vendored
14
README.md
vendored
|
@ -8,7 +8,7 @@
|
||||||
|
|
||||||
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini) in PyTorch. It has been stripped to the bare essentials necessary for doing inference. The only third party dependencies are numpy, requests, pillow and torch.
|
This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini) in PyTorch. It has been stripped to the bare essentials necessary for doing inference. The only third party dependencies are numpy, requests, pillow and torch.
|
||||||
|
|
||||||
It currently takes **7.4 seconds** to generate an image with DALL·E Mega on a standard GPU runtime in Colab.
|
It currently take **35 seconds** to generate a 3x3 grid with DALL·E Mega on a standard GPU runtime in Colab.
|
||||||
|
|
||||||
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).
|
The flax model and code for converting it to torch can be found [here](https://github.com/kuprel/min-dalle-flax).
|
||||||
|
|
||||||
|
@ -33,18 +33,18 @@ model = MinDalle(is_mega=True, models_root='./pretrained')
|
||||||
The required models will be downloaded to `models_root` if they are not already there. Once everything has finished initializing, call `generate_image` with some text and a seed as many times as you want.
|
The required models will be downloaded to `models_root` if they are not already there. Once everything has finished initializing, call `generate_image` with some text and a seed as many times as you want.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
text = "a comfy chair that looks like an avocado"
|
text = 'a comfy chair that looks like an avocado'
|
||||||
image = model.generate_image(text)
|
image = model.generate_image(text)
|
||||||
display(image)
|
display(image)
|
||||||
```
|
```
|
||||||
![Avocado Armchair](https://github.com/kuprel/min-dalle/raw/main/examples/avocado_armchair.png)
|
![Avocado Armchair](https://github.com/kuprel/min-dalle/raw/main/examples/avocado_armchair.png)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
text = "trail cam footage of gollum eating watermelon"
|
text = 'court sketch of godzilla on trial'
|
||||||
image = model.generate_image(text, seed=1)
|
image = model.generate_image(text, seed=6, grid_size=3)
|
||||||
display(image)
|
display(image)
|
||||||
```
|
```
|
||||||
![Gollum Trailcam](https://github.com/kuprel/min-dalle/raw/main/examples/gollum_trailcam.png)
|
![Godzilla Trial](https://github.com/kuprel/min-dalle/raw/main/examples/godzilla_trial.png)
|
||||||
|
|
||||||
|
|
||||||
### Command Line
|
### Command Line
|
||||||
|
@ -57,6 +57,6 @@ $ python image_from_text.py --text='artificial intelligence' --seed=7
|
||||||
![Artificial Intelligence](https://github.com/kuprel/min-dalle/raw/main/examples/artificial_intelligence.png)
|
![Artificial Intelligence](https://github.com/kuprel/min-dalle/raw/main/examples/artificial_intelligence.png)
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ python image_from_text.py --text='court sketch of godzilla on trial' --mega
|
$ python image_from_text.py --text='trail cam footage of gollum eating watermelon' --mega --seed=1 --grid-size=3
|
||||||
```
|
```
|
||||||
![Godzilla Trial](https://github.com/kuprel/min-dalle/raw/main/examples/godzilla_on_trial.png)
|
![Gollum Trailcam](https://github.com/kuprel/min-dalle/raw/main/examples/gollum_trailcam.png)
|
||||||
|
|
BIN
examples/godzilla_on_trial.png
vendored
BIN
examples/godzilla_on_trial.png
vendored
Binary file not shown.
Before Width: | Height: | Size: 155 KiB |
BIN
examples/godzilla_trial.png
vendored
Normal file
BIN
examples/godzilla_trial.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
BIN
examples/gollum_trailcam.png
vendored
BIN
examples/gollum_trailcam.png
vendored
Binary file not shown.
Before Width: | Height: | Size: 82 KiB After Width: | Height: | Size: 892 KiB |
|
@ -11,9 +11,10 @@ parser.add_argument('--no-mega', dest='mega', action='store_false')
|
||||||
parser.set_defaults(mega=False)
|
parser.set_defaults(mega=False)
|
||||||
parser.add_argument('--text', type=str, default='alien life')
|
parser.add_argument('--text', type=str, default='alien life')
|
||||||
parser.add_argument('--seed', type=int, default=-1)
|
parser.add_argument('--seed', type=int, default=-1)
|
||||||
parser.add_argument('--image_path', type=str, default='generated')
|
parser.add_argument('--grid-size', type=int, default=1)
|
||||||
parser.add_argument('--models_root', type=str, default='pretrained')
|
parser.add_argument('--image-path', type=str, default='generated')
|
||||||
parser.add_argument('--token_count', type=int, default=256) # for debugging
|
parser.add_argument('--models-root', type=str, default='pretrained')
|
||||||
|
parser.add_argument('--token-count', type=int, default=256) # for debugging
|
||||||
|
|
||||||
|
|
||||||
def ascii_from_image(image: Image.Image, size: int) -> str:
|
def ascii_from_image(image: Image.Image, size: int) -> str:
|
||||||
|
@ -38,6 +39,7 @@ def generate_image(
|
||||||
is_mega: bool,
|
is_mega: bool,
|
||||||
text: str,
|
text: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
grid_size: int,
|
||||||
image_path: str,
|
image_path: str,
|
||||||
models_root: str,
|
models_root: str,
|
||||||
token_count: int
|
token_count: int
|
||||||
|
@ -51,10 +53,10 @@ def generate_image(
|
||||||
)
|
)
|
||||||
|
|
||||||
if token_count < 256:
|
if token_count < 256:
|
||||||
image_tokens = model.generate_image_tokens(text, seed)
|
image_tokens = model.generate_image_tokens(text, seed, grid_size ** 2)
|
||||||
print('image tokens', list(image_tokens.to('cpu').detach().numpy()))
|
print('image tokens', image_tokens.to('cpu').detach().numpy())
|
||||||
else:
|
else:
|
||||||
image = model.generate_image(text, seed)
|
image = model.generate_image(text, seed, grid_size)
|
||||||
save_image(image, image_path)
|
save_image(image, image_path)
|
||||||
print(ascii_from_image(image, size=128))
|
print(ascii_from_image(image, size=128))
|
||||||
|
|
||||||
|
@ -66,6 +68,7 @@ if __name__ == '__main__':
|
||||||
is_mega=args.mega,
|
is_mega=args.mega,
|
||||||
text=args.text,
|
text=args.text,
|
||||||
seed=args.seed,
|
seed=args.seed,
|
||||||
|
grid_size=args.grid_size,
|
||||||
image_path=args.image_path,
|
image_path=args.image_path,
|
||||||
models_root=args.models_root,
|
models_root=args.models_root,
|
||||||
token_count=args.token_count
|
token_count=args.token_count
|
||||||
|
|
68
min_dalle.ipynb
vendored
68
min_dalle.ipynb
vendored
File diff suppressed because one or more lines are too long
|
@ -5,7 +5,6 @@ from torch import LongTensor
|
||||||
import torch
|
import torch
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
import random
|
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_num_threads(os.cpu_count())
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
|
@ -28,7 +27,6 @@ class MinDalle:
|
||||||
self.is_reusable = is_reusable
|
self.is_reusable = is_reusable
|
||||||
self.is_verbose = is_verbose
|
self.is_verbose = is_verbose
|
||||||
self.sample_token_count = sample_token_count
|
self.sample_token_count = sample_token_count
|
||||||
self.batch_count = 2
|
|
||||||
self.text_token_count = 64
|
self.text_token_count = 64
|
||||||
self.image_token_count = 256
|
self.image_token_count = 256
|
||||||
self.layer_count = 24 if is_mega else 12
|
self.layer_count = 24 if is_mega else 12
|
||||||
|
@ -128,8 +126,7 @@ class MinDalle:
|
||||||
embed_count = self.embed_count,
|
embed_count = self.embed_count,
|
||||||
glu_embed_count = self.glu_embed_count,
|
glu_embed_count = self.glu_embed_count,
|
||||||
layer_count = self.layer_count,
|
layer_count = self.layer_count,
|
||||||
start_token = self.image_vocab_count,
|
start_token = self.image_vocab_count
|
||||||
batch_count = self.batch_count
|
|
||||||
)
|
)
|
||||||
params = torch.load(self.decoder_params_path)
|
params = torch.load(self.decoder_params_path)
|
||||||
self.decoder.load_state_dict(params, strict=False)
|
self.decoder.load_state_dict(params, strict=False)
|
||||||
|
@ -148,7 +145,12 @@ class MinDalle:
|
||||||
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
|
if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda()
|
||||||
|
|
||||||
|
|
||||||
def generate_image_tokens(self, text: str, seed: int) -> LongTensor:
|
def generate_image_tokens(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
seed: int,
|
||||||
|
image_count: int
|
||||||
|
) -> LongTensor:
|
||||||
if self.is_verbose: print("tokenizing text")
|
if self.is_verbose: print("tokenizing text")
|
||||||
tokens = self.tokenizer.tokenize(text)
|
tokens = self.tokenizer.tokenize(text)
|
||||||
if self.is_verbose: print("text tokens", tokens)
|
if self.is_verbose: print("text tokens", tokens)
|
||||||
|
@ -166,18 +168,29 @@ class MinDalle:
|
||||||
|
|
||||||
if not self.is_reusable: self.init_decoder()
|
if not self.is_reusable: self.init_decoder()
|
||||||
if self.is_verbose: print("sampling image tokens")
|
if self.is_verbose: print("sampling image tokens")
|
||||||
if seed < 0: seed = random.randint(0, 2 ** 31)
|
if seed > 0: torch.manual_seed(seed)
|
||||||
torch.manual_seed(seed)
|
image_tokens = self.decoder.forward(
|
||||||
image_tokens = self.decoder.forward(text_tokens, encoder_state)
|
image_count,
|
||||||
|
text_tokens,
|
||||||
|
encoder_state
|
||||||
|
)
|
||||||
if not self.is_reusable: del self.decoder
|
if not self.is_reusable: del self.decoder
|
||||||
return image_tokens
|
return image_tokens
|
||||||
|
|
||||||
|
|
||||||
def generate_image(self, text: str, seed: int) -> Image.Image:
|
def generate_image(
|
||||||
image_tokens = self.generate_image_tokens(text, seed)
|
self,
|
||||||
|
text: str,
|
||||||
|
seed: int = -1,
|
||||||
|
grid_size: int = 1
|
||||||
|
) -> Image.Image:
|
||||||
|
image_count = grid_size ** 2
|
||||||
|
image_tokens = self.generate_image_tokens(text, seed, image_count)
|
||||||
if not self.is_reusable: self.init_detokenizer()
|
if not self.is_reusable: self.init_detokenizer()
|
||||||
if self.is_verbose: print("detokenizing image")
|
if self.is_verbose: print("detokenizing image")
|
||||||
image = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
images = self.detokenizer.forward(image_tokens).to(torch.uint8)
|
||||||
if not self.is_reusable: del self.detokenizer
|
if not self.is_reusable: del self.detokenizer
|
||||||
|
images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
|
||||||
|
image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
|
||||||
image = Image.fromarray(image.to('cpu').detach().numpy())
|
image = Image.fromarray(image.to('cpu').detach().numpy())
|
||||||
return image
|
return image
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Tuple
|
from typing import Tuple, List
|
||||||
import torch
|
import torch
|
||||||
from torch import LongTensor, nn, FloatTensor, BoolTensor
|
from torch import LongTensor, nn, FloatTensor, BoolTensor
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
@ -76,8 +76,8 @@ class DecoderLayer(nn.Module):
|
||||||
residual = decoder_state
|
residual = decoder_state
|
||||||
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
|
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
|
||||||
self_attn_mask = self.token_indices < token_index + 1
|
self_attn_mask = self.token_indices < token_index + 1
|
||||||
|
self_attn_mask = self_attn_mask[None][[0] * decoder_state.shape[0]]
|
||||||
token_mask = self.token_indices == token_index
|
token_mask = self.token_indices == token_index
|
||||||
self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
|
|
||||||
decoder_state, attention_state = self.self_attn.forward(
|
decoder_state, attention_state = self.self_attn.forward(
|
||||||
decoder_state,
|
decoder_state,
|
||||||
attention_state,
|
attention_state,
|
||||||
|
@ -116,11 +116,11 @@ class DalleBartDecoder(nn.Module):
|
||||||
attention_head_count: int,
|
attention_head_count: int,
|
||||||
glu_embed_count: int,
|
glu_embed_count: int,
|
||||||
layer_count: int,
|
layer_count: int,
|
||||||
batch_count: int,
|
|
||||||
start_token: int
|
start_token: int
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layer_count = layer_count
|
self.layer_count = layer_count
|
||||||
|
self.embed_count = embed_count
|
||||||
self.sample_token_count = sample_token_count
|
self.sample_token_count = sample_token_count
|
||||||
self.condition_factor = 10.0
|
self.condition_factor = 10.0
|
||||||
self.image_token_count = image_token_count
|
self.image_token_count = image_token_count
|
||||||
|
@ -138,12 +138,6 @@ class DalleBartDecoder(nn.Module):
|
||||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||||
self.final_ln = nn.LayerNorm(embed_count)
|
self.final_ln = nn.LayerNorm(embed_count)
|
||||||
self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
|
self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
|
||||||
self.attention_state_shape = (
|
|
||||||
layer_count,
|
|
||||||
2 * batch_count,
|
|
||||||
image_token_count,
|
|
||||||
embed_count
|
|
||||||
)
|
|
||||||
self.zero_prob = torch.zeros([1])
|
self.zero_prob = torch.zeros([1])
|
||||||
self.token_indices = torch.arange(self.sample_token_count)
|
self.token_indices = torch.arange(self.sample_token_count)
|
||||||
self.start_token = torch.tensor([start_token]).to(torch.long)
|
self.start_token = torch.tensor([start_token]).to(torch.long)
|
||||||
|
@ -155,17 +149,16 @@ class DalleBartDecoder(nn.Module):
|
||||||
|
|
||||||
def decode_step(
|
def decode_step(
|
||||||
self,
|
self,
|
||||||
text_tokens: LongTensor,
|
attention_mask: BoolTensor,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
attention_state: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
prev_token: LongTensor,
|
prev_tokens: LongTensor,
|
||||||
token_index: LongTensor
|
token_index: LongTensor
|
||||||
) -> Tuple[LongTensor, FloatTensor]:
|
) -> Tuple[LongTensor, FloatTensor]:
|
||||||
attention_mask = text_tokens.not_equal(1)
|
image_count = encoder_state.shape[0] // 2
|
||||||
batch_count = encoder_state.shape[0]
|
token_index_batched = token_index[[0] * image_count * 2]
|
||||||
prev_token_batched = torch.cat([prev_token] * batch_count)
|
prev_tokens = prev_tokens[list(range(image_count)) * 2]
|
||||||
token_index_batched = torch.cat([token_index] * batch_count)
|
decoder_state = self.embed_tokens.forward(prev_tokens)
|
||||||
decoder_state = self.embed_tokens.forward(prev_token_batched)
|
|
||||||
decoder_state += self.embed_positions.forward(token_index_batched)
|
decoder_state += self.embed_positions.forward(token_index_batched)
|
||||||
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
||||||
decoder_state = decoder_state[:, None]
|
decoder_state = decoder_state[:, None]
|
||||||
|
@ -182,38 +175,52 @@ class DalleBartDecoder(nn.Module):
|
||||||
decoder_state = self.final_ln(decoder_state)
|
decoder_state = self.final_ln(decoder_state)
|
||||||
logits = self.lm_head(decoder_state)
|
logits = self.lm_head(decoder_state)
|
||||||
a = self.condition_factor
|
a = self.condition_factor
|
||||||
logits: FloatTensor = (1 - a) * logits[0, -1] + a * logits[1, -1]
|
logits: FloatTensor = (
|
||||||
|
logits[:image_count, -1] * (1 - a) +
|
||||||
|
logits[image_count:, -1] * a
|
||||||
|
)
|
||||||
|
|
||||||
top_logits, _ = logits.topk(50, dim=-1)
|
top_logits, _ = logits.topk(50, dim=-1)
|
||||||
probs = torch.where(
|
probs = torch.where(
|
||||||
logits < top_logits[-1],
|
logits < top_logits[:, [-1]],
|
||||||
self.zero_prob,
|
self.zero_prob,
|
||||||
torch.exp(logits - top_logits[0])
|
torch.exp(logits - top_logits[:, [0]])
|
||||||
)
|
)
|
||||||
return probs, torch.stack(attention_states_new)
|
return probs, torch.stack(attention_states_new)
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
image_count: int,
|
||||||
text_tokens: LongTensor,
|
text_tokens: LongTensor,
|
||||||
encoder_state: FloatTensor
|
encoder_state: FloatTensor
|
||||||
) -> LongTensor:
|
) -> LongTensor:
|
||||||
image_tokens: List[LongTensor] = []
|
expanded_indices = [0] * image_count + [1] * image_count
|
||||||
attention_state = torch.zeros(self.attention_state_shape)
|
text_tokens = text_tokens[expanded_indices]
|
||||||
if torch.cuda.is_available():
|
encoder_state = encoder_state[expanded_indices]
|
||||||
attention_state = attention_state.cuda()
|
attention_mask = text_tokens.not_equal(1)
|
||||||
image_token = self.start_token
|
|
||||||
|
|
||||||
|
attention_state_shape = (
|
||||||
|
self.layer_count,
|
||||||
|
image_count * 4,
|
||||||
|
self.image_token_count,
|
||||||
|
self.embed_count
|
||||||
|
)
|
||||||
|
attention_state = torch.zeros(attention_state_shape)
|
||||||
|
if torch.cuda.is_available(): attention_state = attention_state.cuda()
|
||||||
|
|
||||||
|
image_tokens = self.start_token[[0] * image_count]
|
||||||
|
image_tokens_sequence: List[LongTensor] = []
|
||||||
for i in range(self.sample_token_count):
|
for i in range(self.sample_token_count):
|
||||||
probs, attention_state = self.decode_step(
|
probs, attention_state = self.decode_step(
|
||||||
text_tokens = text_tokens,
|
attention_mask = attention_mask,
|
||||||
encoder_state = encoder_state,
|
encoder_state = encoder_state,
|
||||||
attention_state = attention_state,
|
attention_state = attention_state,
|
||||||
prev_token = image_token,
|
prev_tokens = image_tokens,
|
||||||
token_index = self.token_indices[[i]]
|
token_index = self.token_indices[[i]]
|
||||||
)
|
)
|
||||||
|
|
||||||
image_token = torch.multinomial(probs, 1)
|
image_tokens = torch.multinomial(probs, 1)[:, 0]
|
||||||
image_tokens += [image_token]
|
image_tokens_sequence += [image_tokens]
|
||||||
|
|
||||||
return torch.cat(image_tokens)
|
return torch.stack(image_tokens_sequence).T
|
|
@ -137,8 +137,7 @@ class DalleBartEncoder(nn.Module):
|
||||||
|
|
||||||
def forward(self, text_tokens: LongTensor) -> FloatTensor:
|
def forward(self, text_tokens: LongTensor) -> FloatTensor:
|
||||||
attention_mask = text_tokens.not_equal(1)
|
attention_mask = text_tokens.not_equal(1)
|
||||||
batch_count = text_tokens.shape[0]
|
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
|
||||||
pose_tokens = torch.stack([self.token_indices] * batch_count)
|
|
||||||
encoder_state = (
|
encoder_state = (
|
||||||
self.embed_tokens.forward(text_tokens) +
|
self.embed_tokens.forward(text_tokens) +
|
||||||
self.embed_positions.forward(pose_tokens)
|
self.embed_positions.forward(pose_tokens)
|
||||||
|
|
|
@ -3,8 +3,6 @@ from torch import Tensor
|
||||||
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
BATCH_COUNT: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlock(Module):
|
class ResnetBlock(Module):
|
||||||
def __init__(self, log2_count_in: int, log2_count_out: int):
|
def __init__(self, log2_count_in: int, log2_count_out: int):
|
||||||
|
@ -42,22 +40,22 @@ class AttentionBlock(Module):
|
||||||
self.proj_out = Conv2d(n, n, 1)
|
self.proj_out = Conv2d(n, n, 1)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
n = 2 ** 9
|
n, m = 2 ** 9, x.shape[0]
|
||||||
h = x
|
h = x
|
||||||
h = self.norm(h)
|
h = self.norm(h)
|
||||||
q = self.q.forward(h)
|
q = self.q.forward(h)
|
||||||
k = self.k.forward(h)
|
k = self.k.forward(h)
|
||||||
v = self.v.forward(h)
|
v = self.v.forward(h)
|
||||||
q = q.reshape(BATCH_COUNT, n, 2 ** 8)
|
q = q.reshape(m, n, 2 ** 8)
|
||||||
q = q.permute(0, 2, 1)
|
q = q.permute(0, 2, 1)
|
||||||
k = k.reshape(BATCH_COUNT, n, 2 ** 8)
|
k = k.reshape(m, n, 2 ** 8)
|
||||||
w = torch.bmm(q, k)
|
w = torch.bmm(q, k)
|
||||||
w /= n ** 0.5
|
w /= n ** 0.5
|
||||||
w = torch.softmax(w, dim=2)
|
w = torch.softmax(w, dim=2)
|
||||||
v = v.reshape(BATCH_COUNT, n, 2 ** 8)
|
v = v.reshape(m, n, 2 ** 8)
|
||||||
w = w.permute(0, 2, 1)
|
w = w.permute(0, 2, 1)
|
||||||
h = torch.bmm(v, w)
|
h = torch.bmm(v, w)
|
||||||
h = h.reshape(BATCH_COUNT, n, 2 ** 4, 2 ** 4)
|
h = h.reshape(m, n, 2 ** 4, 2 ** 4)
|
||||||
h = self.proj_out.forward(h)
|
h = self.proj_out.forward(h)
|
||||||
return x + h
|
return x + h
|
||||||
|
|
||||||
|
@ -169,10 +167,10 @@ class VQGanDetokenizer(Module):
|
||||||
|
|
||||||
def forward(self, z: Tensor) -> Tensor:
|
def forward(self, z: Tensor) -> Tensor:
|
||||||
z = self.embedding.forward(z)
|
z = self.embedding.forward(z)
|
||||||
z = z.view((BATCH_COUNT, 2 ** 4, 2 ** 4, 2 ** 8))
|
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
|
||||||
z = z.permute(0, 3, 1, 2).contiguous()
|
z = z.permute(0, 3, 1, 2).contiguous()
|
||||||
z = self.post_quant_conv.forward(z)
|
z = self.post_quant_conv.forward(z)
|
||||||
z = self.decoder.forward(z)
|
z = self.decoder.forward(z)
|
||||||
z = z.permute(0, 2, 3, 1)
|
z = z.permute(0, 2, 3, 1)
|
||||||
z = z.clip(0.0, 1.0) * 255
|
z = z.clip(0.0, 1.0) * 255
|
||||||
return z[0]
|
return z
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from math import inf
|
from math import inf
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
class TextTokenizer:
|
class TextTokenizer:
|
||||||
def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True):
|
def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True):
|
||||||
self.is_verbose = is_verbose
|
self.is_verbose = is_verbose
|
||||||
|
|
|
@ -13,10 +13,10 @@ class Predictor(BasePredictor):
|
||||||
description="Text for generating images.",
|
description="Text for generating images.",
|
||||||
),
|
),
|
||||||
seed: int = Input(
|
seed: int = Input(
|
||||||
description="Specify the seed.",
|
description="Specify a random seed.",
|
||||||
),
|
),
|
||||||
) -> Path:
|
) -> Path:
|
||||||
image = self.model.generate_image(text, seed)
|
image = self.model.generate_image(text, seed, grid_size=3)
|
||||||
out_path = Path(tempfile.mkdtemp()) / "output.png"
|
out_path = Path(tempfile.mkdtemp()) / "output.png"
|
||||||
image.save(str(out_path))
|
image.save(str(out_path))
|
||||||
|
|
||||||
|
|
5
setup.py
5
setup.py
|
@ -5,7 +5,7 @@ setuptools.setup(
|
||||||
name='min-dalle',
|
name='min-dalle',
|
||||||
description = 'min(DALL·E)',
|
description = 'min(DALL·E)',
|
||||||
long_description=(Path(__file__).parent / "README").read_text(),
|
long_description=(Path(__file__).parent / "README").read_text(),
|
||||||
version='0.2.6',
|
version='0.2.9',
|
||||||
author='Brett Kuprel',
|
author='Brett Kuprel',
|
||||||
author_email='brkuprel@gmail.com',
|
author_email='brkuprel@gmail.com',
|
||||||
url='https://github.com/kuprel/min-dalle',
|
url='https://github.com/kuprel/min-dalle',
|
||||||
|
@ -15,7 +15,8 @@ setuptools.setup(
|
||||||
],
|
],
|
||||||
license='MIT',
|
license='MIT',
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'torch>=1.10.0'
|
'torch>=1.10.0',
|
||||||
|
'typing_extensions>=4.1.0'
|
||||||
],
|
],
|
||||||
keywords = [
|
keywords = [
|
||||||
'artificial intelligence',
|
'artificial intelligence',
|
||||||
|
|
Loading…
Reference in New Issue
Block a user