fix typing

This commit is contained in:
Brett Kuprel 2022-07-07 17:18:30 -04:00
parent 2cac9220b5
commit 736904ef2f
3 changed files with 14 additions and 14 deletions

View File

@ -2,7 +2,7 @@ import os
from PIL import Image from PIL import Image
from matplotlib.pyplot import grid from matplotlib.pyplot import grid
import numpy import numpy
from torch import LongTensor from torch import LongTensor, FloatTensor
from math import sqrt from math import sqrt
import torch import torch
import json import json
@ -148,7 +148,7 @@ class MinDalle:
self, self,
image_tokens: LongTensor, image_tokens: LongTensor,
is_verbose: bool = False is_verbose: bool = False
) -> LongTensor: ) -> FloatTensor:
if not self.is_reusable: del self.decoder if not self.is_reusable: del self.decoder
if torch.cuda.is_available(): torch.cuda.empty_cache() if torch.cuda.is_available(): torch.cuda.empty_cache()
if not self.is_reusable: self.init_detokenizer() if not self.is_reusable: self.init_detokenizer()
@ -158,7 +158,7 @@ class MinDalle:
return images return images
def grid_from_images(self, images: LongTensor) -> Image.Image: def grid_from_images(self, images: FloatTensor) -> Image.Image:
grid_size = int(sqrt(images.shape[0])) grid_size = int(sqrt(images.shape[0]))
images = images.reshape([grid_size] * 2 + list(images.shape[1:])) images = images.reshape([grid_size] * 2 + list(images.shape[1:]))
image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2) image = images.flatten(1, 2).transpose(0, 1).flatten(1, 2)
@ -175,7 +175,7 @@ class MinDalle:
log2_k: int = 6, log2_k: int = 6,
log2_supercondition_factor: int = 3, log2_supercondition_factor: int = 3,
is_verbose: bool = False is_verbose: bool = False
) -> Iterator[LongTensor]: ) -> Iterator[FloatTensor]:
assert(log2_mid_count in range(5)) assert(log2_mid_count in range(5))
if is_verbose: print("tokenizing text") if is_verbose: print("tokenizing text")
tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose) tokens = self.tokenizer.tokenize(text, is_verbose=is_verbose)
@ -260,7 +260,7 @@ class MinDalle:
log2_k: int = 6, log2_k: int = 6,
log2_supercondition_factor: int = 3, log2_supercondition_factor: int = 3,
is_verbose: bool = False is_verbose: bool = False
) -> LongTensor: ) -> FloatTensor:
log2_mid_count = 0 log2_mid_count = 0
images_stream = self.generate_images_stream( images_stream = self.generate_images_stream(
text, text,

View File

@ -1,6 +1,6 @@
from typing import Tuple, List from typing import Tuple, List
import torch import torch
from torch import LongTensor, nn, FloatTensor, BoolTensor from torch import nn, LongTensor, FloatTensor, BoolTensor
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
from .dalle_bart_encoder import GLU, AttentionBase from .dalle_bart_encoder import GLU, AttentionBase

View File

@ -1,5 +1,5 @@
import torch import torch
from torch import Tensor from torch import FloatTensor
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
@ -16,7 +16,7 @@ class ResnetBlock(Module):
if not self.is_middle: if not self.is_middle:
self.nin_shortcut = Conv2d(m, n, 1) self.nin_shortcut = Conv2d(m, n, 1)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: FloatTensor) -> FloatTensor:
h = x h = x
h = self.norm1.forward(h) h = self.norm1.forward(h)
h *= torch.sigmoid(h) h *= torch.sigmoid(h)
@ -39,7 +39,7 @@ class AttentionBlock(Module):
self.v = Conv2d(n, n, 1) self.v = Conv2d(n, n, 1)
self.proj_out = Conv2d(n, n, 1) self.proj_out = Conv2d(n, n, 1)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: FloatTensor) -> FloatTensor:
n, m = 2 ** 9, x.shape[0] n, m = 2 ** 9, x.shape[0]
h = x h = x
h = self.norm(h) h = self.norm(h)
@ -67,7 +67,7 @@ class MiddleLayer(Module):
self.attn_1 = AttentionBlock() self.attn_1 = AttentionBlock()
self.block_2 = ResnetBlock(9, 9) self.block_2 = ResnetBlock(9, 9)
def forward(self, h: Tensor) -> Tensor: def forward(self, h: FloatTensor) -> FloatTensor:
h = self.block_1.forward(h) h = self.block_1.forward(h)
h = self.attn_1.forward(h) h = self.attn_1.forward(h)
h = self.block_2.forward(h) h = self.block_2.forward(h)
@ -81,7 +81,7 @@ class Upsample(Module):
self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2) self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
self.conv = Conv2d(n, n, 3, padding=1) self.conv = Conv2d(n, n, 3, padding=1)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: FloatTensor) -> FloatTensor:
x = self.upsample.forward(x.to(torch.float32)) x = self.upsample.forward(x.to(torch.float32))
x = self.conv.forward(x) x = self.conv.forward(x)
return x return x
@ -116,7 +116,7 @@ class UpsampleBlock(Module):
self.upsample = Upsample(log2_count_out) self.upsample = Upsample(log2_count_out)
def forward(self, h: Tensor) -> Tensor: def forward(self, h: FloatTensor) -> FloatTensor:
for j in range(3): for j in range(3):
h = self.block[j].forward(h) h = self.block[j].forward(h)
if self.has_attention: if self.has_attention:
@ -144,7 +144,7 @@ class Decoder(Module):
self.norm_out = GroupNorm(2 ** 5, 2 ** 7) self.norm_out = GroupNorm(2 ** 5, 2 ** 7)
self.conv_out = Conv2d(2 ** 7, 3, 3, padding=1) self.conv_out = Conv2d(2 ** 7, 3, 3, padding=1)
def forward(self, z: Tensor) -> Tensor: def forward(self, z: FloatTensor) -> FloatTensor:
z = self.conv_in.forward(z) z = self.conv_in.forward(z)
z = self.mid.forward(z) z = self.mid.forward(z)
@ -165,7 +165,7 @@ class VQGanDetokenizer(Module):
self.post_quant_conv = Conv2d(n, n, 1) self.post_quant_conv = Conv2d(n, n, 1)
self.decoder = Decoder() self.decoder = Decoder()
def forward(self, z: Tensor) -> Tensor: def forward(self, z: FloatTensor) -> FloatTensor:
z = self.embedding.forward(z) z = self.embedding.forward(z)
z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8)) z = z.view((z.shape[0], 2 ** 4, 2 ** 4, 2 ** 8))
z = z.permute(0, 3, 1, 2).contiguous() z = z.permute(0, 3, 1, 2).contiguous()