license and cleanup
@@ -132,7 +132,7 @@ def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
     return a * logits[0, -1] + (1 - a) * logits[1, -1]
 
 def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
-    top_logits, top_tokens = lax.top_k(logits, k)
+    top_logits, _ = lax.top_k(logits, k)
     suppressed = -jnp.inf * jnp.ones_like(logits)
     return lax.select(logits < top_logits[-1], suppressed, logits)
 
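As a sanity check, the filter above can be exercised on its own: the k-th largest logit becomes the threshold, and everything below it is pushed to -inf so it receives zero probability after softmax. A minimal sketch (the sample values are illustrative):

    import jax.numpy as jnp
    from jax import lax

    def keep_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
        top_logits, _ = lax.top_k(logits, k)
        suppressed = -jnp.inf * jnp.ones_like(logits)
        return lax.select(logits < top_logits[-1], suppressed, logits)

    print(keep_top_k(jnp.array([1.0, 4.0, 2.0, 3.0]), 2))  # [-inf  4. -inf  3.]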
@@ -198,34 +198,6 @@ class DalleBartDecoderFlax(nn.Module):
         decoder_state = self.lm_head(decoder_state)
         return decoder_state, keys_state, values_state
 
-    def compute_logits(self,
-                       text_tokens: jnp.ndarray,
-                       encoder_state: jnp.ndarray,
-                       params: dict
-                       ) -> jnp.ndarray:
-        batch_count = encoder_state.shape[0]
-        state_shape = (
-            self.layer_count,
-            batch_count,
-            self.image_token_count,
-            self.attention_head_count,
-            self.embed_count // self.attention_head_count
-        )
-        keys_state = jnp.zeros(state_shape)
-        values_state = jnp.zeros(state_shape)
-
-        logits, _, _ = self.apply(
-            { 'params': params },
-            encoder_state = encoder_state,
-            keys_state = keys_state,
-            values_state = values_state,
-            attention_mask = jnp.not_equal(text_tokens, 1),
-            prev_token = self.start_token,
-            token_index = 0
-        )
-
-        return super_conditioned(logits, 10.0)
-
     def sample_image_tokens(self,
                             text_tokens: jnp.ndarray,
                             encoder_state: jnp.ndarray,
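The deleted compute_logits only built an empty key/value cache and ran a single decoder step to score the first image token; the surviving super_conditioned helper then mixes the two batch rows at the last position (a classifier-free-guidance-style blend). A minimal sketch of that blend with dummy logits (shapes and values are illustrative, not taken from the model):

    import jax.numpy as jnp

    def super_conditioned(logits: jnp.ndarray, a: float) -> jnp.ndarray:
        return a * logits[0, -1] + (1 - a) * logits[1, -1]

    logits = jnp.ones((2, 3, 5))            # 2 batch rows, 3 positions, 5-token vocab
    logits = logits.at[0, -1, 2].set(4.0)   # row 0 favors token 2 at the last position
    blended = super_conditioned(logits, 10.0)
    print(blended.shape, blended.argmax())  # (5,) 2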
@@ -23,6 +23,7 @@ class GLUFlax(nn.Module):
         z = self.fc2(z)
         return z
 
+
 class AttentionFlax(nn.Module):
     head_count: int
     embed_count: int
@@ -61,6 +62,7 @@ class AttentionFlax(nn.Module):
         attention_output = self.out_proj(attention_output)
         return attention_output
 
+
 class EncoderSelfAttentionFlax(AttentionFlax):
     def __call__(
         self,
@@ -74,6 +76,7 @@ class EncoderSelfAttentionFlax(AttentionFlax):
         queries /= queries.shape[-1] ** 0.5
         return self.forward(keys, values, queries, attention_mask)
 
+
 class DalleBartEncoderLayerFlax(nn.Module):
     attention_head_count: int
     embed_count: int
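The queries /= queries.shape[-1] ** 0.5 line above is the usual scaled dot-product normalization: dividing the queries by sqrt(head_dim) before the dot product is equivalent to dividing the attention scores. A generic sketch of the weight computation it feeds into (not the exact AttentionFlax.forward, which is not shown in this diff):

    import jax
    import jax.numpy as jnp

    def attention_weights(queries, keys, mask):
        # queries, keys: (tokens, head_dim); mask: (tokens,) bools, True for real tokens
        queries = queries / queries.shape[-1] ** 0.5
        scores = queries @ keys.T                             # (tokens, tokens)
        scores = jnp.where(mask[None, :], scores, -jnp.inf)   # hide padding positions
        return jax.nn.softmax(scores, axis=-1)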
@@ -103,6 +106,7 @@ class DalleBartEncoderLayerFlax(nn.Module):
         encoder_state = residual + encoder_state
         return encoder_state, None
 
+
 class DalleBartEncoderFlax(nn.Module):
     attention_head_count: int
     embed_count: int
@@ -2,6 +2,7 @@ from typing import List
 import torch
 from torch import nn, BoolTensor, FloatTensor, LongTensor
 
+
 class GLUTorch(nn.Module):
     def __init__(self, count_in_out, count_middle):
         super().__init__()
@@ -21,6 +22,7 @@ class GLUTorch(nn.Module):
         z = self.fc2.forward(z)
         return z
 
+
 class AttentionTorch(nn.Module):
     def __init__(self, head_count: int, embed_count: int):
         super().__init__()
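For readers unfamiliar with the GLU pattern behind GLUTorch, here is a generic gated-linear-unit feed-forward block in PyTorch. It is a sketch of the technique only; the layer names, gating activation, and projection layout are assumptions rather than GLUTorch's actual implementation:

    import torch
    from torch import nn

    class GatedFeedForward(nn.Module):
        def __init__(self, count_in_out: int, count_middle: int):
            super().__init__()
            self.fc0 = nn.Linear(count_in_out, count_middle, bias=False)  # gate branch
            self.fc1 = nn.Linear(count_in_out, count_middle, bias=False)  # value branch
            self.fc2 = nn.Linear(count_middle, count_in_out, bias=False)  # output projection

        def forward(self, z: torch.Tensor) -> torch.Tensor:
            z = torch.nn.functional.gelu(self.fc0(z)) * self.fc1(z)  # gated activation
            return self.fc2(z)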
@@ -2,7 +2,8 @@ import torch
 from torch import Tensor
 from torch.nn import Module, ModuleList, GroupNorm, Conv2d, Embedding
 
-batch_size: int = 1
+BATCH_COUNT: int = 1
 
+
 class ResnetBlock(Module):
     def __init__(self, log2_count_in: int, log2_count_out: int):
@@ -46,16 +47,16 @@ class AttentionBlock(Module):
         q = self.q.forward(h)
         k = self.k.forward(h)
         v = self.v.forward(h)
-        q = q.reshape(batch_size, n, 2 ** 8)
+        q = q.reshape(BATCH_COUNT, n, 2 ** 8)
         q = q.permute(0, 2, 1)
-        k = k.reshape(batch_size, n, 2 ** 8)
+        k = k.reshape(BATCH_COUNT, n, 2 ** 8)
         w = torch.bmm(q, k)
         w /= n ** 0.5
         w = torch.softmax(w, dim=2)
-        v = v.reshape(batch_size, n, 2 ** 8)
+        v = v.reshape(BATCH_COUNT, n, 2 ** 8)
         w = w.permute(0, 2, 1)
         h = torch.bmm(v, w)
-        h = h.reshape(batch_size, n, 2 ** 4, 2 ** 4)
+        h = h.reshape(BATCH_COUNT, n, 2 ** 4, 2 ** 4)
         h = self.proj_out.forward(h)
         return x + h
 
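The bmm sequence above is single-head self-attention over the flattened 16 x 16 spatial grid, with n as the channel count. A shape walkthrough with dummy tensors (the value of n here is illustrative; the real q, k, v come from the block's projections):

    import torch

    BATCH_COUNT, n = 1, 2 ** 9                 # n: channel count (illustrative)
    q = torch.randn(BATCH_COUNT, n, 2 ** 8)    # (batch, channels, 16 * 16 positions)
    k = torch.randn(BATCH_COUNT, n, 2 ** 8)
    v = torch.randn(BATCH_COUNT, n, 2 ** 8)

    q = q.permute(0, 2, 1)                                   # (batch, positions, channels)
    w = torch.softmax(torch.bmm(q, k) / n ** 0.5, dim=2)     # (batch, positions, positions)
    h = torch.bmm(v, w.permute(0, 2, 1))                     # (batch, channels, positions)
    print(h.reshape(BATCH_COUNT, n, 2 ** 4, 2 ** 4).shape)   # torch.Size([1, 512, 16, 16])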
@@ -162,14 +163,10 @@ class VQGanDetokenizer(Module):
 
     def forward(self, z: Tensor) -> Tensor:
         z = self.embedding.forward(z)
-        z = z.view((batch_size, 2 ** 4, 2 ** 4, 2 ** 8))
+        z = z.view((BATCH_COUNT, 2 ** 4, 2 ** 4, 2 ** 8))
         z = z.permute(0, 3, 1, 2).contiguous()
         z = self.post_quant_conv.forward(z)
         z = self.decoder.forward(z)
         z = z.permute(0, 2, 3, 1)
-        # z = torch.concat((
-        #     torch.concat((z[0], z[1]), axis=1),
-        #     torch.concat((z[2], z[3]), axis=1)
-        # ), axis=0)
         z = z.clip(0.0, 1.0) * 255
         return z[0]
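After this change the detokenizer returns a single (height, width, 3) float tensor scaled to [0, 255], and the commented-out 2x2 grid stitching is dropped. A hedged sketch of saving that output; Pillow usage and the detokenizer / image_tokens names are assumptions, not part of this diff:

    import torch
    from PIL import Image

    # assumes `detokenizer` is a loaded VQGanDetokenizer and `image_tokens` a LongTensor of codes
    image = detokenizer.forward(image_tokens)             # (height, width, 3) floats in [0, 255]
    pixels = image.detach().to(torch.uint8).cpu().numpy()
    Image.fromarray(pixels).save('generated.png')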