torch.no_grad(), cleanup
This commit is contained in:
parent
1e2649148b
commit
aef24ea157
|
@ -2,7 +2,7 @@
|
|||
|
||||
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
|
||||
|
||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `torch` for the torch model and `flax` for the flax model.
|
||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
|
||||
|
||||
### Setup
|
||||
|
||||
|
|
122
min_dalle.ipynb
122
min_dalle.ipynb
File diff suppressed because one or more lines are too long
|
@ -1,7 +1,8 @@
|
|||
import numpy
|
||||
from typing import Dict
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Dict
|
||||
torch.no_grad()
|
||||
|
||||
from .models.vqgan_detokenizer import VQGanDetokenizer
|
||||
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List, Tuple
|
||||
import torch
|
||||
from torch import LongTensor, nn, FloatTensor, BoolTensor
|
||||
from typing import List, Tuple
|
||||
torch.no_grad()
|
||||
|
||||
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List
|
||||
import torch
|
||||
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||
torch.no_grad()
|
||||
|
||||
|
||||
class GLUTorch(nn.Module):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
||||
torch.no_grad()
|
||||
|
||||
BATCH_COUNT: int = 1
|
||||
|
||||
|
|
5
setup.sh
5
setup.sh
|
@ -1,7 +1,6 @@
|
|||
#!/bin/bash
|
||||
|
||||
_pip=$(command -v pip pip3)
|
||||
$_pip install -r requirements.txt
|
||||
pip install -r requirements.txt
|
||||
|
||||
mkdir -p pretrained
|
||||
|
||||
|
@ -10,7 +9,7 @@ git lfs install
|
|||
git clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384 ./pretrained/vqgan
|
||||
|
||||
# download dalle-mini and dalle mega
|
||||
$_pip install wandb
|
||||
pip install wandb
|
||||
wandb login
|
||||
wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
|
||||
wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14
|
||||
|
|
Loading…
Reference in New Issue
Block a user