From fb97ba5e204b81ae869c9b5f4f35eac9d7f23b54 Mon Sep 17 00:00:00 2001
From: Brett Kuprel
Date: Thu, 30 Jun 2022 07:41:31 -0400
Subject: [PATCH] update readme, cleanup

---
 README.md                                    | 2 +-
 min_dalle/min_dalle_torch.py                 | 2 --
 min_dalle/models/dalle_bart_decoder_torch.py | 3 ++-
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index ac20a8b..46724cc 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
 &nbsp;
 [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
 
-This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
+This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which can be saved with `torch.save` once the model is loaded), and wandb is only used to download the models.
 
 It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab
diff --git a/min_dalle/min_dalle_torch.py b/min_dalle/min_dalle_torch.py
index 2dd21cd..bdfa662 100644
--- a/min_dalle/min_dalle_torch.py
+++ b/min_dalle/min_dalle_torch.py
@@ -1,5 +1,3 @@
-from random import sample
-import numpy
 import os
 from PIL import Image
 from typing import Dict
diff --git a/min_dalle/models/dalle_bart_decoder_torch.py b/min_dalle/models/dalle_bart_decoder_torch.py
index 3aab6dd..4b07beb 100644
--- a/min_dalle/models/dalle_bart_decoder_torch.py
+++ b/min_dalle/models/dalle_bart_decoder_torch.py
@@ -36,7 +36,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):
             attention_state
         )
         batch_count = decoder_state.shape[0]
-        keys, values = attention_state[:batch_count], attention_state[batch_count:]
+        keys = attention_state[:batch_count]
+        values = attention_state[batch_count:]
         decoder_state = super().forward(keys, values, queries, attention_mask)
         return decoder_state, attention_state
 
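Note on the README change above: the claim that flax is only needed to convert the weights relies on caching the converted tensors with `torch.save`, so that subsequent runs need only numpy and torch. Below is a minimal sketch of that pattern, not code from this repo; the `flax_params_to_torch` helper and the file name are hypothetical, and the sketch assumes the flax parameters arrive as a nested dict of arrays:

```python
import numpy
import torch

def flax_params_to_torch(params: dict) -> dict:
    # Recursively convert a nested flax parameter tree into torch tensors.
    return {
        key: flax_params_to_torch(value) if isinstance(value, dict)
        else torch.from_numpy(numpy.asarray(value))
        for key, value in params.items()
    }

# One-time conversion: convert the flax weights, then cache them to disk.
# torch_params = flax_params_to_torch(flax_params)
# torch.save(torch_params, 'dalle_bart.pt')

# Subsequent runs can skip flax and wandb entirely:
# torch_params = torch.load('dalle_bart.pt')
```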
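Note on the decoder change above: splitting the one-line tuple unpack into two statements makes the cache layout easier to read. As the slicing implies, `attention_state` stacks the cached keys and values along the leading (batch) axis, so slicing at `batch_count` recovers each half. A minimal sketch of that layout with illustrative shapes (not code from this repo):

```python
import torch

batch_count, token_count, embed_count = 2, 256, 1024
keys = torch.randn(batch_count, token_count, embed_count)
values = torch.randn(batch_count, token_count, embed_count)

# Cache keys and values in one tensor, stacked along the batch axis,
# giving shape (2 * batch_count, token_count, embed_count).
attention_state = torch.cat([keys, values])

# Slicing at batch_count recovers each half, as in the patched forward pass.
assert torch.equal(attention_state[:batch_count], keys)
assert torch.equal(attention_state[batch_count:], values)
```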