delete cache

This commit is contained in:
Brett Kuprel 2022-06-30 15:48:20 -04:00
parent d9d7f34b22
commit 2311a1af7b
2 changed files with 2 additions and 2 deletions

View File

@ -148,7 +148,7 @@ class DalleBartDecoderFlax(nn.Module):
)
self.layers = nn.scan(
DalleBartDecoderLayerFlax,
variable_axes = { "params": 0, "cache": 0 },
variable_axes = { "params": 0 },
split_rngs = { "params": True },
in_axes = (nn.broadcast, 0, nn.broadcast, nn.broadcast),
out_axes = 0,

View File

@ -124,7 +124,7 @@ class DalleBartEncoderFlax(nn.Module):
self.embed_positions = nn.Embed(self.text_token_count, self.embed_count)
self.layers = nn.scan(
DalleBartEncoderLayerFlax,
variable_axes = { "params": 0, "cache": 0 },
variable_axes = { "params": 0 },
split_rngs = { "params": True },
in_axes = nn.broadcast,
length = self.layer_count