refactored to load models once and run multiple times
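The signature cleanups in the diff below support an API where model weights are loaded once and then reused across many generation calls. A minimal sketch of that load-once, run-many pattern, using a stand-in module rather than anything from this commit:

import torch
from torch import nn

# Stand-in for an expensive-to-construct model; the real objects are the
# DalleBartEncoderTorch / DalleBartDecoderTorch modules changed below,
# with weights loaded from disk.
model = nn.Linear(8, 8)
model.eval()

# Load once above; run many times below. Only the cheap forward pass
# repeats, so repeated generations avoid re-reading weights.
with torch.no_grad():
    for _ in range(3):
        out = model(torch.randn(1, 8))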
@@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):


 class DecoderSelfAttentionFlax(AttentionFlax):
-    def __call__(self,
+    def __call__(
+        self,
         decoder_state: jnp.ndarray,
         keys_state: jnp.ndarray,
         values_state: jnp.ndarray,
@@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
         self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

     @nn.compact
-    def __call__(self,
+    def __call__(
+        self,
         decoder_state: jnp.ndarray,
         encoder_state: jnp.ndarray,
         keys_state: jnp.ndarray,
@@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
         self.final_ln = nn.LayerNorm(use_scale=False)
         self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)

-    def __call__(self,
+    def __call__(
+        self,
         encoder_state: jnp.ndarray,
         keys_state: jnp.ndarray,
         values_state: jnp.ndarray,
@@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
         decoder_state = self.lm_head(decoder_state)
         return decoder_state, keys_state, values_state

-    def sample_image_tokens(self,
+    def sample_image_tokens(
+        self,
         text_tokens: jnp.ndarray,
         encoder_state: jnp.ndarray,
         prng_key: jax.random.PRNGKey,
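Since sample_image_tokens takes a precomputed encoder_state and an explicit prng_key, one text encoding can drive many sampling runs, with only the key changing between calls. A rough sketch of the key-splitting half of that pattern (the categorical draw is a stand-in for the real decoder call, which is truncated in the hunk above):

import jax
import jax.numpy as jnp

# One key at load time, one split per generation: repeated calls to
# sample_image_tokens(text_tokens, encoder_state, subkey, ...) can then
# reuse the same encoder_state while producing different samples.
key = jax.random.PRNGKey(0)
for _ in range(4):
    key, subkey = jax.random.split(key)
    # stand-in for the decoder's token draw
    token = jax.random.categorical(subkey, jnp.zeros((16,)))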
@@ -26,7 +26,8 @@ class DecoderCrossAttentionTorch(AttentionTorch):


 class DecoderSelfAttentionTorch(AttentionTorch):
-    def forward(self,
+    def forward(
+        self,
         decoder_state: FloatTensor,
         keys_values: FloatTensor,
         attention_mask: BoolTensor,
@@ -49,7 +50,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):


 class DecoderLayerTorch(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         image_token_count: int,
         head_count: int,
         embed_count: int,
@@ -69,7 +71,8 @@ class DecoderLayerTorch(nn.Module):
         if torch.cuda.is_available():
             self.token_indices = self.token_indices.cuda()

-    def forward(self,
+    def forward(
+        self,
         decoder_state: FloatTensor,
         encoder_state: FloatTensor,
         keys_values_state: FloatTensor,
@@ -111,7 +114,8 @@ class DecoderLayerTorch(nn.Module):


 class DalleBartDecoderTorch(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         image_vocab_size: int,
         image_token_count: int,
         sample_token_count: int,
@@ -158,7 +162,8 @@ class DalleBartDecoderTorch(nn.Module):
             self.start_token = self.start_token.cuda()


-    def decode_step(self,
+    def decode_step(
+        self,
         text_tokens: LongTensor,
         encoder_state: FloatTensor,
         keys_values_state: FloatTensor,
@@ -198,7 +203,8 @@ class DalleBartDecoderTorch(nn.Module):
         return probs, keys_values


-    def forward(self,
+    def forward(
+        self,
         text_tokens: LongTensor,
         encoder_state: FloatTensor
     ) -> LongTensor:
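decode_step returns (probs, keys_values), so the attention cache is threaded through the sampling loop rather than recomputed each step. A toy loop with the same shape of data flow (the tensor sizes and the random probs are illustrative stand-ins, not the real decoder's values):

import torch

# Empty cache to start; each step samples a token from probs and
# appends that step's keys/values, mirroring decode_step's contract.
keys_values = torch.zeros(1, 0, 8)
for _ in range(4):
    probs = torch.softmax(torch.randn(16), dim=0)  # stand-in for decode_step output
    token = torch.multinomial(probs, 1)
    step_kv = torch.randn(1, 1, 8)                 # stand-in for new keys/values
    keys_values = torch.cat([keys_values, step_kv], dim=1)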
@@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
         self.v_proj = nn.Dense(self.embed_count, use_bias=False)
         self.out_proj = nn.Dense(self.embed_count, use_bias=False)

-    def forward(self,
+    def forward(
+        self,
         keys: jnp.ndarray,
         values: jnp.ndarray,
         queries: jnp.ndarray,
@@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
         self.glu = GLUFlax(self.embed_count, self.glu_embed_count)

     @nn.compact
-    def __call__(self,
+    def __call__(
+        self,
         encoder_state: jnp.ndarray,
         attention_mask: jnp.ndarray
     ) -> jnp.ndarray:
@@ -37,7 +37,8 @@ class AttentionTorch(nn.Module):
         self.one = torch.ones((1, 1))
         if torch.cuda.is_available(): self.one = self.one.cuda()

-    def forward(self,
+    def forward(
+        self,
         keys: FloatTensor,
         values: FloatTensor,
         queries: FloatTensor,
@@ -105,7 +106,8 @@ class EncoderLayerTorch(nn.Module):


 class DalleBartEncoderTorch(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         layer_count: int,
         embed_count: int,
         attention_head_count: int,
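AttentionTorch applies the same load-once idea at tensor granularity: its constant self.one is allocated in __init__ and moved to the GPU up front, so forward passes never re-allocate it. A self-contained sketch of that pattern:

import torch
from torch import nn

class Scaler(nn.Module):
    """Toy module using the same one-time-allocation pattern as AttentionTorch."""
    def __init__(self):
        super().__init__()
        self.one = torch.ones((1, 1))
        if torch.cuda.is_available():
            self.one = self.one.cuda()

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        # No per-call allocation; x is assumed to be on the same device,
        # as in the original code.
        return x * self.one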