From 2311a1af7b00e8743efd003561b4c347eb4ebc54 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Thu, 30 Jun 2022 15:48:20 -0400 Subject: [PATCH] delete cache --- min_dalle/models/dalle_bart_decoder_flax.py | 2 +- min_dalle/models/dalle_bart_encoder_flax.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/min_dalle/models/dalle_bart_decoder_flax.py b/min_dalle/models/dalle_bart_decoder_flax.py index 704201d..edce6c2 100644 --- a/min_dalle/models/dalle_bart_decoder_flax.py +++ b/min_dalle/models/dalle_bart_decoder_flax.py @@ -148,7 +148,7 @@ class DalleBartDecoderFlax(nn.Module): ) self.layers = nn.scan( DalleBartDecoderLayerFlax, - variable_axes = { "params": 0, "cache": 0 }, + variable_axes = { "params": 0 }, split_rngs = { "params": True }, in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast), out_axes = 0, diff --git a/min_dalle/models/dalle_bart_encoder_flax.py b/min_dalle/models/dalle_bart_encoder_flax.py index d320e69..7a1cc1b 100644 --- a/min_dalle/models/dalle_bart_encoder_flax.py +++ b/min_dalle/models/dalle_bart_encoder_flax.py @@ -124,7 +124,7 @@ class DalleBartEncoderFlax(nn.Module): self.embed_positions = nn.Embed(self.text_token_count, self.embed_count) self.layers = nn.scan( DalleBartEncoderLayerFlax, - variable_axes = { "params": 0, "cache": 0 }, + variable_axes = { "params": 0 }, split_rngs = { "params": True }, in_axes = nn.broadcast, length = self.layer_count