torch.no_grad(), cleanup

This commit is contained in:
Brett Kuprel 2022-06-28 12:16:44 -04:00
parent 1e2649148b
commit aef24ea157
7 changed files with 24 additions and 115 deletions

View File

@ -2,7 +2,7 @@
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kuprel/min-dalle/blob/main/min_dalle.ipynb)
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `torch` for the torch model and `flax` for the flax model.
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini). It has been stripped to the bare essentials necessary for doing inference, and converted to PyTorch. The only third party dependencies are `numpy` and `torch` for the torch model and `flax` for the flax model.
### Setup

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,8 @@
import numpy
from typing import Dict
import torch
from torch import Tensor
from typing import Dict
torch.no_grad()
from .models.vqgan_detokenizer import VQGanDetokenizer
from .models.dalle_bart_encoder_torch import DalleBartEncoderTorch

View File

@ -1,6 +1,7 @@
from typing import List, Tuple
import torch
from torch import LongTensor, nn, FloatTensor, BoolTensor
from typing import List, Tuple
torch.no_grad()
from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch

View File

@ -1,6 +1,7 @@
from typing import List
import torch
from torch import nn, BoolTensor, FloatTensor, LongTensor
torch.no_grad()
class GLUTorch(nn.Module):

View File

@ -1,6 +1,7 @@
import torch
from torch import Tensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.no_grad()
BATCH_COUNT: int = 1

View File

@ -1,7 +1,6 @@
#!/bin/bash
_pip=$(command -v pip pip3)
$_pip install -r requirements.txt
pip install -r requirements.txt
mkdir -p pretrained
@ -10,7 +9,7 @@ git lfs install
git clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384 ./pretrained/vqgan
# download dalle-mini and dalle mega
$_pip install wandb
pip install wandb
wandb login
wandb artifact get --root=./pretrained/dalle_bart_mini dalle-mini/dalle-mini/mini-1:v0
wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1-fp16:v14