diff --git a/min_dalle/models/dalle_bart_decoder_flax.py b/min_dalle/models/dalle_bart_decoder_flax.py index 704201d..edce6c2 100644 --- a/min_dalle/models/dalle_bart_decoder_flax.py +++ b/min_dalle/models/dalle_bart_decoder_flax.py @@ -148,7 +148,7 @@ class DalleBartDecoderFlax(nn.Module): ) self.layers = nn.scan( DalleBartDecoderLayerFlax, - variable_axes = { "params": 0, "cache": 0 }, + variable_axes = { "params": 0 }, split_rngs = { "params": True }, in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast), out_axes = 0, diff --git a/min_dalle/models/dalle_bart_encoder_flax.py b/min_dalle/models/dalle_bart_encoder_flax.py index d320e69..7a1cc1b 100644 --- a/min_dalle/models/dalle_bart_encoder_flax.py +++ b/min_dalle/models/dalle_bart_encoder_flax.py @@ -124,7 +124,7 @@ class DalleBartEncoderFlax(nn.Module): self.embed_positions = nn.Embed(self.text_token_count, self.embed_count) self.layers = nn.scan( DalleBartEncoderLayerFlax, - variable_axes = { "params": 0, "cache": 0 }, + variable_axes = { "params": 0 }, split_rngs = { "params": True }, in_axes = nn.broadcast, length = self.layer_count