delete cache

main
Brett Kuprel 2 years ago
parent d9d7f34b22
commit 2311a1af7b
  1. 2
      min_dalle/models/dalle_bart_decoder_flax.py
  2. 2
      min_dalle/models/dalle_bart_encoder_flax.py

@ -148,7 +148,7 @@ class DalleBartDecoderFlax(nn.Module):
)
self.layers = nn.scan(
DalleBartDecoderLayerFlax,
variable_axes = { "params": 0, "cache": 0 },
variable_axes = { "params": 0 },
split_rngs = { "params": True },
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
out_axes = 0,

@ -124,7 +124,7 @@ class DalleBartEncoderFlax(nn.Module):
self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
self.layers = nn.scan(
DalleBartEncoderLayerFlax,
variable_axes = { "params": 0, "cache": 0 },
variable_axes = { "params": 0 },
split_rngs = { "params": True },
in_axes = nn.broadcast,
length = self.layer_count

Loading…
Cancel
Save