sort -> topk, prev_token_and_index -> prev_token, token_index

This commit is contained in:
parent fb97ba5e20
commit df9aa6f915
@@ -2,16 +2,17 @@ import os
 import numpy
 from copy import deepcopy
 from typing import Dict
-from flax import traverse_util, serialization
+from flax.traverse_util import flatten_dict
+from flax.serialization import msgpack_restore
 import torch
 torch.set_grad_enabled(False)
 
 
 def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
     with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
-        params: Dict[str, numpy.ndarray] = serialization.msgpack_restore(f.read())
+        params: Dict[str, numpy.ndarray] = msgpack_restore(f.read())
 
-    P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(params, sep='.')
+    P: Dict[str, numpy.ndarray] = flatten_dict(params, sep='.')
 
     for i in list(P.keys()):
         j = i
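For reference, the two helpers now imported directly do the work that the old `serialization` and `traverse_util` module references did. A minimal sketch of how they combine, with `load_flat_params` as a hypothetical helper name and the checkpoint path supplied by the caller:

import os
from typing import Dict

import numpy
from flax.traverse_util import flatten_dict
from flax.serialization import msgpack_restore

def load_flat_params(path: str) -> Dict[str, numpy.ndarray]:
    # Read the msgpack checkpoint and flatten the nested param dict into
    # dot-separated keys, e.g. {'decoder': {'conv_in': {'kernel': ...}}}
    # becomes {'decoder.conv_in.kernel': ...}.
    with open(os.path.join(path, 'flax_model.msgpack'), "rb") as f:
        params = msgpack_restore(f.read())
    return flatten_dict(params, sep='.')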
@@ -42,7 +43,7 @@ def load_vqgan_torch_params(path: str) -> Dict[str, torch.Tensor]:
 
 def load_dalle_bart_flax_params(path: str) -> Dict[str, numpy.ndarray]:
     with open(os.path.join(path, "flax_model.msgpack"), "rb") as f:
-        params = serialization.msgpack_restore(f.read())
+        params = msgpack_restore(f.read())
 
     for codec in ['encoder', 'decoder']:
         k = 'FlaxBart{}Layers'.format(codec.title())
@@ -82,7 +83,7 @@ def convert_dalle_bart_torch_from_flax_params(
     is_encoder: bool
 ) -> dict:
     P = deepcopy(params)
-    P: Dict[str, numpy.ndarray] = traverse_util.flatten_dict(P, sep='.')
+    P: Dict[str, numpy.ndarray] = flatten_dict(P, sep='.')
 
     for i in P:
         P[i] = torch.tensor(P[i])
@@ -160,14 +160,15 @@ class DalleBartDecoderTorch(nn.Module):
         text_tokens: LongTensor,
         encoder_state: FloatTensor,
         attention_state: FloatTensor,
-        prev_token_and_index: LongTensor
+        prev_token: LongTensor,
+        token_index: LongTensor
     ) -> Tuple[LongTensor, FloatTensor]:
         attention_mask = text_tokens.not_equal(1)
         batch_count = encoder_state.shape[0]
-        prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
-        token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
-        decoder_state = self.embed_tokens.forward(prev_token)
-        decoder_state += self.embed_positions.forward(token_index)
+        prev_token_batched = torch.cat([prev_token] * batch_count)
+        token_index_batched = torch.cat([token_index] * batch_count)
+        decoder_state = self.embed_tokens.forward(prev_token_batched)
+        decoder_state += self.embed_positions.forward(token_index_batched)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
         decoder_state = decoder_state[:, None]
         attention_states_new = []
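With the argument split, the length-1 `prev_token` and `token_index` tensors are simply repeated across the batch with `torch.cat`. A small sketch of that batching step, using made-up example values:

import torch

prev_token = torch.tensor([0])    # made-up token id, shape (1,)
token_index = torch.tensor([5])   # made-up position index, shape (1,)
batch_count = 2

# Repeat each length-1 tensor once per batch row, as decode_step now does.
prev_token_batched = torch.cat([prev_token] * batch_count)    # tensor([0, 0])
token_index_batched = torch.cat([token_index] * batch_count)  # tensor([5, 5])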
@@ -177,7 +178,7 @@ class DalleBartDecoderTorch(nn.Module):
                 encoder_state,
                 attention_state[i],
                 attention_mask,
-                token_index[:1]
+                token_index
             )
             attention_states_new.append(attention_state_layer)
         decoder_state = self.final_ln(decoder_state)
@@ -185,7 +186,7 @@ class DalleBartDecoderTorch(nn.Module):
         a = self.condition_factor
         logits: FloatTensor = a * logits[0, -1] + (1 - a) * logits[1, -1]
 
-        top_logits = logits.sort(descending=True)[0][:50]
+        top_logits, _ = logits.topk(50, dim=-1)
         probs = torch.where(
             logits < top_logits[-1],
             self.zero_prob,
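The `sort` and `topk` forms select the same top-50 logits, since `topk` returns its values in descending order by default. A quick check, with an example vocabulary size chosen for illustration only:

import torch

logits = torch.randn(16384)  # example size, not taken from the model config
top_sorted = logits.sort(descending=True)[0][:50]
top_k, _ = logits.topk(50, dim=-1)
assert torch.equal(top_sorted, top_k)  # identical values, same order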
@@ -206,12 +207,12 @@ class DalleBartDecoderTorch(nn.Module):
         image_token = self.start_token
 
         for i in range(self.sample_token_count):
-            token_index = self.token_indices[i:i+1]
             probs, attention_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,
                 attention_state = attention_state,
-                prev_token_and_index = torch.cat([image_token, token_index])
+                prev_token = image_token,
+                token_index = self.token_indices[[i]]
             )
 
             image_token = torch.multinomial(probs, 1)
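At the call site, `self.token_indices[[i]]` replaces the separate `token_index = self.token_indices[i:i+1]` line; both index forms yield a length-1 tensor, illustrated here with a stand-in tensor:

import torch

token_indices = torch.arange(256)  # stand-in for self.token_indices
i = 3
assert torch.equal(token_indices[i:i+1], token_indices[[i]])  # both give tensor([3])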
@@ -61,6 +61,7 @@ class AttentionBlock(Module):
         h = self.proj_out.forward(h)
         return x + h
 
+
 class MiddleLayer(Module):
     def __init__(self):
         super().__init__()
@@ -74,6 +75,7 @@ class MiddleLayer(Module):
         h = self.block_2.forward(h)
         return h
 
+
 class Upsample(Module):
     def __init__(self, log2_count):
         super().__init__()
@@ -86,6 +88,7 @@ class Upsample(Module):
         x = self.conv.forward(x)
         return x
 
+
 class UpsampleBlock(Module):
     def __init__(
         self,
@@ -124,6 +127,7 @@ class UpsampleBlock(Module):
         h = self.upsample.forward(h)
         return h
 
+
 class Decoder(Module):
     def __init__(self):
         super().__init__()
@@ -154,6 +158,7 @@ class Decoder(Module):
         z = self.conv_out.forward(z)
         return z
 
+
 class VQGanDetokenizer(Module):
     def __init__(self):
         super().__init__()