From 3c28b1059bcd6b5cacbe0c0fc31bda326f2556b1 Mon Sep 17 00:00:00 2001 From: Brett Kuprel Date: Fri, 15 Jul 2022 17:18:23 -0400 Subject: [PATCH] 0.3.13, simplified code, specify device when initializing MinDalle --- README.md | 5 +- cog.yaml | 2 +- min_dalle.ipynb | 17 ++- min_dalle/min_dalle.py | 102 +++++++++++------- min_dalle/models/dalle_bart_decoder.py | 140 +++++++------------------ min_dalle/models/dalle_bart_encoder.py | 22 ++-- replicate_predictor.py | 6 +- setup.py | 2 +- tkinter_ui.py | 31 +++--- 9 files changed, 139 insertions(+), 188 deletions(-) diff --git a/README.md b/README.md index e7e0679..6b9ff6b 100644 --- a/README.md +++ b/README.md @@ -36,12 +36,13 @@ from min_dalle import MinDalle model = MinDalle( models_root='./pretrained', dtype=torch.float32, + device='cuda', is_mega=True, is_reusable=True ) ``` -The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`. +The required models will be downloaded to `models_root` if they are not already there. Set the `dtype` to `torch.float16` to save GPU memory. If you have an Ampere architecture GPU you can use `torch.bfloat16`. Set the `device` to either "cuda" or "cpu". Once everything has finished initializing, call `generate_image` with some text as many times as you want. Use a positive `seed` for reproducible results. Higher values for `supercondition_factor` result in better agreement with the text but a narrower variety of generated images. Every image token is sampled from the `top_k` most probable tokens. The largest logit is subtracted from the logits to avoid infs. The logits are then divided by the `temperature`. ```python image = model.generate_image( @@ -88,7 +89,7 @@ image.save('image_{}.png'.format(i)) ### Progressive Outputs -If the model is being used interactively (e.g. in a notebook) `generate_image_stream` can be used to generate a stream of images as the model is decoding. The detokenizer adds a slight delay for each image. Setting `log2_mid_count` to 3 results in a total of `2 ** 3 = 8` generated images. The only valid values for `log2_mid_count` are 0, 1, 2, 3, and 4. This is implemented in the colab. +If the model is being used interactively (e.g. in a notebook) `generate_image_stream` can be used to generate a stream of images as the model is decoding. The detokenizer adds a slight delay for each image. Set `progressive_outputs` to `True` to enable this. An example is implemented in the colab. ```python image_stream = model.generate_image_stream( diff --git a/cog.yaml b/cog.yaml index a900493..f724114 100644 --- a/cog.yaml +++ b/cog.yaml @@ -6,7 +6,7 @@ build: - "libgl1-mesa-glx" - "libglib2.0-0" python_packages: - - "min-dalle==0.3.12" + - "min-dalle==0.3.13" run: - pip install torch==1.12.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html diff --git a/min_dalle.ipynb b/min_dalle.ipynb index 60c70a2..8f94aa7 100644 --- a/min_dalle.ipynb +++ b/min_dalle.ipynb @@ -135,6 +135,7 @@ "\n", "model = MinDalle(\n", " dtype=getattr(torch, dtype),\n", + " device='cuda',\n", " is_mega=True, \n", " is_reusable=True\n", ")" @@ -196,14 +197,13 @@ "grid_size = 5 #@param {type:\"integer\"}\n", "temperature = 2 #@param {type:\"slider\", min:0.01, max:3, step:0.01}\n", "supercondition_factor = 16 #@param {type:\"number\"}\n", - "top_k = 256 #@param {type:\"integer\"}\n", - "log2_mid_count = 3 if progressive_outputs else 0\n", + "top_k = 128 #@param {type:\"integer\"}\n", "\n", "image_stream = model.generate_image_stream(\n", " text=text,\n", " seed=-1,\n", " grid_size=grid_size,\n", - " log2_mid_count=log2_mid_count,\n", + " progressive_outputs=progressive_outputs,\n", " temperature=temperature,\n", " top_k=int(top_k),\n", " supercondition_factor=float(supercondition_factor)\n", @@ -229,11 +229,18 @@ }, "gpuClass": "standard", "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.9.13 64-bit", + "language": "python", "name": "python3" }, "language_info": { - "name": "python" + "name": "python", + "version": "3.9.13" + }, + "vscode": { + "interpreter": { + "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" + } } }, "nbformat": 4, diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py index 26981e6..f38e499 100644 --- a/min_dalle/min_dalle.py +++ b/min_dalle/min_dalle.py @@ -20,6 +20,7 @@ torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/' +IMAGE_TOKEN_COUNT = 256 class MinDalle: @@ -27,10 +28,15 @@ class MinDalle: self, models_root: str = 'pretrained', dtype: torch.dtype = torch.float32, + device: str = None, is_mega: bool = True, is_reusable: bool = True, is_verbose = True ): + if device == None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + if is_verbose: print("using device", device) + self.device = device self.is_mega = is_mega self.is_reusable = is_reusable self.dtype = dtype @@ -112,12 +118,13 @@ class MinDalle: glu_embed_count = self.glu_embed_count, text_token_count = self.text_token_count, text_vocab_count = self.text_vocab_count, - layer_count = self.layer_count + layer_count = self.layer_count, + device=self.device ).to(self.dtype).eval() params = torch.load(self.encoder_params_path) self.encoder.load_state_dict(params, strict=False) del params - if torch.cuda.is_available(): self.encoder = self.encoder.cuda() + self.encoder = self.encoder.to(device=self.device) def init_decoder(self): @@ -130,12 +137,12 @@ class MinDalle: embed_count = self.embed_count, glu_embed_count = self.glu_embed_count, layer_count = self.layer_count, - start_token = self.image_vocab_count + device=self.device ).to(self.dtype).eval() params = torch.load(self.decoder_params_path) self.decoder.load_state_dict(params, strict=False) del params - if torch.cuda.is_available(): self.decoder = self.decoder.cuda() + self.decoder = self.decoder.to(device=self.device) def init_detokenizer(self): @@ -146,7 +153,7 @@ class MinDalle: params = torch.load(self.detoker_params_path) self.detokenizer.load_state_dict(params) del params - if torch.cuda.is_available(): self.detokenizer = self.detokenizer.cuda() + self.detokenizer = self.detokenizer.to(device=self.device) def images_from_tokens( @@ -155,7 +162,7 @@ class MinDalle: is_verbose: bool = False ) -> FloatTensor: if not self.is_reusable: del self.decoder - if torch.cuda.is_available(): torch.cuda.empty_cache() + torch.cuda.empty_cache() if not self.is_reusable: self.init_detokenizer() if is_verbose: print("detokenizing image") images = self.detokenizer.forward(image_tokens).to(torch.uint8) @@ -176,13 +183,12 @@ class MinDalle: text: str, seed: int, image_count: int, - log2_mid_count: int, + progressive_outputs: bool = False, temperature: float = 1, top_k: int = 256, supercondition_factor: int = 16, is_verbose: bool = False ) -> Iterator[FloatTensor]: - assert(log2_mid_count in range(5)) if is_verbose: print("tokenizing text") tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) if len(tokens) > self.text_token_count: @@ -191,49 +197,67 @@ class MinDalle: text_tokens = numpy.ones((2, 64), dtype=numpy.int32) text_tokens[0, :2] = [tokens[0], tokens[-1]] text_tokens[1, :len(tokens)] = tokens - - text_tokens = torch.tensor(text_tokens).to(torch.long) - if torch.cuda.is_available(): text_tokens = text_tokens.cuda() + text_tokens = torch.tensor( + text_tokens, + dtype=torch.long, + device=self.device + ) if not self.is_reusable: self.init_encoder() if is_verbose: print("encoding text tokens") with torch.cuda.amp.autocast(dtype=self.dtype): encoder_state = self.encoder.forward(text_tokens) if not self.is_reusable: del self.encoder - if torch.cuda.is_available(): torch.cuda.empty_cache() + torch.cuda.empty_cache() if not self.is_reusable: self.init_decoder() with torch.cuda.amp.autocast(dtype=self.dtype): - encoder_state, attention_mask, attention_state, image_tokens = ( - self.decoder.decode_initial( - seed=seed, - image_count=image_count, - text_tokens=text_tokens, - encoder_state=encoder_state - ) + expanded_indices = [0] * image_count + [1] * image_count + text_tokens = text_tokens[expanded_indices] + encoder_state = encoder_state[expanded_indices] + attention_mask = text_tokens.not_equal(1) + attention_state = torch.zeros( + size=( + self.layer_count, + image_count * 4, + IMAGE_TOKEN_COUNT, + self.embed_count + ), + device=self.device ) + image_tokens = torch.full( + (IMAGE_TOKEN_COUNT + 1, image_count), + self.image_vocab_count, + dtype=torch.long, + device=self.device + ) + + if seed > 0: torch.manual_seed(seed) - row_count = 16 - for row_index in range(row_count): - if is_verbose: - print('sampling row {} of {}'.format(row_index + 1, row_count)) + token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=self.device) + settings = torch.tensor( + [temperature, top_k, supercondition_factor], + dtype=torch.float32, + device=self.device + ) + for i in range(IMAGE_TOKEN_COUNT): with torch.cuda.amp.autocast(dtype=self.dtype): - attention_state, image_tokens = self.decoder.decode_row( - row_index, - temperature=temperature, - top_k=top_k, - supercondition_factor=supercondition_factor, - encoder_state=encoder_state, + image_tokens[i + 1], attention_state = self.decoder.forward( + settings=settings, attention_mask=attention_mask, + encoder_state=encoder_state, attention_state=attention_state, - image_tokens_sequence=image_tokens + prev_tokens=image_tokens[i], + token_index=token_indices[[i]] ) + with torch.cuda.amp.autocast(dtype=torch.float32): - if ((row_index + 1) * (2 ** log2_mid_count)) % row_count == 0: - tokens = image_tokens[:, 1:] - images = self.images_from_tokens(tokens, is_verbose) - yield images + if ((i + 1) % 32 == 0 and progressive_outputs) or i + 1 == 256: + yield self.images_from_tokens( + image_tokens=image_tokens[1:].T, + is_verbose=is_verbose + ) def generate_image_stream( @@ -241,7 +265,7 @@ class MinDalle: text: str, seed: int, grid_size: int, - log2_mid_count: int, + progressive_outputs: bool = False, temperature: float = 1, top_k: int = 256, supercondition_factor: int = 16, @@ -251,7 +275,7 @@ class MinDalle: text=text, seed=seed, image_count=grid_size ** 2, - log2_mid_count=log2_mid_count, + progressive_outputs=progressive_outputs, temperature=temperature, top_k=top_k, supercondition_factor=supercondition_factor, @@ -271,13 +295,12 @@ class MinDalle: supercondition_factor: int = 16, is_verbose: bool = False ) -> FloatTensor: - log2_mid_count = 0 images_stream = self.generate_images_stream( text=text, seed=seed, image_count=image_count, temperature=temperature, - log2_mid_count=log2_mid_count, + progressive_outputs=False, top_k=top_k, supercondition_factor=supercondition_factor, is_verbose=is_verbose @@ -295,12 +318,11 @@ class MinDalle: supercondition_factor: int = 16, is_verbose: bool = False ) -> Image.Image: - log2_mid_count = 0 image_stream = self.generate_image_stream( text=text, seed=seed, grid_size=grid_size, - log2_mid_count=log2_mid_count, + progressive_outputs=False, temperature=temperature, top_k=top_k, supercondition_factor=supercondition_factor, diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py index 43d3fc8..f2ee357 100644 --- a/min_dalle/models/dalle_bart_decoder.py +++ b/min_dalle/models/dalle_bart_decoder.py @@ -4,7 +4,6 @@ from torch import nn, LongTensor, FloatTensor, BoolTensor from .dalle_bart_encoder import GLU, AttentionBase IMAGE_TOKEN_COUNT = 256 -BLANK_TOKEN = 6965 class DecoderCrossAttention(AttentionBase): @@ -23,21 +22,18 @@ class DecoderCrossAttention(AttentionBase): class DecoderSelfAttention(AttentionBase): def __init__(self, head_count: int, embed_count: int): super().__init__(head_count, embed_count) - token_indices = torch.arange(IMAGE_TOKEN_COUNT) - if torch.cuda.is_available(): token_indices = token_indices.cuda() - self.token_indices = token_indices + def forward( self, decoder_state: FloatTensor, attention_state: FloatTensor, + attn_mask: BoolTensor, token_index: LongTensor ) -> Tuple[FloatTensor, FloatTensor]: keys = self.k_proj.forward(decoder_state) values = self.v_proj.forward(decoder_state) queries = self.q_proj.forward(decoder_state) - attn_mask = self.token_indices < token_index + 1 - attn_mask = attn_mask[None][[0] * decoder_state.shape[0]] attn_state_new = torch.cat([keys, values]).to(attention_state.dtype) attention_state[:, token_index] = attn_state_new batch_count = decoder_state.shape[0] @@ -52,7 +48,8 @@ class DecoderLayer(nn.Module): self, head_count: int, embed_count: int, - glu_embed_count: int + glu_embed_count: int, + device: str ): super().__init__() self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count) @@ -62,6 +59,7 @@ class DecoderLayer(nn.Module): self.encoder_attn = DecoderCrossAttention(head_count, embed_count) self.encoder_attn_layer_norm = nn.LayerNorm(embed_count) self.glu = GLU(embed_count, glu_embed_count) + self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device) def forward( @@ -73,12 +71,15 @@ class DecoderLayer(nn.Module): token_index: LongTensor ) -> Tuple[FloatTensor, FloatTensor]: # Self Attention + self_attn_mask = self.token_indices < token_index + 1 + self_attn_mask = self_attn_mask[None][[0] * decoder_state.shape[0]] residual = decoder_state decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) decoder_state, attention_state = self.self_attn.forward( - decoder_state, - attention_state, - token_index + decoder_state=decoder_state, + attention_state=attention_state, + attn_mask=self_attn_mask, + token_index=token_index ) decoder_state = self.self_attn_layer_norm.forward(decoder_state) decoder_state = residual + decoder_state @@ -87,9 +88,9 @@ class DecoderLayer(nn.Module): residual = decoder_state decoder_state = self.pre_encoder_attn_layer_norm.forward(decoder_state) decoder_state = self.encoder_attn.forward( - decoder_state, - encoder_state, - attention_mask + decoder_state=decoder_state, + encoder_state=encoder_state, + attention_mask=attention_mask ) decoder_state = self.encoder_attn_layer_norm.forward(decoder_state) decoder_state = residual + decoder_state @@ -110,7 +111,7 @@ class DalleBartDecoder(nn.Module): attention_head_count: int, glu_embed_count: int, layer_count: int, - start_token: int + device: str ): super().__init__() self.layer_count = layer_count @@ -120,70 +121,28 @@ class DalleBartDecoder(nn.Module): self.embed_positions = nn.Embedding(IMAGE_TOKEN_COUNT, embed_count) self.layers: List[DecoderLayer] = nn.ModuleList([ DecoderLayer( - attention_head_count, - embed_count, - glu_embed_count + head_count=attention_head_count, + embed_count=embed_count, + glu_embed_count=glu_embed_count, + device=device ) for _ in range(layer_count) ]) self.layernorm_embedding = nn.LayerNorm(embed_count) self.final_ln = nn.LayerNorm(embed_count) self.lm_head = nn.Linear(embed_count, image_vocab_count + 1, bias=False) - self.zero_prob = torch.zeros([1]) - self.token_indices = torch.arange(IMAGE_TOKEN_COUNT) - self.start_token = torch.tensor([start_token]).to(torch.long) - if torch.cuda.is_available(): - self.zero_prob = self.zero_prob.cuda() - self.token_indices = self.token_indices.cuda() - self.start_token = self.start_token.cuda() + self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device) - def decode_initial( + def forward( self, - seed: int, - image_count: int, - text_tokens: LongTensor, - encoder_state: FloatTensor - ) -> Tuple[FloatTensor, FloatTensor, FloatTensor, LongTensor]: - expanded_indices = [0] * image_count + [1] * image_count - text_tokens = text_tokens[expanded_indices] - encoder_state = encoder_state[expanded_indices] - attention_mask = text_tokens.not_equal(1) - - attention_state_shape = ( - self.layer_count, - image_count * 4, - IMAGE_TOKEN_COUNT, - self.embed_count - ) - attention_state = torch.zeros(attention_state_shape) - image_tokens_sequence = torch.full( - (image_count, IMAGE_TOKEN_COUNT + 1), - BLANK_TOKEN, - dtype=torch.long - ) - if torch.cuda.is_available(): - attention_state = attention_state.cuda() - image_tokens_sequence = image_tokens_sequence.cuda() - - image_tokens_sequence[:, 0] = self.start_token[0] - - if seed > 0: torch.manual_seed(seed) - - return encoder_state, attention_mask, attention_state, image_tokens_sequence - - - def decode_step( - self, - temperature: float, - top_k: int, - supercondition_factor: float, + settings: FloatTensor, attention_mask: BoolTensor, encoder_state: FloatTensor, attention_state: FloatTensor, prev_tokens: LongTensor, token_index: LongTensor - ) -> Tuple[FloatTensor, FloatTensor]: + ) -> Tuple[LongTensor, FloatTensor]: image_count = encoder_state.shape[0] // 2 token_index_batched = token_index[[0] * image_count * 2] prev_tokens = prev_tokens[list(range(image_count)) * 2] @@ -202,44 +161,19 @@ class DalleBartDecoder(nn.Module): ) decoder_state = self.final_ln(decoder_state) logits = self.lm_head(decoder_state) - a = supercondition_factor + temperature = settings[0] + top_k = settings[1].to(torch.long) + supercondition_factor = settings[2] + logits = logits[:, -1, : 2 ** 14] logits: FloatTensor = ( - logits[:image_count, -1] * (1 - a) + - logits[image_count:, -1] * a + logits[:image_count] * (1 - supercondition_factor) + + logits[image_count:] * supercondition_factor ) - - top_logits, _ = logits.topk(top_k, dim=-1) - is_kept = logits >= top_logits[:, [-1]] - logits -= top_logits[:, [0]] - logits /= max(temperature, 1e-6) - probs = torch.where(is_kept, torch.exp(logits), self.zero_prob) - probs[:, 2 ** 14:] = 0 # vqgan vocab_count is only 2 ** 14 - return probs, attention_state - - - def decode_row( - self, - row_index: int, - temperature: float, - top_k: int, - supercondition_factor: float, - encoder_state: FloatTensor, - attention_mask: BoolTensor, - attention_state: FloatTensor, - image_tokens_sequence: LongTensor - ) -> Tuple[FloatTensor, LongTensor]: - for col_index in range(16): - i = 16 * row_index + col_index - probs, attention_state = self.decode_step( - temperature = temperature, - top_k = top_k, - supercondition_factor = supercondition_factor, - attention_mask = attention_mask, - encoder_state = encoder_state, - attention_state = attention_state, - prev_tokens = image_tokens_sequence[:, i], - token_index = self.token_indices[[i]] - ) - image_tokens_sequence[:, i + 1] = torch.multinomial(probs, 1)[:, 0] - - return attention_state, image_tokens_sequence \ No newline at end of file + logits_sorted, _ = logits.sort(descending=True) + is_kept = logits >= logits_sorted[:, top_k: top_k + 1] + logits -= logits_sorted[:, [0]] + logits /= temperature + logits.exp_() + logits *= is_kept.to(torch.float32) + image_tokens = torch.multinomial(logits, 1)[:, 0] + return image_tokens, attention_state \ No newline at end of file diff --git a/min_dalle/models/dalle_bart_encoder.py b/min_dalle/models/dalle_bart_encoder.py index 4ee9045..e67a0ed 100644 --- a/min_dalle/models/dalle_bart_encoder.py +++ b/min_dalle/models/dalle_bart_encoder.py @@ -4,7 +4,7 @@ from torch import nn, BoolTensor, FloatTensor, LongTensor class GLU(nn.Module): - def __init__(self, count_in_out, count_middle): + def __init__(self, count_in_out: int, count_middle: int): super().__init__() self.gelu = nn.GELU() self.ln0 = nn.LayerNorm(count_in_out) @@ -33,8 +33,6 @@ class AttentionBase(nn.Module): self.v_proj = nn.Linear(embed_count, embed_count, bias=False) self.q_proj = nn.Linear(embed_count, embed_count, bias=False) self.out_proj = nn.Linear(embed_count, embed_count, bias=False) - self.one = torch.ones((1, 1)) - if torch.cuda.is_available(): self.one = self.one.cuda() def forward( self, @@ -48,11 +46,7 @@ class AttentionBase(nn.Module): queries = queries.reshape(queries.shape[:2] + (self.head_count, -1)) queries /= queries.shape[-1] ** 0.5 - attention_bias = torch.where( - attention_mask, - self.one * 0, - self.one * (-torch.inf), - ) + attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12 attention_weights: FloatTensor = torch.einsum( 'bqhc,bkhc->bhqk', queries, @@ -115,7 +109,8 @@ class DalleBartEncoder(nn.Module): attention_head_count: int, text_vocab_count: int, text_token_count: int, - glu_embed_count: int + glu_embed_count: int, + device: str ): super().__init__() self.text_vocab_count = text_vocab_count @@ -131,17 +126,14 @@ class DalleBartEncoder(nn.Module): ]) self.layernorm_embedding = nn.LayerNorm(embed_count) self.final_ln = nn.LayerNorm(embed_count) - self.token_indices = torch.arange(text_token_count).to(torch.long) - if torch.cuda.is_available(): - self.token_indices = self.token_indices.cuda() + token_indices = torch.arange(text_token_count, device=device) + self.pose_tokens = torch.stack([token_indices] * 2) def forward(self, text_tokens: LongTensor) -> FloatTensor: attention_mask = text_tokens.not_equal(1) - pose_tokens = self.token_indices[None][[0] * text_tokens.shape[0]] - text_tokens.clamp_(0, self.text_vocab_count - 1) encoder_state = ( self.embed_tokens.forward(text_tokens) + - self.embed_positions.forward(pose_tokens) + self.embed_positions.forward(self.pose_tokens) ) encoder_state = self.layernorm_embedding.forward(encoder_state) for layer in self.layers: diff --git a/replicate_predictor.py b/replicate_predictor.py index c5e7f1b..8e702e4 100644 --- a/replicate_predictor.py +++ b/replicate_predictor.py @@ -39,12 +39,11 @@ class ReplicatePredictor(BasePredictor): description='Advanced Setting, see Readme below if interested.' ) ) -> Iterator[Path]: - log2_mid_count = 3 if progressive_outputs else 0 image_stream = self.model.generate_image_stream( text = text, seed = -1, grid_size = grid_size, - log2_mid_count = log2_mid_count, + progressive_outputs = progressive_outputs, temperature = eval(temperature), supercondition_factor = float(supercondition_factor), top_k = top_k, @@ -55,7 +54,8 @@ class ReplicatePredictor(BasePredictor): path = Path(tempfile.mkdtemp()) for image in image_stream: i += 1 - ext = 'png' if i == 2 ** log2_mid_count and save_as_png else 'jpg' + is_final = i == 8 if progressive_outputs else True + ext = 'png' if is_final and save_as_png else 'jpg' image_path = path / 'min-dalle-iter-{}.{}'.format(i, ext) image.save(str(image_path)) yield image_path \ No newline at end of file diff --git a/setup.py b/setup.py index f67f52d..96a4199 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name='min-dalle', description = 'min(DALLĀ·E)', # long_description=(Path(__file__).parent / "README.rst").read_text(), - version='0.3.12', + version='0.3.13', author='Brett Kuprel', author_email='brkuprel@gmail.com', url='https://github.com/kuprel/min-dalle', diff --git a/tkinter_ui.py b/tkinter_ui.py index 06c2551..3c10260 100644 --- a/tkinter_ui.py +++ b/tkinter_ui.py @@ -14,7 +14,7 @@ def regen_root(): root = tkinter.Tk() root.wm_resizable(False, False) - blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 3, 256 * 3), mode="RGB")) + blank_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(256 * 2, 256 * 2), mode="RGB")) padding_image = PIL.ImageTk.PhotoImage(PIL.Image.new(size=(16, 16), mode="RGBA")) regen_root() @@ -33,30 +33,24 @@ frm = ttk.Frame(root, padding=16) frm.grid() ttk.Button(frm, text="Mega", command=set_mega_true_and_destroy).grid(column=0, row=0) ttk.Label(frm, image=padding_image).grid(column=1, row=0) -ttk.Button(frm, text="Not-Mega", command=set_mega_false_and_destroy).grid(column=2, row=0) +ttk.Button(frm, text="Mini", command=set_mega_false_and_destroy).grid(column=2, row=0) root.mainloop() if is_mega is None: print("no option selected") sys.exit(0) -print("confirmed mega: ", str(is_mega)) - -# -- -- +print("is_mega", is_mega) model = MinDalle( - is_mega=is_mega, models_root="./pretrained", + is_mega=is_mega, is_reusable=True, is_verbose=True ) -# -- -- - regen_root() -# -- -- - label_image_content = blank_image sv_prompt = tkinter.StringVar(value="artificial intelligence") @@ -83,17 +77,20 @@ def generate(): return # and continue global label_image_content - image = model.generate_image( + image_stream = model.generate_image_stream( sv_prompt.get(), - grid_size=3, + grid_size=2, + seed=-1, + progressive_outputs=False, temperature=temperature, top_k=topk, supercondition_factor=supercond, is_verbose=True ) - image.save("out.png") - label_image_content = PIL.ImageTk.PhotoImage(image) - label_image.configure(image=label_image_content) + for image in image_stream: + label_image_content = PIL.ImageTk.PhotoImage(image) + label_image.configure(image=label_image_content) + label_image.update() frm = ttk.Frame(root, padding=16) frm.grid() @@ -131,6 +128,4 @@ ttk.Label(props, image=padding_image).grid(column=0, row=7) ttk.Button(props, text="Generate", command=generate).grid(column=0, row=8) ttk.Button(props, text="Quit", command=root.destroy).grid(column=1, row=8) -# alrighty -root.mainloop() - +root.mainloop() \ No newline at end of file