diff --git a/min_dalle/generate_image.py b/min_dalle/generate_image.py
index 6829248..ba7fd55 100644
--- a/min_dalle/generate_image.py
+++ b/min_dalle/generate_image.py
@@ -65,6 +65,8 @@ def generate_image_from_text(
         if image_token_count == config['image_length']:
             image = detokenize_torch(image_tokens)
             return Image.fromarray(image)
+        else:
+            print(list(image_tokens.to('cpu').detach().numpy()))
     else:
         image_tokens = generate_image_tokens_flax(
             text_tokens = text_tokens,
diff --git a/min_dalle/load_params.py b/min_dalle/load_params.py
index 3fd77a8..ac4fad6 100644
--- a/min_dalle/load_params.py
+++ b/min_dalle/load_params.py
@@ -4,7 +4,7 @@ from copy import deepcopy
 from typing import Dict
 from flax import traverse_util, serialization
 import torch
-torch.no_grad()
+torch.set_grad_enabled(False)
 
 
 def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
@@ -30,7 +30,6 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
 
     for i in P:
         P[i] = torch.tensor(P[i])
-        # if torch.cuda.is_available(): P[i] = P[i].cuda()
 
     P['embedding.weight'] = P.pop('quantize.embedding.embedding')
 
@@ -87,7 +86,6 @@ convert_dalle_bart_torch_from_flax_params(
 
     for i in P:
         P[i] = torch.tensor(P[i])
-        # if torch.cuda.is_available(): P[i] = P[i].cuda()
 
     for i in list(P):
         if 'kernel' in i:
diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py
index 690e02c..5a6a39d 100644
--- a/min_dalle/min_dalle_torch.py
+++ b/min_dalle/min_dalle_torch.py
@@ -2,7 +2,7 @@ import numpy
 from typing import Dict
 from torch import LongTensor, FloatTensor
 import torch
-torch.no_grad()
+torch.set_grad_enabled(False)
 
 from .models.vqgan_detokenizer import VQGanDetokenizer
 from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
@@ -35,6 +35,7 @@ def encode_torch(
     )
     encoder.load_state_dict(encoder_params, strict=False)
     del encoder_params
+    if torch.cuda.is_available(): encoder = encoder.cuda()
 
     print("encoding text tokens")
     encoder_state = encoder(text_tokens)
@@ -70,6 +71,7 @@ def decode_torch(
     )
     decoder.load_state_dict(decoder_params, strict=False)
     del decoder_params
+    if torch.cuda.is_available(): decoder = decoder.cuda()
 
     print("sampling image tokens")
     torch.manual_seed(seed)
@@ -85,7 +87,7 @@ def generate_image_tokens_torch(
     image_token_count: int
 ) -> LongTensor:
     text_tokens = torch.tensor(text_tokens).to(torch.long)
-    # if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
+    if torch.cuda.is_available(): text_tokens = text_tokens.cuda()
     encoder_state = encode_torch(
         text_tokens,
         config,
@@ -108,6 +110,8 @@ def detokenize_torch(image_tokens: LongTensor) -> numpy.ndarray:
     params = load_vqgan_torch_params(model_path)
     detokenizer = VQGanDetokenizer()
     detokenizer.load_state_dict(params)
+    if torch.cuda.is_available(): detokenizer = detokenizer.cuda()
 
     image = detokenizer.forward(image_tokens).to(torch.uint8)
-    return image.detach().numpy()
+    del detokenizer, params
+    return image.to('cpu').detach().numpy()
\ No newline at end of file
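A note on the torch.no_grad() -> torch.set_grad_enabled(False) change repeated across these files: a bare torch.no_grad() call at import time only constructs a context-manager object and immediately discards it, so autograd stays enabled, while set_grad_enabled(False) actually flips the global grad mode. A minimal standalone sketch of the difference (not part of the patch):

    import torch

    x = torch.ones(3, requires_grad=True)

    torch.no_grad()                  # builds a context manager, then discards it: no effect
    print((x * 2).requires_grad)     # True, autograd is still recording

    torch.set_grad_enabled(False)    # flips the global grad mode until changed again
    print((x * 2).requires_grad)     # False, no graph is built for inference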
diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py
index ec14493..f4555ab 100644
--- a/min_dalle/models/dalle_bart_decoder_torch.py
+++ b/min_dalle/models/dalle_bart_decoder_torch.py
@@ -1,7 +1,7 @@
 from typing import List, Tuple
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
-torch.no_grad()
+torch.set_grad_enabled(False)
 
 from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
 
@@ -30,14 +30,12 @@ class DecoderSelfAttentionTorch(AttentionTorch):
         decoder_state: FloatTensor,
         keys_values: FloatTensor,
         attention_mask: BoolTensor,
-        token_index: LongTensor
+        token_mask: BoolTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
         batch_count = decoder_state.shape[0]
-        token_count = keys_values.shape[1]
         shape = (batch_count, 1) + keys_values.shape[2:]
         keys = self.k_proj.forward(decoder_state).view(shape)
         values = self.v_proj.forward(decoder_state).view(shape)
-        token_mask = torch.arange(token_count) == token_index
         keys_values = torch.where(
             token_mask[None, :, None, None],
             torch.cat([keys, values]),
@@ -67,6 +65,10 @@ class DecoderLayerTorch(nn.Module):
         self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
         self.glu = GLUTorch(embed_count, glu_embed_count)
 
+        self.token_indices = torch.arange(self.image_token_count)
+        if torch.cuda.is_available():
+            self.token_indices = self.token_indices.cuda()
+
     def forward(self,
         decoder_state: FloatTensor,
         encoder_state: FloatTensor,
@@ -77,13 +79,14 @@ class DecoderLayerTorch(nn.Module):
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
-        self_attn_mask = torch.arange(self.image_token_count) < token_index + 1
+        self_attn_mask = self.token_indices < token_index + 1
+        token_mask = self.token_indices == token_index
         self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
         decoder_state, keys_values_state = self.self_attn.forward(
             decoder_state,
             keys_values_state,
             self_attn_mask,
-            token_index
+            token_mask
         )
         decoder_state = self.self_attn_layer_norm.forward(decoder_state)
         decoder_state = residual + decoder_state
@@ -124,13 +127,7 @@ class DalleBartDecoderTorch(nn.Module):
         self.is_verbose = is_verbose
         self.layer_count = layer_count
         self.sample_token_count = sample_token_count
-        self.start_token = torch.tensor([start_token]).to(torch.long)
-        self.pad_token = torch.tensor([1]).to(torch.long)
-        self.condition_factor = torch.tensor([10]).to(torch.float)
-        # if torch.cuda.is_available():
-        #     self.start_token = self.start_token.cuda()
-        #     self.pad_token = self.pad_token.cuda()
-        #     self.condition_factor = self.condition_factor.cuda()
+        self.condition_factor = 10.0
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
@@ -152,6 +149,13 @@
             attention_head_count,
             embed_count // attention_head_count
         )
+        self.zero_prob = torch.zeros([1])
+        self.token_indices = torch.arange(self.sample_token_count)
+        self.start_token = torch.tensor([start_token]).to(torch.long)
+        if torch.cuda.is_available():
+            self.zero_prob = self.zero_prob.cuda()
+            self.token_indices = self.token_indices.cuda()
+            self.start_token = self.start_token.cuda()
 
 
     def decode_step(self,
@@ -160,7 +164,7 @@
         keys_values_state: FloatTensor,
         prev_token_and_index: LongTensor
     ) -> Tuple[LongTensor, FloatTensor]:
-        attention_mask = text_tokens.not_equal(self.pad_token)
+        attention_mask = text_tokens.not_equal(1)
         batch_count = encoder_state.shape[0]
         prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
         token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
@@ -188,7 +192,7 @@
         top_logits = logits.sort(descending=True)[0][:50]
         probs = torch.where(
             logits < top_logits[-1],
-            torch.zeros([1]),
+            self.zero_prob,
             torch.exp(logits - top_logits[0])
         )
         return probs, keys_values
@@ -200,11 +204,12 @@
     ) -> LongTensor:
         image_tokens: List[LongTensor] = []
         keys_values_state = torch.zeros(self.keys_values_state_shape)
+        if torch.cuda.is_available():
+            keys_values_state = keys_values_state.cuda()
         image_token = self.start_token
 
         for i in range(self.sample_token_count):
-            token_index = torch.tensor([i]).to(torch.long)
-            # if torch.cuda.is_available(): token_index = token_index.cuda()
+            token_index = self.token_indices[i:i+1]
             probs, keys_values_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,
@@ -214,9 +219,5 @@
 
             image_token = torch.multinomial(probs, 1)
             image_tokens += [image_token]
-
-            if self.is_verbose:
-                token = int(image_token.detach().numpy())
-                print("image token {} is {}".format(i, token))
 
         return torch.cat(image_tokens)
\ No newline at end of file
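The DecoderSelfAttentionTorch change above exists because the old code allocated torch.arange(token_count) == token_index on the CPU at every decoding step, which breaks once the surrounding tensors live on the GPU; the layer now receives a token_mask built from indices allocated once in __init__ and moved to CUDA up front. A toy sketch of the same cache-update pattern, with invented sizes rather than the model's real dimensions:

    import torch

    # invented toy dimensions
    token_count, batch_count, head_count, head_dim = 4, 1, 2, 8
    token_indices = torch.arange(token_count)  # allocated once, sent to CUDA with the layer

    # cache stacks keys and values along dim 0, one slot per token position
    keys_values = torch.zeros(2 * batch_count, token_count, head_count, head_dim)
    new_keys_values = torch.ones(2 * batch_count, 1, head_count, head_dim)

    token_index = 2
    token_mask = token_indices == token_index  # boolean one-hot over cache slots

    # broadcasting writes the new keys/values into slot 2 and leaves the rest untouched
    keys_values = torch.where(
        token_mask[None, :, None, None],
        new_keys_values,
        keys_values
    )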
diff --git a/min_dalle/models/dalle_bart_encoder_torch.py b/min_dalle/models/dalle_bart_encoder_torch.py
index d21c542..92bf775 100644
--- a/min_dalle/models/dalle_bart_encoder_torch.py
+++ b/min_dalle/models/dalle_bart_encoder_torch.py
@@ -1,7 +1,7 @@
 from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
-torch.no_grad()
+torch.set_grad_enabled(False)
 
 
 class GLUTorch(nn.Module):
@@ -34,6 +34,8 @@ class AttentionTorch(nn.Module):
         self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
         self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
         self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
+        self.one = torch.ones((1, 1))
+        if torch.cuda.is_available(): self.one = self.one.cuda()
 
     def forward(self,
         keys: FloatTensor,
@@ -43,8 +45,8 @@
     ) -> FloatTensor:
         attention_bias = torch.where(
             attention_mask,
-            torch.full(attention_mask.shape, 0.0),
-            torch.full(attention_mask.shape, -torch.inf),
+            self.one * 0,
+            self.one * (-torch.inf),
         )
         attention_weights: FloatTensor = torch.einsum(
             'bqhc,bkhc->bhqk',
@@ -124,11 +126,14 @@
         ])
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
+        self.token_indices = torch.arange(text_token_count).to(torch.long)
+        if torch.cuda.is_available():
+            self.token_indices = self.token_indices.cuda()
 
     def forward(self, text_tokens: LongTensor) -> FloatTensor:
         attention_mask = text_tokens.not_equal(1)
-        batch_count, token_count = text_tokens.shape
-        pose_tokens = torch.stack([torch.arange(token_count)] * batch_count)
+        batch_count = text_tokens.shape[0]
+        pose_tokens = torch.stack([self.token_indices] * batch_count)
         encoder_state = (
             self.embed_tokens.forward(text_tokens)
             + self.embed_positions.forward(pose_tokens)
diff --git a/min_dalle/models/vqgan_detokenizer.py b/min_dalle/models/vqgan_detokenizer.py
index e74416e..b0b8758 100644
--- a/min_dalle/models/vqgan_detokenizer.py
+++ b/min_dalle/models/vqgan_detokenizer.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
-torch.no_grad()
+torch.set_grad_enabled(False)
 
 
 BATCH_COUNT: int = 1
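The encoder and decoder hunks above pin constant tensors (self.one, self.token_indices, self.zero_prob) to the GPU by hand with torch.cuda.is_available() checks at construction time. An alternative the patch does not use, sketched here for contrast, is nn.Module.register_buffer, which lets .cuda() and .to() carry such constants along with the parameters:

    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            # plain tensor attribute: invisible to .cuda()/.to(), stays on the CPU
            self.one_attr = torch.ones((1, 1))
            # registered buffer: moved together with the module's parameters
            self.register_buffer('one_buf', torch.ones((1, 1)), persistent=False)

    toy = Toy()
    if torch.cuda.is_available():
        toy = toy.cuda()
        print(toy.one_attr.device)  # cpu
        print(toy.one_buf.device)   # cuda:0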