delete cache
This commit is contained in:
parent
d9d7f34b22
commit
2311a1af7b
|
@ -148,7 +148,7 @@ class DalleBartDecoderFlax(nn.Module):
|
|||
)
|
||||
self.layers = nn.scan(
|
||||
DalleBartDecoderLayerFlax,
|
||||
variable_axes = { "params": 0, "cache": 0 },
|
||||
variable_axes = { "params": 0 },
|
||||
split_rngs = { "params": True },
|
||||
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
|
||||
out_axes = 0,
|
||||
|
|
|
@ -124,7 +124,7 @@ class DalleBartEncoderFlax(nn.Module):
|
|||
self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
|
||||
self.layers = nn.scan(
|
||||
DalleBartEncoderLayerFlax,
|
||||
variable_axes = { "params": 0, "cache": 0 },
|
||||
variable_axes = { "params": 0 },
|
||||
split_rngs = { "params": True },
|
||||
in_axes = nn.broadcast,
|
||||
length = self.layer_count
|
||||
|
|
Loading…
Reference in New Issue
Block a user