inplace attention state, faster and less memory
This commit is contained in:
parent
aca617dc64
commit
6f617fe98f
2
README.md
vendored
2
README.md
vendored
|
@ -9,7 +9,7 @@
|
||||||
This is a fast, minimal implementation of Boris Dayma's [DALL·E Mega](https://github.com/borisdayma/dalle-mini). It has been stripped down for inference and converted to PyTorch. The only third party dependencies are numpy, requests, pillow and torch.
|
This is a fast, minimal implementation of Boris Dayma's [DALL·E Mega](https://github.com/borisdayma/dalle-mini). It has been stripped down for inference and converted to PyTorch. The only third party dependencies are numpy, requests, pillow and torch.
|
||||||
|
|
||||||
It takes
|
It takes
|
||||||
- **35 seconds** to generate a 3x3 grid with a P100 in Colab
|
- **32 seconds** to generate a 3x3 grid with a P100 in Colab
|
||||||
- **16 seconds** to generate a 4x4 grid with an A100 on Replicate
|
- **16 seconds** to generate a 4x4 grid with an A100 on Replicate
|
||||||
- **TBD** to generate a 4x4 grid with an H100 (@NVIDIA?)
|
- **TBD** to generate a 4x4 grid with an H100 (@NVIDIA?)
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,8 @@ def generate_image(
|
||||||
|
|
||||||
if token_count < 256:
|
if token_count < 256:
|
||||||
image_tokens = model.generate_image_tokens(text, seed, grid_size ** 2)
|
image_tokens = model.generate_image_tokens(text, seed, grid_size ** 2)
|
||||||
print('image tokens', image_tokens.to('cpu').detach().numpy())
|
image_tokens = image_tokens[:, :token_count].to('cpu').detach().numpy()
|
||||||
|
print('image tokens', image_tokens)
|
||||||
else:
|
else:
|
||||||
image = model.generate_image(text, seed, grid_size)
|
image = model.generate_image(text, seed, grid_size)
|
||||||
save_image(image, image_path)
|
save_image(image, image_path)
|
||||||
|
|
|
@ -20,9 +20,9 @@ class DecoderCrossAttention(AttentionBase):
|
||||||
|
|
||||||
|
|
||||||
class DecoderSelfAttention(AttentionBase):
|
class DecoderSelfAttention(AttentionBase):
|
||||||
def __init__(self, head_count: int, embed_count: int):
|
def __init__(self, head_count: int, embed_count: int, token_count: int):
|
||||||
super().__init__(head_count, embed_count)
|
super().__init__(head_count, embed_count)
|
||||||
token_indices = torch.arange(256)
|
token_indices = torch.arange(token_count)
|
||||||
if torch.cuda.is_available(): token_indices = token_indices.cuda()
|
if torch.cuda.is_available(): token_indices = token_indices.cuda()
|
||||||
self.token_indices = token_indices
|
self.token_indices = token_indices
|
||||||
|
|
||||||
|
@ -56,7 +56,11 @@ class DecoderLayer(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.image_token_count = image_token_count
|
self.image_token_count = image_token_count
|
||||||
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.self_attn = DecoderSelfAttention(head_count, embed_count)
|
self.self_attn = DecoderSelfAttention(
|
||||||
|
head_count,
|
||||||
|
embed_count,
|
||||||
|
image_token_count
|
||||||
|
)
|
||||||
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
|
self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
|
||||||
|
@ -150,7 +154,7 @@ class DalleBartDecoder(nn.Module):
|
||||||
attention_state: FloatTensor,
|
attention_state: FloatTensor,
|
||||||
prev_tokens: LongTensor,
|
prev_tokens: LongTensor,
|
||||||
token_index: LongTensor
|
token_index: LongTensor
|
||||||
) -> Tuple[LongTensor, FloatTensor]:
|
) -> Tuple[FloatTensor, FloatTensor]:
|
||||||
image_count = encoder_state.shape[0] // 2
|
image_count = encoder_state.shape[0] // 2
|
||||||
token_index_batched = token_index[[0] * image_count * 2]
|
token_index_batched = token_index[[0] * image_count * 2]
|
||||||
prev_tokens = prev_tokens[list(range(image_count)) * 2]
|
prev_tokens = prev_tokens[list(range(image_count)) * 2]
|
||||||
|
@ -158,16 +162,14 @@ class DalleBartDecoder(nn.Module):
|
||||||
decoder_state += self.embed_positions.forward(token_index_batched)
|
decoder_state += self.embed_positions.forward(token_index_batched)
|
||||||
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
||||||
decoder_state = decoder_state[:, None]
|
decoder_state = decoder_state[:, None]
|
||||||
attention_states_new = []
|
|
||||||
for i in range(self.layer_count):
|
for i in range(self.layer_count):
|
||||||
decoder_state, attention_state_layer = self.layers[i].forward(
|
decoder_state, attention_state[i] = self.layers[i].forward(
|
||||||
decoder_state,
|
decoder_state,
|
||||||
encoder_state,
|
encoder_state,
|
||||||
attention_state[i],
|
attention_state[i],
|
||||||
attention_mask,
|
attention_mask,
|
||||||
token_index
|
token_index
|
||||||
)
|
)
|
||||||
attention_states_new.append(attention_state_layer)
|
|
||||||
decoder_state = self.final_ln(decoder_state)
|
decoder_state = self.final_ln(decoder_state)
|
||||||
logits = self.lm_head(decoder_state)
|
logits = self.lm_head(decoder_state)
|
||||||
a = self.condition_factor
|
a = self.condition_factor
|
||||||
|
@ -182,7 +184,7 @@ class DalleBartDecoder(nn.Module):
|
||||||
self.zero_prob,
|
self.zero_prob,
|
||||||
torch.exp(logits - top_logits[:, [0]])
|
torch.exp(logits - top_logits[:, [0]])
|
||||||
)
|
)
|
||||||
return probs, torch.stack(attention_states_new)
|
return probs, attention_state
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -203,10 +205,17 @@ class DalleBartDecoder(nn.Module):
|
||||||
self.embed_count
|
self.embed_count
|
||||||
)
|
)
|
||||||
attention_state = torch.zeros(attention_state_shape)
|
attention_state = torch.zeros(attention_state_shape)
|
||||||
if torch.cuda.is_available(): attention_state = attention_state.cuda()
|
image_tokens_sequence = torch.full(
|
||||||
|
(image_count, self.image_token_count),
|
||||||
|
6965, # black token
|
||||||
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
attention_state = attention_state.cuda()
|
||||||
|
image_tokens_sequence = image_tokens_sequence.cuda()
|
||||||
|
|
||||||
image_tokens = self.start_token[[0] * image_count]
|
image_tokens = self.start_token[[0] * image_count]
|
||||||
image_tokens_sequence: List[LongTensor] = []
|
|
||||||
for i in range(self.sample_token_count):
|
for i in range(self.sample_token_count):
|
||||||
probs, attention_state = self.decode_step(
|
probs, attention_state = self.decode_step(
|
||||||
attention_mask = attention_mask,
|
attention_mask = attention_mask,
|
||||||
|
@ -217,6 +226,6 @@ class DalleBartDecoder(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
image_tokens = torch.multinomial(probs, 1)[:, 0]
|
image_tokens = torch.multinomial(probs, 1)[:, 0]
|
||||||
image_tokens_sequence += [image_tokens]
|
image_tokens_sequence[:, i] = image_tokens
|
||||||
|
|
||||||
return torch.stack(image_tokens_sequence).T
|
return image_tokens_sequence
|
2
setup.py
2
setup.py
|
@ -5,7 +5,7 @@ setuptools.setup(
|
||||||
name='min-dalle',
|
name='min-dalle',
|
||||||
description = 'min(DALL·E)',
|
description = 'min(DALL·E)',
|
||||||
long_description=(Path(__file__).parent / "README.rst").read_text(),
|
long_description=(Path(__file__).parent / "README.rst").read_text(),
|
||||||
version='0.2.15',
|
version='0.2.16',
|
||||||
author='Brett Kuprel',
|
author='Brett Kuprel',
|
||||||
author_email='brkuprel@gmail.com',
|
author_email='brkuprel@gmail.com',
|
||||||
url='https://github.com/kuprel/min-dalle',
|
url='https://github.com/kuprel/min-dalle',
|
||||||
|
|
Loading…
Reference in New Issue
Block a user