v0.2.0, MinDalleTorch -> MinDalle, breaking change

Author: Brett Kuprel
Date: 2022-07-01 19:44:24 -04:00
parent 2080e596c3
commit 35e97768a5
10 changed files with 43 additions and 45 deletions
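The breaking change in the title is a one-for-one rename of the public class: MinDalleTorch becomes MinDalle. A minimal migration sketch, assuming the class is importable from the package root (the top-level file is not shown in this excerpt):

# Before this commit (v0.1.x) -- assumed prior usage:
# from min_dalle import MinDalleTorch

# After this commit (v0.2.0):
from min_dalle import MinDalle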

View File

@@ -1,3 +1,3 @@
-from .dalle_bart_encoder_torch import DalleBartEncoderTorch
-from .dalle_bart_decoder_torch import DalleBartDecoderTorch
+from .dalle_bart_encoder import DalleBartEncoder
+from .dalle_bart_decoder import DalleBartDecoder
 from .vqgan_detokenizer import VQGanDetokenizer
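Because this __init__.py re-exports the renamed modules, downstream imports from the models subpackage change in the same one-for-one way. A hedged sketch, assuming the subpackage path is min_dalle.models (the file path is not shown in this excerpt):

# Before: from min_dalle.models import DalleBartEncoderTorch, DalleBartDecoderTorch
# After this commit:
from min_dalle.models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer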

View File

@@ -3,10 +3,10 @@ import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
 torch.set_grad_enabled(False)
-from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
+from .dalle_bart_encoder import GLU, AttentionBase
-class DecoderCrossAttentionTorch(AttentionTorch):
+class DecoderCrossAttention(AttentionBase):
     def forward(
         self,
         decoder_state: FloatTensor,
@@ -19,7 +19,7 @@ class DecoderCrossAttentionTorch(AttentionTorch):
         return super().forward(keys, values, queries, attention_mask)
-class DecoderSelfAttentionTorch(AttentionTorch):
+class DecoderSelfAttention(AttentionBase):
     def forward(
         self,
         decoder_state: FloatTensor,
@@ -42,7 +42,7 @@ class DecoderSelfAttentionTorch(AttentionTorch):
         return decoder_state, attention_state
-class DecoderLayerTorch(nn.Module):
+class DecoderLayer(nn.Module):
     def __init__(
         self,
         image_token_count: int,
@@ -53,12 +53,12 @@ class DecoderLayerTorch(nn.Module):
         super().__init__()
         self.image_token_count = image_token_count
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count)
+        self.self_attn = DecoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
         self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count)
+        self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
         self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.glu = GLUTorch(embed_count, glu_embed_count)
+        self.glu = GLU(embed_count, glu_embed_count)
         self.token_indices = torch.arange(self.image_token_count)
         if torch.cuda.is_available():
@@ -106,7 +106,7 @@ class DecoderLayerTorch(nn.Module):
         return decoder_state, attention_state
-class DalleBartDecoderTorch(nn.Module):
+class DalleBartDecoder(nn.Module):
     def __init__(
         self,
         image_vocab_count: int,
@@ -126,8 +126,8 @@ class DalleBartDecoderTorch(nn.Module):
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
-        self.layers: List[DecoderLayerTorch] = nn.ModuleList([
-            DecoderLayerTorch(
+        self.layers: List[DecoderLayer] = nn.ModuleList([
+            DecoderLayer(
                 image_token_count,
                 attention_head_count,
                 embed_count,
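The decoder attention classes keep the same structure under their new names: AttentionBase owns the attention computation, and DecoderSelfAttention / DecoderCrossAttention only decide where keys, values, and queries come from before calling super().forward(keys, values, queries, attention_mask), as the hunks above show. A toy sketch of that delegation pattern follows; the projection layers, shapes, and masking here are illustrative assumptions, not the repository's code:

import torch
from torch import nn, FloatTensor, BoolTensor
torch.set_grad_enabled(False)

class ToyAttentionBase(nn.Module):
    # Shared attention math; subclasses only choose the inputs (assumed layout).
    def __init__(self, embed_count: int):
        super().__init__()
        self.k_proj = nn.Linear(embed_count, embed_count)
        self.v_proj = nn.Linear(embed_count, embed_count)
        self.q_proj = nn.Linear(embed_count, embed_count)

    def forward(
        self,
        keys: FloatTensor,
        values: FloatTensor,
        queries: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        # queries: [batch, q, d], keys/values: [batch, k, d], mask: [batch, k]
        weights = torch.einsum('bqd,bkd->bqk', queries, keys)
        weights = weights.masked_fill(~attention_mask[:, None, :], float('-inf'))
        return torch.einsum('bqk,bkd->bqd', weights.softmax(-1), values)

class ToyCrossAttention(ToyAttentionBase):
    # Keys and values come from the encoder, queries from the decoder,
    # then the work is delegated to the shared base class.
    def forward(
        self,
        decoder_state: FloatTensor,
        encoder_state: FloatTensor,
        attention_mask: BoolTensor
    ) -> FloatTensor:
        keys = self.k_proj(encoder_state)
        values = self.v_proj(encoder_state)
        queries = self.q_proj(decoder_state)
        return super().forward(keys, values, queries, attention_mask)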

View File

@@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor
 torch.set_grad_enabled(False)
-class GLUTorch(nn.Module):
+class GLU(nn.Module):
     def __init__(self, count_in_out, count_middle):
         super().__init__()
         self.gelu = nn.GELU()
@@ -24,7 +24,7 @@ class GLUTorch(nn.Module):
         return z
-class AttentionTorch(nn.Module):
+class AttentionBase(nn.Module):
     def __init__(self, head_count: int, embed_count: int):
         super().__init__()
         self.head_count = head_count
@@ -72,7 +72,7 @@ class AttentionTorch(nn.Module):
         return attention_output
-class EncoderSelfAttentionTorch(AttentionTorch):
+class EncoderSelfAttention(AttentionBase):
     def forward(
         self,
         encoder_state: FloatTensor,
@@ -84,13 +84,13 @@ class EncoderSelfAttentionTorch(AttentionTorch):
         return super().forward(keys, values, queries, attention_mask)
-class EncoderLayerTorch(nn.Module):
+class EncoderLayer(nn.Module):
     def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
         super().__init__()
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count)
+        self.self_attn = EncoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.glu = GLUTorch(embed_count, glu_embed_count)
+        self.glu = GLU(embed_count, glu_embed_count)
     def forward(
         self,
@@ -108,7 +108,7 @@ class EncoderLayerTorch(nn.Module):
         return encoder_state
-class DalleBartEncoderTorch(nn.Module):
+class DalleBartEncoder(nn.Module):
     def __init__(
         self,
         layer_count: int,
@@ -121,8 +121,8 @@ class DalleBartEncoderTorch(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
-        self.layers: List[EncoderLayerTorch] = nn.ModuleList([
-            EncoderLayerTorch(
+        self.layers: List[EncoderLayer] = nn.ModuleList([
+            EncoderLayer(
                 embed_count = embed_count,
                 head_count = attention_head_count,
                 glu_embed_count = glu_embed_count
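For completeness, constructing the renamed DalleBartEncoder uses the parameter names that appear in the hunks above. A hedged sketch only: the constructor signature is cut off in this diff, whether every name below is a keyword argument is an assumption, and the sizes are illustrative placeholders rather than values from this commit:

# Hedged example -- argument names drawn from the hunks above, values are placeholders.
encoder = DalleBartEncoder(
    layer_count = 12,
    embed_count = 1024,
    attention_head_count = 16,
    glu_embed_count = 2730,
    text_token_count = 64,
    text_vocab_count = 50264
)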