refactored to load models once and run multiple times

This commit is contained in:
Brett Kuprel
2022-06-29 09:42:12 -04:00
parent 1ef9b0b929
commit ed91ab4a30
11 changed files with 225 additions and 282 deletions

View File

@@ -26,7 +26,8 @@ class DecoderCrossAttentionFlax(AttentionFlax):
class DecoderSelfAttentionFlax(AttentionFlax):
def __call__(self,
def __call__(
self,
decoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
@@ -77,7 +78,8 @@ class DalleBartDecoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(self,
def __call__(
self,
decoder_state: jnp.ndarray,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
@@ -173,7 +175,8 @@ class DalleBartDecoderFlax(nn.Module):
self.final_ln = nn.LayerNorm(use_scale=False)
self.lm_head = nn.Dense(self.image_vocab_count + 1, use_bias=False)
def __call__(self,
def __call__(
self,
encoder_state: jnp.ndarray,
keys_state: jnp.ndarray,
values_state: jnp.ndarray,
@@ -198,7 +201,8 @@ class DalleBartDecoderFlax(nn.Module):
decoder_state = self.lm_head(decoder_state)
return decoder_state, keys_state, values_state
def sample_image_tokens(self,
def sample_image_tokens(
self,
text_tokens: jnp.ndarray,
encoder_state: jnp.ndarray,
prng_key: jax.random.PRNGKey,

View File

@@ -26,7 +26,8 @@ class DecoderCrossAttentionTorch(AttentionTorch):
class DecoderSelfAttentionTorch(AttentionTorch):
def forward(self,
def forward(
self,
decoder_state: FloatTensor,
keys_values: FloatTensor,
attention_mask: BoolTensor,
@@ -49,7 +50,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):
class DecoderLayerTorch(nn.Module):
def __init__(self,
def __init__(
self,
image_token_count: int,
head_count: int,
embed_count: int,
@@ -69,7 +71,8 @@ class DecoderLayerTorch(nn.Module):
if torch.cuda.is_available():
self.token_indices = self.token_indices.cuda()
def forward(self,
def forward(
self,
decoder_state: FloatTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
@@ -111,7 +114,8 @@ class DecoderLayerTorch(nn.Module):
class DalleBartDecoderTorch(nn.Module):
def __init__(self,
def __init__(
self,
image_vocab_size: int,
image_token_count: int,
sample_token_count: int,
@@ -158,7 +162,8 @@ class DalleBartDecoderTorch(nn.Module):
self.start_token = self.start_token.cuda()
def decode_step(self,
def decode_step(
self,
text_tokens: LongTensor,
encoder_state: FloatTensor,
keys_values_state: FloatTensor,
@@ -198,7 +203,8 @@ class DalleBartDecoderTorch(nn.Module):
return probs, keys_values
def forward(self,
def forward(
self,
text_tokens: LongTensor,
encoder_state: FloatTensor
) -> LongTensor:

View File

@@ -34,7 +34,8 @@ class AttentionFlax(nn.Module):
self.v_proj = nn.Dense(self.embed_count, use_bias=False)
self.out_proj = nn.Dense(self.embed_count, use_bias=False)
def forward(self,
def forward(
self,
keys: jnp.ndarray,
values: jnp.ndarray,
queries: jnp.ndarray,
@@ -92,7 +93,8 @@ class DalleBartEncoderLayerFlax(nn.Module):
self.glu = GLUFlax(self.embed_count, self.glu_embed_count)
@nn.compact
def __call__(self,
def __call__(
self,
encoder_state: jnp.ndarray,
attention_mask: jnp.ndarray
) -> jnp.ndarray:

View File

@@ -37,7 +37,8 @@ class AttentionTorch(nn.Module):
self.one = torch.ones((1, 1))
if torch.cuda.is_available(): self.one = self.one.cuda()
def forward(self,
def forward(
self,
keys: FloatTensor,
values: FloatTensor,
queries: FloatTensor,
@@ -105,7 +106,8 @@ class EncoderLayerTorch(nn.Module):
class DalleBartEncoderTorch(nn.Module):
def __init__(self,
def __init__(
self,
layer_count: int,
embed_count: int,
attention_head_count: int,