diff --git a/README.md b/README.md index db091a0..a89f0e7 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,12 @@ $ pip install min-dalle ### Python -To load a model once and generate multiple times, first initialize `MinDalleTorch`. +To load a model once and generate multiple times, first initialize `MinDalle`. ```python -from min_dalle import MinDalleTorch +from min_dalle import MinDalle -model = MinDalleTorch( +model = MinDalle( is_mega=True, is_reusable=True, models_root='./pretrained' diff --git a/image_from_text.py b/image_from_text.py index ee771b7..e35341c 100644 --- a/image_from_text.py +++ b/image_from_text.py @@ -2,14 +2,15 @@ import argparse import os from PIL import Image -from min_dalle import MinDalleTorch +from min_dalle import MinDalle + parser = argparse.ArgumentParser() parser.add_argument('--mega', action='store_true') parser.add_argument('--no-mega', dest='mega', action='store_false') parser.set_defaults(mega=False) parser.add_argument('--text', type=str, default='alien life') -parser.add_argument('--seed', type=int, default=7) +parser.add_argument('--seed', type=int, default=-1) parser.add_argument('--image_path', type=str, default='generated') parser.add_argument('--token_count', type=int, default=256) # for debugging @@ -39,7 +40,7 @@ def generate_image( image_path: str, token_count: int ): - model = MinDalleTorch( + model = MinDalle( is_mega=is_mega, models_root='pretrained', is_reusable=False, diff --git a/min_dalle.ipynb b/min_dalle.ipynb index 66beeb5..64db5ff 100644 --- a/min_dalle.ipynb +++ b/min_dalle.ipynb @@ -77,9 +77,9 @@ } ], "source": [ - "from min_dalle import MinDalleTorch\n", + "from min_dalle import MinDalle\n", "\n", - "model = MinDalleTorch(is_mega=True, is_reusable=True)" + "model = MinDalle(is_mega=True, is_reusable=True)" ] }, { diff --git a/min_dalle/__init__.py b/min_dalle/__init__.py index 9807201..83962cb 100644 --- a/min_dalle/__init__.py +++ b/min_dalle/__init__.py @@ -1 +1 @@ -from .min_dalle_torch import MinDalleTorch \ No newline at end of file +from .min_dalle import MinDalle \ No newline at end of file diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle.py similarity index 96% rename from min_dalle/min_dalle_torch.py rename to min_dalle/min_dalle.py index 212f461..2e92cc6 100644 --- a/min_dalle/min_dalle_torch.py +++ b/min_dalle/min_dalle.py @@ -1,6 +1,5 @@ import os from PIL import Image -from typing import Dict import numpy from torch import LongTensor import torch @@ -10,16 +9,13 @@ import random torch.set_grad_enabled(False) torch.set_num_threads(os.cpu_count()) +from .text_tokenizer import TextTokenizer +from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer + MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/' -from .text_tokenizer import TextTokenizer -from .models import ( - DalleBartEncoderTorch, - DalleBartDecoderTorch, - VQGanDetokenizer -) -class MinDalleTorch: +class MinDalle: def __init__( self, is_mega: bool, @@ -104,7 +100,7 @@ class MinDalleTorch: is_downloaded = os.path.exists(self.encoder_params_path) if not is_downloaded: self.download_encoder() print("initializing DalleBartEncoderTorch") - self.encoder = DalleBartEncoderTorch( + self.encoder = DalleBartEncoder( attention_head_count = self.attention_head_count, embed_count = self.embed_count, glu_embed_count = self.glu_embed_count, @@ -122,7 +118,7 @@ class MinDalleTorch: is_downloaded = os.path.exists(self.decoder_params_path) if not is_downloaded: self.download_decoder() print("initializing DalleBartDecoderTorch") - self.decoder = DalleBartDecoderTorch( + self.decoder = DalleBartDecoder( sample_token_count = self.sample_token_count, image_token_count = self.image_token_count, image_vocab_count = self.image_vocab_count, diff --git a/min_dalle/models/__init__.py b/min_dalle/models/__init__.py index f224241..5ac9af0 100644 --- a/min_dalle/models/__init__.py +++ b/min_dalle/models/__init__.py @@ -1,3 +1,3 @@ -from .dalle_bart_encoder_torch import DalleBartEncoderTorch -from .dalle_bart_decoder_torch import DalleBartDecoderTorch +from .dalle_bart_encoder import DalleBartEncoder +from .dalle_bart_decoder import DalleBartDecoder from .vqgan_detokenizer import VQGanDetokenizer \ No newline at end of file diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder.py similarity index 92% rename from min_dalle/models/dalle_bart_decoder_torch.py rename to min_dalle/models/dalle_bart_decoder.py index bdb860c..4b71858 100644 --- a/min_dalle/models/dalle_bart_decoder_torch.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -3,10 +3,10 @@ import torch from torch import LongTensor, nn, FloatTensor, BoolTensor torch.set_grad_enabled(False) -from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch +from .dalle_bart_encoder import GLU, AttentionBase -class DecoderCrossAttentionTorch(AttentionTorch): +class DecoderCrossAttention(AttentionBase): def forward( self, decoder_state: FloatTensor, @@ -19,7 +19,7 @@ class DecoderCrossAttentionTorch(AttentionTorch): return super().forward(keys, values, queries, attention_mask) -class DecoderSelfAttentionTorch(AttentionTorch): +class DecoderSelfAttention(AttentionBase): def forward( self, decoder_state: FloatTensor, @@ -42,7 +42,7 @@ class DecoderSelfAttentionTorch(AttentionTorch): return decoder_state, attention_state -class DecoderLayerTorch(nn.Module): +class DecoderLayer(nn.Module): def __init__( self, image_token_count: int, @@ -53,12 +53,12 @@ class DecoderLayerTorch(nn.Module): super().__init__() self.image_token_count = image_token_count self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) - self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count) + self.self_attn = DecoderSelfAttention(head_count, embed_count) self.self_attn_layer_norm = nn.LayerNorm(embed_count) self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count) - self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count) + self.encoder_attn = DecoderCrossAttention(head_count, embed_count) self.encoder_attn_layer_norm = nn.LayerNorm(embed_count) - self.glu = GLUTorch(embed_count, glu_embed_count) + self.glu = GLU(embed_count, glu_embed_count) self.token_indices = torch.arange(self.image_token_count) if torch.cuda.is_available(): @@ -106,7 +106,7 @@ class DecoderLayerTorch(nn.Module): return decoder_state, attention_state -class DalleBartDecoderTorch(nn.Module): +class DalleBartDecoder(nn.Module): def __init__( self, image_vocab_count: int, @@ -126,8 +126,8 @@ class DalleBartDecoderTorch(nn.Module): self.image_token_count = image_token_count self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) self.embed_positions = nn.Embedding(image_token_count, embed_count) - self.layers: List[DecoderLayerTorch] = nn.ModuleList([ - DecoderLayerTorch( + self.layers: List[DecoderLayer] = nn.ModuleList([ + DecoderLayer( image_token_count, attention_head_count, embed_count, diff --git a/min_dalle/models/dalle_bart_encoder_torch.py b/min_dalle/models/dalle_bart_encoder.py similarity index 92% rename from min_dalle/models/dalle_bart_encoder_torch.py rename to min_dalle/models/dalle_bart_encoder.py index 296cdec..e3d8eb8 100644 --- a/min_dalle/models/dalle_bart_encoder_torch.py +++ b/min_dalle/models/dalle_bart_encoder.py @@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor torch.set_grad_enabled(False) -class GLUTorch(nn.Module): +class GLU(nn.Module): def __init__(self, count_in_out, count_middle): super().__init__() self.gelu = nn.GELU() @@ -24,7 +24,7 @@ class GLUTorch(nn.Module): return z -class AttentionTorch(nn.Module): +class AttentionBase(nn.Module): def __init__(self, head_count: int, embed_count: int): super().__init__() self.head_count = head_count @@ -72,7 +72,7 @@ class AttentionTorch(nn.Module): return attention_output -class EncoderSelfAttentionTorch(AttentionTorch): +class EncoderSelfAttention(AttentionBase): def forward( self, encoder_state: FloatTensor, @@ -84,13 +84,13 @@ class EncoderSelfAttentionTorch(AttentionTorch): return super().forward(keys, values, queries, attention_mask) -class EncoderLayerTorch(nn.Module): +class EncoderLayer(nn.Module): def __init__(self, embed_count: int, head_count: int, glu_embed_count: int): super().__init__() self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) - self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count) + self.self_attn = EncoderSelfAttention(head_count, embed_count) self.self_attn_layer_norm = nn.LayerNorm(embed_count) - self.glu = GLUTorch(embed_count, glu_embed_count) + self.glu = GLU(embed_count, glu_embed_count) def forward( self, @@ -108,7 +108,7 @@ class EncoderLayerTorch(nn.Module): return encoder_state -class DalleBartEncoderTorch(nn.Module): +class DalleBartEncoder(nn.Module): def __init__( self, layer_count: int, @@ -121,8 +121,8 @@ class DalleBartEncoderTorch(nn.Module): super().__init__() self.embed_tokens = nn.Embedding(text_vocab_count, embed_count) self.embed_positions = nn.Embedding(text_token_count, embed_count) - self.layers: List[EncoderLayerTorch] = nn.ModuleList([ - EncoderLayerTorch( + self.layers: List[EncoderLayer] = nn.ModuleList([ + EncoderLayer( embed_count = embed_count, head_count = attention_head_count, glu_embed_count = glu_embed_count diff --git a/replicate/predict.py b/replicate/predict.py index 6c6701f..b1b9509 100644 --- a/replicate/predict.py +++ b/replicate/predict.py @@ -1,11 +1,11 @@ import tempfile from cog import BasePredictor, Path, Input -from min_dalle.min_dalle_torch import MinDalleTorch +from min_dalle import MinDalle class Predictor(BasePredictor): def setup(self): - self.model = MinDalleTorch(is_mega=True) + self.model = MinDalle(is_mega=True) def predict( self, diff --git a/setup.py b/setup.py index b0cfaae..11251c2 100644 --- a/setup.py +++ b/setup.py @@ -3,9 +3,9 @@ import setuptools setuptools.setup( name='min-dalle', description = 'min(DALLĀ·E)', - version='0.1.4', + version='0.2.0', author='Brett Kuprel', - author_email = 'brkuprel@gmail.com', + author_email='brkuprel@gmail.com', packages=[ 'min_dalle', 'min_dalle.models' @@ -18,6 +18,7 @@ setuptools.setup( keywords = [ 'artificial intelligence', 'deep learning', - 'text to image' + 'text-to-image', + 'pytorch' ] ) \ No newline at end of file