remove config.json dependency, default to torch in image_from_text.py
This commit is contained in:
@@ -130,7 +130,6 @@ def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
|
||||
|
||||
class DalleBartDecoderFlax(nn.Module):
|
||||
image_token_count: int
|
||||
text_token_count: int
|
||||
image_vocab_count: int
|
||||
attention_head_count: int
|
||||
embed_count: int
|
||||
|
@@ -109,7 +109,7 @@ class DecoderLayerTorch(nn.Module):
|
||||
class DalleBartDecoderTorch(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_vocab_size: int,
|
||||
image_vocab_count: int,
|
||||
image_token_count: int,
|
||||
sample_token_count: int,
|
||||
embed_count: int,
|
||||
@@ -117,16 +117,14 @@ class DalleBartDecoderTorch(nn.Module):
|
||||
glu_embed_count: int,
|
||||
layer_count: int,
|
||||
batch_count: int,
|
||||
start_token: int,
|
||||
is_verbose: bool
|
||||
start_token: int
|
||||
):
|
||||
super().__init__()
|
||||
self.is_verbose = is_verbose
|
||||
self.layer_count = layer_count
|
||||
self.sample_token_count = sample_token_count
|
||||
self.condition_factor = 10.0
|
||||
self.image_token_count = image_token_count
|
||||
self.embed_tokens = nn.Embedding(image_vocab_size + 1, embed_count)
|
||||
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
|
||||
self.embed_positions = nn.Embedding(image_token_count, embed_count)
|
||||
self.layers: List[DecoderLayerTorch] = nn.ModuleList([
|
||||
DecoderLayerTorch(
|
||||
@@ -139,7 +137,7 @@ class DalleBartDecoderTorch(nn.Module):
|
||||
])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_count)
|
||||
self.final_ln = nn.LayerNorm(embed_count)
|
||||
self.lm_head = nn.Linear(embed_count, image_vocab_size + 1, bias=False)
|
||||
self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False)
|
||||
self.attention_state_shape = (
|
||||
layer_count,
|
||||
2 * batch_count,
|
||||
|
Reference in New Issue
Block a user