v0.2.0, MinDalleTorch -> MinDalle, breaking change
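For downstream users the rename is the whole breaking change: every torch-suffixed class loses its suffix, so imports must be updated. A minimal before/after sketch of the user-facing API (an assumption-laden illustration: it presumes the package is installed as min_dalle and shows only the is_mega argument, since that is the only constructor parameter visible in the diff below; other arguments are omitted):

    # pre-v0.2.0 (old API)
    from min_dalle import MinDalleTorch
    model = MinDalleTorch(is_mega=False)

    # v0.2.0 and later (new API) -- same arguments, new class name
    from min_dalle import MinDalle
    model = MinDalle(is_mega=False)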
--- a/min_dalle/__init__.py
+++ b/min_dalle/__init__.py
@@ -1 +1 @@
-from .min_dalle_torch import MinDalleTorch
+from .min_dalle import MinDalle
--- a/min_dalle/min_dalle_torch.py
+++ b/min_dalle/min_dalle.py
@@ -1,6 +1,5 @@
 import os
 from PIL import Image
 from typing import Dict
 import numpy
 from torch import LongTensor
 import torch
@@ -10,16 +9,13 @@ import random
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())

 from .text_tokenizer import TextTokenizer
-from .models import (
-    DalleBartEncoderTorch,
-    DalleBartDecoderTorch,
-    VQGanDetokenizer
-)
+from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer

 MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'


-class MinDalleTorch:
+class MinDalle:
     def __init__(
         self,
         is_mega: bool,
@@ -104,7 +100,7 @@ class MinDalleTorch:
         is_downloaded = os.path.exists(self.encoder_params_path)
         if not is_downloaded: self.download_encoder()
         print("initializing DalleBartEncoderTorch")
-        self.encoder = DalleBartEncoderTorch(
+        self.encoder = DalleBartEncoder(
             attention_head_count = self.attention_head_count,
             embed_count = self.embed_count,
             glu_embed_count = self.glu_embed_count,
@@ -122,7 +118,7 @@ class MinDalleTorch:
         is_downloaded = os.path.exists(self.decoder_params_path)
         if not is_downloaded: self.download_decoder()
         print("initializing DalleBartDecoderTorch")
-        self.decoder = DalleBartDecoderTorch(
+        self.decoder = DalleBartDecoder(
             sample_token_count = self.sample_token_count,
             image_token_count = self.image_token_count,
             image_vocab_count = self.image_vocab_count,
--- a/min_dalle/models/__init__.py
+++ b/min_dalle/models/__init__.py
@@ -1,3 +1,3 @@
-from .dalle_bart_encoder_torch import DalleBartEncoderTorch
-from .dalle_bart_decoder_torch import DalleBartDecoderTorch
+from .dalle_bart_encoder import DalleBartEncoder
+from .dalle_bart_decoder import DalleBartDecoder
 from .vqgan_detokenizer import VQGanDetokenizer
--- a/min_dalle/models/dalle_bart_decoder_torch.py
+++ b/min_dalle/models/dalle_bart_decoder.py
@@ -3,10 +3,10 @@ import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
 torch.set_grad_enabled(False)

-from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
+from .dalle_bart_encoder import GLU, AttentionBase


-class DecoderCrossAttentionTorch(AttentionTorch):
+class DecoderCrossAttention(AttentionBase):
     def forward(
         self,
         decoder_state: FloatTensor,
@@ -19,7 +19,7 @@ class DecoderCrossAttentionTorch(AttentionTorch):
         return super().forward(keys, values, queries, attention_mask)


-class DecoderSelfAttentionTorch(AttentionTorch):
+class DecoderSelfAttention(AttentionBase):
     def forward(
         self,
         decoder_state: FloatTensor,
@@ -42,7 +42,7 @@ class DecoderSelfAttentionTorch(AttentionTorch):
         return decoder_state, attention_state


-class DecoderLayerTorch(nn.Module):
+class DecoderLayer(nn.Module):
     def __init__(
         self,
         image_token_count: int,
@@ -53,12 +53,12 @@ class DecoderLayerTorch(nn.Module):
         super().__init__()
         self.image_token_count = image_token_count
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count)
+        self.self_attn = DecoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
         self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count)
+        self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
         self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.glu = GLUTorch(embed_count, glu_embed_count)
+        self.glu = GLU(embed_count, glu_embed_count)

         self.token_indices = torch.arange(self.image_token_count)
         if torch.cuda.is_available():
@@ -106,7 +106,7 @@ class DecoderLayerTorch(nn.Module):
         return decoder_state, attention_state


-class DalleBartDecoderTorch(nn.Module):
+class DalleBartDecoder(nn.Module):
     def __init__(
         self,
         image_vocab_count: int,
@@ -126,8 +126,8 @@ class DalleBartDecoderTorch(nn.Module):
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
-        self.layers: List[DecoderLayerTorch] = nn.ModuleList([
-            DecoderLayerTorch(
+        self.layers: List[DecoderLayer] = nn.ModuleList([
+            DecoderLayer(
                 image_token_count,
                 attention_head_count,
                 embed_count,
--- a/min_dalle/models/dalle_bart_encoder_torch.py
+++ b/min_dalle/models/dalle_bart_encoder.py
@@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor
 torch.set_grad_enabled(False)


-class GLUTorch(nn.Module):
+class GLU(nn.Module):
     def __init__(self, count_in_out, count_middle):
         super().__init__()
         self.gelu = nn.GELU()
@@ -24,7 +24,7 @@ class GLUTorch(nn.Module):
         return z


-class AttentionTorch(nn.Module):
+class AttentionBase(nn.Module):
     def __init__(self, head_count: int, embed_count: int):
         super().__init__()
         self.head_count = head_count
@@ -72,7 +72,7 @@ class AttentionTorch(nn.Module):
         return attention_output


-class EncoderSelfAttentionTorch(AttentionTorch):
+class EncoderSelfAttention(AttentionBase):
     def forward(
         self,
         encoder_state: FloatTensor,
@@ -84,13 +84,13 @@ class EncoderSelfAttentionTorch(AttentionTorch):
         return super().forward(keys, values, queries, attention_mask)


-class EncoderLayerTorch(nn.Module):
+class EncoderLayer(nn.Module):
     def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
         super().__init__()
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count)
+        self.self_attn = EncoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.glu = GLUTorch(embed_count, glu_embed_count)
+        self.glu = GLU(embed_count, glu_embed_count)

     def forward(
         self,
@@ -108,7 +108,7 @@ class EncoderLayerTorch(nn.Module):
         return encoder_state


-class DalleBartEncoderTorch(nn.Module):
+class DalleBartEncoder(nn.Module):
     def __init__(
         self,
         layer_count: int,
@@ -121,8 +121,8 @@ class DalleBartEncoderTorch(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
-        self.layers: List[EncoderLayerTorch] = nn.ModuleList([
-            EncoderLayerTorch(
+        self.layers: List[EncoderLayer] = nn.ModuleList([
+            EncoderLayer(
                 embed_count = embed_count,
                 head_count = attention_head_count,
                 glu_embed_count = glu_embed_count