support bfloat16

Author: Brett Kuprel
Date: 2022-07-07 08:21:20 -04:00
Commit: da62298f06 (parent 5f526e2109)
9 changed files with 108 additions and 96 deletions


@@ -18,13 +18,15 @@ MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
 class MinDalle:
     def __init__(
         self,
-        is_mega: bool,
-        is_reusable: bool = True,
         models_root: str = 'pretrained',
+        dtype: torch.dtype = torch.float32,
+        is_mega: bool = True,
+        is_reusable: bool = True,
         is_verbose = True
     ):
         self.is_mega = is_mega
         self.is_reusable = is_reusable
+        self.dtype = dtype
         self.is_verbose = is_verbose
         self.text_token_count = 64
         self.layer_count = 24 if is_mega else 12
@@ -34,7 +36,6 @@ class MinDalle:
         self.text_vocab_count = 50272 if is_mega else 50264
         self.image_vocab_count = 16415 if is_mega else 16384
-
         if self.is_verbose: print("initializing MinDalle")
         model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
         dalle_path = os.path.join(models_root, model_name)
         vqgan_path = os.path.join(models_root, 'vqgan')
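
The hunks above reorder the constructor's keyword arguments and add a `dtype` parameter defaulting to `torch.float32`. A minimal usage sketch of the new signature, assuming the package is importable as `min_dalle` (the paths and flags below are illustrative, not prescribed by the commit):

    import torch
    from min_dalle import MinDalle

    # Load the mega model with bfloat16 weights; float32 stays the default.
    model = MinDalle(
        models_root='pretrained',
        dtype=torch.bfloat16,
        is_mega=True,
        is_reusable=True
    )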
@@ -105,7 +106,7 @@ class MinDalle:
             text_token_count = self.text_token_count,
             text_vocab_count = self.text_vocab_count,
             layer_count = self.layer_count
-        )
+        ).to(self.dtype).eval()
         params = torch.load(self.encoder_params_path)
         self.encoder.load_state_dict(params, strict=False)
         del params
@@ -123,7 +124,7 @@ class MinDalle:
             glu_embed_count = self.glu_embed_count,
             layer_count = self.layer_count,
             start_token = self.image_vocab_count
-        )
+        ).to(self.dtype).eval()
         params = torch.load(self.decoder_params_path)
         self.decoder.load_state_dict(params, strict=False)
         del params
@@ -134,7 +135,7 @@ class MinDalle:
         is_downloaded = os.path.exists(self.detoker_params_path)
         if not is_downloaded: self.download_detokenizer()
         if self.is_verbose: print("initializing VQGanDetokenizer")
-        self.detokenizer = VQGanDetokenizer()
+        self.detokenizer = VQGanDetokenizer().to(self.dtype).eval()
         params = torch.load(self.detoker_params_path)
         self.detokenizer.load_state_dict(params)
         del params
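
All three `init_*` hunks apply the same pattern: cast the freshly constructed module with `.to(self.dtype)` before calling `load_state_dict`. The order works because `load_state_dict` copies each checkpoint tensor into the module's existing parameter storage, converting dtype on the fly. A minimal sketch with a stand-in `Linear` layer (the layer and shapes are illustrative):

    import torch

    # Cast the module first; the float32 checkpoint is converted on load.
    layer = torch.nn.Linear(8, 8).to(torch.bfloat16).eval()
    fp32_checkpoint = {k: v.float() for k, v in layer.state_dict().items()}
    layer.load_state_dict(fp32_checkpoint)
    print(next(layer.parameters()).dtype)  # torch.bfloat16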
@@ -184,38 +185,41 @@ class MinDalle:
         if not self.is_reusable: self.init_encoder()
         if is_verbose: print("encoding text tokens")
-        encoder_state = self.encoder.forward(text_tokens)
+        with torch.cuda.amp.autocast(dtype=self.dtype):
+            encoder_state = self.encoder.forward(text_tokens)
         if not self.is_reusable: del self.encoder
         if torch.cuda.is_available(): torch.cuda.empty_cache()

         if not self.is_reusable: self.init_decoder()
-        encoder_state, attention_mask, attention_state, image_tokens = (
-            self.decoder.decode_initial(
-                seed,
-                grid_size ** 2,
-                text_tokens,
-                encoder_state
-            )
-        )
+        with torch.cuda.amp.autocast(dtype=self.dtype):
+            encoder_state, attention_mask, attention_state, image_tokens = (
+                self.decoder.decode_initial(
+                    seed,
+                    grid_size ** 2,
+                    text_tokens,
+                    encoder_state
+                )
+            )

         row_count = 16
         for row_index in range(row_count):
             if is_verbose:
                 print('sampling row {} of {}'.format(row_index + 1, row_count))
-            attention_state, image_tokens = self.decoder.decode_row(
-                row_index,
-                log2_k,
-                log2_supercondition_factor,
-                encoder_state,
-                attention_mask,
-                attention_state,
-                image_tokens
-            )
-            if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
-                tokens = image_tokens[:, 1:]
-                image = self.image_from_tokens(grid_size, tokens, is_verbose)
-                yield image
+            with torch.cuda.amp.autocast(dtype=self.dtype):
+                attention_state, image_tokens = self.decoder.decode_row(
+                    row_index,
+                    log2_k,
+                    log2_supercondition_factor,
+                    encoder_state,
+                    attention_mask,
+                    attention_state,
+                    image_tokens
+                )
+            if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
+                tokens = image_tokens[:, 1:]
+                image = self.image_from_tokens(grid_size, tokens, is_verbose)
+                yield image

     def generate_image(

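The hunk above wraps each generation stage (encode, initial decode, per-row decode) in `torch.cuda.amp.autocast(dtype=self.dtype)`, so eligible ops such as matmuls run in the reduced dtype. A minimal sketch of that pattern with a stand-in layer (requires a CUDA device; the `Linear` module and shapes are illustrative):

    import torch

    model = torch.nn.Linear(64, 64).cuda()
    tokens = torch.randn(1, 64, device='cuda')

    # Inside the context, the matmul runs in bfloat16 even though the
    # layer's weights are still stored in float32.
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        hidden = model(tokens)

    print(hidden.dtype)  # torch.bfloat16
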

@@ -40,7 +40,8 @@ class DecoderSelfAttention(AttentionBase):
         queries = self.q_proj.forward(decoder_state)
         attn_mask = self.token_indices < token_index + 1
         attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
-        attention_state[:, token_index] = torch.cat([keys, values])
+        attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
+        attention_state[:, token_index] = attn_state_new
         batch_count = decoder_state.shape[0]
         keys = attention_state[:batch_count]
         values = attention_state[batch_count:]
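
The new `attn_state_new` intermediate exists to align dtypes before the indexed assignment: under autocast, the concatenated keys/values can come out in the reduced dtype while the preallocated `attention_state` cache does not, and assignment through a tensor index dispatches to `index_put_`, which can refuse mismatched dtypes ("Index put requires the source and destination dtypes match"). A small illustration; the names and shapes here are hypothetical:

    import torch

    attention_state = torch.zeros(4, 8, 16)              # float32 cache
    token_index = torch.tensor(2)                        # tensor index -> index_put_
    update = torch.ones(4, 16, dtype=torch.bfloat16)     # reduced-dtype autocast output

    # Casting to the destination dtype first keeps the indexed write valid.
    attention_state[:, token_index] = update.to(attention_state.dtype)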


@@ -82,7 +82,7 @@ class Upsample(Module):
         self.conv = Conv2d(n, n, 3, padding=1)

     def forward(self, x: Tensor) -> Tensor:
-        x = self.upsample.forward(x)
+        x = self.upsample.forward(x.to(torch.float32))
         x = self.conv.forward(x)
         return x
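
The cast to `torch.float32` before the nearest-neighbor upsample appears to sidestep a missing reduced-precision kernel: PyTorch builds of this era did not implement nearest upsampling for bfloat16 (an inference from the cast's placement, not stated in the commit). A minimal sketch of the workaround; the shapes below are illustrative:

    import torch

    x = torch.randn(1, 16, 8, 8, dtype=torch.bfloat16)
    upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)

    # Upsampling in float32 avoids the missing bfloat16 kernel; the
    # convolution that follows can run under autocast as usual.
    y = upsample(x.to(torch.float32))
    print(y.shape, y.dtype)  # torch.Size([1, 16, 16, 16]) torch.float32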