0.3.13, simplified code, specify device when initializing MinDalle
This commit is contained in:
@@ -4,7 +4,6 @@ from torch import nn, LongTensor, FloatTensor, BoolTensor
|
||||
from .dalle_bart_encoder import GLU, AttentionBase
|
||||
|
||||
IMAGE_TOKEN_COUNT = 256
|
||||
BLANK_TOKEN = 6965
|
||||
|
||||
|
||||
class DecoderCrossAttention(AttentionBase):
|
||||
@@ -23,21 +22,18 @@ class DecoderCrossAttention(AttentionBase):
|
||||
class DecoderSelfAttention(AttentionBase):
|
||||
def __init__(self, head_count: int, embed_count: int):
|
||||
super().__init__(head_count, embed_count)
|
||||
token_indices = torch.arange(IMAGE_TOKEN_COUNT)
|
||||
if torch.cuda.is_available(): token_indices = token_indices.cuda()
|
||||
self.token_indices = token_indices
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
decoder_state: FloatTensor,
|
||||
attention_state: FloatTensor,
|
||||
attn_mask: BoolTensor,
|
||||
token_index: LongTensor
|
||||
) -> Tuple[FloatTensor, FloatTensor]:
|
||||
keys = self.k_proj.forward(decoder_state)
|
||||
values = self.v_proj.forward(decoder_state)
|
||||
queries = self.q_proj.forward(decoder_state)
|
||||
attn_mask = self.token_indices < token_index + 1
|
||||
attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
|
||||
attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
|
||||
attention_state[:, token_index] = attn_state_new
|
||||
batch_count = decoder_state.shape[0]
|
||||
@@ -52,7 +48,8 @@ class DecoderLayer(nn.Module):
|
||||
self,
|
||||
head_count: int,
|
||||
embed_count: int,
|
||||
glu_embed_count: int
|
||||
glu_embed_count: int,
|
||||
device: str
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||
@@ -62,6 +59,7 @@ class DecoderLayer(nn.Module):
|
||||
self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
|
||||
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||
self.glu = GLU(embed_count, glu_embed_count)
|
||||
self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device)
|
||||
|
||||
|
||||
def forward(
|
||||
@@ -73,12 +71,15 @@ class DecoderLayer(nn.Module):
|
||||
token_index: LongTensor
|
||||
) -> Tuple[FloatTensor, FloatTensor]:
|
||||
# Self Attention
|
||||
self_attn_mask = self.token_indices < token_index + 1
|
||||
self_attn_mask = self_attn_mask[None][[0] * decoder_state.shape[0]]
|
||||
residual = decoder_state
|
||||
decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
|
||||
decoder_state, attention_state = self.self_attn.forward(
|
||||
decoder_state,
|
||||
attention_state,
|
||||
token_index
|
||||
decoder_state=decoder_state,
|
||||
attention_state=attention_state,
|
||||
attn_mask=self_attn_mask,
|
||||
token_index=token_index
|
||||
)
|
||||
decoder_state = self.self_attn_layer_norm.forward(decoder_state)
|
||||
decoder_state = residual + decoder_state
|
||||
@@ -87,9 +88,9 @@ class DecoderLayer(nn.Module):
|
||||
residual = decoder_state
|
||||
decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state)
|
||||
decoder_state = self.encoder_attn.forward(
|
||||
decoder_state,
|
||||
encoder_state,
|
||||
attention_mask
|
||||
decoder_state=decoder_state,
|
||||
encoder_state=encoder_state,
|
||||
attention_mask=attention_mask
|
||||
)
|
||||
decoder_state = self.encoder_attn_layer_norm.forward(decoder_state)
|
||||
decoder_state = residual + decoder_state
|
||||
@@ -110,7 +111,7 @@ class DalleBartDecoder(nn.Module):
|
||||
attention_head_count: int,
|
||||
glu_embed_count: int,
|
||||
layer_count: int,
|
||||
start_token: int
|
||||
device: str
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_count = layer_count
|
||||
@@ -120,70 +121,28 @@ class DalleBartDecoder(nn.Module):
|
||||
self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count)
|
||||
self.layers: List[DecoderLayer] = nn.ModuleList([
|
||||
DecoderLayer(
|
||||
attention_head_count,
|
||||
embed_count,
|
||||
glu_embed_count
|
||||
head_count=attention_head_count,
|
||||
embed_count=embed_count,
|
||||
glu_embed_count=glu_embed_count,
|
||||
device=device
|
||||
)
|
||||
for _ in range(layer_count)
|
||||
])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||
self.final_ln = nn.LayerNorm(embed_count)
|
||||
self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
|
||||
self.zero_prob = torch.zeros([1])
|
||||
self.token_indices = torch.arange(IMAGE_TOKEN_COUNT)
|
||||
self.start_token = torch.tensor([start_token]).to(torch.long)
|
||||
if torch.cuda.is_available():
|
||||
self.zero_prob = self.zero_prob.cuda()
|
||||
self.token_indices = self.token_indices.cuda()
|
||||
self.start_token = self.start_token.cuda()
|
||||
self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device)
|
||||
|
||||
|
||||
def decode_initial(
|
||||
def forward(
|
||||
self,
|
||||
seed: int,
|
||||
image_count: int,
|
||||
text_tokens: LongTensor,
|
||||
encoder_state: FloatTensor
|
||||
) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]:
|
||||
expanded_indices = [0] * image_count + [1] * image_count
|
||||
text_tokens = text_tokens[expanded_indices]
|
||||
encoder_state = encoder_state[expanded_indices]
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
|
||||
attention_state_shape = (
|
||||
self.layer_count,
|
||||
image_count * 4,
|
||||
IMAGE_TOKEN_COUNT,
|
||||
self.embed_count
|
||||
)
|
||||
attention_state = torch.zeros(attention_state_shape)
|
||||
image_tokens_sequence = torch.full(
|
||||
(image_count, IMAGE_TOKEN_COUNT + 1),
|
||||
BLANK_TOKEN,
|
||||
dtype=torch.long
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
attention_state = attention_state.cuda()
|
||||
image_tokens_sequence = image_tokens_sequence.cuda()
|
||||
|
||||
image_tokens_sequence[:, 0] = self.start_token[0]
|
||||
|
||||
if seed > 0: torch.manual_seed(seed)
|
||||
|
||||
return encoder_state, attention_mask, attention_state, image_tokens_sequence
|
||||
|
||||
|
||||
def decode_step(
|
||||
self,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
supercondition_factor: float,
|
||||
settings: FloatTensor,
|
||||
attention_mask: BoolTensor,
|
||||
encoder_state: FloatTensor,
|
||||
attention_state: FloatTensor,
|
||||
prev_tokens: LongTensor,
|
||||
token_index: LongTensor
|
||||
) -> Tuple[FloatTensor, FloatTensor]:
|
||||
) -> Tuple[LongTensor, FloatTensor]:
|
||||
image_count = encoder_state.shape[0] // 2
|
||||
token_index_batched = token_index[[0] * image_count * 2]
|
||||
prev_tokens = prev_tokens[list(range(image_count)) * 2]
|
||||
@@ -202,44 +161,19 @@ class DalleBartDecoder(nn.Module):
|
||||
)
|
||||
decoder_state = self.final_ln(decoder_state)
|
||||
logits = self.lm_head(decoder_state)
|
||||
a = supercondition_factor
|
||||
temperature = settings[0]
|
||||
top_k = settings[1].to(torch.long)
|
||||
supercondition_factor = settings[2]
|
||||
logits = logits[:, -1, : 2 ** 14]
|
||||
logits: FloatTensor = (
|
||||
logits[:image_count, -1] * (1 - a) +
|
||||
logits[image_count:, -1] * a
|
||||
logits[:image_count] * (1 - supercondition_factor) +
|
||||
logits[image_count:] * supercondition_factor
|
||||
)
|
||||
|
||||
top_logits, _ = logits.topk(top_k, dim=-1)
|
||||
is_kept = logits >= top_logits[:, [-1]]
|
||||
logits -= top_logits[:, [0]]
|
||||
logits /= max(temperature, 1e-6)
|
||||
probs = torch.where(is_kept, torch.exp(logits), self.zero_prob)
|
||||
probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14
|
||||
return probs, attention_state
|
||||
|
||||
|
||||
def decode_row(
|
||||
self,
|
||||
row_index: int,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
supercondition_factor: float,
|
||||
encoder_state: FloatTensor,
|
||||
attention_mask: BoolTensor,
|
||||
attention_state: FloatTensor,
|
||||
image_tokens_sequence: LongTensor
|
||||
) -> Tuple[FloatTensor, LongTensor]:
|
||||
for col_index in range(16):
|
||||
i = 16 * row_index + col_index
|
||||
probs, attention_state = self.decode_step(
|
||||
temperature = temperature,
|
||||
top_k = top_k,
|
||||
supercondition_factor = supercondition_factor,
|
||||
attention_mask = attention_mask,
|
||||
encoder_state = encoder_state,
|
||||
attention_state = attention_state,
|
||||
prev_tokens = image_tokens_sequence[:, i],
|
||||
token_index = self.token_indices[[i]]
|
||||
)
|
||||
image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0]
|
||||
|
||||
return attention_state, image_tokens_sequence
|
||||
logits_sorted, _ = logits.sort(descending=True)
|
||||
is_kept = logits >= logits_sorted[:, top_k: top_k + 1]
|
||||
logits -= logits_sorted[:, [0]]
|
||||
logits /= temperature
|
||||
logits.exp_()
|
||||
logits *= is_kept.to(torch.float32)
|
||||
image_tokens = torch.multinomial(logits, 1)[:, 0]
|
||||
return image_tokens, attention_state
|
@@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
def __init__(self, count_in_out, count_middle):
|
||||
def __init__(self, count_in_out: int, count_middle: int):
|
||||
super().__init__()
|
||||
self.gelu = nn.GELU()
|
||||
self.ln0 = nn.LayerNorm(count_in_out)
|
||||
@@ -33,8 +33,6 @@ class AttentionBase(nn.Module):
|
||||
self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||
self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||
self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
|
||||
self.one = torch.ones((1, 1))
|
||||
if torch.cuda.is_available(): self.one = self.one.cuda()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -48,11 +46,7 @@ class AttentionBase(nn.Module):
|
||||
queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
|
||||
queries /= queries.shape[-1] ** 0.5
|
||||
|
||||
attention_bias = torch.where(
|
||||
attention_mask,
|
||||
self.one * 0,
|
||||
self.one * (-torch.inf),
|
||||
)
|
||||
attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
|
||||
attention_weights: FloatTensor = torch.einsum(
|
||||
'bqhc,bkhc->bhqk',
|
||||
queries,
|
||||
@@ -115,7 +109,8 @@ class DalleBartEncoder(nn.Module):
|
||||
attention_head_count: int,
|
||||
text_vocab_count: int,
|
||||
text_token_count: int,
|
||||
glu_embed_count: int
|
||||
glu_embed_count: int,
|
||||
device: str
|
||||
):
|
||||
super().__init__()
|
||||
self.text_vocab_count = text_vocab_count
|
||||
@@ -131,17 +126,14 @@ class DalleBartEncoder(nn.Module):
|
||||
])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||
self.final_ln = nn.LayerNorm(embed_count)
|
||||
self.token_indices = torch.arange(text_token_count).to(torch.long)
|
||||
if torch.cuda.is_available():
|
||||
self.token_indices = self.token_indices.cuda()
|
||||
token_indices = torch.arange(text_token_count, device=device)
|
||||
self.pose_tokens = torch.stack([token_indices] * 2)
|
||||
|
||||
def forward(self, text_tokens: LongTensor) -> FloatTensor:
|
||||
attention_mask = text_tokens.not_equal(1)
|
||||
pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]]
|
||||
text_tokens.clamp_(0, self.text_vocab_count - 1)
|
||||
encoder_state = (
|
||||
self.embed_tokens.forward(text_tokens) +
|
||||
self.embed_positions.forward(pose_tokens)
|
||||
self.embed_positions.forward(self.pose_tokens)
|
||||
)
|
||||
encoder_state = self.layernorm_embedding.forward(encoder_state)
|
||||
for layer in self.layers:
|
||||
|
Reference in New Issue
Block a user