works with cuda
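Constant tensors (token indices, the start token, the zero probability, the ones used for attention bias) were previously built with torch.arange / torch.tensor / torch.full inside forward, so they stayed on the CPU while weights and activations sat on the GPU, and mixed-device comparisons and torch.where calls failed. Each constant is now built once in __init__ and moved with .cuda() when a GPU is available; the decode loop then just slices or broadcasts it. The bare torch.no_grad() statements, which construct a context manager and immediately discard it without disabling anything, are also replaced with torch.set_grad_enabled(False), which actually turns autograd off globally.

A minimal sketch of the placement pattern, with a hypothetical module name for illustration:

    import torch
    from torch import nn

    class IndexCache(nn.Module):  # hypothetical module; not part of this repo
        def __init__(self, token_count: int):
            super().__init__()
            # build the constant once, on the right device, not once per step
            self.token_indices = torch.arange(token_count)
            if torch.cuda.is_available():
                self.token_indices = self.token_indices.cuda()

        def forward(self, token_index: torch.Tensor) -> torch.Tensor:
            # same-device comparison; no per-step CPU allocation
            return self.token_indices == token_index

Note that plain tensor attributes like these are not moved by module.cuda() or module.to(device), hence the explicit .cuda() calls; register_buffer would be the idiomatic alternative.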
@@ -1,7 +1,7 @@
 from typing import List, Tuple
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
-torch.no_grad()
+torch.set_grad_enabled(False)
 
 from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
 
@@ -30,14 +30,12 @@ class DecoderSelfAttentionTorch(AttentionTorch):
         decoder_state: FloatTensor,
         keys_values: FloatTensor,
         attention_mask: BoolTensor,
-        token_index: LongTensor
+        token_mask: BoolTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
         batch_count = decoder_state.shape[0]
-        token_count = keys_values.shape[1]
         shape = (batch_count, 1) + keys_values.shape[2:]
         keys = self.k_proj.forward(decoder_state).view(shape)
         values = self.v_proj.forward(decoder_state).view(shape)
-        token_mask = torch.arange(token_count) == token_index
         keys_values = torch.where(
             token_mask[None, :, None, None],
             torch.cat([keys, values]),
@@ -67,6 +65,10 @@ class DecoderLayerTorch(nn.Module):
         self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
         self.glu = GLUTorch(embed_count, glu_embed_count)
 
+        self.token_indices = torch.arange(self.image_token_count)
+        if torch.cuda.is_available():
+            self.token_indices = self.token_indices.cuda()
+
     def forward(self,
         decoder_state: FloatTensor,
         encoder_state: FloatTensor,
@@ -77,13 +79,14 @@ class DecoderLayerTorch(nn.Module):
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
-        self_attn_mask = torch.arange(self.image_token_count) < token_index + 1
+        self_attn_mask = self.token_indices < token_index + 1
+        token_mask = self.token_indices == token_index
         self_attn_mask = torch.stack([self_attn_mask] * decoder_state.shape[0])
         decoder_state, keys_values_state = self.self_attn.forward(
             decoder_state,
             keys_values_state,
             self_attn_mask,
-            token_index
+            token_mask
         )
         decoder_state = self.self_attn_layer_norm.forward(decoder_state)
         decoder_state = residual + decoder_state
@@ -124,13 +127,7 @@ class DalleBartDecoderTorch(nn.Module):
         self.is_verbose = is_verbose
         self.layer_count = layer_count
         self.sample_token_count = sample_token_count
-        self.start_token = torch.tensor([start_token]).to(torch.long)
-        self.pad_token = torch.tensor([1]).to(torch.long)
-        self.condition_factor = torch.tensor([10]).to(torch.float)
-        # if torch.cuda.is_available():
-        #     self.start_token = self.start_token.cuda()
-        #     self.pad_token = self.pad_token.cuda()
-        #     self.condition_factor = self.condition_factor.cuda()
+        self.condition_factor = 10.0
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
@@ -152,6 +149,13 @@ class DalleBartDecoderTorch(nn.Module):
             attention_head_count,
             embed_count // attention_head_count
         )
+        self.zero_prob = torch.zeros([1])
+        self.token_indices = torch.arange(self.sample_token_count)
+        self.start_token = torch.tensor([start_token]).to(torch.long)
+        if torch.cuda.is_available():
+            self.zero_prob = self.zero_prob.cuda()
+            self.token_indices = self.token_indices.cuda()
+            self.start_token = self.start_token.cuda()
 
 
     def decode_step(self,
@@ -160,7 +164,7 @@ class DalleBartDecoderTorch(nn.Module):
         keys_values_state: FloatTensor,
         prev_token_and_index: LongTensor
     ) -> Tuple[LongTensor, FloatTensor]:
-        attention_mask = text_tokens.not_equal(self.pad_token)
+        attention_mask = text_tokens.not_equal(1)
         batch_count = encoder_state.shape[0]
         prev_token = torch.cat([prev_token_and_index[:1]] * batch_count)
         token_index = torch.cat([prev_token_and_index[1:]] * batch_count)
@@ -188,7 +192,7 @@ class DalleBartDecoderTorch(nn.Module):
         top_logits = logits.sort(descending=True)[0][:50]
         probs = torch.where(
             logits < top_logits[-1],
-            torch.zeros([1]),
+            self.zero_prob,
             torch.exp(logits - top_logits[0])
         )
         return probs, keys_values
@@ -200,11 +204,12 @@ class DalleBartDecoderTorch(nn.Module):
     ) -> LongTensor:
         image_tokens: List[LongTensor] = []
         keys_values_state = torch.zeros(self.keys_values_state_shape)
+        if torch.cuda.is_available():
+            keys_values_state = keys_values_state.cuda()
         image_token = self.start_token
 
         for i in range(self.sample_token_count):
-            token_index = torch.tensor([i]).to(torch.long)
-            # if torch.cuda.is_available(): token_index = token_index.cuda()
+            token_index = self.token_indices[i:i+1]
             probs, keys_values_state = self.decode_step(
                 text_tokens = text_tokens,
                 encoder_state = encoder_state,
@@ -214,9 +219,5 @@ class DalleBartDecoderTorch(nn.Module):
 
             image_token = torch.multinomial(probs, 1)
             image_tokens += [image_token]
-
-            if self.is_verbose:
-                token = int(image_token.detach().numpy())
-                print("image token {} is {}".format(i, token))
 
         return torch.cat(image_tokens)
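The token_mask now passed into DecoderSelfAttentionTorch is a one-hot vector sliced from the precomputed token_indices; it selects which slot of the preallocated keys/values cache the current step's projections are written into. A self-contained toy run of that torch.where update, with all shapes invented for illustration:

    import torch

    # toy shapes: cat(keys, values) -> 2, token_count 4, 1 head, head_dim 3
    keys_values = torch.zeros(2, 4, 1, 3)        # preallocated cache
    keys = torch.ones(1, 1, 1, 3)                # current step's keys
    values = torch.full((1, 1, 1, 3), 2.0)       # current step's values
    token_index = 2
    token_mask = torch.arange(4) == token_index  # one-hot over token slots

    # write the new rows into slot `token_index`, keep every other slot
    keys_values = torch.where(
        token_mask[None, :, None, None],         # (1, 4, 1, 1), broadcasts
        torch.cat([keys, values]),               # (2, 1, 1, 3), broadcasts
        keys_values,                             # (2, 4, 1, 3)
    )

The verbose print was dropped along the way for a related reason: image_token.detach().numpy() raises on CUDA tensors and would need a .cpu() first.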
@@ -1,7 +1,7 @@
 from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
-torch.no_grad()
+torch.set_grad_enabled(False)
 
 
 class GLUTorch(nn.Module):
@@ -34,6 +34,8 @@ class AttentionTorch(nn.Module):
         self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
         self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
         self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
+        self.one = torch.ones((1, 1))
+        if torch.cuda.is_available(): self.one = self.one.cuda()
 
     def forward(self,
         keys: FloatTensor,
@@ -43,8 +45,8 @@ class AttentionTorch(nn.Module):
     ) -> FloatTensor:
         attention_bias = torch.where(
             attention_mask,
-            torch.full(attention_mask.shape, 0.0),
-            torch.full(attention_mask.shape, -torch.inf),
+            self.one * 0,
+            self.one * (-torch.inf),
         )
         attention_weights: FloatTensor = torch.einsum(
             'bqhc,bkhc->bhqk',
@@ -124,11 +126,14 @@ class DalleBartEncoderTorch(nn.Module):
         ])
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
+        self.token_indices = torch.arange(text_token_count).to(torch.long)
+        if torch.cuda.is_available():
+            self.token_indices = self.token_indices.cuda()
 
     def forward(self, text_tokens: LongTensor) -> FloatTensor:
         attention_mask = text_tokens.not_equal(1)
-        batch_count, token_count = text_tokens.shape
-        pose_tokens = torch.stack([torch.arange(token_count)] * batch_count)
+        batch_count = text_tokens.shape[0]
+        pose_tokens = torch.stack([self.token_indices] * batch_count)
         encoder_state = (
             self.embed_tokens.forward(text_tokens) +
             self.embed_positions.forward(pose_tokens)
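The self.one change in AttentionTorch serves the same goal: torch.full allocated two fresh CPU tensors on every forward call, whereas a (1, 1) ones tensor placed on the GPU once at construction can be scaled and broadcast inside torch.where. A toy illustration, with mask values invented:

    import torch

    attention_mask = torch.tensor([[True, True, False]])
    one = torch.ones((1, 1))  # moved to CUDA once at init when available
    attention_bias = torch.where(
        attention_mask,
        one * 0,                # broadcasts from (1, 1) to the mask's shape
        one * (-torch.inf),     # -inf drives masked keys to zero after softmax
    )
    # attention_bias: tensor([[0., 0., -inf]])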
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
-torch.no_grad()
+torch.set_grad_enabled(False)
 
 BATCH_COUNT: int = 1
 