v0.2.0, MinDalleTorch -> MinDalle, breaking change

Branch: main
Author: Brett Kuprel, 2 years ago
Parent: 2080e596c3
Commit: 35e97768a5
Changed files (lines changed):
  1. README.md (6)
  2. image_from_text.py (7)
  3. min_dalle.ipynb (4)
  4. min_dalle/__init__.py (2)
  5. min_dalle/min_dalle.py (16)
  6. min_dalle/models/__init__.py (4)
  7. min_dalle/models/dalle_bart_decoder.py (20)
  8. min_dalle/models/dalle_bart_encoder.py (18)
  9. replicate/predict.py (4)
  10. setup.py (7)

README.md (6 lines changed)

@@ -22,12 +22,12 @@ $ pip install min-dalle
 ### Python
-To load a model once and generate multiple times, first initialize `MinDalleTorch`.
+To load a model once and generate multiple times, first initialize `MinDalle`.
 ```python
-from min_dalle import MinDalleTorch
-model = MinDalleTorch(
+from min_dalle import MinDalle
+model = MinDalle(
     is_mega=True,
     is_reusable=True,
     models_root='./pretrained'
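The README code block above is cut off in this view; a minimal end-to-end sketch of the renamed API (the `generate_image` call and its PIL return type are assumptions, not shown in this hunk):

```python
from min_dalle import MinDalle

# load the (mega) model once and keep it resident for repeated generations
model = MinDalle(
    is_mega=True,
    is_reusable=True,
    models_root='./pretrained'
)

# assumption: the model exposes generate_image(text, seed) and returns a PIL image
image = model.generate_image('alien life', seed=7)
image.save('generated.png')
```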

image_from_text.py (7 lines changed)

@@ -2,14 +2,15 @@ import argparse
 import os
 from PIL import Image
-from min_dalle import MinDalleTorch
+from min_dalle import MinDalle
 parser = argparse.ArgumentParser()
 parser.add_argument('--mega', action='store_true')
 parser.add_argument('--no-mega', dest='mega', action='store_false')
 parser.set_defaults(mega=False)
 parser.add_argument('--text', type=str, default='alien life')
-parser.add_argument('--seed', type=int, default=7)
+parser.add_argument('--seed', type=int, default=-1)
 parser.add_argument('--image_path', type=str, default='generated')
 parser.add_argument('--token_count', type=int, default=256) # for debugging

@@ -39,7 +40,7 @@ def generate_image(
     image_path: str,
     token_count: int
 ):
-    model = MinDalleTorch(
+    model = MinDalle(
         is_mega=is_mega,
         models_root='pretrained',
         is_reusable=False,
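With the default seed changed from 7 to -1, a plain run presumably no longer uses a fixed seed; the old behavior can still be reproduced by passing the flags defined above, for example:

$ python image_from_text.py --mega --text='alien life' --seed=7 --image_path=generated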

min_dalle.ipynb (4 lines changed)

@@ -77,9 +77,9 @@
     }
 ],
 "source": [
-    "from min_dalle import MinDalleTorch\n",
+    "from min_dalle import MinDalle\n",
     "\n",
-    "model = MinDalleTorch(is_mega=True, is_reusable=True)"
+    "model = MinDalle(is_mega=True, is_reusable=True)"
 ]
 },
 {

min_dalle/__init__.py (2 lines changed)

@@ -1 +1 @@
-from .min_dalle_torch import MinDalleTorch
+from .min_dalle import MinDalle
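Since the package-level export is what most callers import, this line is the breaking part of the release; a minimal migration sketch using only names that appear in this diff:

```python
# v0.1.x (no longer works after this commit)
# from min_dalle import MinDalleTorch

# v0.2.0
from min_dalle import MinDalle
```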

min_dalle/min_dalle.py (16 lines changed)

@@ -1,6 +1,5 @@
 import os
 from PIL import Image
-from typing import Dict
 import numpy
 from torch import LongTensor
 import torch

@@ -10,16 +9,13 @@ import random
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())
+from .text_tokenizer import TextTokenizer
+from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
 MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
-from .text_tokenizer import TextTokenizer
-from .models import (
-    DalleBartEncoderTorch,
-    DalleBartDecoderTorch,
-    VQGanDetokenizer
-)
-class MinDalleTorch:
+class MinDalle:
     def __init__(
         self,
         is_mega: bool,

@@ -104,7 +100,7 @@ class MinDalleTorch:
         is_downloaded = os.path.exists(self.encoder_params_path)
         if not is_downloaded: self.download_encoder()
         print("initializing DalleBartEncoderTorch")
-        self.encoder = DalleBartEncoderTorch(
+        self.encoder = DalleBartEncoder(
             attention_head_count = self.attention_head_count,
             embed_count = self.embed_count,
             glu_embed_count = self.glu_embed_count,

@@ -122,7 +118,7 @@ class MinDalleTorch:
         is_downloaded = os.path.exists(self.decoder_params_path)
         if not is_downloaded: self.download_decoder()
         print("initializing DalleBartDecoderTorch")
-        self.decoder = DalleBartDecoderTorch(
+        self.decoder = DalleBartDecoder(
             sample_token_count = self.sample_token_count,
             image_token_count = self.image_token_count,
             image_vocab_count = self.image_vocab_count,

min_dalle/models/__init__.py (4 lines changed)

@@ -1,3 +1,3 @@
-from .dalle_bart_encoder_torch import DalleBartEncoderTorch
-from .dalle_bart_decoder_torch import DalleBartDecoderTorch
+from .dalle_bart_encoder import DalleBartEncoder
+from .dalle_bart_decoder import DalleBartDecoder
 from .vqgan_detokenizer import VQGanDetokenizer

min_dalle/models/dalle_bart_decoder.py (20 lines changed)

@@ -3,10 +3,10 @@ import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
 torch.set_grad_enabled(False)
-from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
+from .dalle_bart_encoder import GLU, AttentionBase
-class DecoderCrossAttentionTorch(AttentionTorch):
+class DecoderCrossAttention(AttentionBase):
     def forward(
         self,
         decoder_state: FloatTensor,

@@ -19,7 +19,7 @@ class DecoderCrossAttentionTorch(AttentionTorch):
         return super().forward(keys, values, queries, attention_mask)
-class DecoderSelfAttentionTorch(AttentionTorch):
+class DecoderSelfAttention(AttentionBase):
     def forward(
         self,
         decoder_state: FloatTensor,

@@ -42,7 +42,7 @@ class DecoderSelfAttentionTorch(AttentionTorch):
         return decoder_state, attention_state
-class DecoderLayerTorch(nn.Module):
+class DecoderLayer(nn.Module):
     def __init__(
         self,
         image_token_count: int,

@@ -53,12 +53,12 @@ class DecoderLayerTorch(nn.Module):
         super().__init__()
         self.image_token_count = image_token_count
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count)
+        self.self_attn = DecoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
         self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count)
+        self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
         self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.glu = GLUTorch(embed_count, glu_embed_count)
+        self.glu = GLU(embed_count, glu_embed_count)
         self.token_indices = torch.arange(self.image_token_count)
         if torch.cuda.is_available():

@@ -106,7 +106,7 @@ class DecoderLayerTorch(nn.Module):
         return decoder_state, attention_state
-class DalleBartDecoderTorch(nn.Module):
+class DalleBartDecoder(nn.Module):
     def __init__(
         self,
         image_vocab_count: int,

@@ -126,8 +126,8 @@ class DalleBartDecoderTorch(nn.Module):
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
-        self.layers: List[DecoderLayerTorch] = nn.ModuleList([
-            DecoderLayerTorch(
+        self.layers: List[DecoderLayer] = nn.ModuleList([
+            DecoderLayer(
                 image_token_count,
                 attention_head_count,
                 embed_count,

min_dalle/models/dalle_bart_encoder.py (18 lines changed)

@@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor
 torch.set_grad_enabled(False)
-class GLUTorch(nn.Module):
+class GLU(nn.Module):
     def __init__(self, count_in_out, count_middle):
         super().__init__()
         self.gelu = nn.GELU()

@@ -24,7 +24,7 @@ class GLUTorch(nn.Module):
         return z
-class AttentionTorch(nn.Module):
+class AttentionBase(nn.Module):
     def __init__(self, head_count: int, embed_count: int):
         super().__init__()
         self.head_count = head_count

@@ -72,7 +72,7 @@ class AttentionTorch(nn.Module):
         return attention_output
-class EncoderSelfAttentionTorch(AttentionTorch):
+class EncoderSelfAttention(AttentionBase):
     def forward(
         self,
         encoder_state: FloatTensor,

@@ -84,13 +84,13 @@ class EncoderSelfAttentionTorch(AttentionTorch):
         return super().forward(keys, values, queries, attention_mask)
-class EncoderLayerTorch(nn.Module):
+class EncoderLayer(nn.Module):
     def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
         super().__init__()
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count)
+        self.self_attn = EncoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.glu = GLUTorch(embed_count, glu_embed_count)
+        self.glu = GLU(embed_count, glu_embed_count)
     def forward(
         self,

@@ -108,7 +108,7 @@ class EncoderLayerTorch(nn.Module):
         return encoder_state
-class DalleBartEncoderTorch(nn.Module):
+class DalleBartEncoder(nn.Module):
     def __init__(
         self,
         layer_count: int,

@@ -121,8 +121,8 @@ class DalleBartEncoderTorch(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
-        self.layers: List[EncoderLayerTorch] = nn.ModuleList([
-            EncoderLayerTorch(
+        self.layers: List[EncoderLayer] = nn.ModuleList([
+            EncoderLayer(
                 embed_count = embed_count,
                 head_count = attention_head_count,
                 glu_embed_count = glu_embed_count

replicate/predict.py (4 lines changed)

@@ -1,11 +1,11 @@
 import tempfile
 from cog import BasePredictor, Path, Input
-from min_dalle.min_dalle_torch import MinDalleTorch
+from min_dalle import MinDalle
 class Predictor(BasePredictor):
     def setup(self):
-        self.model = MinDalleTorch(is_mega=True)
+        self.model = MinDalle(is_mega=True)
     def predict(
         self,

setup.py (7 lines changed)

@@ -3,9 +3,9 @@ import setuptools
 setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
-    version='0.1.4',
+    version='0.2.0',
     author='Brett Kuprel',
-    author_email = 'brkuprel@gmail.com',
+    author_email='brkuprel@gmail.com',
     packages=[
         'min_dalle',
         'min_dalle.models'

@@ -18,6 +18,7 @@ setuptools.setup(
     keywords = [
         'artificial intelligence',
         'deep learning',
-        'text to image'
+        'text-to-image',
+        'pytorch'
     ]
 )
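Because the exported class is renamed in the same release as this version bump, upgrading requires the import change shown earlier; a typical command for the package defined above:

$ pip install --upgrade min-dalle

Pinning `min-dalle==0.1.4` keeps the old `MinDalleTorch` API for code that has not been migrated yet.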