torch.no_grad(), cleanup

main
Brett Kuprel 2 years ago
parent 1e2649148b
commit aef24ea157
  1. README.md (2 lines changed)
  2. min_dalle.ipynb (124 lines changed)
  3. min_dalle/min_dalle_torch.py (3 lines changed)
  4. min_dalle/models/dalle_bart_decoder_torch.py (3 lines changed)
  5. min_dalle/models/dalle_bart_encoder_torch.py (1 line changed)
  6. min_dalle/models/vqgan_detokenizer.py (1 line changed)
  7. setup.sh (5 lines changed)

README.md
@@ -2,7 +2,7 @@
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
-This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `torch` for the torch model and `flax` for the flax model.
+This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
 ### Setup
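The `numpy` addition in the README line above most likely reflects how pretrained weights reach the torch model: parameters loaded from the flax checkpoints arrive as numpy arrays and get bridged into torch tensors. A minimal sketch under that assumption; `some_param.npy` is a hypothetical file and this is not the repository's actual loading code:

    import numpy
    import torch

    # Hypothetical parameter file standing in for the real checkpoints
    # downloaded into ./pretrained/ by setup.sh.
    flax_param = numpy.load("some_param.npy")

    # Bridge the numpy array into a torch tensor for inference-only use.
    torch_param = torch.from_numpy(flax_param).float()
    print(torch_param.shape, torch_param.requires_grad)  # requires_grad is False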

min_dalle.ipynb
File diff suppressed because one or more lines are too long

min_dalle/min_dalle_torch.py
@@ -1,7 +1,8 @@
 import numpy
+from typing import Dict
 import torch
 from torch import Tensor
-from typing import Dict
+torch.no_grad()
 from .models.vqgan_detokenizer import VQGanDetokenizer
 from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch

min_dalle/models/dalle_bart_decoder_torch.py
@@ -1,6 +1,7 @@
+from typing import List, Tuple
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
-from typing import List, Tuple
+torch.no_grad()
 from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch

min_dalle/models/dalle_bart_encoder_torch.py
@@ -1,6 +1,7 @@
 from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
+torch.no_grad()
 class GLUTorch(nn.Module):

@ -1,6 +1,7 @@
import torch import torch
from torch import Tensor from torch import Tensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.no_grad()
BATCH_COUNT: int = 1 BATCH_COUNT: int = 1
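Each of the torch modules above now calls `torch.no_grad()` at import time. For reference, `torch.no_grad()` disables gradient tracking when it wraps inference code as a context manager or decorator, which skips building the autograd graph and saves memory. A minimal, self-contained sketch with a stand-in model, not code from this repository:

    import torch

    model = torch.nn.Linear(4, 2)  # stand-in model, not one of the modules above

    # Inside the context manager no autograd graph is recorded, so the output
    # does not require gradients and intermediate activations are not retained.
    with torch.no_grad():
        output = model(torch.ones(1, 4))

    print(output.requires_grad)  # False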

setup.sh
@@ -1,7 +1,6 @@
 #!/bin/bash
-_pip=$(command -v pip pip3)
-$_pip install -r requirements.txt
+pip install -r requirements.txt
 mkdir -p pretrained
@@ -10,7 +9,7 @@ git lfs install
 git clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384 ./pretrained/vqgan
 # download dalle-mini and dalle mega
-$_pip install wandb
+pip install wandb
 wandb login
 wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
 wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
