fixed typing error for older python versions

Branch: main
Brett Kuprel, 2 years ago
parent 2dadfdfb31
commit 313635e914
Changed files:
  1. min_dalle/min_dalle.py (2 changes)
  2. min_dalle/models/dalle_bart_decoder.py (11 changes)
  3. min_dalle/models/dalle_bart_encoder.py (3 changes)
  4. min_dalle/text_tokenizer.py (9 changes)
  5. setup.py (5 changes)
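The underlying error: PEP 585 builtin generics such as tuple[FloatTensor, FloatTensor] and list[DecoderLayer] are only subscriptable at runtime from Python 3.9 onward, and because annotations are evaluated when a module is imported, these files crash on Python 3.7/3.8 with "TypeError: 'type' object is not subscriptable". The typing.Tuple and typing.List aliases mean the same thing but work on the older versions. A minimal sketch of the failure and the fix (illustrative names, not code from this commit):

    from typing import List, Tuple

    # On Python 3.8, a definition like the one the diff below rewrites,
    #     def decode_step(...) -> tuple[LongTensor, FloatTensor]: ...
    # fails at import time with "TypeError: 'type' object is not
    # subscriptable", because builtin generics need Python 3.9+.
    # The typing aliases are accepted from Python 3.5 on:

    def split_merge(pair: str) -> Tuple[str, str]:
        # mirrors how text_tokenizer.py splits a merge rule like "e r"
        first, second = pair.split()
        return first, second

    def merge_ranks(merges: List[str]) -> List[int]:
        return list(range(len(merges)))

    print(split_merge("e r"))           # ('e', 'r')
    print(merge_ranks(["e r", "t h"]))  # [0, 1]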

min_dalle/min_dalle.py
@@ -1,12 +1,10 @@
 import os
-from re import I
 from PIL import Image
 import numpy
 from torch import LongTensor
 import torch
 import json
 import requests
-import random
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())

min_dalle/models/dalle_bart_decoder.py
@@ -1,3 +1,4 @@
+from typing import Tuple, List
 import torch
 from torch import LongTensor, nn, FloatTensor, BoolTensor
 torch.set_grad_enabled(False)
@@ -25,7 +26,7 @@ class DecoderSelfAttention(AttentionBase):
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_mask: BoolTensor
-    ) -> tuple[FloatTensor, FloatTensor]:
+    ) -> Tuple[FloatTensor, FloatTensor]:
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
@@ -70,7 +71,7 @@ class DecoderLayer(nn.Module):
         attention_state: FloatTensor,
         attention_mask: BoolTensor,
         token_index: LongTensor
-    ) -> tuple[FloatTensor, FloatTensor]:
+    ) -> Tuple[FloatTensor, FloatTensor]:
         # Self Attention
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
@@ -125,7 +126,7 @@ class DalleBartDecoder(nn.Module):
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
-        self.layers: list[DecoderLayer] = nn.ModuleList([
+        self.layers: List[DecoderLayer] = nn.ModuleList([
             DecoderLayer(
                 image_token_count,
                 attention_head_count,
@@ -153,7 +154,7 @@ class DalleBartDecoder(nn.Module):
         attention_state: FloatTensor,
         prev_tokens: LongTensor,
         token_index: LongTensor
-    ) -> tuple[LongTensor, FloatTensor]:
+    ) -> Tuple[LongTensor, FloatTensor]:
         image_count = encoder_state.shape[0] // 2
         token_index_batched = token_index[[0] * image_count * 2]
         prev_tokens = prev_tokens[list(range(image_count)) * 2]
@@ -209,7 +210,7 @@ class DalleBartDecoder(nn.Module):
         if torch.cuda.is_available(): attention_state = attention_state.cuda()

         image_tokens = self.start_token[[0] * image_count]
-        image_tokens_sequence: list[LongTensor] = []
+        image_tokens_sequence: List[LongTensor] = []
         for i in range(self.sample_token_count):
             probs, attention_state = self.decode_step(
                 attention_mask = attention_mask,

min_dalle/models/dalle_bart_encoder.py
@@ -1,3 +1,4 @@
+from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
 torch.set_grad_enabled(False)
@@ -120,7 +121,7 @@ class DalleBartEncoder(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
-        self.layers: list[EncoderLayer] = nn.ModuleList([
+        self.layers: List[EncoderLayer] = nn.ModuleList([
             EncoderLayer(
                 embed_count = embed_count,
                 head_count = attention_head_count,

min_dalle/text_tokenizer.py
@@ -1,13 +1,14 @@
 from math import inf
+from typing import List, Tuple

 class TextTokenizer:
-    def __init__(self, vocab: dict, merges: list[str], is_verbose: bool = True):
+    def __init__(self, vocab: dict, merges: List[str], is_verbose: bool = True):
         self.is_verbose = is_verbose
         self.token_from_subword = vocab
         pairs = [tuple(pair.split()) for pair in merges]
         self.rank_from_pair = dict(zip(pairs, range(len(pairs))))

-    def tokenize(self, text: str) -> list[int]:
+    def tokenize(self, text: str) -> List[int]:
         sep_token = self.token_from_subword['</s>']
         cls_token = self.token_from_subword['<s>']
         unk_token = self.token_from_subword['<unk>']
@@ -19,8 +20,8 @@ class TextTokenizer:
         ]
         return [cls_token] + tokens + [sep_token]

-    def get_byte_pair_encoding(self, word: str) -> list[str]:
-        def get_pair_rank(pair: tuple[str, str]) -> int:
+    def get_byte_pair_encoding(self, word: str) -> List[str]:
+        def get_pair_rank(pair: Tuple[str, str]) -> int:
             return self.rank_from_pair.get(pair, inf)
         subwords = [chr(ord(" ") + 256)] + list(word)

setup.py
@@ -5,7 +5,7 @@ setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
     long_description=(Path(__file__).parent / "README").read_text(),
-    version='0.2.6',
+    version='0.2.9',
     author='Brett Kuprel',
     author_email='brkuprel@gmail.com',
     url='https://github.com/kuprel/min-dalle',
@@ -15,7 +15,8 @@ setuptools.setup(
     ],
     license='MIT',
     install_requires=[
-        'torch>=1.10.0'
+        'torch>=1.10.0',
+        'typing_extensions>=4.1.0'
     ],
     keywords = [
         'artificial intelligence',
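The added typing_extensions requirement is presumably there so that newer typing features stay importable on the older interpreters this commit targets; Tuple and List themselves come from the standard typing module on every version of Python that torch>=1.10.0 supports. The conventional fallback pattern looks like this (illustrative, using Protocol as an example; the diff does not show how min-dalle uses the package):

    try:
        from typing import Protocol  # in the stdlib since Python 3.8
    except ImportError:
        from typing_extensions import Protocol  # backport for Python 3.7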
