support bfloat16
@@ -40,7 +40,8 @@ class DecoderSelfAttention(AttentionBase):
         queries = self.q_proj.forward(decoder_state)
         attn_mask = self.token_indices < token_index + 1
         attn_mask = attn_mask[None][[0] * decoder_state.shape[0]]
-        attention_state[:, token_index] = torch.cat([keys, values])
+        attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
+        attention_state[:, token_index] = attn_state_new
         batch_count = decoder_state.shape[0]
         keys = attention_state[:batch_count]
         values = attention_state[batch_count:]
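Note on the first hunk: the new .to(attention_state.dtype) cast is presumably needed because token_index is a tensor, so the assignment dispatches to index_put_, which, unlike plain slice assignment, refuses mismatched dtypes. A minimal sketch of the failure mode, with made-up shapes rather than the model's real dimensions:

import torch

attention_state = torch.zeros(4, 16, 8, dtype=torch.bfloat16)  # bfloat16 cache
keys_values = torch.randn(4, 8)  # float32 by default
token_index = torch.tensor(3)    # tensor index -> dispatches to index_put_

# attention_state[:, token_index] = keys_values  # RuntimeError: "Index put
# requires the source and destination dtypes match"
attention_state[:, token_index] = keys_values.to(attention_state.dtype)  # ok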
@@ -82,7 +82,7 @@ class Upsample(Module):
         self.conv = Conv2d(n, n, 3, padding=1)

     def forward(self, x: Tensor) -> Tensor:
-        x = self.upsample.forward(x)
+        x = self.upsample.forward(x.to(torch.float32))
         x = self.conv.forward(x)
         return x

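Note on the second hunk: casting to float32 before upsampling works around missing low-precision kernels; older PyTorch builds had no bfloat16 implementation for nearest-neighbor upsampling and raised a "not implemented for 'BFloat16'"-style RuntimeError, at least on CPU. A sketch under the assumption that self.upsample is torch.nn.UpsamplingNearest2d:

import torch

upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
x = torch.randn(1, 3, 8, 8, dtype=torch.bfloat16)

# upsample(x)  # may raise, e.g.: RuntimeError: "upsample_nearest2d_out_frame"
#              # not implemented for 'BFloat16' (version/backend dependent)
y = upsample(x.to(torch.float32))  # run the resize in float32 instead

The diff does not cast back afterward, so the following Conv2d receives a float32 tensor; whether a cast back is needed depends on the dtype of that convolution's weights.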