display intermediate images
This commit is contained in:
@@ -5,6 +5,9 @@ torch.set_grad_enabled(False)
|
||||
|
||||
from .dalle_bart_encoder import GLU, AttentionBase
|
||||
|
||||
IMAGE_TOKEN_COUNT = 256
|
||||
BLANK_TOKEN = 6965
|
||||
|
||||
|
||||
class DecoderCrossAttention(AttentionBase):
|
||||
def forward(
|
||||
@@ -20,9 +23,9 @@ class DecoderCrossAttention(AttentionBase):
|
||||
|
||||
|
||||
class DecoderSelfAttention(AttentionBase):
|
||||
def __init__(self, head_count: int, embed_count: int, token_count: int):
|
||||
def __init__(self, head_count: int, embed_count: int):
|
||||
super().__init__(head_count, embed_count)
|
||||
token_indices = torch.arange(token_count)
|
||||
token_indices = torch.arange(IMAGE_TOKEN_COUNT)
|
||||
if torch.cuda.is_available(): token_indices = token_indices.cuda()
|
||||
self.token_indices = token_indices
|
||||
|
||||
@@ -48,19 +51,13 @@ class DecoderSelfAttention(AttentionBase):
|
||||
class DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_token_count: int,
|
||||
head_count: int,
|
||||
embed_count: int,
|
||||
glu_embed_count: int
|
||||
):
|
||||
super().__init__()
|
||||
self.image_token_count = image_token_count
|
||||
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||
self.self_attn = DecoderSelfAttention(
|
||||
head_count,
|
||||
embed_count,
|
||||
image_token_count
|
||||
)
|
||||
self.self_attn = DecoderSelfAttention(head_count, embed_count)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||
self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
|
||||
@@ -110,7 +107,6 @@ class DalleBartDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_vocab_count: int,
|
||||
image_token_count: int,
|
||||
embed_count: int,
|
||||
attention_head_count: int,
|
||||
glu_embed_count: int,
|
||||
@@ -121,12 +117,10 @@ class DalleBartDecoder(nn.Module):
|
||||
self.layer_count = layer_count
|
||||
self.embed_count = embed_count
|
||||
self.condition_factor = 10.0
|
||||
self.image_token_count = image_token_count
|
||||
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
|
||||
self.embed_positions = nn.Embedding(image_token_count, embed_count)
|
||||
self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
|
||||
self.layers: List[DecoderLayer] = nn.ModuleList([
|
||||
DecoderLayer(
|
||||
image_token_count,
|
||||
attention_head_count,
|
||||
embed_count,
|
||||
glu_embed_count
|
||||
@@ -137,7 +131,7 @@ class DalleBartDecoder(nn.Module):
|
||||
self.final_ln = nn.LayerNorm(embed_count)
|
||||
self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
|
||||
self.zero_prob = torch.zeros([1])
|
||||
self.token_indices = torch.arange(self.image_token_count)
|
||||
self.token_indices = torch.arange(IMAGE_TOKEN_COUNT)
|
||||
self.start_token = torch.tensor([start_token]).to(torch.long)
|
||||
if torch.cuda.is_available():
|
||||
self.zero_prob = self.zero_prob.cuda()
|
||||
@@ -183,13 +177,13 @@ class DalleBartDecoder(nn.Module):
|
||||
torch.exp(logits - top_logits[:, [0]])
|
||||
)
|
||||
return probs, attention_state
|
||||
|
||||
|
||||
|
||||
def decode_row(
|
||||
self,
|
||||
row_index: int,
|
||||
attention_mask: BoolTensor,
|
||||
encoder_state: FloatTensor,
|
||||
attention_mask: BoolTensor,
|
||||
attention_state: FloatTensor,
|
||||
image_tokens_sequence: LongTensor
|
||||
) -> Tuple[FloatTensor, LongTensor]:
|
||||
@@ -202,19 +196,18 @@ class DalleBartDecoder(nn.Module):
|
||||
prev_tokens = image_tokens_sequence[:, i],
|
||||
token_index = self.token_indices[[i]]
|
||||
)
|
||||
|
||||
image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
|
||||
|
||||
return attention_state, image_tokens_sequence
|
||||
|
||||
|
||||
def forward(
|
||||
|
||||
def decode_initial(
|
||||
self,
|
||||
seed: int,
|
||||
image_count: int,
|
||||
row_count: int,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor
|
||||
) -> LongTensor:
|
||||
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
|
||||
expanded_indices = [0] * image_count + [1] * image_count
|
||||
text_tokens = text_tokens[expanded_indices]
|
||||
encoder_state = encoder_state[expanded_indices]
|
||||
@@ -223,13 +216,13 @@ class DalleBartDecoder(nn.Module):
|
||||
attention_state_shape = (
|
||||
self.layer_count,
|
||||
image_count * 4,
|
||||
self.image_token_count,
|
||||
IMAGE_TOKEN_COUNT,
|
||||
self.embed_count
|
||||
)
|
||||
attention_state = torch.zeros(attention_state_shape)
|
||||
image_tokens_sequence = torch.full(
|
||||
(image_count, self.image_token_count + 1),
|
||||
6965, # black token
|
||||
(image_count, IMAGE_TOKEN_COUNT + 1),
|
||||
BLANK_TOKEN,
|
||||
dtype=torch.long
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
@@ -238,13 +231,6 @@ class DalleBartDecoder(nn.Module):
|
||||
|
||||
image_tokens_sequence[:, 0] = self.start_token[0]
|
||||
|
||||
for row_index in range(row_count):
|
||||
attention_state, image_tokens_sequence = self.decode_row(
|
||||
row_index,
|
||||
attention_mask,
|
||||
encoder_state,
|
||||
attention_state,
|
||||
image_tokens_sequence
|
||||
)
|
||||
|
||||
return image_tokens_sequence[:, 1:]
|
||||
if seed > 0: torch.manual_seed(seed)
|
||||
|
||||
return encoder_state, attention_mask, attention_state, image_tokens_sequence
|
Reference in New Issue
Block a user