fixed typing error for older python versions
This commit is contained in:
		| @@ -1,12 +1,10 @@ | |||||||
| import os | import os | ||||||
| from re import I |  | ||||||
| from PIL import Image | from PIL import Image | ||||||
| import numpy | import numpy | ||||||
| from torch import LongTensor | from torch import LongTensor | ||||||
| import torch | import torch | ||||||
| import json | import json | ||||||
| import requests | import requests | ||||||
| import random |  | ||||||
| torch.set_grad_enabled(False) | torch.set_grad_enabled(False) | ||||||
| torch.set_num_threads(os.cpu_count()) | torch.set_num_threads(os.cpu_count()) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,3 +1,4 @@ | |||||||
|  | from typing import Tuple, List | ||||||
| import torch | import torch | ||||||
| from torch import LongTensor, nn, FloatTensor, BoolTensor | from torch import LongTensor, nn, FloatTensor, BoolTensor | ||||||
| torch.set_grad_enabled(False) | torch.set_grad_enabled(False) | ||||||
| @@ -25,7 +26,7 @@ class DecoderSelfAttention(AttentionBase): | |||||||
|         attention_state: FloatTensor, |         attention_state: FloatTensor, | ||||||
|         attention_mask: BoolTensor, |         attention_mask: BoolTensor, | ||||||
|         token_mask: BoolTensor |         token_mask: BoolTensor | ||||||
|     ) -> tuple[FloatTensor, FloatTensor]: |     ) -> Tuple[FloatTensor, FloatTensor]: | ||||||
|         keys = self.k_proj.forward(decoder_state) |         keys = self.k_proj.forward(decoder_state) | ||||||
|         values = self.v_proj.forward(decoder_state) |         values = self.v_proj.forward(decoder_state) | ||||||
|         queries = self.q_proj.forward(decoder_state) |         queries = self.q_proj.forward(decoder_state) | ||||||
| @@ -70,7 +71,7 @@ class DecoderLayer(nn.Module): | |||||||
|         attention_state: FloatTensor, |         attention_state: FloatTensor, | ||||||
|         attention_mask: BoolTensor, |         attention_mask: BoolTensor, | ||||||
|         token_index: LongTensor |         token_index: LongTensor | ||||||
|     ) -> tuple[FloatTensor, FloatTensor]: |     ) -> Tuple[FloatTensor, FloatTensor]: | ||||||
|         # Self Attention |         # Self Attention | ||||||
|         residual = decoder_state |         residual = decoder_state | ||||||
|         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) |         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state) | ||||||
| @@ -125,7 +126,7 @@ class DalleBartDecoder(nn.Module): | |||||||
|         self.image_token_count = image_token_count |         self.image_token_count = image_token_count | ||||||
|         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) |         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count) | ||||||
|         self.embed_positions = nn.Embedding(image_token_count, embed_count) |         self.embed_positions = nn.Embedding(image_token_count, embed_count) | ||||||
|         self.layers: list[DecoderLayer] = nn.ModuleList([ |         self.layers: List[DecoderLayer] = nn.ModuleList([ | ||||||
|             DecoderLayer( |             DecoderLayer( | ||||||
|                 image_token_count, |                 image_token_count, | ||||||
|                 attention_head_count, |                 attention_head_count, | ||||||
| @@ -153,7 +154,7 @@ class DalleBartDecoder(nn.Module): | |||||||
|         attention_state: FloatTensor, |         attention_state: FloatTensor, | ||||||
|         prev_tokens: LongTensor, |         prev_tokens: LongTensor, | ||||||
|         token_index: LongTensor |         token_index: LongTensor | ||||||
|     ) -> tuple[LongTensor, FloatTensor]: |     ) -> Tuple[LongTensor, FloatTensor]: | ||||||
|         image_count = encoder_state.shape[0] // 2 |         image_count = encoder_state.shape[0] // 2 | ||||||
|         token_index_batched = token_index[[0] * image_count * 2] |         token_index_batched = token_index[[0] * image_count * 2] | ||||||
|         prev_tokens = prev_tokens[list(range(image_count)) * 2] |         prev_tokens = prev_tokens[list(range(image_count)) * 2] | ||||||
| @@ -209,7 +210,7 @@ class DalleBartDecoder(nn.Module): | |||||||
|         if torch.cuda.is_available(): attention_state = attention_state.cuda() |         if torch.cuda.is_available(): attention_state = attention_state.cuda() | ||||||
|          |          | ||||||
|         image_tokens = self.start_token[[0] * image_count] |         image_tokens = self.start_token[[0] * image_count] | ||||||
|         image_tokens_sequence: list[LongTensor] = [] |         image_tokens_sequence: List[LongTensor] = [] | ||||||
|         for i in range(self.sample_token_count): |         for i in range(self.sample_token_count): | ||||||
|             probs, attention_state = self.decode_step( |             probs, attention_state = self.decode_step( | ||||||
|                 attention_mask = attention_mask, |                 attention_mask = attention_mask, | ||||||
|   | |||||||
| @@ -1,3 +1,4 @@ | |||||||
|  | from typing import List | ||||||
| import torch | import torch | ||||||
| from torch import nn, BoolTensor, FloatTensor, LongTensor | from torch import nn, BoolTensor, FloatTensor, LongTensor | ||||||
| torch.set_grad_enabled(False) | torch.set_grad_enabled(False) | ||||||
| @@ -120,7 +121,7 @@ class DalleBartEncoder(nn.Module): | |||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count) |         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count) | ||||||
|         self.embed_positions = nn.Embedding(text_token_count, embed_count) |         self.embed_positions = nn.Embedding(text_token_count, embed_count) | ||||||
|         self.layers: list[EncoderLayer] = nn.ModuleList([ |         self.layers: List[EncoderLayer] = nn.ModuleList([ | ||||||
|             EncoderLayer( |             EncoderLayer( | ||||||
|                 embed_count = embed_count, |                 embed_count = embed_count, | ||||||
|                 head_count = attention_head_count, |                 head_count = attention_head_count, | ||||||
|   | |||||||
| @@ -1,13 +1,14 @@ | |||||||
| from math import inf | from math import inf | ||||||
|  | from typing import List, Tuple | ||||||
|  |  | ||||||
| class TextTokenizer: | class TextTokenizer: | ||||||
|     def __init__(self, vocab: dict, merges: list[str], is_verbose: bool = True): |     def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True): | ||||||
|         self.is_verbose = is_verbose |         self.is_verbose = is_verbose | ||||||
|         self.token_from_subword = vocab |         self.token_from_subword = vocab | ||||||
|         pairs = [tuple(pair.split()) for pair in merges] |         pairs = [tuple(pair.split()) for pair in merges] | ||||||
|         self.rank_from_pair = dict(zip(pairs, range(len(pairs)))) |         self.rank_from_pair = dict(zip(pairs, range(len(pairs)))) | ||||||
|  |  | ||||||
|     def tokenize(self, text: str) -> list[int]: |     def tokenize(self, text: str) -> List[int]: | ||||||
|         sep_token = self.token_from_subword['</s>'] |         sep_token = self.token_from_subword['</s>'] | ||||||
|         cls_token = self.token_from_subword['<s>'] |         cls_token = self.token_from_subword['<s>'] | ||||||
|         unk_token = self.token_from_subword['<unk>'] |         unk_token = self.token_from_subword['<unk>'] | ||||||
| @@ -19,8 +20,8 @@ class TextTokenizer: | |||||||
|         ] |         ] | ||||||
|         return [cls_token] + tokens + [sep_token] |         return [cls_token] + tokens + [sep_token] | ||||||
|  |  | ||||||
|     def get_byte_pair_encoding(self, word: str) -> list[str]: |     def get_byte_pair_encoding(self, word: str) -> List[str]: | ||||||
|         def get_pair_rank(pair: tuple[str, str]) -> int: |         def get_pair_rank(pair: Tuple[str, str]) -> int: | ||||||
|             return self.rank_from_pair.get(pair, inf) |             return self.rank_from_pair.get(pair, inf) | ||||||
|  |  | ||||||
|         subwords = [chr(ord(" ") + 256)] + list(word) |         subwords = [chr(ord(" ") + 256)] + list(word) | ||||||
|   | |||||||
							
								
								
									
										5
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								setup.py
									
									
									
									
									
								
							| @@ -5,7 +5,7 @@ setuptools.setup( | |||||||
|     name='min-dalle', |     name='min-dalle', | ||||||
|     description = 'min(DALL·E)', |     description = 'min(DALL·E)', | ||||||
|     long_description=(Path(__file__).parent / "README").read_text(), |     long_description=(Path(__file__).parent / "README").read_text(), | ||||||
|     version='0.2.6', |     version='0.2.9', | ||||||
|     author='Brett Kuprel', |     author='Brett Kuprel', | ||||||
|     author_email='brkuprel@gmail.com', |     author_email='brkuprel@gmail.com', | ||||||
|     url='https://github.com/kuprel/min-dalle', |     url='https://github.com/kuprel/min-dalle', | ||||||
| @@ -15,7 +15,8 @@ setuptools.setup( | |||||||
|     ], |     ], | ||||||
|     license='MIT', |     license='MIT', | ||||||
|     install_requires=[ |     install_requires=[ | ||||||
|         'torch>=1.10.0' |         'torch>=1.10.0', | ||||||
|  |         'typing_extensions>=4.1.0' | ||||||
|     ], |     ], | ||||||
|     keywords = [ |     keywords = [ | ||||||
|         'artificial intelligence', |         'artificial intelligence', | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user