support bfloat16

Brett Kuprel
2022-07-07 08:21:20 -04:00
parent 5f526e2109
commit da62298f06
9 changed files with 108 additions and 96 deletions


@@ -40,7 +40,8 @@ class DecoderSelfAttention(AttentionBase):
         queries = self.q_proj.forward(decoder_state)
         attn_mask = self.token_indices < token_index + 1
         attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
-        attention_state[:, token_index] = torch.cat([keys, values])
+        attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
+        attention_state[:, token_index] = attn_state_new
         batch_count = decoder_state.shape[0]
         keys = attention_state[:batch_count]
         values = attention_state[batch_count:]
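
The explicit .to(attention_state.dtype) cast matters because the write uses advanced (tensor) indexing, and PyTorch's index_put_ path has historically required the source and destination dtypes to match, so assigning float32 keys/values into a bfloat16 cache can raise a dtype-mismatch error rather than converting silently. A minimal sketch of the pattern, with illustrative shapes and names rather than the project's real ones:

import torch

# Illustrative sizes only.
batch, seq_len, embed = 2, 16, 32

# Key/value cache held in bfloat16; keys are stacked over values along dim 0,
# mirroring the torch.cat([keys, values]) layout in the diff.
attention_state = torch.zeros(2 * batch, seq_len, embed, dtype=torch.bfloat16)

# Freshly projected keys/values typically come out in float32.
keys = torch.randn(batch, 1, embed)
values = torch.randn(batch, 1, embed)
token_index = torch.tensor([3])  # tensor index, as in the diff

# Cast to the cache's dtype before the indexed write; writing float32 directly
# into the bfloat16 buffer via tensor indexing may fail with a dtype mismatch.
attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
attention_state[:, token_index] = attn_state_new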


@@ -82,7 +82,7 @@ class Upsample(Module):
         self.conv = Conv2d(n, n, 3, padding=1)
     def forward(self, x: Tensor) -> Tensor:
-        x = self.upsample.forward(x)
+        x = self.upsample.forward(x.to(torch.float32))
         x = self.conv.forward(x)
         return x
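
Casting the input to float32 before the upsample sidesteps an op that may lack a bfloat16 kernel: nearest-neighbor upsampling has not been available for bfloat16 on every backend and version, while the following convolution handles the float32 result fine. A small self-contained sketch of the same round-trip, using an illustrative stand-in module (channel count and scale factor are assumptions, not the project's values):

import torch
from torch import Tensor
from torch.nn import Conv2d, Module, Upsample

class UpsampleBlock(Module):
    # Illustrative stand-in for the Upsample module in the diff.
    def __init__(self, n: int):
        super().__init__()
        self.upsample = Upsample(scale_factor=2, mode='nearest')
        self.conv = Conv2d(n, n, 3, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # Run the nearest-neighbor upsample in float32, since a bfloat16 input
        # may hit a missing kernel on some backends; the conv then consumes
        # the float32 tensor.
        x = self.upsample.forward(x.to(torch.float32))
        x = self.conv.forward(x)
        return x

x = torch.randn(1, 8, 4, 4, dtype=torch.bfloat16)
y = UpsampleBlock(8).forward(x)
print(y.shape, y.dtype)  # torch.Size([1, 8, 8, 8]) torch.float32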