sort -> topk, prev_token_and_index -> prev_token, token_index
This commit is contained in:
parent
fb97ba5e20
commit
df9aa6f915
|
@ -2,16 +2,17 @@ import os
|
|||
import numpy
|
||||
from copy import deepcopy
|
||||
from typing import Dict
|
||||
from flax import traverse_util, serialization
|
||||
from flax.traverse_util import flatten_dict
|
||||
from flax.serialization import msgpack_restore
|
||||
import torch
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
||||
with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
|
||||
params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
|
||||
params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
|
||||
|
||||
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')
|
||||
P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
|
||||
|
||||
for i in list(P.keys()):
|
||||
j = i
|
||||
|
@ -42,7 +43,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
|
|||
|
||||
def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
|
||||
with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
|
||||
params = serialization.msgpack_restore(f.read())
|
||||
params = msgpack_restore(f.read())
|
||||
|
||||
for codec in ['encoder', 'decoder']:
|
||||
k = 'FlaxBart{}Layers'.format(codec.title())
|
||||
|
@ -82,7 +83,7 @@ def convert_dalle_bart_torch_from_flax_params(
|
|||
is_encoder: bool
|
||||
) -> dict:
|
||||
P = deepcopy(params)
|
||||
P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.')
|
||||
P: Dict[str, numpy.ndarray] = flatten_dict(P, sep='.')
|
||||
|
||||
for i in P:
|
||||
P[i] = torch.tensor(P[i])
|
||||
|
|
|
@ -160,14 +160,15 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor,
|
||||
attention_state: FloatTensor,
|
||||
prev_token_and_index: LongTensor
|
||||
prev_token: LongTensor,
|
||||
token_index: LongTensor
|
||||
) -> Tuple[LongTensor, FloatTensor]:
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
batch_count = encoder_state.shape[0]
|
||||
prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
|
||||
token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
|
||||
decoder_state = self.embed_tokens.forward(prev_token)
|
||||
decoder_state += self.embed_positions.forward(token_index)
|
||||
prev_token_batched = torch.cat([prev_token] * batch_count)
|
||||
token_index_batched = torch.cat([token_index] * batch_count)
|
||||
decoder_state = self.embed_tokens.forward(prev_token_batched)
|
||||
decoder_state += self.embed_positions.forward(token_index_batched)
|
||||
decoder_state = self.layernorm_embedding.forward(decoder_state)
|
||||
decoder_state = decoder_state[:, None]
|
||||
attention_states_new = []
|
||||
|
@ -177,7 +178,7 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
encoder_state,
|
||||
attention_state[i],
|
||||
attention_mask,
|
||||
token_index[:1]
|
||||
token_index
|
||||
)
|
||||
attention_states_new.append(attention_state_layer)
|
||||
decoder_state = self.final_ln(decoder_state)
|
||||
|
@ -185,7 +186,7 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
a = self.condition_factor
|
||||
logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
|
||||
|
||||
top_logits = logits.sort(descending=True)[0][:50]
|
||||
top_logits, _ = logits.topk(50, dim=-1)
|
||||
probs = torch.where(
|
||||
logits < top_logits[-1],
|
||||
self.zero_prob,
|
||||
|
@ -206,12 +207,12 @@ class DalleBartDecoderTorch(nn.Module):
|
|||
image_token = self.start_token
|
||||
|
||||
for i in range(self.sample_token_count):
|
||||
token_index = self.token_indices[i:i+1]
|
||||
probs, attention_state = self.decode_step(
|
||||
text_tokens = text_tokens,
|
||||
encoder_state = encoder_state,
|
||||
attention_state = attention_state,
|
||||
prev_token_and_index = torch.cat([image_token, token_index])
|
||||
prev_token = image_token,
|
||||
token_index = self.token_indices[[i]]
|
||||
)
|
||||
|
||||
image_token = torch.multinomial(probs, 1)
|
||||
|
|
|
@ -61,6 +61,7 @@ class AttentionBlock(Module):
|
|||
h = self.proj_out.forward(h)
|
||||
return x + h
|
||||
|
||||
|
||||
class MiddleLayer(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -74,6 +75,7 @@ class MiddleLayer(Module):
|
|||
h = self.block_2.forward(h)
|
||||
return h
|
||||
|
||||
|
||||
class Upsample(Module):
|
||||
def __init__(self, log2_count):
|
||||
super().__init__()
|
||||
|
@ -86,6 +88,7 @@ class Upsample(Module):
|
|||
x = self.conv.forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class UpsampleBlock(Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -124,6 +127,7 @@ class UpsampleBlock(Module):
|
|||
h = self.upsample.forward(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -154,6 +158,7 @@ class Decoder(Module):
|
|||
z = self.conv_out.forward(z)
|
||||
return z
|
||||
|
||||
|
||||
class VQGanDetokenizer(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue
Block a user