remove config.json dependency, default to torch in image_from_text.py

Brett Kuprel
2022-07-01 12:03:37 -04:00
parent 4404e70764
commit 85f5866eff
10 changed files with 52 additions and 64 deletions
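
In practice the change means the decoder is constructed from hyperparameters passed in directly instead of values read out of the checkpoint's config.json, with the torch decoder's parameter renamed from image_vocab_size to image_vocab_count and is_verbose dropped from its constructor. A minimal sketch of the resulting call, assuming the import path and using illustrative numbers (neither is taken from this commit):

# sketch only: the import path and the numeric values are assumptions, not part of this commit
from min_dalle.models.dalle_bart_decoder_torch import DalleBartDecoderTorch

decoder = DalleBartDecoderTorch(
    image_vocab_count = 16384,     # renamed from image_vocab_size
    image_token_count = 256,
    sample_token_count = 256,
    embed_count = 1024,
    attention_head_count = 16,     # parameter assumed by analogy with the flax decoder fields
    glu_embed_count = 2730,
    layer_count = 12,
    batch_count = 1,
    start_token = 16384            # is_verbose is no longer accepted here
)

With the config.json read gone, image_from_text.py can default to this torch path, as the commit title says.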


@@ -130,7 +130,6 @@ def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
 class DalleBartDecoderFlax(nn.Module):
     image_token_count: int
-    text_token_count: int
     image_vocab_count: int
     attention_head_count: int
     embed_count: int


@@ -109,7 +109,7 @@ class DecoderLayerTorch(nn.Module):
 class DalleBartDecoderTorch(nn.Module):
     def __init__(
         self,
-        image_vocab_size: int,
+        image_vocab_count: int,
         image_token_count: int,
         sample_token_count: int,
         embed_count: int,
@@ -117,16 +117,14 @@ class DalleBartDecoderTorch(nn.Module):
         glu_embed_count: int,
         layer_count: int,
         batch_count: int,
-        start_token: int,
-        is_verbose: bool
+        start_token: int
     ):
         super().__init__()
-        self.is_verbose = is_verbose
         self.layer_count = layer_count
         self.sample_token_count = sample_token_count
         self.condition_factor = 10.0
         self.image_token_count = image_token_count
-        self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
+        self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
         self.layers: List[DecoderLayerTorch] = nn.ModuleList([
             DecoderLayerTorch(
@@ -139,7 +137,7 @@ class DalleBartDecoderTorch(nn.Module):
         ])
         self.layernorm_embedding = nn.LayerNorm(embed_count)
         self.final_ln = nn.LayerNorm(embed_count)
-        self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
+        self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
         self.attention_state_shape = (
             layer_count,
             2 * batch_count,
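
A side note on the renamed lines: both the token embedding and the output head are sized image_vocab_count + 1. A plausible reading, not stated in the diff, is that the start token is assigned the index just past the last image token, so both tables need one extra slot. A small self-contained check of that sizing, with illustrative values:

import torch
from torch import nn

image_vocab_count = 16384          # illustrative value, not from this commit
start_token = image_vocab_count    # assumption: start token sits one past the image vocabulary
embed_count = 1024

embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)

# the start token indexes the extra row; without the +1 it would be out of range
hidden = embed_tokens(torch.tensor([start_token]))
logits = lm_head(hidden)
assert logits.shape == (1, image_vocab_count + 1)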