fixed typing error for older python versions

Brett Kuprel 2022-07-02 09:06:22 -04:00
parent 2dadfdfb31
commit 313635e914
5 changed files with 16 additions and 14 deletions
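
The error being fixed: an annotation like ") -> tuple[FloatTensor, FloatTensor]:" is evaluated when the function is defined, and subscripting the builtin tuple and list types is only supported on Python 3.9+ (PEP 585). On older interpreters, importing these modules fails with "TypeError: 'type' object is not subscriptable". The commit replaces every builtin-generic annotation with the typing.Tuple and typing.List aliases, which subscript correctly on much older versions. A minimal sketch of the failure and the workaround (the decode helper below is a hypothetical stand-in, not a function from this repository):

    # On Python 3.8 and earlier, this definition raises at import time:
    #   TypeError: 'type' object is not subscriptable
    # def decode(tokens: list[int]) -> tuple[int, int]: ...

    # The pattern adopted by this commit works on older interpreters too:
    from typing import List, Tuple

    def decode(tokens: List[int]) -> Tuple[int, int]:
        return tokens[0], tokens[-1]

Another common workaround is "from __future__ import annotations" (PEP 563), which defers annotation evaluation entirely; this commit instead standardizes on the typing aliases.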

View File

@@ -1,12 +1,10 @@
 import os
-from re import I
 from PIL import Image
 import numpy
 from torch import LongTensor
 import torch
 import json
 import requests
-import random
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())

View File

@@ -1,3 +1,4 @@
+from typing import Tuple, List
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
 torch.set_grad_enabled(False)
@@ -25,7 +26,7 @@ class DecoderSelfAttention(AttentionBase):
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_mask: BoolTensor
-    ) -> tuple[FloatTensor, FloatTensor]:
+    ) -> Tuple[FloatTensor, FloatTensor]:
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
@@ -70,7 +71,7 @@ class DecoderLayer(nn.Module):
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
-    ) -> tuple[FloatTensor, FloatTensor]:
+    ) -> Tuple[FloatTensor, FloatTensor]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
@@ -125,7 +126,7 @@ class DalleBartDecoder(nn.Module):
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
-        self.layers: list[DecoderLayer] = nn.ModuleList([
+        self.layers: List[DecoderLayer] = nn.ModuleList([
             DecoderLayer(
                 image_token_count,
                 attention_head_count,
@@ -153,7 +154,7 @@ class DalleBartDecoder(nn.Module):
         attention_state: FloatTensor,
         prev_tokens: LongTensor,
         token_index: LongTensor
-    ) -> tuple[LongTensor, FloatTensor]:
+    ) -> Tuple[LongTensor, FloatTensor]:
         image_count = encoder_state.shape[0] // 2
         token_index_batched = token_index[[0] * image_count * 2]
         prev_tokens = prev_tokens[list(range(image_count)) * 2]
@@ -209,7 +210,7 @@ class DalleBartDecoder(nn.Module):
         if torch.cuda.is_available(): attention_state = attention_state.cuda()
         image_tokens = self.start_token[[0] * image_count]
-        image_tokens_sequence: list[LongTensor] = []
+        image_tokens_sequence: List[LongTensor] = []
         for i in range(self.sample_token_count):
             probs, attention_state = self.decode_step(
                 attention_mask = attention_mask,

View File

@@ -1,3 +1,4 @@
+from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
 torch.set_grad_enabled(False)
@@ -120,7 +121,7 @@ class DalleBartEncoder(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
-        self.layers: list[EncoderLayer] = nn.ModuleList([
+        self.layers: List[EncoderLayer] = nn.ModuleList([
             EncoderLayer(
                 embed_count = embed_count,
                 head_count = attention_head_count,

View File

@@ -1,13 +1,14 @@
 from math import inf
+from typing import List, Tuple
 class TextTokenizer:
-    def __init__(self, vocab: dict, merges: list[str], is_verbose: bool = True):
+    def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True):
         self.is_verbose = is_verbose
         self.token_from_subword = vocab
         pairs = [tuple(pair.split()) for pair in merges]
         self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
-    def tokenize(self, text: str) -> list[int]:
+    def tokenize(self, text: str) -> List[int]:
         sep_token = self.token_from_subword['</s>']
         cls_token = self.token_from_subword['<s>']
         unk_token = self.token_from_subword['<unk>']
@@ -19,8 +20,8 @@ class TextTokenizer:
         ]
         return [cls_token] + tokens + [sep_token]
-    def get_byte_pair_encoding(self, word: str) -> list[str]:
-        def get_pair_rank(pair: tuple[str, str]) -> int:
+    def get_byte_pair_encoding(self, word: str) -> List[str]:
+        def get_pair_rank(pair: Tuple[str, str]) -> int:
             return self.rank_from_pair.get(pair, inf)
         subwords = [chr(ord(" ") + 256)] + list(word)

View File

@@ -5,7 +5,7 @@ setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
     long_description=(Path(__file__).parent / "README").read_text(),
-    version='0.2.6',
+    version='0.2.9',
     author='Brett Kuprel',
     author_email='brkuprel@gmail.com',
     url='https://github.com/kuprel/min-dalle',
@@ -15,7 +15,8 @@ setuptools.setup(
     ],
     license='MIT',
     install_requires=[
-        'torch>=1.10.0'
+        'torch>=1.10.0',
+        'typing_extensions>=4.1.0'
     ],
     keywords = [
         'artificial intelligence',
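
The last hunk also adds typing_extensions>=4.1.0 to install_requires. None of the files in this diff import it, so the sketch below is only an assumption about how such a backport is typically consumed, not code from this repository: prefer the stdlib name where it exists and fall back to the backport on older interpreters.

    import sys

    # Hypothetical pattern (typing.Literal only exists on Python 3.8+);
    # the Mode alias is purely illustrative.
    if sys.version_info >= (3, 8):
        from typing import Literal
    else:
        from typing_extensions import Literal

    Mode = Literal["text", "image"]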