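Support bfloat16: add a configurable `dtype` to MinDalle, applied to the encoder, decoder, and detokenizer and used for autocast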
This commit is contained in:
@@ -18,13 +18,15 @@ MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
||||
class MinDalle:
|
||||
def __init__(
|
||||
self,
|
||||
is_mega: bool,
|
||||
is_reusable: bool = True,
|
||||
models_root: str = 'pretrained',
|
||||
dtype: torch.dtype = torch.float32,
|
||||
is_mega: bool = True,
|
||||
is_reusable: bool = True,
|
||||
is_verbose = True
|
||||
):
|
||||
self.is_mega = is_mega
|
||||
self.is_reusable = is_reusable
|
||||
self.dtype = dtype
|
||||
self.is_verbose = is_verbose
|
||||
self.text_token_count = 64
|
||||
self.layer_count = 24 if is_mega else 12
|
||||
@@ -34,7 +36,6 @@ class MinDalle:
|
||||
self.text_vocab_count = 50272 if is_mega else 50264
|
||||
self.image_vocab_count = 16415 if is_mega else 16384
|
||||
|
||||
if self.is_verbose: print("initializing MinDalle")
|
||||
model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
|
||||
dalle_path = os.path.join(models_root, model_name)
|
||||
vqgan_path = os.path.join(models_root, 'vqgan')
|
||||
@@ -105,7 +106,7 @@ class MinDalle:
|
||||
text_token_count = self.text_token_count,
|
||||
text_vocab_count = self.text_vocab_count,
|
||||
layer_count = self.layer_count
|
||||
)
|
||||
).to(self.dtype).eval()
|
||||
params = torch.load(self.encoder_params_path)
|
||||
self.encoder.load_state_dict(params, strict=False)
|
||||
del params
|
||||
@@ -123,7 +124,7 @@ class MinDalle:
|
||||
glu_embed_count = self.glu_embed_count,
|
||||
layer_count = self.layer_count,
|
||||
start_token = self.image_vocab_count
|
||||
)
|
||||
).to(self.dtype).eval()
|
||||
params = torch.load(self.decoder_params_path)
|
||||
self.decoder.load_state_dict(params, strict=False)
|
||||
del params
|
||||
@@ -134,7 +135,7 @@ class MinDalle:
|
||||
is_downloaded = os.path.exists(self.detoker_params_path)
|
||||
if not is_downloaded: self.download_detokenizer()
|
||||
if self.is_verbose: print("initializing VQGanDetokenizer")
|
||||
self.detokenizer = VQGanDetokenizer()
|
||||
self.detokenizer = VQGanDetokenizer().to(self.dtype).eval()
|
||||
params = torch.load(self.detoker_params_path)
|
||||
self.detokenizer.load_state_dict(params)
|
||||
del params
|
||||
@@ -184,38 +185,41 @@ class MinDalle:
|
||||
|
||||
if not self.is_reusable: self.init_encoder()
|
||||
if is_verbose: print("encoding text tokens")
|
||||
encoder_state = self.encoder.forward(text_tokens)
|
||||
with torch.cuda.amp.autocast(dtype=self.dtype):
|
||||
encoder_state = self.encoder.forward(text_tokens)
|
||||
if not self.is_reusable: del self.encoder
|
||||
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
|
||||
if not self.is_reusable: self.init_decoder()
|
||||
|
||||
encoder_state, attention_mask, attention_state, image_tokens = (
|
||||
self.decoder.decode_initial(
|
||||
seed,
|
||||
grid_size ** 2,
|
||||
text_tokens,
|
||||
encoder_state
|
||||
with torch.cuda.amp.autocast(dtype=self.dtype):
|
||||
encoder_state, attention_mask, attention_state, image_tokens = (
|
||||
self.decoder.decode_initial(
|
||||
seed,
|
||||
grid_size ** 2,
|
||||
text_tokens,
|
||||
encoder_state
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
row_count = 16
|
||||
for row_index in range(row_count):
|
||||
if is_verbose:
|
||||
print('sampling row {} of {}'.format(row_index + 1, row_count))
|
||||
attention_state, image_tokens = self.decoder.decode_row(
|
||||
row_index,
|
||||
log2_k,
|
||||
log2_supercondition_factor,
|
||||
encoder_state,
|
||||
attention_mask,
|
||||
attention_state,
|
||||
image_tokens
|
||||
)
|
||||
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
|
||||
tokens = image_tokens[:, 1:]
|
||||
image = self.image_from_tokens(grid_size, tokens, is_verbose)
|
||||
yield image
|
||||
with torch.cuda.amp.autocast(dtype=self.dtype):
|
||||
attention_state, image_tokens = self.decoder.decode_row(
|
||||
row_index,
|
||||
log2_k,
|
||||
log2_supercondition_factor,
|
||||
encoder_state,
|
||||
attention_mask,
|
||||
attention_state,
|
||||
image_tokens
|
||||
)
|
||||
if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0:
|
||||
tokens = image_tokens[:, 1:]
|
||||
image = self.image_from_tokens(grid_size, tokens, is_verbose)
|
||||
yield image
|
||||
|
||||
|
||||
def generate_image(
|
||||
|
@@ -40,7 +40,8 @@ class DecoderSelfAttention(AttentionBase):
|
||||
queries = self.q_proj.forward(decoder_state)
|
||||
attn_mask = self.token_indices < token_index + 1
|
||||
attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
|
||||
attention_state[:, token_index] = torch.cat([keys, values])
|
||||
attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
|
||||
attention_state[:, token_index] = attn_state_new
|
||||
batch_count = decoder_state.shape[0]
|
||||
keys = attention_state[:batch_count]
|
||||
values = attention_state[batch_count:]
|
||||
|
@@ -82,7 +82,7 @@ class Upsample(Module):
|
||||
self.conv = Conv2d(n, n, 3, padding=1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.upsample.forward(x)
|
||||
x = self.upsample.forward(x.to(torch.float32))
|
||||
x = self.conv.forward(x)
|
||||
return x
|
||||
|
||||
|
Reference in New Issue
Block a user