update readme, cleanup

Brett Kuprel 2022-06-30 07:41:31 -04:00
parent 1e18ba0ffa
commit fb97ba5e20
3 changed files with 3 additions and 4 deletions

README.md

@@ -3,7 +3,7 @@
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)  
 [![Replicate](https://replicate.com/kuprel/min-dalle/badge)](https://replicate.com/kuprel/min-dalle)
-This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are numpy, torch, and flax (and optionally wandb to download the models).
+This is a minimal implementation of Boris Dayma's [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. To run the torch model, the only third party dependencies are numpy and torch. Flax is used to convert the weights (which can be saved with `torch.save` once the model is loaded), and wandb is only used to download the models.
 It currently takes **7.3 seconds** to generate an avocado armchair with DALL·E Mega in PyTorch on Colab
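
The updated README line notes that Flax is needed only to convert the weights, which can then be cached with `torch.save` so that inference requires just numpy and torch. A minimal sketch of that pattern (the helper name and the `flax_params` argument are hypothetical, not the repo's actual conversion code):

```python
import numpy
import torch
from flax.traverse_util import flatten_dict


def convert_and_save(flax_params: dict, path: str):
    # Flatten the nested Flax parameter tree into {"block.layer.kernel": array} form,
    # convert each array to a torch tensor, and cache the result with torch.save.
    flat = flatten_dict(flax_params)  # keys become tuples like ("block", "layer", "kernel")
    state_dict = {
        ".".join(key): torch.from_numpy(numpy.asarray(value))
        for key, value in flat.items()
    }
    torch.save(state_dict, path)
```

Later runs can then skip flax and wandb entirely and load the cached tensors with `torch.load(path)`.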


@@ -1,5 +1,3 @@
-from random import sample
-import numpy
 import os
 from PIL import Image
 from typing import Dict


@@ -36,7 +36,8 @@ class DecoderSelfAttentionTorch(AttentionTorch):
             attention_state
         )
         batch_count = decoder_state.shape[0]
-        keys, values = attention_state[:batch_count], attention_state[batch_count:]
+        keys = attention_state[:batch_count]
+        values = attention_state[batch_count:]
         decoder_state = super().forward(keys, values, queries, attention_mask)
         return decoder_state, attention_state
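
The change above only splits a tuple assignment into two statements; the underlying layout it relies on is that the cached `attention_state` stacks keys and values along the batch axis, so the first `batch_count` rows are the keys and the remaining rows are the values. A tiny illustrative check of that convention, with made-up shapes:

```python
import torch

# Illustrative only: keys and values are concatenated along dim 0,
# so slicing at batch_count recovers each half.
batch_count, seq_len, dim = 2, 16, 8
keys = torch.zeros(batch_count, seq_len, dim)
values = torch.ones(batch_count, seq_len, dim)
attention_state = torch.cat([keys, values], dim=0)

assert torch.equal(attention_state[:batch_count], keys)
assert torch.equal(attention_state[batch_count:], values)
```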