v0.2.0, MinDalleTorch -> MinDalle, breaking change
This commit is contained in:
parent
2080e596c3
commit
35e97768a5
6
README.md
vendored
6
README.md
vendored
|
@ -22,12 +22,12 @@ $ pip install min-dalle
|
||||||
|
|
||||||
### Python
|
### Python
|
||||||
|
|
||||||
To load a model once and generate multiple times, first initialize `MinDalleTorch`.
|
To load a model once and generate multiple times, first initialize `MinDalle`.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from min_dalle import MinDalleTorch
|
from min_dalle import MinDalle
|
||||||
|
|
||||||
model = MinDalleTorch(
|
model = MinDalle(
|
||||||
is_mega=True,
|
is_mega=True,
|
||||||
is_reusable=True,
|
is_reusable=True,
|
||||||
models_root='./pretrained'
|
models_root='./pretrained'
|
||||||
|
|
|
@ -2,14 +2,15 @@ import argparse
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from min_dalle import MinDalleTorch
|
from min_dalle import MinDalle
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--mega', action='store_true')
|
parser.add_argument('--mega', action='store_true')
|
||||||
parser.add_argument('--no-mega', dest='mega', action='store_false')
|
parser.add_argument('--no-mega', dest='mega', action='store_false')
|
||||||
parser.set_defaults(mega=False)
|
parser.set_defaults(mega=False)
|
||||||
parser.add_argument('--text', type=str, default='alien life')
|
parser.add_argument('--text', type=str, default='alien life')
|
||||||
parser.add_argument('--seed', type=int, default=7)
|
parser.add_argument('--seed', type=int, default=-1)
|
||||||
parser.add_argument('--image_path', type=str, default='generated')
|
parser.add_argument('--image_path', type=str, default='generated')
|
||||||
parser.add_argument('--token_count', type=int, default=256) # for debugging
|
parser.add_argument('--token_count', type=int, default=256) # for debugging
|
||||||
|
|
||||||
|
@ -39,7 +40,7 @@ def generate_image(
|
||||||
image_path: str,
|
image_path: str,
|
||||||
token_count: int
|
token_count: int
|
||||||
):
|
):
|
||||||
model = MinDalleTorch(
|
model = MinDalle(
|
||||||
is_mega=is_mega,
|
is_mega=is_mega,
|
||||||
models_root='pretrained',
|
models_root='pretrained',
|
||||||
is_reusable=False,
|
is_reusable=False,
|
||||||
|
|
4
min_dalle.ipynb
vendored
4
min_dalle.ipynb
vendored
|
@ -77,9 +77,9 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from min_dalle import MinDalleTorch\n",
|
"from min_dalle import MinDalle\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model = MinDalleTorch(is_mega=True, is_reusable=True)"
|
"model = MinDalle(is_mega=True, is_reusable=True)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
from .min_dalle_torch import MinDalleTorch
|
from .min_dalle import MinDalle
|
|
@ -1,6 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing import Dict
|
|
||||||
import numpy
|
import numpy
|
||||||
from torch import LongTensor
|
from torch import LongTensor
|
||||||
import torch
|
import torch
|
||||||
|
@ -10,16 +9,13 @@ import random
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_num_threads(os.cpu_count())
|
torch.set_num_threads(os.cpu_count())
|
||||||
|
|
||||||
|
from .text_tokenizer import TextTokenizer
|
||||||
|
from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
|
||||||
|
|
||||||
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
|
||||||
|
|
||||||
from .text_tokenizer import TextTokenizer
|
|
||||||
from .models import (
|
|
||||||
DalleBartEncoderTorch,
|
|
||||||
DalleBartDecoderTorch,
|
|
||||||
VQGanDetokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
class MinDalleTorch:
|
class MinDalle:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
is_mega: bool,
|
is_mega: bool,
|
||||||
|
@ -104,7 +100,7 @@ class MinDalleTorch:
|
||||||
is_downloaded = os.path.exists(self.encoder_params_path)
|
is_downloaded = os.path.exists(self.encoder_params_path)
|
||||||
if not is_downloaded: self.download_encoder()
|
if not is_downloaded: self.download_encoder()
|
||||||
print("initializing DalleBartEncoderTorch")
|
print("initializing DalleBartEncoderTorch")
|
||||||
self.encoder = DalleBartEncoderTorch(
|
self.encoder = DalleBartEncoder(
|
||||||
attention_head_count = self.attention_head_count,
|
attention_head_count = self.attention_head_count,
|
||||||
embed_count = self.embed_count,
|
embed_count = self.embed_count,
|
||||||
glu_embed_count = self.glu_embed_count,
|
glu_embed_count = self.glu_embed_count,
|
||||||
|
@ -122,7 +118,7 @@ class MinDalleTorch:
|
||||||
is_downloaded = os.path.exists(self.decoder_params_path)
|
is_downloaded = os.path.exists(self.decoder_params_path)
|
||||||
if not is_downloaded: self.download_decoder()
|
if not is_downloaded: self.download_decoder()
|
||||||
print("initializing DalleBartDecoderTorch")
|
print("initializing DalleBartDecoderTorch")
|
||||||
self.decoder = DalleBartDecoderTorch(
|
self.decoder = DalleBartDecoder(
|
||||||
sample_token_count = self.sample_token_count,
|
sample_token_count = self.sample_token_count,
|
||||||
image_token_count = self.image_token_count,
|
image_token_count = self.image_token_count,
|
||||||
image_vocab_count = self.image_vocab_count,
|
image_vocab_count = self.image_vocab_count,
|
|
@ -1,3 +1,3 @@
|
||||||
from .dalle_bart_encoder_torch import DalleBartEncoderTorch
|
from .dalle_bart_encoder import DalleBartEncoder
|
||||||
from .dalle_bart_decoder_torch import DalleBartDecoderTorch
|
from .dalle_bart_decoder import DalleBartDecoder
|
||||||
from .vqgan_detokenizer import VQGanDetokenizer
|
from .vqgan_detokenizer import VQGanDetokenizer
|
|
@ -3,10 +3,10 @@ import torch
|
||||||
from torch import LongTensor, nn, FloatTensor, BoolTensor
|
from torch import LongTensor, nn, FloatTensor, BoolTensor
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
|
from .dalle_bart_encoder import GLU, AttentionBase
|
||||||
|
|
||||||
|
|
||||||
class DecoderCrossAttentionTorch(AttentionTorch):
|
class DecoderCrossAttention(AttentionBase):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
|
@ -19,7 +19,7 @@ class DecoderCrossAttentionTorch(AttentionTorch):
|
||||||
return super().forward(keys, values, queries, attention_mask)
|
return super().forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
class DecoderSelfAttentionTorch(AttentionTorch):
|
class DecoderSelfAttention(AttentionBase):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
decoder_state: FloatTensor,
|
decoder_state: FloatTensor,
|
||||||
|
@ -42,7 +42,7 @@ class DecoderSelfAttentionTorch(AttentionTorch):
|
||||||
return decoder_state, attention_state
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
class DecoderLayerTorch(nn.Module):
|
class DecoderLayer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
image_token_count: int,
|
image_token_count: int,
|
||||||
|
@ -53,12 +53,12 @@ class DecoderLayerTorch(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.image_token_count = image_token_count
|
self.image_token_count = image_token_count
|
||||||
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count)
|
self.self_attn = DecoderSelfAttention(head_count, embed_count)
|
||||||
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count)
|
self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.glu = GLUTorch(embed_count, glu_embed_count)
|
self.glu = GLU(embed_count, glu_embed_count)
|
||||||
|
|
||||||
self.token_indices = torch.arange(self.image_token_count)
|
self.token_indices = torch.arange(self.image_token_count)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -106,7 +106,7 @@ class DecoderLayerTorch(nn.Module):
|
||||||
return decoder_state, attention_state
|
return decoder_state, attention_state
|
||||||
|
|
||||||
|
|
||||||
class DalleBartDecoderTorch(nn.Module):
|
class DalleBartDecoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
image_vocab_count: int,
|
image_vocab_count: int,
|
||||||
|
@ -126,8 +126,8 @@ class DalleBartDecoderTorch(nn.Module):
|
||||||
self.image_token_count = image_token_count
|
self.image_token_count = image_token_count
|
||||||
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
|
self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
|
||||||
self.embed_positions = nn.Embedding(image_token_count, embed_count)
|
self.embed_positions = nn.Embedding(image_token_count, embed_count)
|
||||||
self.layers: List[DecoderLayerTorch] = nn.ModuleList([
|
self.layers: List[DecoderLayer] = nn.ModuleList([
|
||||||
DecoderLayerTorch(
|
DecoderLayer(
|
||||||
image_token_count,
|
image_token_count,
|
||||||
attention_head_count,
|
attention_head_count,
|
||||||
embed_count,
|
embed_count,
|
|
@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
|
|
||||||
class GLUTorch(nn.Module):
|
class GLU(nn.Module):
|
||||||
def __init__(self, count_in_out, count_middle):
|
def __init__(self, count_in_out, count_middle):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gelu = nn.GELU()
|
self.gelu = nn.GELU()
|
||||||
|
@ -24,7 +24,7 @@ class GLUTorch(nn.Module):
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
class AttentionTorch(nn.Module):
|
class AttentionBase(nn.Module):
|
||||||
def __init__(self, head_count: int, embed_count: int):
|
def __init__(self, head_count: int, embed_count: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_count = head_count
|
self.head_count = head_count
|
||||||
|
@ -72,7 +72,7 @@ class AttentionTorch(nn.Module):
|
||||||
return attention_output
|
return attention_output
|
||||||
|
|
||||||
|
|
||||||
class EncoderSelfAttentionTorch(AttentionTorch):
|
class EncoderSelfAttention(AttentionBase):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
encoder_state: FloatTensor,
|
encoder_state: FloatTensor,
|
||||||
|
@ -84,13 +84,13 @@ class EncoderSelfAttentionTorch(AttentionTorch):
|
||||||
return super().forward(keys, values, queries, attention_mask)
|
return super().forward(keys, values, queries, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
class EncoderLayerTorch(nn.Module):
|
class EncoderLayer(nn.Module):
|
||||||
def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
|
def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count)
|
self.self_attn = EncoderSelfAttention(head_count, embed_count)
|
||||||
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
|
self.self_attn_layer_norm = nn.LayerNorm(embed_count)
|
||||||
self.glu = GLUTorch(embed_count, glu_embed_count)
|
self.glu = GLU(embed_count, glu_embed_count)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -108,7 +108,7 @@ class EncoderLayerTorch(nn.Module):
|
||||||
return encoder_state
|
return encoder_state
|
||||||
|
|
||||||
|
|
||||||
class DalleBartEncoderTorch(nn.Module):
|
class DalleBartEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layer_count: int,
|
layer_count: int,
|
||||||
|
@ -121,8 +121,8 @@ class DalleBartEncoderTorch(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
|
self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
|
||||||
self.embed_positions = nn.Embedding(text_token_count, embed_count)
|
self.embed_positions = nn.Embedding(text_token_count, embed_count)
|
||||||
self.layers: List[EncoderLayerTorch] = nn.ModuleList([
|
self.layers: List[EncoderLayer] = nn.ModuleList([
|
||||||
EncoderLayerTorch(
|
EncoderLayer(
|
||||||
embed_count = embed_count,
|
embed_count = embed_count,
|
||||||
head_count = attention_head_count,
|
head_count = attention_head_count,
|
||||||
glu_embed_count = glu_embed_count
|
glu_embed_count = glu_embed_count
|
|
@ -1,11 +1,11 @@
|
||||||
import tempfile
|
import tempfile
|
||||||
from cog import BasePredictor, Path, Input
|
from cog import BasePredictor, Path, Input
|
||||||
|
|
||||||
from min_dalle.min_dalle_torch import MinDalleTorch
|
from min_dalle import MinDalle
|
||||||
|
|
||||||
class Predictor(BasePredictor):
|
class Predictor(BasePredictor):
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.model = MinDalleTorch(is_mega=True)
|
self.model = MinDalle(is_mega=True)
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
|
|
5
setup.py
5
setup.py
|
@ -3,7 +3,7 @@ import setuptools
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name='min-dalle',
|
name='min-dalle',
|
||||||
description = 'min(DALL·E)',
|
description = 'min(DALL·E)',
|
||||||
version='0.1.4',
|
version='0.2.0',
|
||||||
author='Brett Kuprel',
|
author='Brett Kuprel',
|
||||||
author_email='brkuprel@gmail.com',
|
author_email='brkuprel@gmail.com',
|
||||||
packages=[
|
packages=[
|
||||||
|
@ -18,6 +18,7 @@ setuptools.setup(
|
||||||
keywords = [
|
keywords = [
|
||||||
'artificial intelligence',
|
'artificial intelligence',
|
||||||
'deep learning',
|
'deep learning',
|
||||||
'text to image'
|
'text-to-image',
|
||||||
|
'pytorch'
|
||||||
]
|
]
|
||||||
)
|
)
|
Loading…
Reference in New Issue
Block a user