license and cleanup
This commit is contained in:
parent
32b7aa196b
commit
18e6a9852f
7
LICENSE
Normal file
7
LICENSE
Normal file
|
@ -0,0 +1,7 @@
|
|||
Copyright 2022 Brett Kuprel
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
This is a minimal implementation of [DALL·E Mini](https://github.com/borisdayma/dalle-mini) in both Flax and PyTorch
|
||||
|
||||
## Setup
|
||||
### Setup
|
||||
|
||||
Run `sh setup.sh` to install dependencies and download pretrained models. The only required dependencies are Flax and Torch. In the ash script, GitHub LFS is used to download VQGan detokenizer and the Weight and Biases python package is used to download the DALL·E Mini and DALL·E Mega transformer models. You can also download those files manually by visting the links in the bash script.
|
||||
Run `sh setup.sh` to install dependencies and download pretrained models. The only required dependencies are `flax` and `torch`. In the bash script, GitHub LFS is used to download the VQGan detokenizer and the Weight & Biases python package is used to download the DALL·E Mini and DALL·E Mega transformer models. You can also download those files manually by visting the links in the bash script.
|
||||
|
||||
## Run
|
||||
### Run
|
||||
|
||||
Here are some examples
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
|
|||
return a * logits[0, -1] + (1 - a) * logits[1, -1]
|
||||
|
||||
def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
|
||||
top_logits, top_tokens = lax.top_k(logits, k)
|
||||
top_logits, _ = lax.top_k(logits, k)
|
||||
suppressed = -jnp.inf * jnp.ones_like(logits)
|
||||
return lax.select(logits < top_logits[-1], suppressed, logits)
|
||||
|
||||
|
@ -198,34 +198,6 @@ class DalleBartDecoderFlax(nn.Module):
|
|||
decoder_state = self.lm_head(decoder_state)
|
||||
return decoder_state, keys_state, values_state
|
||||
|
||||
def compute_logits(self,
|
||||
text_tokens: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
params: dict
|
||||
) -> jnp.ndarray:
|
||||
batch_count = encoder_state.shape[0]
|
||||
state_shape = (
|
||||
self.layer_count,
|
||||
batch_count,
|
||||
self.image_token_count,
|
||||
self.attention_head_count,
|
||||
self.embed_count // self.attention_head_count
|
||||
)
|
||||
keys_state = jnp.zeros(state_shape)
|
||||
values_state = jnp.zeros(state_shape)
|
||||
|
||||
logits, _, _ = self.apply(
|
||||
{ 'params': params },
|
||||
encoder_state = encoder_state,
|
||||
keys_state = keys_state,
|
||||
values_state = values_state,
|
||||
attention_mask = jnp.not_equal(text_tokens, 1),
|
||||
prev_token = self.start_token,
|
||||
token_index = 0
|
||||
)
|
||||
|
||||
return super_conditioned(logits, 10.0)
|
||||
|
||||
def sample_image_tokens(self,
|
||||
text_tokens: jnp.ndarray,
|
||||
encoder_state: jnp.ndarray,
|
||||
|
|
|
@ -23,6 +23,7 @@ class GLUFlax(nn.Module):
|
|||
z = self.fc2(z)
|
||||
return z
|
||||
|
||||
|
||||
class AttentionFlax(nn.Module):
|
||||
head_count: int
|
||||
embed_count: int
|
||||
|
@ -61,6 +62,7 @@ class AttentionFlax(nn.Module):
|
|||
attention_output = self.out_proj(attention_output)
|
||||
return attention_output
|
||||
|
||||
|
||||
class EncoderSelfAttentionFlax(AttentionFlax):
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -74,6 +76,7 @@ class EncoderSelfAttentionFlax(AttentionFlax):
|
|||
queries /= queries.shape[-1] ** 0.5
|
||||
return self.forward(keys, values, queries, attention_mask)
|
||||
|
||||
|
||||
class DalleBartEncoderLayerFlax(nn.Module):
|
||||
attention_head_count: int
|
||||
embed_count: int
|
||||
|
@ -103,6 +106,7 @@ class DalleBartEncoderLayerFlax(nn.Module):
|
|||
encoder_state = residual + encoder_state
|
||||
return encoder_state, None
|
||||
|
||||
|
||||
class DalleBartEncoderFlax(nn.Module):
|
||||
attention_head_count: int
|
||||
embed_count: int
|
||||
|
|
|
@ -2,6 +2,7 @@ from typing import List
|
|||
import torch
|
||||
from torch import nn, BoolTensor, FloatTensor, LongTensor
|
||||
|
||||
|
||||
class GLUTorch(nn.Module):
|
||||
def __init__(self, count_in_out, count_middle):
|
||||
super().__init__()
|
||||
|
@ -21,6 +22,7 @@ class GLUTorch(nn.Module):
|
|||
z = self.fc2.forward(z)
|
||||
return z
|
||||
|
||||
|
||||
class AttentionTorch(nn.Module):
|
||||
def __init__(self, head_count: int, embed_count: int):
|
||||
super().__init__()
|
||||
|
|
|
@ -2,7 +2,8 @@ import torch
|
|||
from torch import Tensor
|
||||
from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
|
||||
|
||||
batch_size: int = 1
|
||||
BATCH_COUNT: int = 1
|
||||
|
||||
|
||||
class ResnetBlock(Module):
|
||||
def __init__(self, log2_count_in: int, log2_count_out: int):
|
||||
|
@ -46,16 +47,16 @@ class AttentionBlock(Module):
|
|||
q = self.q.forward(h)
|
||||
k = self.k.forward(h)
|
||||
v = self.v.forward(h)
|
||||
q = q.reshape(batch_size, n, 2 ** 8)
|
||||
q = q.reshape(BATCH_COUNT, n, 2 ** 8)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(batch_size, n, 2 ** 8)
|
||||
k = k.reshape(BATCH_COUNT, n, 2 ** 8)
|
||||
w = torch.bmm(q, k)
|
||||
w /= n ** 0.5
|
||||
w = torch.softmax(w, dim=2)
|
||||
v = v.reshape(batch_size, n, 2 ** 8)
|
||||
v = v.reshape(BATCH_COUNT, n, 2 ** 8)
|
||||
w = w.permute(0, 2, 1)
|
||||
h = torch.bmm(v, w)
|
||||
h = h.reshape(batch_size, n, 2 ** 4, 2 ** 4)
|
||||
h = h.reshape(BATCH_COUNT, n, 2 ** 4, 2 ** 4)
|
||||
h = self.proj_out.forward(h)
|
||||
return x + h
|
||||
|
||||
|
@ -162,14 +163,10 @@ class VQGanDetokenizer(Module):
|
|||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
z = self.embedding.forward(z)
|
||||
z = z.view((batch_size, 2 ** 4, 2 ** 4, 2 ** 8))
|
||||
z = z.view((BATCH_COUNT, 2 ** 4, 2 ** 4, 2 ** 8))
|
||||
z = z.permute(0, 3, 1, 2).contiguous()
|
||||
z = self.post_quant_conv.forward(z)
|
||||
z = self.decoder.forward(z)
|
||||
z = z.permute(0, 2, 3, 1)
|
||||
# z = torch.concat((
|
||||
# torch.concat((z[0], z[1]), axis=1),
|
||||
# torch.concat((z[2], z[3]), axis=1)
|
||||
# ), axis=0)
|
||||
z = z.clip(0.0, 1.0) * 255
|
||||
return z[0]
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from math import inf
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class TextTokenizer:
|
||||
def __init__(self, vocab: dict, merges: List[str]):
|
||||
self.token_from_subword = vocab
|
||||
|
|
Loading…
Reference in New Issue
Block a user