sort -> topk, prev_token_and_index -> prev_token, token_index

This commit is contained in:
Brett Kuprel 2022-06-30 09:04:11 -04:00
parent fb97ba5e20
commit df9aa6f915
3 changed files with 21 additions and 14 deletions

View File

@ -2,16 +2,17 @@ import os
import numpy
from copy import deepcopy
from typing import Dict
from flax import traverse_util, serialization
from flax.traverse_util import flatten_dict
from flax.serialization import msgpack_restore
import torch
torch.set_grad_enabled(False)
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
for i in list(P.keys()):
j = i
@ -42,7 +43,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
params = serialization.msgpack_restore(f.read())
params = msgpack_restore(f.read())
for codec in ['encoder', 'decoder']:
k = 'FlaxBart{}Layers'.format(codec.title())
@ -82,7 +83,7 @@ def convert_dalle_bart_torch_from_flax_params(
is_encoder: bool
) -> dict:
P = deepcopy(params)
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.')
P: Dict[str, numpy.ndarray] = flatten_dict(P, sep='.')
for i in P:
P[i] = torch.tensor(P[i])

View File

@ -160,14 +160,15 @@ class DalleBartDecoderTorch(nn.Module):
text_tokens: LongTensor,
encoder_state: FloatTensor,
attention_state: FloatTensor,
prev_token_and_index: LongTensor
prev_token: LongTensor,
token_index: LongTensor
) -> Tuple[LongTensor, FloatTensor]:
attention_mask = text_tokens.not_equal(1)
batch_count = encoder_state.shape[0]
prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
decoder_state = self.embed_tokens.forward(prev_token)
decoder_state += self.embed_positions.forward(token_index)
prev_token_batched = torch.cat([prev_token] * batch_count)
token_index_batched = torch.cat([token_index] * batch_count)
decoder_state = self.embed_tokens.forward(prev_token_batched)
decoder_state += self.embed_positions.forward(token_index_batched)
decoder_state = self.layernorm_embedding.forward(decoder_state)
decoder_state = decoder_state[:, None]
attention_states_new = []
@ -177,7 +178,7 @@ class DalleBartDecoderTorch(nn.Module):
encoder_state,
attention_state[i],
attention_mask,
token_index[:1]
token_index
)
attention_states_new.append(attention_state_layer)
decoder_state = self.final_ln(decoder_state)
@ -185,7 +186,7 @@ class DalleBartDecoderTorch(nn.Module):
a = self.condition_factor
logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
top_logits = logits.sort(descending=True)[0][:50]
top_logits, _ = logits.topk(50, dim=-1)
probs = torch.where(
logits < top_logits[-1],
self.zero_prob,
@ -206,12 +207,12 @@ class DalleBartDecoderTorch(nn.Module):
image_token = self.start_token
for i in range(self.sample_token_count):
token_index = self.token_indices[i:i+1]
probs, attention_state = self.decode_step(
text_tokens = text_tokens,
encoder_state = encoder_state,
attention_state = attention_state,
prev_token_and_index = torch.cat([image_token, token_index])
prev_token = image_token,
token_index = self.token_indices[[i]]
)
image_token = torch.multinomial(probs, 1)

View File

@ -61,6 +61,7 @@ class AttentionBlock(Module):
h = self.proj_out.forward(h)
return x + h
class MiddleLayer(Module):
def __init__(self):
super().__init__()
@ -74,6 +75,7 @@ class MiddleLayer(Module):
h = self.block_2.forward(h)
return h
class Upsample(Module):
def __init__(self, log2_count):
super().__init__()
@ -86,6 +88,7 @@ class Upsample(Module):
x = self.conv.forward(x)
return x
class UpsampleBlock(Module):
def __init__(
self,
@ -124,6 +127,7 @@ class UpsampleBlock(Module):
h = self.upsample.forward(h)
return h
class Decoder(Module):
def __init__(self):
super().__init__()
@ -154,6 +158,7 @@ class Decoder(Module):
z = self.conv_out.forward(z)
return z
class VQGanDetokenizer(Module):
def __init__(self):
super().__init__()