sort -> topk, prev_token_and_index -> prev_token, token_index

Brett Kuprel 2022-06-30 09:04:11 -04:00
parent fb97ba5e20
commit df9aa6f915
3 changed files with 21 additions and 14 deletions


@@ -2,16 +2,17 @@ import os
 import numpy
 from copy import deepcopy
 from typing import Dict
-from flax import traverse_util, serialization
+from flax.traverse_util import flatten_dict
+from flax.serialization import msgpack_restore
 import torch
 torch.set_grad_enabled(False)


 def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
     with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
-        params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
-    P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')
+        params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
+    P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
     for i in list(P.keys()):
         j = i
@@ -42,7 +43,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
 def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
     with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
-        params = serialization.msgpack_restore(f.read())
+        params = msgpack_restore(f.read())
     for codec in ['encoder', 'decoder']:
         k = 'FlaxBart{}Layers'.format(codec.title())
@@ -82,7 +83,7 @@ def convert_dalle_bart_torch_from_flax_params(
     is_encoder: bool
 ) -> dict:
     P = deepcopy(params)
-    P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.')
+    P: Dict[str, numpy.ndarray] = flatten_dict(P, sep='.')
     for i in P:
         P[i] = torch.tensor(P[i])
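For context, flatten_dict collapses a nested parameter tree into a flat dict with dot-joined keys, so importing it directly from flax.traverse_util is behavior-preserving. A minimal standalone sketch (the example values are illustrative, not from the repo):

    from flax.traverse_util import flatten_dict

    nested = {'encoder': {'embed_tokens': {'embedding': [0.1, 0.2]}}}
    flat = flatten_dict(nested, sep='.')
    print(flat)  # {'encoder.embed_tokens.embedding': [0.1, 0.2]}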


@@ -160,14 +160,15 @@ class DalleBartDecoderTorch(nn.Module):
         text_tokens: LongTensor,
         encoder_state: FloatTensor,
         attention_state: FloatTensor,
-        prev_token_and_index: LongTensor
+        prev_token: LongTensor,
+        token_index: LongTensor
     ) -> Tuple[LongTensor, FloatTensor]:
         attention_mask = text_tokens.not_equal(1)
         batch_count = encoder_state.shape[0]
-        prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
-        token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
-        decoder_state = self.embed_tokens.forward(prev_token)
-        decoder_state += self.embed_positions.forward(token_index)
+        prev_token_batched = torch.cat([prev_token] * batch_count)
+        token_index_batched = torch.cat([token_index] * batch_count)
+        decoder_state = self.embed_tokens.forward(prev_token_batched)
+        decoder_state += self.embed_positions.forward(token_index_batched)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
         decoder_state = decoder_state[:, None]
         attention_states_new = []
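The new signature still takes length-1 tensors for the current token and its position, and decode_step tiles them across the batch with torch.cat. A standalone sketch of that tiling (values made up for illustration):

    import torch

    batch_count = 2
    prev_token = torch.tensor([42])  # token sampled last step, shape (1,)
    token_index = torch.tensor([7])  # its position in the sequence, shape (1,)

    # tile each length-1 tensor so every sequence in the batch sees the same value
    prev_token_batched = torch.cat([prev_token] * batch_count)    # tensor([42, 42])
    token_index_batched = torch.cat([token_index] * batch_count)  # tensor([7, 7])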
@@ -177,7 +178,7 @@ class DalleBartDecoderTorch(nn.Module):
                 encoder_state,
                 attention_state[i],
                 attention_mask,
-                token_index[:1]
+                token_index
             )
             attention_states_new.append(attention_state_layer)
         decoder_state = self.final_ln(decoder_state)
@@ -185,7 +186,7 @@ class DalleBartDecoderTorch(nn.Module):
         a = self.condition_factor
         logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
-        top_logits = logits.sort(descending=True)[0][:50]
+        top_logits, _ = logits.topk(50, dim=-1)
         probs = torch.where(
             logits < top_logits[-1],
             self.zero_prob,
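topk returns the k largest values already sorted in descending order (sorted=True is the default), so it is a drop-in replacement for a full descending sort followed by a slice, and avoids sorting the whole vocabulary. A standalone sketch of the equivalence and the cutoff trick; the softmax here merely stands in for the model's own probability computation:

    import torch

    logits = torch.randn(16384)  # roughly vocabulary-sized

    old_top = logits.sort(descending=True)[0][:50]  # full sort, then slice
    new_top = logits.topk(50)[0]                    # partial selection only
    assert torch.equal(old_top, new_top)

    # suppress everything below the 50th-largest logit before sampling
    probs = torch.where(logits < new_top[-1], torch.tensor(0.0), logits.softmax(-1))
    next_token = torch.multinomial(probs, 1)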
@@ -206,12 +207,12 @@ class DalleBartDecoderTorch(nn.Module):
         image_token = self.start_token
         for i in range(self.sample_token_count):
-            token_index = self.token_indices[i:i+1]
             probs, attention_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,
                 attention_state = attention_state,
-                prev_token_and_index = torch.cat([image_token, token_index])
+                prev_token = image_token,
+                token_index = self.token_indices[[i]]
             )
             image_token = torch.multinomial(probs, 1)
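Indexing with a one-element list, as in token_indices[[i]], keeps the dimension, so it matches the token_indices[i:i+1] slice the deleted line used; that is what lets the lookup move inline into the call. A quick standalone check:

    import torch

    token_indices = torch.arange(256)
    i = 5
    assert torch.equal(token_indices[[i]], token_indices[i:i+1])  # both give tensor([5])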


@@ -61,6 +61,7 @@ class AttentionBlock(Module):
         h = self.proj_out.forward(h)
         return x + h
 
+
 class MiddleLayer(Module):
     def __init__(self):
         super().__init__()
@@ -74,6 +75,7 @@ class MiddleLayer(Module):
         h = self.block_2.forward(h)
         return h
 
+
 class Upsample(Module):
     def __init__(self, log2_count):
         super().__init__()
@@ -86,6 +88,7 @@ class Upsample(Module):
         x = self.conv.forward(x)
         return x
 
+
 class UpsampleBlock(Module):
     def __init__(
         self,
@@ -124,6 +127,7 @@ class UpsampleBlock(Module):
         h = self.upsample.forward(h)
         return h
 
+
 class Decoder(Module):
     def __init__(self):
         super().__init__()
@@ -154,6 +158,7 @@ class Decoder(Module):
         z = self.conv_out.forward(z)
         return z
 
+
 class VQGanDetokenizer(Module):
     def __init__(self):
         super().__init__()