diff --git a/min_dalle/min_dalle.py b/min_dalle/min_dalle.py
index f7bb6ef..c47ad39 100644
--- a/min_dalle/min_dalle.py
+++ b/min_dalle/min_dalle.py
@@ -1,12 +1,10 @@
 import os
-from re import I
 from PIL import Image
 import numpy
 from torch import LongTensor
 import torch
 import json
 import requests
-import random
 
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())
diff --git a/min_dalle/models/dalle_bart_decoder.py b/min_dalle/models/dalle_bart_decoder.py
index ffbf51c..ce18c93 100644
--- a/min_dalle/models/dalle_bart_decoder.py
+++ b/min_dalle/models/dalle_bart_decoder.py
@@ -1,3 +1,4 @@
+from typing import Tuple, List
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
 torch.set_grad_enabled(False)
@@ -25,7 +26,7 @@ class DecoderSelfAttention(AttentionBase):
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_mask: BoolTensor
-    ) -> tuple[FloatTensor, FloatTensor]:
+    ) -> Tuple[FloatTensor, FloatTensor]:
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
@@ -70,7 +71,7 @@ class DecoderLayer(nn.Module):
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
-    ) -> tuple[FloatTensor, FloatTensor]:
+    ) -> Tuple[FloatTensor, FloatTensor]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
@@ -125,7 +126,7 @@ class DalleBartDecoder(nn.Module):
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
-        self.layers: list[DecoderLayer] = nn.ModuleList([
+        self.layers: List[DecoderLayer] = nn.ModuleList([
             DecoderLayer(
                 image_token_count,
                 attention_head_count,
@@ -153,7 +154,7 @@ class DalleBartDecoder(nn.Module):
         attention_state: FloatTensor,
         prev_tokens: LongTensor,
         token_index: LongTensor
-    ) -> tuple[LongTensor, FloatTensor]:
+    ) -> Tuple[LongTensor, FloatTensor]:
         image_count = encoder_state.shape[0] // 2
         token_index_batched = token_index[[0] * image_count * 2]
         prev_tokens = prev_tokens[list(range(image_count)) * 2]
@@ -209,7 +210,7 @@ class DalleBartDecoder(nn.Module):
         if torch.cuda.is_available():
             attention_state = attention_state.cuda()
         image_tokens = self.start_token[[0] * image_count]
-        image_tokens_sequence: list[LongTensor] = []
+        image_tokens_sequence: List[LongTensor] = []
         for i in range(self.sample_token_count):
             probs, attention_state = self.decode_step(
                 attention_mask = attention_mask,
diff --git a/min_dalle/models/dalle_bart_encoder.py b/min_dalle/models/dalle_bart_encoder.py
index 7d18a65..a96cd6b 100644
--- a/min_dalle/models/dalle_bart_encoder.py
+++ b/min_dalle/models/dalle_bart_encoder.py
@@ -1,3 +1,4 @@
+from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
 torch.set_grad_enabled(False)
@@ -120,7 +121,7 @@ class DalleBartEncoder(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
-        self.layers: list[EncoderLayer] = nn.ModuleList([
+        self.layers: List[EncoderLayer] = nn.ModuleList([
             EncoderLayer(
                 embed_count = embed_count,
                 head_count = attention_head_count,
diff --git a/min_dalle/text_tokenizer.py b/min_dalle/text_tokenizer.py
index 0cc0a55..f2f201a 100644
--- a/min_dalle/text_tokenizer.py
+++ b/min_dalle/text_tokenizer.py
@@ -1,13 +1,14 @@
 from math import inf
+from typing import List, Tuple
 
 class TextTokenizer:
-    def __init__(self, vocab: dict, merges: list[str], is_verbose: bool = True):
+    def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True):
         self.is_verbose = is_verbose
         self.token_from_subword = vocab
         pairs = [tuple(pair.split()) for pair in merges]
         self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
 
-    def tokenize(self, text: str) -> list[int]:
+    def tokenize(self, text: str) -> List[int]:
         sep_token = self.token_from_subword['</s>']
         cls_token = self.token_from_subword['<s>']
         unk_token = self.token_from_subword['<unk>']
@@ -19,8 +20,8 @@ class TextTokenizer:
         ]
         return [cls_token] + tokens + [sep_token]
 
-    def get_byte_pair_encoding(self, word: str) -> list[str]:
-        def get_pair_rank(pair: tuple[str, str]) -> int:
+    def get_byte_pair_encoding(self, word: str) -> List[str]:
+        def get_pair_rank(pair: Tuple[str, str]) -> int:
             return self.rank_from_pair.get(pair, inf)
 
         subwords = [chr(ord(" ") + 256)] + list(word)
diff --git a/setup.py b/setup.py
index 2eca2d6..aced9b4 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
     long_description=(Path(__file__).parent / "README").read_text(),
-    version='0.2.6',
+    version='0.2.9',
     author='Brett Kuprel',
     author_email='brkuprel@gmail.com',
     url='https://github.com/kuprel/min-dalle',
@@ -15,7 +15,8 @@ setuptools.setup(
     ],
     license='MIT',
     install_requires=[
-        'torch>=1.10.0'
+        'torch>=1.10.0',
+        'typing_extensions>=4.1.0'
     ],
     keywords = [
         'artificial intelligence',
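
Note: the annotation changes above replace subscripted built-in containers (tuple[...], list[...]) with their typing-module aliases, which is presumably what keeps the package importable on interpreters older than Python 3.9, where subscripting the built-in types in a signature raises a TypeError at definition time. A minimal sketch of the difference; the split_pair function is illustrative only and not part of the repository:

from typing import Tuple

# On Python < 3.9 this definition fails with
# "TypeError: 'type' object is not subscriptable":
#     def split_pair(pair: str) -> tuple[str, str]: ...
# The typing alias is accepted on older interpreters as well:
def split_pair(pair: str) -> Tuple[str, str]:
    # Split a space-separated merge rule such as "th e" into its two halves.
    left, right = pair.split()
    return left, right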