torch.no_grad(), cleanup
parent 1e2649148b
commit aef24ea157
@@ -2,7 +2,7 @@
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
 
-This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `torch` for the torch model and `flax` for the flax model.
+This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
 
 ### Setup
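The updated description above states that the torch model needs only `numpy` and `torch`; a minimal import check along those lines (illustrative only, not part of this diff):

    # Illustrative sanity check that the torch-model dependencies are importable.
    import numpy
    import torch

    print("numpy", numpy.__version__)
    print("torch", torch.__version__)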
124  min_dalle.ipynb
File diff suppressed because one or more lines are too long
@@ -1,7 +1,8 @@
 import numpy
+from typing import Dict
 import torch
 from torch import Tensor
-from typing import Dict
+torch.no_grad()
 
 from .models.vqgan_detokenizer import VQGanDetokenizer
 from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch
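For context on the `torch.no_grad()` lines added in this and the following hunks: `torch.no_grad` is commonly applied as a decorator or context manager around the inference code whose gradients are not needed. A minimal sketch of that pattern (illustrative; `run_inference` and the model object are placeholder names, not from this repository):

    import torch
    from torch import nn, Tensor

    @torch.no_grad()  # gradient tracking is disabled for every call to this function
    def run_inference(model: nn.Module, tokens: Tensor) -> Tensor:
        return model(tokens)

    # Equivalent context-manager form around a single call site:
    # with torch.no_grad():
    #     logits = model(tokens)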
@@ -1,6 +1,7 @@
+from typing import List, Tuple
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
-from typing import List, Tuple
+torch.no_grad()
 
 from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
@@ -1,6 +1,7 @@
 from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
+torch.no_grad()
 
 
 class GLUTorch(nn.Module):
@@ -1,6 +1,7 @@
 import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
+torch.no_grad()
 
 
 BATCH_COUNT: int = 1
5  setup.sh
@@ -1,7 +1,6 @@
 #!/bin/bash
 
-_pip=$(command -v pip pip3)
-$_pip install -r requirements.txt
+pip install -r requirements.txt
 
 mkdir -p pretrained
@@ -10,7 +9,7 @@ git lfs install
 git clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384 ./pretrained/vqgan
 
 # download dalle-mini and dalle mega
-$_pip install wandb
+pip install wandb
 wandb login
 wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
 wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14