v0.2.0, MinDalleTorch -> MinDalle, breaking change

This commit is contained in:
Brett Kuprel 2022-07-01 19:44:24 -04:00
parent 2080e596c3
commit 35e97768a5
10 changed files with 43 additions and 45 deletions

6
README.md vendored
View File

@ -22,12 +22,12 @@ $ pip install min-dalle
### Python ### Python
To load a model once and generate multiple times, first initialize `MinDalleTorch`. To load a model once and generate multiple times, first initialize `MinDalle`.
```python ```python
from min_dalle import MinDalleTorch from min_dalle import MinDalle
model = MinDalleTorch( model = MinDalle(
is_mega=True, is_mega=True,
is_reusable=True, is_reusable=True,
models_root='./pretrained' models_root='./pretrained'

View File

@ -2,14 +2,15 @@ import argparse
import os import os
from PIL import Image from PIL import Image
from min_dalle import MinDalleTorch from min_dalle import MinDalle
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--mega', action='store_true') parser.add_argument('--mega', action='store_true')
parser.add_argument('--no-mega', dest='mega', action='store_false') parser.add_argument('--no-mega', dest='mega', action='store_false')
parser.set_defaults(mega=False) parser.set_defaults(mega=False)
parser.add_argument('--text', type=str, default='alien life') parser.add_argument('--text', type=str, default='alien life')
parser.add_argument('--seed', type=int, default=7) parser.add_argument('--seed', type=int, default=-1)
parser.add_argument('--image_path', type=str, default='generated') parser.add_argument('--image_path', type=str, default='generated')
parser.add_argument('--token_count', type=int, default=256) # for debugging parser.add_argument('--token_count', type=int, default=256) # for debugging
@ -39,7 +40,7 @@ def generate_image(
image_path: str, image_path: str,
token_count: int token_count: int
): ):
model = MinDalleTorch( model = MinDalle(
is_mega=is_mega, is_mega=is_mega,
models_root='pretrained', models_root='pretrained',
is_reusable=False, is_reusable=False,

4
min_dalle.ipynb vendored
View File

@ -77,9 +77,9 @@
} }
], ],
"source": [ "source": [
"from min_dalle import MinDalleTorch\n", "from min_dalle import MinDalle\n",
"\n", "\n",
"model = MinDalleTorch(is_mega=True, is_reusable=True)" "model = MinDalle(is_mega=True, is_reusable=True)"
] ]
}, },
{ {

View File

@ -1 +1 @@
from .min_dalle_torch import MinDalleTorch from .min_dalle import MinDalle

View File

@ -1,6 +1,5 @@
import os import os
from PIL import Image from PIL import Image
from typing import Dict
import numpy import numpy
from torch import LongTensor from torch import LongTensor
import torch import torch
@ -10,16 +9,13 @@ import random
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count()) torch.set_num_threads(os.cpu_count())
from .text_tokenizer import TextTokenizer
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/' MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
from .text_tokenizer import TextTokenizer
from .models import (
DalleBartEncoderTorch,
DalleBartDecoderTorch,
VQGanDetokenizer
)
class MinDalleTorch: class MinDalle:
def __init__( def __init__(
self, self,
is_mega: bool, is_mega: bool,
@ -104,7 +100,7 @@ class MinDalleTorch:
is_downloaded = os.path.exists(self.encoder_params_path) is_downloaded = os.path.exists(self.encoder_params_path)
if not is_downloaded: self.download_encoder() if not is_downloaded: self.download_encoder()
print("initializing DalleBartEncoderTorch") print("initializing DalleBartEncoderTorch")
self.encoder = DalleBartEncoderTorch( self.encoder = DalleBartEncoder(
attention_head_count = self.attention_head_count, attention_head_count = self.attention_head_count,
embed_count = self.embed_count, embed_count = self.embed_count,
glu_embed_count = self.glu_embed_count, glu_embed_count = self.glu_embed_count,
@ -122,7 +118,7 @@ class MinDalleTorch:
is_downloaded = os.path.exists(self.decoder_params_path) is_downloaded = os.path.exists(self.decoder_params_path)
if not is_downloaded: self.download_decoder() if not is_downloaded: self.download_decoder()
print("initializing DalleBartDecoderTorch") print("initializing DalleBartDecoderTorch")
self.decoder = DalleBartDecoderTorch( self.decoder = DalleBartDecoder(
sample_token_count = self.sample_token_count, sample_token_count = self.sample_token_count,
image_token_count = self.image_token_count, image_token_count = self.image_token_count,
image_vocab_count = self.image_vocab_count, image_vocab_count = self.image_vocab_count,

View File

@ -1,3 +1,3 @@
from .dalle_bart_encoder_torch import DalleBartEncoderTorch from .dalle_bart_encoder import DalleBartEncoder
from .dalle_bart_decoder_torch import DalleBartDecoderTorch from .dalle_bart_decoder import DalleBartDecoder
from .vqgan_detokenizer import VQGanDetokenizer from .vqgan_detokenizer import VQGanDetokenizer

View File

@ -3,10 +3,10 @@ import torch
from torch import LongTensor, nn, FloatTensor, BoolTensor from torch import LongTensor, nn, FloatTensor, BoolTensor
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch from .dalle_bart_encoder import GLU, AttentionBase
class DecoderCrossAttentionTorch(AttentionTorch): class DecoderCrossAttention(AttentionBase):
def forward( def forward(
self, self,
decoder_state: FloatTensor, decoder_state: FloatTensor,
@ -19,7 +19,7 @@ class DecoderCrossAttentionTorch(AttentionTorch):
return super().forward(keys, values, queries, attention_mask) return super().forward(keys, values, queries, attention_mask)
class DecoderSelfAttentionTorch(AttentionTorch): class DecoderSelfAttention(AttentionBase):
def forward( def forward(
self, self,
decoder_state: FloatTensor, decoder_state: FloatTensor,
@ -42,7 +42,7 @@ class DecoderSelfAttentionTorch(AttentionTorch):
return decoder_state, attention_state return decoder_state, attention_state
class DecoderLayerTorch(nn.Module): class DecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
image_token_count: int, image_token_count: int,
@ -53,12 +53,12 @@ class DecoderLayerTorch(nn.Module):
super().__init__() super().__init__()
self.image_token_count = image_token_count self.image_token_count = image_token_count
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count) self.self_attn = DecoderSelfAttention(head_count, embed_count)
self.self_attn_layer_norm = nn.LayerNorm(embed_count) self.self_attn_layer_norm = nn.LayerNorm(embed_count)
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count) self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count) self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count) self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
self.glu = GLUTorch(embed_count, glu_embed_count) self.glu = GLU(embed_count, glu_embed_count)
self.token_indices = torch.arange(self.image_token_count) self.token_indices = torch.arange(self.image_token_count)
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -106,7 +106,7 @@ class DecoderLayerTorch(nn.Module):
return decoder_state, attention_state return decoder_state, attention_state
class DalleBartDecoderTorch(nn.Module): class DalleBartDecoder(nn.Module):
def __init__( def __init__(
self, self,
image_vocab_count: int, image_vocab_count: int,
@ -126,8 +126,8 @@ class DalleBartDecoderTorch(nn.Module):
self.image_token_count = image_token_count self.image_token_count = image_token_count
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
self.embed_positions = nn.Embedding(image_token_count, embed_count) self.embed_positions = nn.Embedding(image_token_count, embed_count)
self.layers: List[DecoderLayerTorch] = nn.ModuleList([ self.layers: List[DecoderLayer] = nn.ModuleList([
DecoderLayerTorch( DecoderLayer(
image_token_count, image_token_count,
attention_head_count, attention_head_count,
embed_count, embed_count,

View File

@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
class GLUTorch(nn.Module): class GLU(nn.Module):
def __init__(self, count_in_out, count_middle): def __init__(self, count_in_out, count_middle):
super().__init__() super().__init__()
self.gelu = nn.GELU() self.gelu = nn.GELU()
@ -24,7 +24,7 @@ class GLUTorch(nn.Module):
return z return z
class AttentionTorch(nn.Module): class AttentionBase(nn.Module):
def __init__(self, head_count: int, embed_count: int): def __init__(self, head_count: int, embed_count: int):
super().__init__() super().__init__()
self.head_count = head_count self.head_count = head_count
@ -72,7 +72,7 @@ class AttentionTorch(nn.Module):
return attention_output return attention_output
class EncoderSelfAttentionTorch(AttentionTorch): class EncoderSelfAttention(AttentionBase):
def forward( def forward(
self, self,
encoder_state: FloatTensor, encoder_state: FloatTensor,
@ -84,13 +84,13 @@ class EncoderSelfAttentionTorch(AttentionTorch):
return super().forward(keys, values, queries, attention_mask) return super().forward(keys, values, queries, attention_mask)
class EncoderLayerTorch(nn.Module): class EncoderLayer(nn.Module):
def __init__(self, embed_count: int, head_count: int, glu_embed_count: int): def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
super().__init__() super().__init__()
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count) self.self_attn = EncoderSelfAttention(head_count, embed_count)
self.self_attn_layer_norm = nn.LayerNorm(embed_count) self.self_attn_layer_norm = nn.LayerNorm(embed_count)
self.glu = GLUTorch(embed_count, glu_embed_count) self.glu = GLU(embed_count, glu_embed_count)
def forward( def forward(
self, self,
@ -108,7 +108,7 @@ class EncoderLayerTorch(nn.Module):
return encoder_state return encoder_state
class DalleBartEncoderTorch(nn.Module): class DalleBartEncoder(nn.Module):
def __init__( def __init__(
self, self,
layer_count: int, layer_count: int,
@ -121,8 +121,8 @@ class DalleBartEncoderTorch(nn.Module):
super().__init__() super().__init__()
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count) self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
self.embed_positions = nn.Embedding(text_token_count, embed_count) self.embed_positions = nn.Embedding(text_token_count, embed_count)
self.layers: List[EncoderLayerTorch] = nn.ModuleList([ self.layers: List[EncoderLayer] = nn.ModuleList([
EncoderLayerTorch( EncoderLayer(
embed_count = embed_count, embed_count = embed_count,
head_count = attention_head_count, head_count = attention_head_count,
glu_embed_count = glu_embed_count glu_embed_count = glu_embed_count

View File

@ -1,11 +1,11 @@
import tempfile import tempfile
from cog import BasePredictor, Path, Input from cog import BasePredictor, Path, Input
from min_dalle.min_dalle_torch import MinDalleTorch from min_dalle import MinDalle
class Predictor(BasePredictor): class Predictor(BasePredictor):
def setup(self): def setup(self):
self.model = MinDalleTorch(is_mega=True) self.model = MinDalle(is_mega=True)
def predict( def predict(
self, self,

View File

@ -3,9 +3,9 @@ import setuptools
setuptools.setup( setuptools.setup(
name='min-dalle', name='min-dalle',
description = 'min(DALL·E)', description = 'min(DALL·E)',
version='0.1.4', version='0.2.0',
author='Brett Kuprel', author='Brett Kuprel',
author_email = 'brkuprel@gmail.com', author_email='brkuprel@gmail.com',
packages=[ packages=[
'min_dalle', 'min_dalle',
'min_dalle.models' 'min_dalle.models'
@ -18,6 +18,7 @@ setuptools.setup(
keywords = [ keywords = [
'artificial intelligence', 'artificial intelligence',
'deep learning', 'deep learning',
'text to image' 'text-to-image',
'pytorch'
] ]
) )