diff --git a/README.md b/README.md index c588160..681f1c7 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ This is a fast, minimal implementation of Boris Dayma's [DALL·E Mega](https://github.com/borisdayma/dalle-mini). It has been stripped down for inference and converted to PyTorch. The only third party dependencies are numpy, requests, pillow and torch. It takes -- **35 seconds** to generate a 3x3 grid with a P100 in Colab +- **32 seconds** to generate a 3x3 grid with a P100 in Colab - **16 seconds** to generate a 4x4 grid with an A100 on Replicate - **TBD** to generate a 4x4 grid with an H100 (@NVIDIA?) diff --git a/image_from_text.py b/image_from_text.py index 75b6b71..061fd05 100644 --- a/image_from_text.py +++ b/image_from_text.py @@ -54,7 +54,8 @@ def generate_image( if token_count < 256: image_tokens = model.generate_image_tokens(text, seed, grid_size ** 2) - print('image tokens', image_tokens.to('cpu').detach().numpy()) + image_tokens = image_tokens[:, :token_count].to('cpu').detach().numpy() + print('image tokens', image_tokens) else: image = model.generate_image(text, seed, grid_size) save_image(image, image_path) diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index 50afe6d..b45e6cc 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -20,9 +20,9 @@ class DecoderCrossAttention(AttentionBase): class DecoderSelfAttention(AttentionBase): - def __init__(self, head_count: int, embed_count: int): + def __init__(self, head_count: int, embed_count: int, token_count: int): super().__init__(head_count, embed_count) - token_indices = torch.arange(256) + token_indices = torch.arange(token_count) if torch.cuda.is_available(): token_indices = token_indices.cuda() self.token_indices = token_indices @@ -56,7 +56,11 @@ class DecoderLayer(nn.Module): super().__init__() self.image_token_count = image_token_count self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) - self.self_attn = DecoderSelfAttention(head_count, embed_count) + self.self_attn = DecoderSelfAttention( + head_count, + embed_count, + image_token_count + ) self.self_attn_layer_norm = nn.LayerNorm(embed_count) self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count) self.encoder_attn = DecoderCrossAttention(head_count, embed_count) @@ -150,7 +154,7 @@ class DalleBartDecoder(nn.Module): attention_state: FloatTensor, prev_tokens: LongTensor, token_index: LongTensor - ) -> Tuple[LongTensor, FloatTensor]: + ) -> Tuple[FloatTensor, FloatTensor]: image_count = encoder_state.shape[0] // 2 token_index_batched = token_index[[0] * image_count * 2] prev_tokens = prev_tokens[list(range(image_count)) * 2] @@ -158,16 +162,14 @@ class DalleBartDecoder(nn.Module): decoder_state += self.embed_positions.forward(token_index_batched) decoder_state = self.layernorm_embedding.forward(decoder_state) decoder_state = decoder_state[:, None] - attention_states_new = [] for i in range(self.layer_count): - decoder_state, attention_state_layer = self.layers[i].forward( + decoder_state, attention_state[i] = self.layers[i].forward( decoder_state, encoder_state, attention_state[i], attention_mask, token_index ) - attention_states_new.append(attention_state_layer) decoder_state = self.final_ln(decoder_state) logits = self.lm_head(decoder_state) a = self.condition_factor @@ -182,7 +184,7 @@ class DalleBartDecoder(nn.Module): self.zero_prob, torch.exp(logits - top_logits[:, [0]]) ) - return probs, torch.stack(attention_states_new) + return probs, attention_state def forward( @@ -203,10 +205,17 @@ class DalleBartDecoder(nn.Module): self.embed_count ) attention_state = torch.zeros(attention_state_shape) - if torch.cuda.is_available(): attention_state = attention_state.cuda() + image_tokens_sequence = torch.full( + (image_count, self.image_token_count), + 6965, # black token + dtype=torch.long + ) + if torch.cuda.is_available(): + attention_state = attention_state.cuda() + image_tokens_sequence = image_tokens_sequence.cuda() image_tokens = self.start_token[[0] * image_count] - image_tokens_sequence: List[LongTensor] = [] + for i in range(self.sample_token_count): probs, attention_state = self.decode_step( attention_mask = attention_mask, @@ -217,6 +226,6 @@ class DalleBartDecoder(nn.Module): ) image_tokens = torch.multinomial(probs, 1)[:, 0] - image_tokens_sequence += [image_tokens] + image_tokens_sequence[:, i] = image_tokens - return torch.stack(image_tokens_sequence).T \ No newline at end of file + return image_tokens_sequence \ No newline at end of file diff --git a/setup.py b/setup.py index 00196c5..f5510ab 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALL·E)', long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.2.15', + version='0.2.16', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle',