|
|
|
@ -116,7 +116,6 @@ class DalleBartDecoder(nn.Module): |
|
|
|
|
super().__init__() |
|
|
|
|
self.layer_count = layer_count |
|
|
|
|
self.embed_count = embed_count |
|
|
|
|
self.condition_factor = 10.0 |
|
|
|
|
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) |
|
|
|
|
self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count) |
|
|
|
|
self.layers: List[DecoderLayer] = nn.ModuleList([ |
|
|
|
@ -141,6 +140,7 @@ class DalleBartDecoder(nn.Module): |
|
|
|
|
|
|
|
|
|
def decode_step( |
|
|
|
|
self, |
|
|
|
|
log2_supercondition_factor: int, |
|
|
|
|
attention_mask: BoolTensor, |
|
|
|
|
encoder_state: FloatTensor, |
|
|
|
|
attention_state: FloatTensor, |
|
|
|
@ -164,7 +164,7 @@ class DalleBartDecoder(nn.Module): |
|
|
|
|
) |
|
|
|
|
decoder_state = self.final_ln(decoder_state) |
|
|
|
|
logits = self.lm_head(decoder_state) |
|
|
|
|
a = self.condition_factor |
|
|
|
|
a = log2_supercondition_factor |
|
|
|
|
logits: FloatTensor = ( |
|
|
|
|
logits[:image_count, -1] * (1 - a) + |
|
|
|
|
logits[image_count:, -1] * a |
|
|
|
@ -182,6 +182,7 @@ class DalleBartDecoder(nn.Module): |
|
|
|
|
def decode_row( |
|
|
|
|
self, |
|
|
|
|
row_index: int, |
|
|
|
|
log2_supercondition_factor: int, |
|
|
|
|
encoder_state: FloatTensor, |
|
|
|
|
attention_mask: BoolTensor, |
|
|
|
|
attention_state: FloatTensor, |
|
|
|
@ -190,6 +191,7 @@ class DalleBartDecoder(nn.Module): |
|
|
|
|
for col_index in range(16): |
|
|
|
|
i = 16 * row_index + col_index |
|
|
|
|
probs, attention_state = self.decode_step( |
|
|
|
|
log2_supercondition_factor = log2_supercondition_factor, |
|
|
|
|
attention_mask = attention_mask, |
|
|
|
|
encoder_state = encoder_state, |
|
|
|
|
attention_state = attention_state, |
|
|
|
|