first commit
comfy/ldm/audio/autoencoder.py (new file, 282 lines)
@@ -0,0 +1,282 @@
# code adapted from: https://github.com/Stability-AI/stable-audio-tools

import torch
from torch import nn
from typing import Literal, Dict, Any
import math
import comfy.ops
ops = comfy.ops.disable_weight_init

def vae_sample(mean, scale):
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean

    kl = (mean * mean + var - logvar - 1).sum(1).mean()

    return latents, kl
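
# Note: (mean * mean + var - logvar - 1) is the closed-form KL divergence between
# N(mean, stdev^2) and the standard normal N(0, 1), up to the conventional factor of 1/2:
# KL = 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1), summed over channels, averaged over the batch.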

class VAEBottleneck(nn.Module):
    def __init__(self):
        super().__init__()
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["kl"] = kl

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x


def snake_beta(x, alpha, beta):
    return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
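
# snake_beta computes x + (1/beta) * sin^2(alpha * x), the Snake-Beta activation from the
# BigVGAN paper: alpha sets the frequency of the periodic component, beta its magnitude.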

# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
class SnakeBeta(nn.Module):

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale: # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else: # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        # self.alpha.requires_grad = alpha_trainable
        # self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x

def WNConv1d(*args, **kwargs):
    try:
        return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
    except:
        return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older

def WNConvTranspose1d(*args, **kwargs):
    try:
        return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
    except:
        return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older

def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    if activation == "elu":
        act = torch.nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        # Activation1d (the anti-aliased activation wrapper from stable-audio-tools) is not
        # defined in this file, so antialias=True would raise a NameError unless it is provided elsewhere.
        act = Activation1d(act)

    return act


class ResidualUnit(nn.Module):
    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        padding = (dilation * (7-1)) // 2

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=padding),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1)
        )

    def forward(self, x):
        res = x

        #x = checkpoint(self.layers, x)
        x = self.layers(x)

        return x + res

class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=9, use_snake=use_snake),
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
        )

    def forward(self, x):
        return self.layers(x)

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2*stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=1, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=3, use_snake=use_snake),
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)

class OobleckEncoder(nn.Module):
    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]

        for i in range(self.depth-1):
            layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class OobleckDecoder(nn.Module):
    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
        ]

        for i in range(self.depth-1, 0, -1):
            layers += [DecoderBlock(
                in_channels=c_mults[i]*channels,
                out_channels=c_mults[i-1]*channels,
                stride=strides[i-1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
                )
            ]

        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class AudioOobleckVAE(nn.Module):
    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults = [1, 2, 4, 8, 16],
                 strides = [2, 4, 4, 8, 8],
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
        self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
                                      use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        return self.decoder(self.bottleneck.decode(x))
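
For orientation, here is a minimal usage sketch of the VAE defined above. The default strides [2, 4, 4, 8, 8] give a total downsampling factor of 2*4*4*8*8 = 2048, so roughly one 64-channel latent frame per 2048 stereo samples. The 44.1 kHz input length is an illustrative assumption, and in practice the module's weights come from a Stable Audio checkpoint rather than the uninitialized parameters created here.

import torch
from comfy.ldm.audio.autoencoder import AudioOobleckVAE

vae = AudioOobleckVAE().eval()    # defaults: stereo in/out, latent_dim=64, snake activations

audio = torch.randn(1, 2, 44100)  # one second of stereo audio (assumed 44.1 kHz)

with torch.no_grad():
    latents = vae.encode(audio)   # -> roughly (1, 64, 22): one latent frame per 2048 samples
    recon = vae.decode(latents)   # -> (1, 2, ~45056): length rounds up to the stride product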
comfy/ldm/audio/dit.py (new file, 891 lines)
@@ -0,0 +1,891 @@
# code adapted from: https://github.com/Stability-AI/stable-audio-tools

from comfy.ldm.modules.attention import optimized_attention
import typing as tp

import torch

from einops import rearrange
from einops.layers.torch import Rearrange # referenced by FeedForward below (conv / non-GLU paths) but previously not imported
from torch import nn
from torch.nn import functional as F
import math
import comfy.ops

def or_reduce(masks):
    # boolean-mask OR helper used by Attention.forward below; present in
    # stable-audio-tools but missing from this file as committed
    head, *body = masks
    for rest in body:
        head = head | rest
    return head

class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.empty(
            [out_features // 2, in_features], dtype=dtype, device=device))

    def forward(self, input):
        f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
        return torch.cat([f.cos(), f.sin()], dim=-1)
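
# FourierFeatures maps a scalar timestep t to [cos(2*pi*W*t), sin(2*pi*W*t)] with a
# learned frequency matrix W, i.e. random Fourier features used as the timestep embedding.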

# norms
class LayerNorm(nn.Module):
    def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
        """
        super().__init__()

        self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))

        if bias:
            self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
        else:
            self.beta = None

    def forward(self, x):
        beta = self.beta
        if beta is not None:
            beta = comfy.ops.cast_to_input(beta, x)
        return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)

class GLU(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        activation,
        use_conv = False,
        conv_kernel_size = 3,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.act = activation
        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
        self.use_conv = use_conv

    def forward(self, x):
        if self.use_conv:
            x = rearrange(x, 'b n d -> b d n')
            x = self.proj(x)
            x = rearrange(x, 'b d n -> b n d')
        else:
            x = self.proj(x)

        x, gate = x.chunk(2, dim = -1)
        return x * self.act(gate)

class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len, device = x.shape[1], x.device
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        pos_emb = self.emb(pos)
        pos_emb = pos_emb * self.scale
        return pos_emb

class ScaledSinusoidalEmbedding(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = torch.arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        self.register_buffer('inv_freq', inv_freq, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        seq_len, device = x.shape[1], x.device

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        emb = torch.einsum('i, j -> i j', pos, self.inv_freq)
        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
        return emb * self.scale

class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.,
        dtype=None,
        device=None,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len, device, dtype):
        # device = self.inv_freq.device

        t = torch.arange(seq_len, device=device, dtype=dtype)
        return self.forward(t)

    def forward(self, t):
        # device = self.inv_freq.device
        device = t.device
        dtype = t.dtype

        # t = t.to(torch.float32)

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        seq_len = t.shape[-1] # needed by the xpos branch below; it was previously undefined here
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale

def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(t, freqs, scale = 1):
    out_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
    freqs, t = freqs.to(dtype), t.to(dtype)
    freqs = freqs[-seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

    t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)

    return torch.cat((t, t_unrotated), dim = -1)

class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        inner_dim = int(dim * mult)

        # Default to SwiGLU

        activation = nn.SiLU()

        dim_out = dim if dim_out is None else dim_out

        if glu:
            linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
        else:
            linear_in = nn.Sequential(
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                activation
            )

        linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)

        # # init last linear layer to 0
        # if zero_init_output:
        #     nn.init.zeros_(linear_out.weight)
        #     if not no_bias:
        #         nn.init.zeros_(linear_out.bias)


        self.ff = nn.Sequential(
            linear_in,
            Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
            linear_out,
            Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
        )

    def forward(self, x):
        return self.ff(x)

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_heads = 64,
        dim_context = None,
        causal = False,
        zero_init_output=True,
        qk_norm = False,
        natten_kernel_size = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.causal = causal

        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
            self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
        else:
            self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)

        self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # if zero_init_output:
        #     nn.init.zeros_(self.to_out.weight)

        self.qk_norm = qk_norm


    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None,
        causal = None
    ):
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Use separate linear projections for q and k/v
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)

            k, v = self.to_kv(kv_input).chunk(2, dim=-1)

            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
        else:
            # Use fused linear projection
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Normalize q and k for cosine sim attention
        if self.qk_norm:
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)

        if rotary_pos_emb is not None and not has_context:
            freqs, _ = rotary_pos_emb

            q_dtype = q.dtype
            k_dtype = k.dtype

            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)

            q = apply_rotary_pos_emb(q, freqs)
            k = apply_rotary_pos_emb(k, freqs)

            q = q.to(q_dtype)
            k = k.to(k_dtype)

        input_mask = context_mask

        if input_mask is None and not has_context:
            input_mask = mask

        # determine masking
        masks = []
        final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account

        if input_mask is not None:
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
            masks.append(~input_mask)

        # Other masks will be added here later

        if len(masks) > 0:
            final_attn_mask = ~or_reduce(masks)

        n, device = q.shape[-2], q.device

        causal = self.causal if causal is None else causal

        if n == 1 and causal:
            causal = False

        if h != kv_h:
            # Repeat interleave kv_heads to match q_heads
            heads_per_kv_head = h // kv_h
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

        # note: final_attn_mask and causal are computed above but not passed on;
        # optimized_attention is currently called without an attention mask
        out = optimized_attention(q, k, v, h, skip_reshape=True)
        out = self.to_out(out)

        if mask is not None:
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        return out

class ConformerModule(nn.Module):
    def __init__(
        self,
        dim,
        norm_kwargs = {},
    ):

        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.glu = GLU(dim, dim, nn.SiLU())
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.in_norm(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.glu(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.depthwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.mid_norm(x)
        x = self.swish(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv_2(x)
        x = rearrange(x, 'b d n -> b n d')

        return x

class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_heads = 64,
        cross_attend = False,
        dim_context = None,
        global_cond_dim = None,
        causal = False,
        zero_init_branch_outputs = True,
        conformer = False,
        layer_ix = -1,
        remove_norms = False,
        attn_kwargs = {},
        ff_kwargs = {},
        norm_kwargs = {},
        dtype=None,
        device=None,
        operations=None,
    ):

        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()

        self.self_attn = Attention(
            dim,
            dim_heads = dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            dtype=dtype,
            device=device,
            operations=operations,
            **attn_kwargs
        )

        if cross_attend:
            self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
            self.cross_attn = Attention(
                dim,
                dim_heads = dim_heads,
                dim_context=dim_context,
                causal = causal,
                zero_init_output=zero_init_branch_outputs,
                dtype=dtype,
                device=device,
                operations=operations,
                **attn_kwargs
            )

        self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations, **ff_kwargs)

        self.layer_ix = layer_ix

        self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None

        self.global_cond_dim = global_cond_dim

        if global_cond_dim is not None:
            self.to_scale_shift_gate = nn.Sequential(
                nn.SiLU(),
                nn.Linear(global_cond_dim, dim * 6, bias=False)
            )

            nn.init.zeros_(self.to_scale_shift_gate[1].weight)
            #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)

    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None
    ):
        if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)

            # self-attention with adaLN
            residual = x
            x = self.pre_norm(x)
            x = x * (1 + scale_self) + shift_self
            x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
            # note: the branch is gated with sigmoid(1 - gate), not the plain
            # gate multiplication used in standard DiT adaLN-Zero
            x = x * torch.sigmoid(1 - gate_self)
            x = x + residual

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            # feedforward with adaLN
            residual = x
            x = self.ff_norm(x)
            x = x * (1 + scale_ff) + shift_ff
            x = self.ff(x)
            x = x * torch.sigmoid(1 - gate_ff)
            x = x + residual

        else:
            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)

            if context is not None:
                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

            if self.conformer is not None:
                x = x + self.conformer(x)

            x = x + self.ff(self.ff_norm(x))

        return x

class ContinuousTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in = None,
        dim_out = None,
        dim_heads = 64,
        cross_attend=False,
        cond_token_dim=None,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):

        super().__init__()

        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])

        self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
        self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()

        if rotary_pos_emb:
            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
        else:
            self.rotary_pos_emb = None

        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)

        for i in range(depth):
            self.layers.append(
                TransformerBlock(
                    dim,
                    dim_heads = dim_heads,
                    cross_attend = cross_attend,
                    dim_context = cond_token_dim,
                    global_cond_dim = global_cond_dim,
                    causal = causal,
                    zero_init_branch_outputs = zero_init_branch_outputs,
                    conformer=conformer,
                    layer_ix=i,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                    **kwargs
                )
            )

    def forward(
        self,
        x,
        mask = None,
        prepend_embeds = None,
        prepend_mask = None,
        global_cond = None,
        return_info = False,
        **kwargs
    ):
        batch, seq, device = *x.shape[:2], x.device

        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)

        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim = -2)

            if prepend_mask is not None or mask is not None:
                mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
                prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)

                mask = torch.cat((prepend_mask, mask), dim = -1)

        # Attention layers

        if self.rotary_pos_emb is not None:
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
        else:
            rotary_pos_emb = None

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        # Iterate over the transformer layers
        for layer in self.layers:
            x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
            # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

            if return_info:
                info["hidden_states"].append(x)

        x = self.project_out(x)

        if return_info:
            return x, info

        return x

class AudioDiffusionTransformer(nn.Module):
    def __init__(self,
        io_channels=64,
        patch_size=1,
        embed_dim=1536,
        cond_token_dim=768,
        project_cond_tokens=False,
        global_cond_dim=1536,
        project_global_cond=True,
        input_concat_dim=0,
        prepend_cond_dim=0,
        depth=24,
        num_heads=24,
        transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
        global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
        audio_model="",
        dtype=None,
        device=None,
        operations=None,
        **kwargs):

        super().__init__()

        self.dtype = dtype
        self.cond_token_dim = cond_token_dim

        # Timestep embeddings
        timestep_features_dim = 256

        self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)

        self.to_timestep_embed = nn.Sequential(
            operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
        )

        if cond_token_dim > 0:
            # Conditioning tokens

            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = nn.Sequential(
                operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
            )
        else:
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = nn.Sequential(
                operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
            )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
            )

        self.input_concat_dim = input_concat_dim

        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        # Transformer

        self.transformer_type = transformer_type

        self.global_cond_type = global_cond_type

        if self.transformer_type == "continuous_transformer":

            global_dim = None

            if self.global_cond_type == "adaLN":
                # The global conditioning is projected to the embed_dim already at this point
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend = cond_token_dim > 0,
                cond_token_dim = cond_embed_dim,
                global_cond_dim=global_dim,
                dtype=dtype,
                device=device,
                operations=operations,
                **kwargs
            )
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
        self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)

    def _forward(
        self,
        x,
        t,
        mask=None,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        input_concat_cond=None,
        global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        return_info=False,
        **kwargs):

        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:

            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            x = torch.cat([x, input_concat_cond], dim=1)

        # Get the batch of timestep embeddings
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)

        # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
        if global_embed is not None:
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed

        # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
        if self.global_cond_type == "prepend":
            if prepend_inputs is None:
                # Prepend inputs are just the global embed, and the mask is all ones
                prepend_inputs = global_embed.unsqueeze(1)
                prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
            else:
                # Prepend inputs are the prepend conditioning + the global embed
                prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
                prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)

            prepend_length = prepend_inputs.shape[1]

        x = self.preprocess_conv(x) + x

        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        if self.transformer_type == "x-transformers":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
        elif self.transformer_type == "continuous_transformer":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)

            if return_info:
                output, info = output
        elif self.transformer_type == "mm_transformer":
            output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)

        output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
        self,
        x,
        timestep,
        context=None,
        context_mask=None,
        input_concat_cond=None,
        global_embed=None,
        negative_global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        mask=None,
        return_info=False,
        control=None,
        transformer_options={},
        **kwargs):
        return self._forward(
            x,
            timestep,
            cross_attn_cond=context,
            cross_attn_cond_mask=context_mask,
            input_concat_cond=input_concat_cond,
            global_embed=global_embed,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            mask=mask,
            return_info=return_info,
            **kwargs
        )
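
To make the rotary helpers above concrete, here is a small sketch. Shapes and the frequency initialization are illustrative assumptions; in ComfyUI the `inv_freq` buffer is loaded from the model checkpoint, which is why `__init__` registers an empty buffer.

import torch
from comfy.ldm.audio.dit import RotaryEmbedding, apply_rotary_pos_emb

dim_heads = 64
rope = RotaryEmbedding(max(dim_heads // 2, 32))  # rotate the first 32 of 64 head dims

# fill the (intentionally) uninitialized buffer with the usual RoPE frequencies,
# matching the commented-out formula in RotaryEmbedding.__init__
rope.inv_freq.copy_(1.0 / (10000 ** (torch.arange(0, 32, 2).float() / 32)))

freqs, _ = rope.forward_from_seq_len(8, device="cpu", dtype=torch.float32)  # (8, 32)

q = torch.randn(1, 4, 8, dim_heads)      # (batch, heads, seq, head_dim)
q_rot = apply_rotary_pos_emb(q, freqs)   # partial RoPE: the last 32 dims pass through
assert torch.allclose(q_rot[..., 32:], q[..., 32:])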
comfy/ldm/audio/embedders.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# code adapted from: https://github.com/Stability-AI/stable-audio-tools

import torch
import torch.nn as nn
from torch import Tensor, einsum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from einops import rearrange
import math
import comfy.ops

class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time"""

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.empty(half_dim))

    def forward(self, x: Tensor) -> Tensor:
        x = rearrange(x, "b -> b 1")
        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered

def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    return nn.Sequential(
        LearnedPositionalEmbedding(dim),
        comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
    )


class NumberEmbedder(nn.Module):
    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if not torch.is_tensor(x):
            device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=device)
        assert isinstance(x, Tensor)
        shape = x.shape
        x = rearrange(x, "... -> (...)")
        embedding = self.embedding(x)
        x = embedding.view(*shape, self.features)
        return x # type: ignore


class Conditioner(nn.Module):
    def __init__(
        self,
        dim: int,
        output_dim: int,
        project_out: bool = False
    ):

        super().__init__()

        self.dim = dim
        self.output_dim = output_dim
        self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()

    def forward(self, x):
        raise NotImplementedError()

class NumberConditioner(Conditioner):
    '''
    Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
    '''
    def __init__(self,
                 output_dim: int,
                 min_val: float=0,
                 max_val: float=1
                 ):
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats, device=None):
        # Cast the inputs to floats
        floats = [float(x) for x in floats]

        if device is None:
            device = next(self.embedder.parameters()).device

        floats = torch.tensor(floats).to(device)

        floats = floats.clamp(self.min_val, self.max_val)

        normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)

        # Cast floats to same type as embedder
        embedder_dtype = next(self.embedder.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        float_embeds = self.embedder(normalized_floats).unsqueeze(1)

        return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
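
A quick sketch of how the number conditioner might be used (Stable Audio conditions on values such as the clip's start time and total seconds; the output dimension and range below are illustrative assumptions, and the uninitialized weights would normally come from a checkpoint):

import torch
from comfy.ldm.audio.embedders import NumberConditioner

# e.g. a "total seconds" conditioner; 512 output dims and the 0-512 range are assumptions
seconds_cond = NumberConditioner(output_dim=512, min_val=0, max_val=512)

embeds, mask = seconds_cond([47.0, 180.0])   # two conditioning values
print(embeds.shape, mask.shape)              # torch.Size([2, 1, 512]) torch.Size([2, 1])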
comfy/ldm/aura/mmdit.py (new file, 480 lines)
@@ -0,0 +1,480 @@
#AuraFlow MMDiT
#Originally written by the AuraFlow Authors

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.ldm.modules.attention import optimized_attention
import comfy.ops

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim

        n_hidden = int(2 * hidden_dim / 3)
        n_hidden = find_multiple(n_hidden, 256)

        self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
        self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
        self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x
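
# MLP.forward computes silu(W1 x) * (W2 x), i.e. the SwiGLU gating variant; the hidden
# width is shrunk to 2/3 of 4*dim and rounded up to a multiple of 256, as in LLaMA-style MLPs.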


class MultiHeadLayerNorm(nn.Module):
    def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
        # Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78

        super().__init__()
        self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(
            variance + self.variance_epsilon
        )
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype)

class SingleAttention(nn.Module):
    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

    #@torch.compile()
    def forward(self, c):

        bsz, seqlen1, _ = c.shape

        q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
        q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
        k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
        v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
        q, k = self.q_norm1(q), self.k_norm1(k)

        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
        c = self.w1o(output)
        return c



class DoubleAttention(nn.Module):
    def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        # this is for x
        self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
        self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )

        self.q_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )
        self.k_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
            if mh_qknorm
            else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
        )


    #@torch.compile()
    def forward(self, c, x):

        bsz, seqlen1, _ = c.shape
        bsz, seqlen2, _ = x.shape
        seqlen = seqlen1 + seqlen2

        cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
        cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
        ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cq, ck = self.q_norm1(cq), self.k_norm1(ck)

        xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
        xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xq, xk = self.q_norm2(xq), self.k_norm2(xk)

        # concat all
        q, k, v = (
            torch.cat([cq, xq], dim=1),
            torch.cat([ck, xk], dim=1),
            torch.cat([cv, xv], dim=1),
        )

        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)

        c, x = output.split([seqlen1, seqlen2], dim=1)
        c = self.w1o(c)
        x = self.w2o(x)

        return c, x


class MMDiTBlock(nn.Module):
    def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
        super().__init__()

        self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        if not is_last:
            self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
            )
        else:
            self.modC = nn.Sequential(
                nn.SiLU(),
                operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
            )

        self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
        self.modX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.is_last = is_last

    #@torch.compile()
    def forward(self, c, x, global_cond, **kwargs):

        cres, xres = c, x

        cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
            self.modC(global_cond).chunk(6, dim=1)
        )

        c = modulate(self.normC1(c), cshift_msa, cscale_msa)

        # xpath
        xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
            self.modX(global_cond).chunk(6, dim=1)
        )

        x = modulate(self.normX1(x), xshift_msa, xscale_msa)

        # attention
        c, x = self.attn(c, x)


        c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
        c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
        c = cres + c

        x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
        x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
        x = xres + x

        return c, x

class DiTBlock(nn.Module):
    # like MMDiTBlock, but it only has X
    def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
        super().__init__()

        self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)

        self.modCX = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
        )

        self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
        self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

    #@torch.compile()
    def forward(self, cx, global_cond, **kwargs):
        cxres = cx
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
            global_cond
        ).chunk(6, dim=1)
        cx = modulate(self.norm1(cx), shift_msa, scale_msa)
        cx = self.attn(cx)
        cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
        mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
        cx = gate_mlp.unsqueeze(1) * mlpout

        cx = cxres + cx

        return cx




class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = 1000 * torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    #@torch.compile()
    def forward(self, t, dtype):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
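
# Note the extra factor of 1000 on freqs compared to the standard DiT sinusoidal
# embedding; presumably this is because AuraFlow passes timesteps in [0, 1], so the
# scaling restores the usual 0-1000 discrete-timestep range before taking sin/cos.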
|
||||
|
||||
|
||||
class MMDiT(nn.Module):
    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        patch_size=2,
        dim=3072,
        n_layers=36,
        n_double_layers=4,
        n_heads=12,
        global_conddim=3072,
        cond_seq_dim=2048,
        max_seq=32 * 32,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype

        self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)

        self.cond_seq_linear = operations.Linear(
            cond_seq_dim, dim, bias=False, dtype=dtype, device=device
        )  # linear projection for the conditioning (e.g. text) sequence.
        self.init_x_linear = operations.Linear(
            patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
        )  # input linear for the patchified image.

        self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
        self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))

        self.double_layers = nn.ModuleList([])
        self.single_layers = nn.ModuleList([])

        for idx in range(n_double_layers):
            self.double_layers.append(
                MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
            )

        for idx in range(n_double_layers, n_layers):
            self.single_layers.append(
                DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
            )

        self.final_linear = operations.Linear(
            dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
        )

        self.modF = nn.Sequential(
            nn.SiLU(),
            operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
        )

        self.out_channels = out_channels
        self.patch_size = patch_size
        self.n_double_layers = n_double_layers
        self.n_layers = n_layers

        self.h_max = round(max_seq**0.5)
        self.w_max = round(max_seq**0.5)

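    # Hypothetical usage sketch (the dims below are illustrative, not a shipped
    # config; `latent` and `text_embeds` are placeholder names):
    #
    #   model = MMDiT(in_channels=4, dim=3072, n_layers=36, n_double_layers=4,
    #                 operations=comfy.ops.disable_weight_init)
    #   out = model(latent, timestep, text_embeds)   # latent: (B, 4, H, W)
    #
    # The first n_double_layers blocks keep conditioning and image streams
    # separate (MMDiTBlock); the rest run on the concatenated sequence (DiTBlock).
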
    @torch.no_grad()
    def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
        # extend pe
        pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]

        pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)

        # extend to target_dim via bilinear interpolation (torch.nn.functional.interpolate)
        pe_as_2d = F.interpolate(
            pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
        )
        pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
        self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
        self.h_max, self.w_max = target_dim
        print("PE extended to", target_dim)

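    # Shape sketch for the defaults: a (1, 256, dim) PE is viewed as a
    # (dim, 16, 16) grid, bilinearly resized to (dim, 64, 64), and flattened
    # back to (1, 4096, dim), so pretrained PEs carry over to larger canvases.
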
    def pe_selection_index_based_on_dim(self, h, w):
        h_p, w_p = h // self.patch_size, w // self.patch_size
        original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
        original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
        starth = self.h_max // 2 - h_p // 2
        endh = starth + h_p
        startw = self.w_max // 2 - w_p // 2
        endw = startw + w_p
        original_pe_indexes = original_pe_indexes[
            starth:endh, startw:endw
        ]
        return original_pe_indexes.flatten()

    def unpatchify(self, x, h, w):
        c = self.out_channels
        p = self.patch_size

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        B, C, H, W = x.size()
        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size

        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
        x = x.view(
            B,
            C,
            (H + 1) // self.patch_size,
            self.patch_size,
            (W + 1) // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

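    # Shape walk-through for patch_size=2 (the default): (B, C, H, W) is
    # circularly padded to even H/W, split into 2x2 patches, and flattened to
    # (B, (H'//2) * (W'//2), C*4). Note the (H + 1) // 2 arithmetic equals
    # ceil(H / 2) only for patch_size == 2, so this method assumes that value.
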
    def apply_pos_embeds(self, x, h, w):
        h = (h + 1) // self.patch_size
        w = (w + 1) // self.patch_size
        max_dim = max(h, w)

        cur_dim = self.h_max
        pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)

        if max_dim > cur_dim:
            pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
            cur_dim = max_dim

        from_h = (cur_dim - h) // 2
        from_w = (cur_dim - w) // 2
        pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
        return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])

    def forward(self, x, timestep, context, **kwargs):
        # patchify x, add PE
        b, c, h, w = x.shape

        # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
        # print(pe_indexes, pe_indexes.shape)

        x = self.init_x_linear(self.patchify(x))  # B, T_x, D
        x = self.apply_pos_embeds(x, h, w)
        # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
        # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)

        # process conditions for MMDiT Blocks
        c_seq = context  # B, T_c, D_c
        t = timestep

        c = self.cond_seq_linear(c_seq)  # B, T_c, D
        c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)

        global_cond = self.t_embedder(t, x.dtype)  # B, D

        if len(self.double_layers) > 0:
            for layer in self.double_layers:
                c, x = layer(c, x, global_cond, **kwargs)

        if len(self.single_layers) > 0:
            c_len = c.size(1)
            cx = torch.cat([c, x], dim=1)
            for layer in self.single_layers:
                cx = layer(cx, global_cond, **kwargs)

            x = cx[:, c_len:]

        fshift, fscale = self.modF(global_cond).chunk(2, dim=1)

        x = modulate(x, fshift, fscale)
        x = self.final_linear(x)
        x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:, :, :h, :w]
        return x
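# Overall flow, for orientation: register tokens are prepended to the projected
# conditioning sequence, the double-stream blocks update (c, x) jointly, the
# single-stream blocks run on cat([c, x]), and a final modulated linear head is
# unpatchified back to (B, C_out, h, w), cropping away any padding from patchify.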
154
comfy/ldm/cascade/common.py
Normal file
@@ -0,0 +1,154 @@
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops


class OptimizedAttention(nn.Module):
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead

        self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, q, k, v):
        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)

        out = optimized_attention(q, k, v, self.heads)

        return self.out_proj(out)


class Attention2D(nn.Module):
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
        # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)

    def forward(self, x, kv, self_attn=False):
        orig_shape = x.shape
        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # BxCxHxW -> Bx(HxW)xC
        if self_attn:
            kv = torch.cat([x, kv], dim=1)
        # x = self.attn(x, kv, kv, need_weights=False)[0]
        x = self.attn(x, kv, kv)
        x = x.permute(0, 2, 1).view(*orig_shape)
        return x

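# Illustrative shapes: a (B, C, H, W) feature map becomes a (B, H*W, C) query
# sequence; with self_attn=True the (already C-dim) conditioning tokens are
# concatenated onto the image tokens, so the block attends over both at once.
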
def LayerNorm2d_op(operations):
    class LayerNorm2d(operations.LayerNorm):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def forward(self, x):
            return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    return LayerNorm2d


class GlobalResponseNorm(nn.Module):
    "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x

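# GRN in equations (matching the forward above; input is channels-last):
#   Gx = ||x||_2 over the spatial dims     # one norm per (batch, channel)
#   Nx = Gx / (mean_over_channels(Gx) + 1e-6)
#   y  = gamma * (x * Nx) + beta + x
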
class ResBlock(nn.Module):
    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None):  # , num_heads=4, expansion=2):
        super().__init__()
        self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
        # self.depthwise = SAMBlock(c, num_heads, expansion)
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.channelwise = nn.Sequential(
            operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device)
        )

    def forward(self, x, x_skip=None):
        x_res = x
        x = self.norm(self.depthwise(x))
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x + x_res


class AttnBlock(nn.Module):
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_cond, c, dtype=dtype, device=device)
        )

    def forward(self, x, kv):
        kv = self.kv_mapper(kv)
        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
        return x

class FeedForwardBlock(nn.Module):
    def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.channelwise = nn.Sequential(
            operations.Linear(c, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device)
        )

    def forward(self, x):
        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x


class TimestepBlock(nn.Module):
    def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
        super().__init__()
        self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
        self.conds = conds
        for cname in conds:
            setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))

    def forward(self, x, t):
        t = t.chunk(len(self.conds) + 1, dim=1)
        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
        for i, c in enumerate(self.conds):
            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
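# Illustration: `t` is the concatenation of one c_timestep-wide embedding for
# the timestep plus one per extra condition in `conds` (e.g. ['sca']). Each
# chunk is mapped to an (a, b) pair, the pairs are summed, and the block applies
# a single affine x * (1 + a) + b broadcast over the spatial dimensions.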
93
comfy/ldm/cascade/controlnet.py
Normal file
@@ -0,0 +1,93 @@
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import torch
import torchvision
from torch import nn
from .common import LayerNorm2d_op


class CNetResBlock(nn.Module):
    def __init__(self, c, dtype=None, device=None, operations=None):
        super().__init__()
        self.blocks = nn.Sequential(
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c, c, kernel_size=3, padding=1),
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c, c, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.blocks(x)


class ControlNet(nn.Module):
    def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
        super().__init__()
        if bottleneck_mode is None:
            bottleneck_mode = 'effnet'
        self.proj_blocks = proj_blocks
        if bottleneck_mode == 'effnet':
            embd_channels = 1280
            self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
            if c_in != 3:
                in_weights = self.backbone[0][0].weight.data
                self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
                if c_in > 3:
                    # nn.init.constant_(self.backbone[0][0].weight, 0)
                    self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
                else:
                    self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
        elif bottleneck_mode == 'simple':
            embd_channels = c_in
            self.backbone = nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
            )
        elif bottleneck_mode == 'large':
            self.backbone = nn.Sequential(
                operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
                *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
                operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
            )
            embd_channels = 1280
        else:
            raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
        self.projections = nn.ModuleList()
        for _ in range(len(proj_blocks)):
            self.projections.append(nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
            ))
        # nn.init.constant_(self.projections[-1][-1].weight, 0)  # zero output projection
        self.xl = False
        self.input_channels = c_in
        self.unshuffle_amount = 8

    def forward(self, x):
        x = self.backbone(x)
        proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
        for i, idx in enumerate(self.proj_blocks):
            proj_outputs[idx] = self.projections[i](x)
        return {"input": proj_outputs[::-1]}
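# Note: the projection outputs are returned reversed ("input" is
# proj_outputs[::-1]), which lets a consumer pop() residuals off the end of the
# list in the order its down blocks encounter them.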
255
comfy/ldm/cascade/stage_a.py
Normal file
@@ -0,0 +1,255 @@
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import torch
from torch import nn
from torch.autograd import Function


class vector_quantize(Function):
    @staticmethod
    def forward(ctx, x, codebook):
        with torch.no_grad():
            codebook_sqr = torch.sum(codebook ** 2, dim=1)
            x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)

            dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
            _, indices = dist.min(dim=1)

            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)

            nn = torch.index_select(codebook, 0, indices)
            return nn, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        grad_inputs, grad_codebook = None, None

        if ctx.needs_input_grad[0]:
            grad_inputs = grad_output.clone()
        if ctx.needs_input_grad[1]:
            # Gradient wrt. the codebook
            indices, codebook = ctx.saved_tensors

            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output)

        return (grad_inputs, grad_codebook)

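# The addmm above computes all squared distances in one shot:
#   dist[i, j] = ||x_i||^2 + ||c_j||^2 - 2 * x_i . c_j = ||x_i - c_j||^2
# backward() is a straight-through estimator: gradients flow to the encoder as
# if quantization were the identity, while the codebook receives index_add_'d
# gradients from the rows that selected each code.
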
class VectorQuantize(nn.Module):
    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
        with the same size as the input, vq and commitment components for the loss as a tuple
        in the second output and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices
        """
        super(VectorQuantize, self).__init__()

        self.codebook = nn.Embedding(k, embedding_size)
        self.codebook.weight.data.uniform_(-1. / k, 1. / k)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            self.register_buffer('ema_element_count', torch.ones(k))
            self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        n = torch.sum(x)
        return ((x + epsilon) / (n + x.size(0) * epsilon) * n)

    def _updateEMA(self, z_e_x, indices):
        mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
        elem_count = mask.sum(dim=0)
        weight_sum = torch.mm(mask.t(), z_e_x)

        self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1 - self.ema_decay) * elem_count)
        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
        self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1 - self.ema_decay) * weight_sum)

        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        q_idx = self.codebook(idx)
        if dim != -1:
            q_idx = q_idx.movedim(-1, dim)
        return q_idx

    def forward(self, x, get_losses=True, dim=-1):
        if dim != -1:
            x = x.movedim(dim, -1)
        z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
        vq_loss, commit_loss = None, None
        if self.ema_loss and self.training:
            self._updateEMA(z_e_x.detach(), indices.detach())
        # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
        z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
        if get_losses:
            vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
            commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()

        z_q_x = z_q_x.view(x.shape)
        if dim != -1:
            z_q_x = z_q_x.movedim(-1, dim)
        return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])

class ResBlock(nn.Module):
    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )

        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights
        def _basic_init(module):
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        mods = self.gammas

        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        try:
            x = x + self.depthwise(x_temp) * mods[2]
        except:  # operation not implemented for bf16
            x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
            x = x + self.depthwise[1](x_temp) * mods[2]

        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]

        return x

class StageA(nn.Module):
    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
        super().__init__()
        self.c_latent = c_latent
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)
        self.down_blocks[0]

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                       padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
            return qe, x, indices, vq_loss + commit_loss * 0.25
        else:
            return x

    def decode(self, x):
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        qe, x, _, vq_loss = self.encode(x, quantize)
        x = self.decode(qe)
        return x, vq_loss

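# Rough shape sketch with the defaults (levels=2): PixelUnshuffle(2) plus one
# stride-2 conv give an overall 4x downsample, so a (B, 3, H, W) image maps to a
# (B, 4, H/4, W/4) latent, optionally quantized against the 8192-entry codebook;
# decode() mirrors this back up and finishes with PixelShuffle(2).
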
class Discriminator(nn.Module):
    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        layers = [
            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            c_in = c_hidden // (2 ** max((d - i), 0))
            c_out = c_hidden // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)
        self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        x = self.encoder(x)
        if cond is not None:
            cond = cond.view(cond.size(0), cond.size(1), 1, 1).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = self.logits(x)
        return x
256
comfy/ldm/cascade/stage_b.py
Normal file
@@ -0,0 +1,256 @@
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import math
import torch
from torch import nn
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock


class StageB(nn.Module):
    def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
                 nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
                 block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
                 c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
                 t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.effnet_mapper = nn.Sequential(
            operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        self.pixels_mapper = nn.Sequential(
            operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

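        # get_block decodes the level_config strings: 'C' = convolutional
        # ResBlock, 'A' = attention AttnBlock, 'F' = FeedForwardBlock,
        # 'T' = TimestepBlock. E.g. level 'CTA' stacks one of each per repeat.
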
        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
                                      self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # self.apply(self._init_weights)  # General init
        # nn.init.normal_(self.clip_mapper.weight, std=0.02)  # conditionings
        # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02)  # conditionings
        # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02)  # conditionings
        # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02)  # conditionings
        # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02)  # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        # nn.init.constant_(self.clf[1].weight, 0)  # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
    #
    # def _init_weights(self, m):
    #     if isinstance(m, (nn.Conv2d, nn.Linear)):
    #         torch.nn.init.xavier_uniform_(m.weight)
    #         if m.bias is not None:
    #             nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip):
        if len(clip.shape) == 2:
            clip = clip.unsqueeze(1)
        clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
        clip = self.clip_norm(clip)
        return clip

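    # gen_r_embedding is the classic sinusoidal embedding of r scaled by
    # max_positions: emb = [sin(r * w), cos(r * w)] with geometrically spaced w.
    # gen_c_embeddings expands each CLIP vector into c_clip_seq tokens of width
    # c_cond via a single linear map, e.g. (B, 1, 1280) -> (B, 4, 1280).
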
    def _down_encode(self, x, r_embed, clip):
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip):
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
        if pixels is None:
            pixels = x.new_zeros(x.size(0), 3, 8, 8)

        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip)

        # Model Blocks
        x = self.embedding(x)
        x = x + self.effnet_mapper(
            nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
        x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
                                          align_corners=True)
        level_outputs = self._down_encode(x, r_embed, clip)
        x = self._up_decode(level_outputs, r_embed, clip)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
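    # update_weights_ema applies the usual exponential moving average,
    # theta_ema <- beta * theta_ema + (1 - beta) * theta_src, to parameters and
    # buffers alike; with beta=0.999 the EMA tracks roughly the last ~1000 steps.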
273
comfy/ldm/cascade/stage_c.py
Normal file
@@ -0,0 +1,273 @@
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import torch
from torch import nn
import math
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
# from .controlnet import ControlNetDeliverer


class UpDownBlock2d(nn.Module):
    def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert mode in ['up', 'down']
        interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
                                    align_corners=True) if enabled else nn.Identity()
        mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
        self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

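# Ordering note: 'up' interpolates first and then applies the 1x1 mapping conv,
# while 'down' maps first and then shrinks, so the mapping conv always sees the
# higher-resolution tensor of the two.
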
class StageC(nn.Module):
    def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
                 blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
                 c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
                 dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
                 dtype=None, device=None, operations=None):
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
        self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
                                      self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # self.apply(self._init_weights)  # General init
        # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02)  # conditionings
        # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)  # conditionings
        # nn.init.normal_(self.clip_img_mapper.weight, std=0.02)  # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        # nn.init.constant_(self.clf[1].weight, 0)  # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
    #
    # def _init_weights(self, m):
    #     if isinstance(m, (nn.Conv2d, nn.Linear)):
    #         torch.nn.init.xavier_uniform_(m.weight)
    #         if m.bias is not None:
    #             nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
        clip_txt = self.clip_txt_mapper(clip_txt)
        if len(clip_txt_pooled.shape) == 2:
            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
        if len(clip_img.shape) == 2:
            clip_img = clip_img.unsqueeze(1)
        clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
        clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
        clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip, cnet=None):
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        if cnet is not None:
                            next_cnet = cnet.pop()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True).to(x.dtype)
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  ResBlock)):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
                                                                align_corners=True)
                        if cnet is not None:
                            next_cnet = cnet.pop()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
                                                                  align_corners=True).to(x.dtype)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                  TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

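    # ControlNet residuals arrive as control["input"], a list built in reversed
    # block order, so _down_encode and _up_decode consume them with cnet.pop()
    # and bilinearly resize each one to the current feature resolution.
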
    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)

        if control is not None:
            cnet = control.get("input")
        else:
            cnet = None

        # Model Blocks
        x = self.embedding(x)
        level_outputs = self._down_encode(x, r_embed, clip, cnet)
        x = self._up_decode(level_outputs, r_embed, clip, cnet)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
95
comfy/ldm/cascade/stage_c_coder.py
Normal file
@@ -0,0 +1,95 @@
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
import torchvision
from torch import nn


# EfficientNet
class EfficientNetEncoder(nn.Module):
    def __init__(self, c_latent=16):
        super().__init__()
        self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
        self.mapper = nn.Sequential(
            nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent, affine=False),  # then normalize them to have mean 0 and std 1
        )
        self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
        self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))

    def forward(self, x):
        x = x * 0.5 + 0.5
        x = (x - self.mean.view([3, 1, 1])) / self.std.view([3, 1, 1])
        o = self.mapper(self.backbone(x))
        return o


# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
class Previewer(nn.Module):
    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),

            nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),

            nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
        )

    def forward(self, x):
        return (self.blocks(x) - 0.5) * 2.0

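# For orientation: the three stride-2 transposed convs above give the 8x
# upsample promised in the class comment (16 x 24 x 24 -> 3 x 192 x 192), and
# (blocks(x) - 0.5) * 2.0 rescales the output back to the [-1, 1] image range.
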
class StageC_coder(nn.Module):
    def __init__(self):
        super().__init__()
        self.previewer = Previewer()
        self.encoder = EfficientNetEncoder()

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.previewer(x)
256
comfy/ldm/flux/layers.py
Normal file
@@ -0,0 +1,256 @@
import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn

from .math import attention, rope
import comfy.ops


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to t before computing the embedding.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding

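# Shape sketch (illustrative, not part of the original file):
#   t = torch.tensor([0.5, 1.0])          # (N,) fractional timesteps
#   emb = timestep_embedding(t, 256)      # (N, 256): cos half followed by sin half
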
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.silu = nn.SiLU()
        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)

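# Math note (illustrative): RMSNorm computes x / sqrt(mean(x^2) + eps) * scale,
# i.e. it rescales each token to unit RMS without subtracting a mean, which is
# cheaper than LayerNorm; here it is applied to queries and keys before attention.
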
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x

@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

    def forward(self, vec: Tensor) -> tuple:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )

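# Usage sketch (illustrative, assuming `ops` is a comfy.ops-style namespace and a
# hidden size of 3072, neither of which is fixed by this file): a "double"
# modulation projects the conditioning vector to 6 * dim and splits it into two
# (shift, scale, gate) triples, one for the attention branch and one for the MLP:
#   mod = Modulation(dim=3072, double=True, operations=ops)
#   mod1, mod2 = mod(vec)   # vec: (B, dim); each of shift/scale/gate: (B, 1, dim)
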
class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt

class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output

class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
35
comfy/ldm/flux/math.py
Normal file
@@ -0,0 +1,35 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True)
    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    if comfy.model_management.is_device_mps(pos.device):
        device = torch.device("cpu")  # MPS does not support float64, so compute on CPU
    else:
        device = pos.device

    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)

def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

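# Math note (illustrative): rope() packs each frequency as a 2x2 rotation matrix
# [[cos, -sin], [sin, cos]], and apply_rope multiplies each consecutive feature
# pair (x0, x1) by it, yielding (x0*cos - x1*sin, x0*sin + x1*cos). This is the
# usual complex-multiplication view of RoPE: (x0 + i*x1) * (cos + i*sin).
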
144
comfy/ldm/flux/model.py
Normal file
@@ -0,0 +1,144 @@
#Original code can be found on: https://github.com/black-forest-labs/flux

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from .layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)

from einops import rearrange, repeat


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list
    theta: int
    qkv_bias: bool
    guidance_embed: bool

class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        params = FluxParams(**kwargs)
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    def forward(self, x, timestep, context, y, guidance, **kwargs):
        bs, c, h, w = x.shape
        patch_size = 2
        pad_h = (patch_size - h % patch_size) % patch_size
        pad_w = (patch_size - w % patch_size) % patch_size

        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

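# Shape sketch (illustrative, not part of the original file): for a hypothetical
# (B, 16, 96, 96) latent and patch_size = 2, the model sees a sequence of
# 48 * 48 = 2304 tokens of width 16 * 2 * 2 = 64, and the final rearrange undoes
# the patching back to (B, 16, 96, 96).
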
219
comfy/ldm/hydit/attn_layers.py
Normal file
@@ -0,0 +1,219 @@
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional
from comfy.ldm.modules.attention import optimized_attention


def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert ndim > 1

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

def apply_rotary_emb(
        xq: torch.Tensor,
        xk: Optional[torch.Tensor],
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
        head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (Optional[torch.Tensor]): Key tensor to apply rotary embeddings, or None. [B, S, H, D]
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        if xk is not None:
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out

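# Math note (illustrative): in the real-valued branch,
#   x * cos + rotate_half(x) * sin
# applied to each pair (x0, x1) gives (x0*cos - x1*sin, x1*cos + x0*sin),
# which is exactly the complex rotation (x0 + i*x1) * (cos + i*sin) that the
# complex-valued branch computes with view_as_complex.
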
class CrossAttention(nn.Module):
    """
    Use QK Normalization.
    """
    def __init__(self,
                 qdim,
                 kdim,
                 num_heads,
                 qkv_bias=True,
                 qk_norm=False,
                 attn_drop=0.0,
                 proj_drop=0.0,
                 attn_precision=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.attn_precision = attn_precision
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // 2), RoPE for image
        """
        b, s1, c = x.shape     # [b, s1, D]
        _, s2, c = y.shape     # [b, s2, 1024]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)        # [b, s1, h, d]
        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)   # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)  # [b, s2, h, d]
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq

        q = q.transpose(-2, -3).contiguous()  # q: [b, s1, h, d] -> [b, h, s1, d]
        k = k.transpose(-2, -3).contiguous()  # k: [b, s2, h, d] -> [b, h, s2, d]
        v = v.transpose(-2, -3).contiguous()

        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)

        out = self.out_proj(context)  # context: [b, s1, D]
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple

class Attention(nn.Module):
    """
    We rename some layer names to align with flash attention
    """
    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn_precision = attn_precision
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        # qkv --> Wqkv
        self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis_img=None):
        B, N, C = x.shape
        qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # [3, b, h, s, d]
        q, k, v = qkv.unbind(0)  # [b, h, s, d]
        q = self.q_norm(q)  # [b, h, s, d]
        k = self.k_norm(k)  # [b, h, s, d]

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
            assert qq.shape == q.shape and kk.shape == k.shape, \
                f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
        x = self.out_proj(x)
        x = self.proj_drop(x)

        out_tuple = (x,)

        return out_tuple
405
comfy/ldm/hydit/models.py
Normal file
@@ -0,0 +1,405 @@
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F

import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint

from .attn_layers import Attention, CrossAttention
from .poolers import AttentionPool
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop

def calc_rope(x, patch_size, head_size):
    th = (x.shape[2] + (patch_size // 2)) // patch_size
    tw = (x.shape[3] + (patch_size // 2)) // patch_size
    base_size = 512 // 8 // patch_size
    start, stop = get_fill_resize_and_crop((th, tw), base_size)
    sub_args = [start, stop, (th, tw)]
    # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
    rope = get_2d_rotary_pos_embed(head_size, *sub_args)
    return rope

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

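# Shape sketch (illustrative): x is a token sequence (B, L, D) while shift and
# scale are per-sample vectors (B, D); unsqueeze(1) broadcasts them over L, so
# every token of a sample receives the same adaLN shift/scale.
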
class HunYuanDiTBlock(nn.Module):
    """
    A HunYuanDiT block with `add` conditioning.
    """
    def __init__(self,
                 hidden_size,
                 c_emb_size,
                 num_heads,
                 mlp_ratio=4.0,
                 text_states_dim=1024,
                 qk_norm=False,
                 norm_type="layer",
                 skip=False,
                 attn_precision=None,
                 dtype=None,
                 device=None,
                 operations=None,
                 ):
        super().__init__()
        use_ele_affine = True

        if norm_type == "layer":
            norm_layer = operations.LayerNorm
        elif norm_type == "rms":
            norm_layer = RMSNorm
        else:
            raise ValueError(f"Unknown norm_type: {norm_type}")

        # ========================= Self-Attention =========================
        self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)

        # ========================= FFN =========================
        self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, dtype=dtype, device=device, operations=operations)

        # ========================= Add =========================
        # Simply use add like SDXL.
        self.default_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_emb_size, hidden_size, bias=True, dtype=dtype, device=device)
        )

        # ========================= Cross-Attention =========================
        self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
                                    qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)

        # ========================= Skip Connection =========================
        if skip:
            self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
            self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, dtype=dtype, device=device)
        else:
            self.skip_linear = None

        self.gradient_checkpointing = False

    def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        # Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([x, skip], dim=-1)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        # Self-Attention
        shift_msa = self.default_modulation(c).unsqueeze(dim=1)
        attn_inputs = (
            self.norm1(x) + shift_msa, freq_cis_img,
        )
        x = x + self.attn1(*attn_inputs)[0]

        # Cross-Attention
        cross_inputs = (
            self.norm3(x), text_states, freq_cis_img
        )
        x = x + self.attn2(*cross_inputs)[0]

        # FFN Layer
        mlp_inputs = self.norm2(x)
        x = x + self.mlp(mlp_inputs)

        return x

    def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        if self.gradient_checkpointing and self.training:
            return checkpoint.checkpoint(self._forward, x, c, text_states, freq_cis_img, skip)
        return self._forward(x, c, text_states, freq_cis_img, skip)

class FinalLayer(nn.Module):
    """
    The final layer of HunYuanDiT.
    """
    def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class HunYuanDiT(nn.Module):
    """
    HunYuanDiT: Diffusion model with a Transformer backbone.

    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.

    Parameters
    ----------
    args: argparse.Namespace
        The arguments parsed by argparse.
    input_size: tuple
        The size of the input image.
    patch_size: int
        The size of the patch.
    in_channels: int
        The number of input channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    depth: int
        The number of transformer blocks.
    num_heads: int
        The number of attention heads.
    mlp_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    log_fn: callable
        The logging function.
    """
    #@register_to_config
    def __init__(self,
                 input_size: tuple = 32,
                 patch_size: int = 2,
                 in_channels: int = 4,
                 hidden_size: int = 1152,
                 depth: int = 28,
                 num_heads: int = 16,
                 mlp_ratio: float = 4.0,
                 text_states_dim = 1024,
                 text_states_dim_t5 = 2048,
                 text_len = 77,
                 text_len_t5 = 256,
                 qk_norm = True,  # See http://arxiv.org/abs/2302.05442 for details.
                 size_cond = False,
                 use_style_cond = False,
                 learn_sigma = True,
                 norm = "layer",
                 log_fn: callable = print,
                 attn_precision=None,
                 dtype=None,
                 device=None,
                 operations=None,
                 **kwargs,
                 ):
        super().__init__()
        self.log_fn = log_fn
        self.depth = depth
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.text_states_dim = text_states_dim
        self.text_states_dim_t5 = text_states_dim_t5
        self.text_len = text_len
        self.text_len_t5 = text_len_t5
        self.size_cond = size_cond
        self.use_style_cond = use_style_cond
        self.norm = norm
        self.dtype = dtype

        self.mlp_t5 = nn.Sequential(
            operations.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True, dtype=dtype, device=device),
        )
        # learnable replace
        self.text_embedding_padding = nn.Parameter(
            torch.empty(self.text_len + self.text_len_t5, self.text_states_dim, dtype=dtype, device=device))

        # Attention pooling
        pooler_out_dim = 1024
        self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim, dtype=dtype, device=device, operations=operations)

        # Dimension of the extra input vectors
        self.extra_in_dim = pooler_out_dim

        if self.size_cond:
            # Image size and crop size conditions
            self.extra_in_dim += 6 * 256

        if self.use_style_cond:
            # Here we use a default learned embedder layer for future extension.
            self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
            self.extra_in_dim += hidden_size

        # Text embedding for `add`
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, dtype=dtype, device=device, operations=operations)
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device, operations=operations)
        self.extra_embedder = nn.Sequential(
            operations.Linear(self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
        )

        # Image embedding
        num_patches = self.x_embedder.num_patches

        # HunYuanDiT blocks
        self.blocks = nn.ModuleList([
            HunYuanDiTBlock(hidden_size=hidden_size,
                            c_emb_size=hidden_size,
                            num_heads=num_heads,
                            mlp_ratio=mlp_ratio,
                            text_states_dim=self.text_states_dim,
                            qk_norm=qk_norm,
                            norm_type=self.norm,
                            skip=layer > depth // 2,
                            attn_precision=attn_precision,
                            dtype=dtype,
                            device=device,
                            operations=operations,
                            )
            for layer in range(depth)
        ])

        self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
        self.unpatchify_channels = self.out_channels

    def forward(self,
                x,
                t,
                context,  # encoder_hidden_states
                text_embedding_mask=None,
                encoder_hidden_states_t5=None,
                text_embedding_mask_t5=None,
                image_meta_size=None,
                style=None,
                return_dict=False,
                control=None,
                transformer_options=None,
                ):
        """
        Forward pass of the encoder.

        Parameters
        ----------
        x: torch.Tensor
            (B, D, H, W)
        t: torch.Tensor
            (B)
        context: torch.Tensor
            CLIP text embedding (encoder_hidden_states), (B, L_clip, D)
        text_embedding_mask: torch.Tensor
            CLIP text embedding mask, (B, L_clip)
        encoder_hidden_states_t5: torch.Tensor
            T5 text embedding, (B, L_t5, D)
        text_embedding_mask_t5: torch.Tensor
            T5 text embedding mask, (B, L_t5)
        image_meta_size: torch.Tensor
            (B, 6)
        style: torch.Tensor
            (B)
        return_dict: bool
            Whether to return a dictionary.
        """
        encoder_hidden_states = context
        text_states = encoder_hidden_states                   # 2,77,1024
        text_states_t5 = encoder_hidden_states_t5             # 2,256,2048
        text_states_mask = text_embedding_mask.bool()         # 2,77
        text_states_t5_mask = text_embedding_mask_t5.bool()   # 2,256
        b_t5, l_t5, c_t5 = text_states_t5.shape
        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

        text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
        text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])

        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024
        # clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)

        _, _, oh, ow = x.shape
        th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size

        # Get image RoPE embedding according to resolution.
        freqs_cis_img = calc_rope(x, self.patch_size, self.hidden_size // self.num_heads)  # (cos_cis_img, sin_cis_img)

        # ========================= Build time and image embedding =========================
        t = self.t_embedder(t, dtype=x.dtype)
        x = self.x_embedder(x)

        # ========================= Concatenate all extra vectors =========================
        # Build text tokens with pooling
        extra_vec = self.pooler(encoder_hidden_states_t5)

        # Build image meta size tokens if applicable
        if self.size_cond:
            image_meta_size = timestep_embedding(image_meta_size.view(-1), 256).to(x.dtype)  # [B * 6, 256]
            image_meta_size = image_meta_size.view(-1, 6 * 256)
            extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]

        # Build style tokens
        if self.use_style_cond:
            if style is None:
                style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
            style_embedding = self.style_embedder(style, out_dtype=x.dtype)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        # Concatenate all extra vectors
        c = t + self.extra_embedder(extra_vec)  # [B, D]

        controls = None
        # ========================= Forward pass through HunYuanDiT blocks =========================
        skips = []
        for layer, block in enumerate(self.blocks):
            if layer > self.depth // 2:
                if controls is not None:
                    skip = skips.pop() + controls.pop()
                else:
                    skip = skips.pop()
                x = block(x, c, text_states, freqs_cis_img, skip)  # (N, L, D)
            else:
                x = block(x, c, text_states, freqs_cis_img)  # (N, L, D)

            if layer < (self.depth // 2 - 1):
                skips.append(x)
        if controls is not None and len(controls) != 0:
            raise ValueError("The number of controls is not equal to the number of skip connections.")

        # ========================= Final layer =========================
        x = self.final_layer(x, c)  # (N, L, patch_size ** 2 * out_channels)
        x = self.unpatchify(x, th, tw)  # (N, out_channels, H, W)

        if return_dict:
            return {'x': x}
        if self.learn_sigma:
            return x[:,:self.out_channels // 2,:oh,:ow]
        return x[:,:,:oh,:ow]

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.unpatchify_channels
        p = self.x_embedder.patch_size[0]
        # h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
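# Shape sketch (illustrative, with hypothetical values): for p = 2, c = 8,
# h = w = 64, unpatchify maps
#   (N, 4096, 32) -> reshape (N, 64, 64, 2, 2, 8) -> einsum -> (N, 8, 128, 128)
# i.e. each token is unfolded back into a p x p pixel patch.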
37
comfy/ldm/hydit/poolers.py
Normal file
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops


class AttentionPool(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
        self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x):
        x = x[:,:self.positional_embedding.shape[0] - 1]
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x)  # (L+1)NC

        q = self.q_proj(x[:1])
        k = self.k_proj(x)
        v = self.v_proj(x)

        batch_size = q.shape[1]
        head_dim = self.embed_dim // self.num_heads
        q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)

        attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)

        attn_output = self.c_proj(attn_output)
        return attn_output.squeeze(0)
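# Usage sketch (illustrative): the query is the mean token, so this pools a
# sequence into one vector. With the values HunYuanDiT passes in models.py:
#   pool = AttentionPool(256, 2048, num_heads=8, output_dim=1024, operations=ops)
#   vec = pool(t5_states)   # (B, 256, 2048) -> (B, 1024)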
224
comfy/ldm/hydit/posemb_layers.py
Normal file
@@ -0,0 +1,224 @@
import torch
import numpy as np
from typing import Union


def _to_tuple(x):
    if isinstance(x, int):
        return x, x
    else:
        return x


def get_fill_resize_and_crop(src, tgt):
    th, tw = _to_tuple(tgt)
    h, w = _to_tuple(src)

    tr = th / tw  # aspect ratio of the base resolution (tgt)
    r = h / w  # aspect ratio of the incoming resolution (src)

    # resize
    if r > tr:
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))  # fit the incoming resolution into the base resolution

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

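# Worked example (illustrative): get_fill_resize_and_crop((40, 64), 32) fits a
# 40x64 grid into a 32x32 base: r = 0.625 < tr = 1, so it resizes to 20x32 and
# centers it vertically, returning ((6, 0), (26, 32)).
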
def get_meshgrid(start, *args):
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start)
        start = (0, 0)
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start)
        stop = _to_tuple(args[0])
        num = (stop[0] - start[0], stop[1] - start[1])
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start)
        stop = _to_tuple(args[0])
        num = _to_tuple(args[1])
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)  # [2, H, W]
    return grid

#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py


def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid = get_meshgrid(start, *args)  # [2, H, W]
    # grid_h = np.arange(grid_size, dtype=np.float32)
    # grid_w = np.arange(grid_size, dtype=np.float32)
    # grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    # grid = np.stack(grid, axis=0)  # [2, W, H]

    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (W,H)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

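# Math note (illustrative): for position m and channel index i this computes
#   emb[m, i]       = sin(m * omega_i)
#   emb[m, D/2 + i] = cos(m * omega_i),   omega_i = 1 / 10000**(2i / D)
# i.e. the standard transformer sin/cos positional encoding, built with numpy.
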
#################################################################################
#                   Rotary Positional Embedding Functions                       #
#################################################################################
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443


def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
    """
    This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.

    Parameters
    ----------
    embed_dim: int
        embedding dimension size
    start: int or tuple of int
        If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
        If len(args) == 2, start is start, args[0] is stop, args[1] is num.
    use_real: bool
        If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns
    -------
    pos_embed: torch.Tensor
        [HW, D/2]
    """
    grid = get_meshgrid(start, *args)  # [2, H, W]
    grid = grid.reshape([2, 1, *grid.shape[1:]])  # returns a sampling matrix with the same resolution as the target resolution
    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed

def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    assert embed_dim % 4 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)

    if use_real:
        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
        return cos, sin
    else:
        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
        return emb

def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the position indices 'pos'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
                                   Otherwise, return complex numbers.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    if isinstance(pos, int):
        pos = np.arange(pos)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis

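# Usage sketch (illustrative): the two output conventions differ in shape.
#   cos, sin = get_1d_rotary_pos_embed(64, 16, use_real=True)  # each (16, 64)
#   cis = get_1d_rotary_pos_embed(64, 16)                      # (16, 32) complex64
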
def calc_sizes(rope_img, patch_size, th, tw):
    if rope_img == 'extend':
        # Expansion mode
        sub_args = [(th, tw)]
    elif rope_img.startswith('base'):
        # Based on the specified dimensions, other dimensions are obtained through interpolation.
        base_size = int(rope_img[4:]) // 8 // patch_size
        start, stop = get_fill_resize_and_crop((th, tw), base_size)
        sub_args = [start, stop, (th, tw)]
    else:
        raise ValueError(f"Unknown rope_img: {rope_img}")
    return sub_args


def init_image_posemb(rope_img,
                      resolutions,
                      patch_size,
                      hidden_size,
                      num_heads,
                      log_fn,
                      rope_real=True,
                      ):
    freqs_cis_img = {}
    for reso in resolutions:
        th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
        sub_args = calc_sizes(rope_img, patch_size, th, tw)
        freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
        log_fn(f"    Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
               f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
    return freqs_cis_img
226
comfy/ldm/models/autoencoder.py
Normal file
@@ -0,0 +1,226 @@
import logging

import torch
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

from comfy.ldm.util import get_obj_from_str, instantiate_from_config
from comfy.ldm.modules.ema import LitEma
import comfy.ops

logpy = logging.getLogger(__name__)  # used by the EMA/optimizer log messages below

class DiagonalGaussianRegularizer(torch.nn.Module):
    def __init__(self, sample: bool = True):
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        log = dict()
        posterior = DiagonalGaussianDistribution(z)
        if self.sample:
            z = posterior.sample()
        else:
            z = posterior.mode()
        kl_loss = posterior.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        log["kl_loss"] = kl_loss
        return z, log

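# Note (illustrative): with sample=True the regularizer draws z = mean + std * eps
# via the reparameterization trick and logs the KL divergence to N(0, I); with
# sample=False it returns the posterior mode (the mean), the deterministic
# behaviour typically used at inference time.
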
class AbstractAutoencoder(torch.nn.Module):
    """
    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
    unCLIP models, etc. Hence, it is fairly general, and specific features
    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
        **kwargs,
    ):
        super().__init__()

        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                logpy.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    logpy.info(f"{context}: Restored training weights")

    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg["target"])(
            params, lr=lr, **cfg.get("params", dict())
        )

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()

class AutoencodingEngine(AbstractAutoencoder):
|
||||
"""
|
||||
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
||||
(we also restore them explicitly as special cases for legacy reasons).
|
||||
Regularizations such as KL or VQ are moved to the regularizer class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
encoder_config: Dict,
|
||||
decoder_config: Dict,
|
||||
regularizer_config: Dict,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||
self.regularization: AbstractRegularizer = instantiate_from_config(
|
||||
regularizer_config
|
||||
)
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.get_last_layer()
|
||||
|
||||
def encode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
z = self.encoder(x)
|
||||
if unregularized:
|
||||
return z, dict()
|
||||
z, reg_log = self.regularization(z)
|
||||
if return_reg_log:
|
||||
return z, reg_log
|
||||
return z
|
||||
|
||||
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
x = self.decoder(z, **kwargs)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, **additional_decode_kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
z, reg_log = self.encode(x, return_reg_log=True)
|
||||
dec = self.decode(z, **additional_decode_kwargs)
|
||||
return z, dec, reg_log
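
# Example (editor's sketch, hypothetical smoke test assuming a comfy checkout
# on PYTHONPATH): identity encoder/decoder configs exercise the encode ->
# regularize -> decode round trip; the encoder output needs an even channel
# count for the diagonal-Gaussian split.
def _autoencoding_engine_sketch():
    import torch
    engine = AutoencodingEngine(
        encoder_config={"target": "torch.nn.Identity"},
        decoder_config={"target": "torch.nn.Identity"},
        regularizer_config={"target": "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
    )
    x = torch.randn(1, 8, 16, 16)
    z, dec, reg_log = engine(x)                 # z: [1, 4, 16, 16] after sampling
    return z.shape, dec.shape, reg_log["kl_loss"]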


class AutoencodingEngineLegacy(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        ddconfig = kwargs.pop("ddconfig")
        super().__init__(
            encoder_config={
                "target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
                "params": ddconfig,
            },
            decoder_config={
                "target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
                "params": ddconfig,
            },
            **kwargs,
        )
        self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
            (1 + ddconfig["double_z"]) * embed_dim,
            1,
        )
        self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

    def get_autoencoder_params(self) -> list:
        params = super().get_autoencoder_params()
        return params

    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        if self.max_batch_size is None:
            z = self.encoder(x)
            z = self.quant_conv(z)
        else:
            N = x.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            z = list()
            for i_batch in range(n_batches):
                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
                z_batch = self.quant_conv(z_batch)
                z.append(z_batch)
            z = torch.cat(z, 0)

        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        if self.max_batch_size is None:
            dec = self.post_quant_conv(z)
            dec = self.decoder(dec, **decoder_kwargs)
        else:
            N = z.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            dec = list()
            for i_batch in range(n_batches):
                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
                dec.append(dec_batch)
            dec = torch.cat(dec, 0)

        return dec


class AutoencoderKL(AutoencodingEngineLegacy):
    def __init__(self, **kwargs):
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
                )
            },
            **kwargs,
        )
865
comfy/ldm/modules/attention.py
Normal file
@@ -0,0 +1,865 @@
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional
import logging

from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention

from comfy import model_management

if model_management.xformers_enabled():
    import xformers
    import xformers.ops

from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init

FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()

def get_attn_precision(attn_precision):
    if args.dont_upcast_attention:
        return None
    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
        return FORCE_UPCAST_ATTENTION_DTYPE
    return attn_precision

def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            operations.Linear(dim, inner_dim, dtype=dtype, device=device),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
        )

    def forward(self, x):
        return self.net(x)

def Normalize(in_channels, dtype=None, device=None):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    h = heads
    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    # force cast to fp32 to avoid overflowing
    if attn_precision == torch.float32:
        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if exists(mask):
        if mask.dtype == torch.bool:
            mask = rearrange(mask, 'b ... -> b (...)')  # TODO: check if this bool part matches pytorch attention
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)
        else:
            if len(mask.shape) == 2:
                bs = 1
            else:
                bs = mask.shape[0]
            mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
            sim.add_(mask)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out
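
# Example (editor's sketch): the reshape above maps [B, T, H*Dh] -> [B*H, T, Dh]
# and the output path inverts it; a quick pure-torch round-trip check:
def _head_reshape_sketch():
    import torch
    b, t, heads, dim_head = 2, 5, 4, 8
    x = torch.randn(b, t, heads * dim_head)
    y = x.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
    back = y.reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
    assert torch.equal(x, back)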


def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = query.shape
    else:
        b, _, dim_head = query.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    if skip_reshape:
        query = query.reshape(b * heads, -1, dim_head)
        value = value.reshape(b * heads, -1, dim_head)
        key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
    else:
        query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

    dtype = query.dtype
    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
    if upcast_attention:
        bytes_per_token = torch.finfo(torch.float32).bits // 8
    else:
        bytes_per_token = torch.finfo(query.dtype).bits // 8
    batch_x_heads, q_tokens, _ = query.shape
    _, _, k_tokens = key.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)

    kv_chunk_size_min = None
    kv_chunk_size = None
    query_chunk_size = None

    for x in [4096, 2048, 1024, 512, 256]:
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens
            query_chunk_size = x
            break

    if query_chunk_size is None:
        query_chunk_size = 512

    if mask is not None:
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    hidden_states = efficient_dot_product_attention(
        query,
        key,
        value,
        query_chunk_size=query_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=kv_chunk_size_min,
        use_checkpoint=False,
        upcast_attention=upcast_attention,
        mask=mask,
    )

    hidden_states = hidden_states.to(dtype)

    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
    return hidden_states

def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    h = heads
    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    mem_free_total = model_management.get_free_memory(q.device)

    if attn_precision == torch.float32:
        element_size = 4
        upcast = True
    else:
        element_size = q.element_size()
        upcast = False

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    modifier = 3
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

    if mask is not None:
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
    first_op_done = False
    cleared_cache = False
    while True:
        try:
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                if upcast:
                    with torch.autocast(enabled=False, device_type='cuda'):
                        s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                else:
                    s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

                if mask is not None:
                    if len(mask.shape) == 2:
                        s1 += mask[i:end]
                    else:
                        s1 += mask[:, i:end]

                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                first_op_done = True

                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                del s2
            break
        except model_management.OOM_EXCEPTION as e:
            if first_op_done == False:
                model_management.soft_empty_cache(True)
                if cleared_cache == False:
                    cleared_cache = True
                    logging.warning("out of memory error, emptying cache and trying again")
                    continue
                steps *= 2
                if steps > 64:
                    raise e
                logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
            else:
                raise e

    del q, k, v

    r1 = (
        r1.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return r1

BROKEN_XFORMERS = False
try:
    x_vers = xformers.__version__
    # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
    BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
except:
    pass

def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    disabled_xformers = False

    if BROKEN_XFORMERS:
        if b * heads > 65535:
            disabled_xformers = True

    if not disabled_xformers:
        if torch.jit.is_tracing() or torch.jit.is_scripting():
            disabled_xformers = True

    if disabled_xformers:
        return attention_pytorch(q, k, v, heads, mask)

    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.reshape(b, -1, heads, dim_head),
            (q, k, v),
        )

    if mask is not None:
        pad = 8 - q.shape[1] % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
        mask_out[:, :, :mask.shape[-1]] = mask
        mask = mask_out[:, :, :mask.shape[-1]]

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    if skip_reshape:
        out = (
            out.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, heads * dim_head)
        )
    else:
        out = (
            out.reshape(b, -1, heads * dim_head)
        )

    return out

def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
            (q, k, v),
        )

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    out = (
        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    )
    return out
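
# Example (editor's sketch): scaled_dot_product_attention computes the same
# softmax(q @ k^T / sqrt(d)) @ v as the manual kernels above:
def _sdpa_equivalence_sketch():
    import math
    import torch
    q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))   # [B, H, T, Dh]
    ref = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1) @ v
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(ref, out, atol=1e-5)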


optimized_attention = attention_basic

if model_management.xformers_enabled():
    logging.info("Using xformers cross attention")
    optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
    logging.info("Using pytorch cross attention")
    optimized_attention = attention_pytorch
else:
    if args.use_split_cross_attention:
        logging.info("Using split optimization for cross attention")
        optimized_attention = attention_split
    else:
        logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
        optimized_attention = attention_sub_quad

optimized_attention_masked = optimized_attention

def optimized_attention_for_device(device, mask=False, small_input=False):
    if small_input:
        if model_management.pytorch_attention_enabled():
            return attention_pytorch  # TODO: need to confirm but this is probably slightly faster for small inputs in all cases
        else:
            return attention_basic

    if device == torch.device("cpu"):
        return attention_sub_quad

    if mask:
        return optimized_attention_masked

    return optimized_attention


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, value=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        if value is not None:
            v = self.to_v(value)
            del value
        else:
            v = self.to_v(context)

        if mask is None:
            out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
                 disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()

        self.ff_in = ff_in or inner_dim is not None
        if inner_dim is None:
            inner_dim = dim

        self.is_res = inner_dim == dim
        self.attn_precision = attn_precision

        if self.ff_in:
            self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
            self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                                    context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

        if disable_temporal_crossattention:
            if switch_temporal_ca_to_sa:
                raise ValueError
            else:
                self.attn2 = None
        else:
            context_dim_attn2 = None
            if not switch_temporal_ca_to_sa:
                context_dim_attn2 = context_dim

            self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
                                        heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
            self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
        self.n_heads = n_heads
        self.d_head = d_head
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

    def forward(self, x, context=None, transformer_options={}):
        extra_options = {}
        block = transformer_options.get("block", None)
        block_index = transformer_options.get("block_index", 0)
        transformer_patches = {}
        transformer_patches_replace = {}

        for k in transformer_options:
            if k == "patches":
                transformer_patches = transformer_options[k]
            elif k == "patches_replace":
                transformer_patches_replace = transformer_options[k]
            else:
                extra_options[k] = transformer_options[k]

        extra_options["n_heads"] = self.n_heads
        extra_options["dim_head"] = self.d_head
        extra_options["attn_precision"] = self.attn_precision

        if self.ff_in:
            x_skip = x
            x = self.ff_in(self.norm_in(x))
            if self.is_res:
                x += x_skip

        n = self.norm1(x)
        if self.disable_self_attn:
            context_attn1 = context
        else:
            context_attn1 = None
        value_attn1 = None

        if "attn1_patch" in transformer_patches:
            patch = transformer_patches["attn1_patch"]
            if context_attn1 is None:
                context_attn1 = n
            value_attn1 = context_attn1
            for p in patch:
                n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)

        if block is not None:
            transformer_block = (block[0], block[1], block_index)
        else:
            transformer_block = None
        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
        block_attn1 = transformer_block
        if block_attn1 not in attn1_replace_patch:
            block_attn1 = block

        if block_attn1 in attn1_replace_patch:
            if context_attn1 is None:
                context_attn1 = n
                value_attn1 = n
            n = self.attn1.to_q(n)
            context_attn1 = self.attn1.to_k(context_attn1)
            value_attn1 = self.attn1.to_v(value_attn1)
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
        else:
            n = self.attn1(n, context=context_attn1, value=value_attn1)

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x += n
        if "middle_patch" in transformer_patches:
            patch = transformer_patches["middle_patch"]
            for p in patch:
                x = p(x, extra_options)

        if self.attn2 is not None:
            n = self.norm2(x)
            if self.switch_temporal_ca_to_sa:
                context_attn2 = n
            else:
                context_attn2 = context
            value_attn2 = None
            if "attn2_patch" in transformer_patches:
                patch = transformer_patches["attn2_patch"]
                value_attn2 = context_attn2
                for p in patch:
                    n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

            attn2_replace_patch = transformer_patches_replace.get("attn2", {})
            block_attn2 = transformer_block
            if block_attn2 not in attn2_replace_patch:
                block_attn2 = block

            if block_attn2 in attn2_replace_patch:
                if value_attn2 is None:
                    value_attn2 = context_attn2
                n = self.attn2.to_q(n)
                context_attn2 = self.attn2.to_k(context_attn2)
                value_attn2 = self.attn2.to_v(value_attn2)
                n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
                n = self.attn2.to_out(n)
            else:
                n = self.attn2(n, context=context_attn2, value=value_attn2)

        if "attn2_output_patch" in transformer_patches:
            patch = transformer_patches["attn2_output_patch"]
            for p in patch:
                n = p(n, extra_options)

        x += n
        if self.is_res:
            x_skip = x
        x = self.ff(self.norm3(x))
        if self.is_res:
            x += x_skip

        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image.
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
        if not use_linear:
            self.proj_in = operations.Conv2d(in_channels,
                                             inner_dim,
                                             kernel_size=1,
                                             stride=1,
                                             padding=0, dtype=dtype, device=device)
        else:
            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
             for d in range(depth)]
        )
        if not use_linear:
            self.proj_out = operations.Conv2d(inner_dim, in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0, dtype=dtype, device=device)
        else:
            self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options={}):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = x.movedim(1, 3).flatten(1, 2).contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(3, 1).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in


class SpatialVideoTransformer(SpatialTransformer):
    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        use_linear=False,
        context_dim=None,
        use_spatial_context=False,
        timesteps=None,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        time_context_dim=None,
        ff_in=False,
        checkpoint=False,
        time_depth=1,
        disable_self_attn=False,
        disable_temporal_crossattention=False,
        max_time_embed_period: int = 10000,
        attn_precision=None,
        dtype=None, device=None, operations=ops
    ):
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
            attn_precision=attn_precision,
            dtype=dtype, device=device, operations=operations
        )
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        inner_dim = n_heads * d_head
        if use_spatial_context:
            time_context_dim = context_dim

        self.time_stack = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_time_mix_heads,
                    time_mix_d_head,
                    dropout=dropout,
                    context_dim=time_context_dim,
                    # timesteps=timesteps,
                    checkpoint=checkpoint,
                    ff_in=ff_in,
                    inner_dim=time_mix_inner_dim,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                    attn_precision=attn_precision,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(self.depth)
            ]
        )

        assert len(self.time_stack) == len(self.transformer_blocks)

        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        time_embed_dim = self.in_channels * 4
        self.time_pos_embed = nn.Sequential(
            operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
        )

        self.time_mixer = AlphaBlender(
            alpha=merge_factor, merge_strategy=merge_strategy
        )

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        time_context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
        transformer_options={}
    ) -> torch.Tensor:
        _, _, h, w = x.shape
        x_in = x
        spatial_context = None
        if exists(context):
            spatial_context = context

        if self.use_spatial_context:
            assert (
                context.ndim == 3
            ), f"n dims of spatial context should be 3 but are {context.ndim}"

            if time_context is None:
                time_context = context
            time_context_first_timestep = time_context[::timesteps]
            time_context = repeat(
                time_context_first_timestep, "b ... -> (b n) ...", n=h * w
            )
        elif time_context is not None and not self.use_spatial_context:
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
            if time_context.ndim == 2:
                time_context = rearrange(time_context, "b c -> b 1 c")

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        if self.use_linear:
            x = self.proj_in(x)

        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
        emb = self.time_pos_embed(t_emb)
        emb = emb[:, None, :]

        for it_, (block, mix_block) in enumerate(
            zip(self.transformer_blocks, self.time_stack)
        ):
            transformer_options["block_index"] = it_
            x = block(
                x,
                context=spatial_context,
                transformer_options=transformer_options,
            )

            x_mix = x
            x_mix = x_mix + emb

            B, S, C = x_mix.shape
            x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
            x_mix = mix_block(x_mix, context=time_context)  # TODO: transformer_options
            x_mix = rearrange(
                x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
            )

            x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        if not self.use_linear:
            x = self.proj_out(x)
        out = x + x_in
        return out
0
comfy/ldm/modules/diffusionmodules/__init__.py
Normal file
956
comfy/ldm/modules/diffusionmodules/mmdit.py
Normal file
@@ -0,0 +1,956 @@
import logging
import math
from functools import partial
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from .. import attention
from einops import rearrange, repeat
from .util import timestep_embedding
import comfy.ops

def default(x, y):
    if x is not None:
        return x
    return y

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            dtype=None,
            device=None,
            operations=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = drop
        linear_layer = partial(operations.Conv2d, kernel_size=1) if use_conv else operations.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs)
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
        self.drop2 = nn.Dropout(drop_probs)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer = None,
            flatten: bool = True,
            bias: bool = True,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = True,
            padding_mode='circular',
            dtype=None,
            device=None,
            operations=None,
    ):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        self.padding_mode = padding_mode
        if img_size is not None:
            self.img_size = (img_size, img_size)
            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        # flatten spatial dim and transpose to channels last, kept for bwd compat
        self.flatten = flatten
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        # if self.img_size is not None:
        #     if self.strict_img_size:
        #         _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
        #         _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
        #     elif not self.dynamic_img_pad:
        #         _assert(
        #             H % self.patch_size[0] == 0,
        #             f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
        #         )
        #         _assert(
        #             W % self.patch_size[1] == 0,
        #             f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
        #         )
        if self.dynamic_img_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode)
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        x = self.norm(x)
        return x

def modulate(x, shift, scale):
    if shift is None:
        shift = torch.zeros_like(scale)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
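
# Example (editor's sketch): modulate broadcasts per-sample (shift, scale)
# vectors over the token dimension:
def _modulate_sketch():
    import torch
    x = torch.randn(2, 7, 16)                   # [B, T, C]
    shift, scale = torch.randn(2, 16), torch.randn(2, 16)
    out = modulate(x, shift, scale)             # x * (1 + scale[:, None]) + shift[:, None]
    assert out.shape == x.shape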


#################################################################################
#               Sine/Cosine Positional Embedding Functions                      #
#################################################################################


def get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token=False,
    extra_tokens=0,
    scaling_factor=None,
    offset=None,
):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    if scaling_factor is not None:
        grid = grid / scaling_factor
    if offset is not None:
        grid = grid - offset

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate(
            [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
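
# Example (editor's sketch): D/2 sine and D/2 cosine channels concatenate
# into a D-dim embedding per position:
def _sincos_shape_sketch():
    import numpy as np
    emb = get_1d_sincos_pos_embed_from_grid(16, np.arange(10))
    assert emb.shape == (10, 16)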

def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos, device=None, dtype=torch.float32):
    omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)
    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)
    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb

def get_2d_sincos_pos_embed_torch(embed_dim, w, h, val_center=7.5, val_magnitude=7.5, device=None, dtype=torch.float32):
    small = min(h, w)
    val_h = (h / small) * val_magnitude
    val_w = (w / small) * val_magnitude
    grid_h, grid_w = torch.meshgrid(torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype), torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype), indexing='ij')
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
    emb = torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)
    return emb


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t, dtype, **kwargs):
        t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class VectorEmbedder(nn.Module):
    """
    Embeds a flat vector of dimension input_dim
    """

    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.mlp = nn.Sequential(
            operations.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        emb = self.mlp(x)
        return emb


#################################################################################
#                                 Core DiT Model                                #
#################################################################################


def split_qkv(qkv, head_dim):
    qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
    return qkv[0], qkv[1], qkv[2]
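
# Example (editor's sketch): split_qkv undoes the fused qkv projection,
# yielding per-head tensors of shape [B, L, n_heads, head_dim]:
def _split_qkv_sketch():
    import torch
    B, L, heads, head_dim = 2, 5, 4, 8
    qkv = torch.randn(B, L, 3 * heads * head_dim)
    q, k, v = split_qkv(qkv, head_dim)
    assert q.shape == (B, L, heads, head_dim)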

def optimized_attention(qkv, num_heads):
    return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)

class SelfAttention(nn.Module):
    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        proj_drop: float = 0.0,
        attn_mode: str = "xformers",
        pre_only: bool = False,
        qk_norm: Optional[str] = None,
        rmsnorm: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        if not pre_only:
            self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
            self.proj_drop = nn.Dropout(proj_drop)
        assert attn_mode in self.ATTENTION_MODES
        self.attn_mode = attn_mode
        self.pre_only = pre_only

        if qk_norm == "rms":
            self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm == "ln":
            self.ln_q = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm is None:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()
        else:
            raise ValueError(qk_norm)

    def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
        B, L, C = x.shape
        qkv = self.qkv(x)
        q, k, v = split_qkv(qkv, self.head_dim)
        q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
        k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
        return (q, k, v)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.pre_attention(x)
        x = optimized_attention(
            qkv, num_heads=self.num_heads
        )
        x = self.post_attention(x)
        return x


class RMSNorm(torch.nn.Module):
    def __init__(
        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
    ):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        x = self._norm(x)
        if self.learnable_scale:
            return x * self.weight.to(device=x.device, dtype=x.dtype)
        else:
            return x
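
# Example (editor's sketch): RMSNorm divides by the root-mean-square of the
# last dimension; a manual check against the module above:
def _rmsnorm_sketch():
    import torch
    norm = RMSNorm(16, elementwise_affine=False)
    x = torch.randn(2, 16)
    ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(norm(x), ref)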


class SwiGLUFeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.

        Attributes:
            w1 (nn.Linear): Linear transformation for the first layer.
            w2 (nn.Linear): Linear transformation for the second layer.
            w3 (nn.Linear): Linear transformation for the third layer.

        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
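
# Example (editor's sketch): the effective hidden width is 2/3 of the
# requested one, rounded up to a multiple_of boundary. For hidden_dim=4096,
# multiple_of=256: int(2 * 4096 / 3) = 2730, which rounds up to 2816.
def _swiglu_width_sketch():
    hidden_dim, multiple_of = 4096, 256
    h = int(2 * hidden_dim / 3)                                # 2730
    h = multiple_of * ((h + multiple_of - 1) // multiple_of)   # 2816
    assert h == 2816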
|
||||
|
||||
|
||||
class DismantledBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with gated adaptive layer norm (adaLN) conditioning.
|
||||
"""
|
||||
|
||||
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: str = "xformers",
|
||||
qkv_bias: bool = False,
|
||||
pre_only: bool = False,
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if not rmsnorm:
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
else:
|
||||
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
pre_only=pre_only,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = operations.LayerNorm(
|
||||
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
else:
|
||||
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
if not pre_only:
|
||||
if not swiglu:
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=lambda: nn.GELU(approximate="tanh"),
|
||||
drop=0,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
else:
|
||||
self.mlp = SwiGLUFeedForward(
|
||||
dim=hidden_size,
|
||||
hidden_dim=mlp_hidden_dim,
|
||||
multiple_of=256,
|
||||
)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.pre_only = pre_only
|
||||
|
||||
def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
if not self.pre_only:
|
||||
if not self.scale_mod_only:
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
) = self.adaLN_modulation(c).chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
shift_mlp = None
|
||||
(
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
) = self.adaLN_modulation(
|
||||
c
|
||||
).chunk(4, dim=1)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, (
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
)
|
||||
else:
|
||||
if not self.scale_mod_only:
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
) = self.adaLN_modulation(
|
||||
c
|
||||
).chunk(2, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
scale_msa = self.adaLN_modulation(c)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, None
|
||||
|
||||
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
||||
assert not self.pre_only
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
qkv, intermediates = self.pre_attention(x, c)
|
||||
attn = optimized_attention(
|
||||
qkv,
|
||||
num_heads=self.attn.num_heads,
|
||||
)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(*args, use_checkpoint=True, **kwargs):
    if use_checkpoint:
        return torch.utils.checkpoint.checkpoint(
            _block_mixing, *args, use_reentrant=False, **kwargs
        )
    else:
        return _block_mixing(*args, **kwargs)


def _block_mixing(context, x, context_block, x_block, c):
    context_qkv, context_intermediates = context_block.pre_attention(context, c)

    x_qkv, x_intermediates = x_block.pre_attention(x, c)

    o = []
    for t in range(3):
        o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
    qkv = tuple(o)

    attn = optimized_attention(
        qkv,
        num_heads=x_block.attn.num_heads,
    )
    context_attn, x_attn = (
        attn[:, : context_qkv[0].shape[1]],
        attn[:, context_qkv[0].shape[1] :],
    )

    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None
    x = x_block.post_attention(x_attn, *x_intermediates)
    return context, x

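# Shape sketch (illustrative, not part of the upstream file): _block_mixing runs a
# single attention call over the concatenation of context and image tokens along
# the sequence dimension, then splits the result back apart. With hypothetical
# sizes B=2, D=1536, Lc=77 context tokens and Lx=1024 image tokens:
#
#     q/k/v from context_block: (2, 77, 1536) each
#     q/k/v from x_block:       (2, 1024, 1536) each
#     concatenated q/k/v:       (2, 1101, 1536) each
#     attn output:              (2, 1101, 1536)
#     context_attn, x_attn = attn[:, :77], attn[:, 77:]
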
class JointBlock(nn.Module):
    """Just a small wrapper to serve as an FSDP unit."""

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__()
        pre_only = kwargs.pop("pre_only")
        qk_norm = kwargs.pop("qk_norm", None)
        self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)

    def forward(self, *args, **kwargs):
        return block_mixing(
            *args, context_block=self.context_block, x_block=self.x_block, **kwargs
        )

class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(
        self,
        hidden_size: int,
        patch_size: int,
        out_channels: int,
        total_out_channels: Optional[int] = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = (
            operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
            if (total_out_channels is None)
            else operations.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class SelfAttentionContext(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dtype=None, device=None, operations=None):
        super().__init__()
        # note: the dim_head argument is overridden; the head width is always dim // heads
        dim_head = dim // heads
        inner_dim = dim

        self.heads = heads
        self.dim_head = dim_head

        self.qkv = operations.Linear(dim, dim * 3, bias=True, dtype=dtype, device=device)

        self.proj = operations.Linear(inner_dim, dim, dtype=dtype, device=device)

    def forward(self, x):
        qkv = self.qkv(x)
        q, k, v = split_qkv(qkv, self.dim_head)
        x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
        return self.proj(x)

class ContextProcessorBlock(nn.Module):
    def __init__(self, context_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm1 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attn = SelfAttentionContext(context_size, dtype=dtype, device=device, operations=operations)
        self.norm2 = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.mlp = Mlp(in_features=context_size, hidden_features=(context_size * 4), act_layer=lambda: nn.GELU(approximate="tanh"), drop=0, dtype=dtype, device=device, operations=operations)

    def forward(self, x):
        x += self.attn(self.norm1(x))
        x += self.mlp(self.norm2(x))
        return x


class ContextProcessor(nn.Module):
    def __init__(self, context_size, num_layers, dtype=None, device=None, operations=None):
        super().__init__()
        self.layers = torch.nn.ModuleList([ContextProcessorBlock(context_size, dtype=dtype, device=device, operations=operations) for i in range(num_layers)])
        self.norm = operations.LayerNorm(context_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)

class MMDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        depth: int = 28,
        # hidden_size: Optional[int] = None,
        # num_heads: Optional[int] = None,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_config: Optional[Dict] = None,
        compile_core: bool = False,
        use_checkpoint: bool = False,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches=None,
        qk_norm: Optional[str] = None,
        qkv_bias: bool = True,
        context_processor_layers=None,
        context_size=4096,
        num_blocks=None,
        final_layer=True,
        dtype=None,  #TODO
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = default(out_channels, default_out_channels)
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset
        self.pos_embed_max_size = pos_embed_max_size

        # hidden_size = default(hidden_size, 64 * depth)
        # num_heads = default(num_heads, hidden_size // 64)

        # apply magic --> this defines a head_size of 64
        self.hidden_size = 64 * depth
        num_heads = depth
        if num_blocks is None:
            num_blocks = depth

        self.depth = depth
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=self.pos_embed_max_size is None,
            dtype=dtype,
            device=device,
            operations=operations
        )
        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.y_embedder = None
        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = VectorEmbedder(adm_in_channels, self.hidden_size, dtype=dtype, device=device, operations=operations)

        if context_processor_layers is not None:
            self.context_processor = ContextProcessor(context_size, context_processor_layers, dtype=dtype, device=device, operations=operations)
        else:
            self.context_processor = None

        self.context_embedder = nn.Identity()
        if context_embedder_config is not None:
            if context_embedder_config["target"] == "torch.nn.Linear":
                self.context_embedder = operations.Linear(**context_embedder_config["params"], dtype=dtype, device=device)

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size, dtype=dtype, device=device))

        # num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        # just use a buffer already
        if num_patches is not None:
            self.register_buffer(
                "pos_embed",
                torch.empty(1, num_patches, self.hidden_size, dtype=dtype, device=device),
            )
        else:
            self.pos_embed = None

        self.use_checkpoint = use_checkpoint
        self.joint_blocks = nn.ModuleList(
            [
                JointBlock(
                    self.hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    attn_mode=attn_mode,
                    pre_only=(i == num_blocks - 1) and final_layer,
                    rmsnorm=rmsnorm,
                    scale_mod_only=scale_mod_only,
                    swiglu=swiglu,
                    qk_norm=qk_norm,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )
                for i in range(num_blocks)
            ]
        )

        if final_layer:
            self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

        if compile_core:
            assert False  # this path is deliberately disabled; the line below is unreachable
            self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)

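    # Worked example (illustrative): with the default depth=28, the "magic" above
    # gives hidden_size = 64 * 28 = 1792 and num_heads = 28, i.e. a fixed
    # per-head width of hidden_size // num_heads = 64 regardless of depth.
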
    def cropped_pos_embed(self, hw, device=None):
        p = self.x_embedder.patch_size[0]
        h, w = hw
        # patched size
        h = (h + 1) // p
        w = (w + 1) // p
        if self.pos_embed is None:
            return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
        assert self.pos_embed_max_size is not None
        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
        top = (self.pos_embed_max_size - h) // 2
        left = (self.pos_embed_max_size - w) // 2
        spatial_pos_embed = rearrange(
            self.pos_embed,
            "1 (h w) c -> 1 h w c",
            h=self.pos_embed_max_size,
            w=self.pos_embed_max_size,
        )
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
        # print(spatial_pos_embed, top, left, h, w)
        # # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.875, 7.875, device=device) #matches exactly for 1024 res
        # t = get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, 7.5, 7.5, device=device) #scales better
        # # print(t)
        # return t
        return spatial_pos_embed

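    # Worked example (illustrative): for a 128x128 latent with patch_size=2 and
    # pos_embed_max_size=192, the patched size is h = w = (128 + 1) // 2 = 64, so
    # a centered 64x64 window is cropped from the stored 192x192 embedding grid:
    # top = left = (192 - 64) // 2 = 64.
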
    def unpatchify(self, x, hw=None):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        if hw is None:
            h = w = int(x.shape[1] ** 0.5)
        else:
            h, w = hw
            h = (h + 1) // p
            w = (w + 1) // p
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

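    # Shape walkthrough (illustrative): with N=1, a 64x64 patch grid (T=4096),
    # p=2 and c=16, the input (1, 4096, 64) is reshaped to (1, 64, 64, 2, 2, 16),
    # permuted to (1, 16, 64, 2, 64, 2), and flattened to images of shape
    # (1, 16, 128, 128).
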
    def forward_core_with_concat(
        self,
        x: torch.Tensor,
        c_mod: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        control=None,
    ) -> torch.Tensor:
        if self.register_length > 0:
            context = torch.cat(
                (
                    repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
                    default(context, torch.Tensor([]).type_as(x)),
                ),
                1,
            )

        # context is B, L', D
        # x is B, L, D
        blocks = len(self.joint_blocks)
        for i in range(blocks):
            context, x = self.joint_blocks[i](
                context,
                x,
                c=c_mod,
                use_checkpoint=self.use_checkpoint,
            )
            if control is not None:
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        x += add

        x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
        return x

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        control=None,
    ) -> torch.Tensor:
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """

        if self.context_processor is not None:
            context = self.context_processor(context)

        hw = x.shape[-2:]
        x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        if y is not None and self.y_embedder is not None:
            y = self.y_embedder(y)  # (N, D)
            c = c + y  # (N, D)

        if context is not None:
            context = self.context_embedder(context)

        x = self.forward_core_with_concat(x, c, context, control)

        x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
        return x[:, :, :hw[-2], :hw[-1]]

class OpenAISignatureMMDITWrapper(MMDiT):
    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        control=None,
        **kwargs,
    ) -> torch.Tensor:
        return super().forward(x, timesteps, context=context, y=y, control=control)

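# Usage sketch (illustrative, not part of the upstream file; all values below are
# hypothetical and chosen small). Note that comfy.ops.disable_weight_init skips
# weight initialization, so outputs are meaningless until a checkpoint is loaded:
#
#     model = MMDiT(depth=4, num_patches=16 * 16, pos_embed_max_size=16,
#                   operations=comfy.ops.disable_weight_init)
#     x = torch.randn(1, 4, 32, 32)        # latent image
#     t = torch.randint(0, 1000, (1,))     # diffusion timesteps
#     ctx = torch.randn(1, 8, 4 * 64)      # tokens already at hidden_size = 64 * depth
#     out = model(x, t, context=ctx)       # -> (1, 4, 32, 32)
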
650
comfy/ldm/modules/diffusionmodules/model.py
Normal file
@@ -0,0 +1,650 @@
# pytorch_diffusion + derived encoder decoder

import math
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Any
import logging

from comfy import model_management
import comfy.ops
ops = comfy.ops.disable_weight_init

if model_management.xformers_enabled_vae():
    import xformers
    import xformers.ops


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models
    (from Fairseq): build sinusoidal embeddings. This matches the implementation
    in tensor2tensor, but differs slightly from the description in Section 3.5
    of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

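# Usage sketch (illustrative, not part of the upstream file):
#
#     emb = get_timestep_embedding(torch.tensor([0, 10, 999]), 128)
#     emb.shape  # torch.Size([3, 128]) -- first half sin, second half cos
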
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = ops.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)

    def forward(self, x):
        try:
            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        except Exception:  # operation not implemented for bf16; upsample in float32 slices instead
            b, c, h, w = x.shape
            out = torch.empty((b, c, h * 2, w * 2), dtype=x.dtype, layout=x.layout, device=x.device)
            split = 8
            l = out.shape[1] // split
            for i in range(0, out.shape[1], l):
                out[:, i:i + l] = torch.nn.functional.interpolate(x[:, i:i + l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
            del x
            x = out

        if self.with_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = ops.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=3,
                                   stride=2,
                                   padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.swish = torch.nn.SiLU(inplace=True)
        self.norm1 = Normalize(in_channels)
        self.conv1 = ops.Conv2d(in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)
        if temb_channels > 0:
            self.temb_proj = ops.Linear(temb_channels,
                                        out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout, inplace=True)
        self.conv2 = ops.Conv2d(out_channels,
                                out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = ops.Conv2d(in_channels,
                                                out_channels,
                                                kernel_size=3,
                                                stride=1,
                                                padding=1)
            else:
                self.nin_shortcut = ops.Conv2d(in_channels,
                                               out_channels,
                                               kernel_size=1,
                                               stride=1,
                                               padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = self.swish(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(self.swish(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h

def slice_attention(q, k, v):
    r1 = torch.zeros_like(k, device=q.device)
    scale = (int(q.shape[-1]) ** (-0.5))

    mem_free_total = model_management.get_free_memory(q.device)

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    while True:
        try:
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                s1 = torch.bmm(q[:, i:end], k) * scale

                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0, 2, 1)
                del s1

                r1[:, :, i:end] = torch.bmm(v, s2)
                del s2
            break
        except model_management.OOM_EXCEPTION as e:
            model_management.soft_empty_cache(True)
            steps *= 2
            if steps > 128:
                raise e
            logging.warning("out of memory error, increasing steps and trying again {}".format(steps))

    return r1

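# Worked example (illustrative): if the attention matrix would need roughly 4x the
# free memory, then steps = 2**ceil(log2(4)) = 4, so the query length is processed
# in 4 slices (when it divides evenly). Each OOM retry doubles steps again,
# giving up once steps exceeds 128.
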
def normal_attention(q, k, v):
    # compute attention
    b, c, h, w = q.shape

    q = q.reshape(b, c, h * w)
    q = q.permute(0, 2, 1)    # b,hw,c
    k = k.reshape(b, c, h * w)  # b,c,hw
    v = v.reshape(b, c, h * w)

    r1 = slice_attention(q, k, v)
    h_ = r1.reshape(b, c, h, w)
    del r1
    return h_

def xformers_attention(q, k, v):
    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(
        lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
        (q, k, v),
    )

    try:
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
        out = out.transpose(1, 2).reshape(B, C, H, W)
    except NotImplementedError:
        out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
    return out

def pytorch_attention(q, k, v):
    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )

    try:
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        out = out.transpose(2, 3).reshape(B, C, H, W)
    except model_management.OOM_EXCEPTION:
        logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
        out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
    return out

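# Usage sketch (illustrative, not part of the upstream file): all three backends
# above share the same (B, C, H, W) interface, so they are interchangeable:
#
#     q = k = v = torch.randn(1, 512, 32, 32)
#     out = pytorch_attention(q, k, v)
#     assert out.shape == (1, 512, 32, 32)
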
class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = ops.Conv2d(in_channels,
                            in_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0)
        self.k = ops.Conv2d(in_channels,
                            in_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0)
        self.v = ops.Conv2d(in_channels,
                            in_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0)
        self.proj_out = ops.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0)

        if model_management.xformers_enabled_vae():
            logging.info("Using xformers attention in VAE")
            self.optimized_attention = xformers_attention
        elif model_management.pytorch_attention_enabled():
            logging.info("Using pytorch attention in VAE")
            self.optimized_attention = pytorch_attention
        else:
            logging.info("Using split attention in VAE")
            self.optimized_attention = normal_attention

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        h_ = self.optimized_attention(q, k, v)

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    return AttnBlock(in_channels)

class Model(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                ops.Linear(self.ch,
                           self.temb_ch),
                ops.Linear(self.temb_ch,
                           self.temb_ch),
            ])

        # downsampling
        self.conv_in = ops.Conv2d(in_channels,
                                  self.ch,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in + skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = ops.Conv2d(block_in,
                                   out_ch,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)

    def forward(self, x, t=None, context=None):
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight

class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = ops.Conv2d(in_channels,
                                  self.ch,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = ops.Conv2d(block_in,
                                   2 * z_channels if double_z else z_channels,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)

    def forward(self, x):
        # timestep embedding
        temb = None
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

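# Usage sketch (illustrative, not part of the upstream file; the channel layout
# mirrors a typical SD-style VAE but the exact numbers here are assumptions, and
# weights from the disable_weight_init ops are uninitialized, so this only
# exercises shapes):
#
#     enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#                   attn_resolutions=[], in_channels=3, resolution=256, z_channels=4)
#     moments = enc(torch.randn(1, 3, 256, 256))   # (1, 8, 32, 32): mean + logvar
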
class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 conv_out_op=ops.Conv2d,
                 resnet_op=ResnetBlock,
                 attn_op=AttnBlock,
                 **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"  # note: unused here; attention modules come from attn_op
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logging.debug("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = ops.Conv2d(z_channels,
                                  block_in,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout)
        self.mid.attn_1 = attn_op(block_in)
        self.mid.block_2 = resnet_op(in_channels=block_in,
                                     out_channels=block_in,
                                     temb_channels=self.temb_ch,
                                     dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(resnet_op(in_channels=block_in,
                                       out_channels=block_out,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(attn_op(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = conv_out_op(block_in,
                                    out_ch,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)

    def forward(self, z, **kwargs):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, temb, **kwargs)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            h = torch.tanh(h)
        return h

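# Usage sketch (illustrative, not part of the upstream file), mirroring the
# hypothetical Encoder example above:
#
#     dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#                   attn_resolutions=[], in_channels=3, resolution=256, z_channels=4)
#     img = dec(torch.randn(1, 4, 32, 32))   # (1, 3, 256, 256)
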
892
comfy/ldm/modules/diffusionmodules/openaimodel.py
Normal file
@@ -0,0 +1,892 @@
from abc import abstractmethod

import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import logging

from .util import (
    checkpoint,
    avg_pool_nd,
    zero_module,
    timestep_embedding,
    AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
ops = comfy.ops.disable_weight_init


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


# This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
    for layer in ts:
        if isinstance(layer, VideoResBlock):
            x = layer(x, emb, num_video_frames, image_only_indicator)
        elif isinstance(layer, TimestepBlock):
            x = layer(x, emb)
        elif isinstance(layer, SpatialVideoTransformer):
            x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
            if "transformer_index" in transformer_options:
                transformer_options["transformer_index"] += 1
        elif isinstance(layer, SpatialTransformer):
            x = layer(x, context, transformer_options)
            if "transformer_index" in transformer_options:
                transformer_options["transformer_index"] += 1
        elif isinstance(layer, Upsample):
            x = layer(x, output_shape=output_shape)
        else:
            x = layer(x)
    return x

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, *args, **kwargs):
        return forward_timestep_embed(self, *args, **kwargs)

class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = operations.conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)

    def forward(self, x, output_shape=None):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
            if output_shape is not None:
                shape[1] = output_shape[3]
                shape[2] = output_shape[4]
        else:
            shape = [x.shape[2] * 2, x.shape[3] * 2]
            if output_shape is not None:
                shape[0] = output_shape[2]
                shape[1] = output_shape[3]

        x = F.interpolate(x, size=shape, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = operations.conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)

class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
        kernel_size=3,
        exchange_temb_dims=False,
        skip_t_emb=False,
        dtype=None,
        device=None,
        operations=ops
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        if isinstance(kernel_size, list):
            padding = [k // 2 for k in kernel_size]
        else:
            padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            operations.GroupNorm(32, channels, dtype=dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
            self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
        elif down:
            self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
            self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.skip_t_emb = skip_t_emb
        if self.skip_t_emb:
            self.emb_layers = None
            self.exchange_temb_dims = False
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                operations.Linear(
                    emb_channels,
                    2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
                ),
            )
        self.out_layers = nn.Sequential(
            operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = operations.conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
            )
        else:
            self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        emb_out = None
        if not self.skip_t_emb:
            emb_out = self.emb_layers(emb).type(h.dtype)
            while len(emb_out.shape) < len(h.shape):
                emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            h = out_norm(h)
            if emb_out is not None:
                scale, shift = th.chunk(emb_out, 2, dim=1)
                h *= (1 + scale)
                h += shift
            h = out_rest(h)
        else:
            if emb_out is not None:
                if self.exchange_temb_dims:
                    emb_out = emb_out.movedim(1, 2)
                h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

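# Worked sketch (illustrative, not part of the upstream file): with
# use_scale_shift_norm=True the embedding acts as FiLM-style conditioning,
# h = norm(h) * (1 + scale) + shift, instead of a plain additive bias. The
# emb_layers output is 2 * out_channels wide precisely so that
# th.chunk(emb_out, 2, dim=1) can split it into the (scale, shift) pair, which
# then broadcasts over the spatial dimensions via the trailing [..., None] axes.
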
class VideoResBlock(ResBlock):
    def __init__(
        self,
        channels: int,
        emb_channels: int,
        dropout: float,
        video_kernel_size=3,
        merge_strategy: str = "fixed",
        merge_factor: float = 0.5,
        out_channels=None,
        use_conv: bool = False,
        use_scale_shift_norm: bool = False,
        dims: int = 2,
        use_checkpoint: bool = False,
        up: bool = False,
        down: bool = False,
        dtype=None,
        device=None,
        operations=ops
    ):
        super().__init__(
            channels,
            emb_channels,
            dropout,
            out_channels=out_channels,
            use_conv=use_conv,
            use_scale_shift_norm=use_scale_shift_norm,
            dims=dims,
            use_checkpoint=use_checkpoint,
            up=up,
            down=down,
            dtype=dtype,
            device=device,
            operations=operations
        )

        self.time_stack = ResBlock(
            default(out_channels, channels),
            emb_channels,
            dropout=dropout,
            dims=3,
            out_channels=default(out_channels, channels),
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=use_checkpoint,
            exchange_temb_dims=True,
            dtype=dtype,
            device=device,
            operations=operations
        )
        self.time_mixer = AlphaBlender(
            alpha=merge_factor,
            merge_strategy=merge_strategy,
            rearrange_pattern="b t -> b 1 t 1 1",
        )

    def forward(
        self,
        x: th.Tensor,
        emb: th.Tensor,
        num_video_frames: int,
        image_only_indicator=None,
    ) -> th.Tensor:
        x = super().forward(x, emb)

        x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)

        x = self.time_stack(
            x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
        )
        x = self.time_mixer(
            x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
        )
        x = rearrange(x, "b c t h w -> (b t) c h w")
        return x

class Timestep(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return timestep_embedding(t, self.dim)

def apply_control(h, control, name):
    if control is not None and name in control and len(control[name]) > 0:
        ctrl = control[name].pop()
        if ctrl is not None:
            try:
                h += ctrl
            except Exception:
                logging.warning("control could not be applied {} {}".format(h.shape, ctrl.shape))
    return h

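# Usage sketch (illustrative, not part of the upstream file): `control` is a dict
# of per-site residual lists (e.g. produced by a ControlNet), consumed
# back-to-front via pop(); the hypothetical shapes below are assumptions:
#
#     control = {"input": [], "middle": [th.zeros(1, 1280, 8, 8)], "output": []}
#     h = apply_control(h, control, "middle")   # adds the residual if shapes match
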
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=th.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        use_temporal_resblock=False,
        use_temporal_attention=False,
        time_context_dim=None,
        extra_ff_mix_layer=False,
        use_spatial_context=False,
        merge_strategy=None,
        merge_factor=0.0,
        video_kernel_size=None,
        disable_temporal_crossattention=False,
        max_ddpm_temb_period=10000,
        attn_precision=None,
        device=None,
        operations=ops,
    ):
        super().__init__()

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # from omegaconf.listconfig import ListConfig
            # if type(context_dim) == ListConfig:
            #     context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks

        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)

        transformer_depth = transformer_depth[:]
        transformer_depth_output = transformer_depth_output[:]

        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.use_temporal_resblocks = use_temporal_resblock
        self.predict_codebook_ids = n_embed is not None

        self.default_num_video_frames = None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
            elif self.num_classes == "continuous":
                logging.debug("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1

        def get_attention_layer(
            ch,
            num_heads,
            dim_head,
            depth=1,
            context_dim=None,
            use_checkpoint=False,
            disable_self_attn=False,
        ):
            if use_temporal_attention:
                return SpatialVideoTransformer(
                    ch,
                    num_heads,
                    dim_head,
                    depth=depth,
                    context_dim=context_dim,
                    time_context_dim=time_context_dim,
                    dropout=dropout,
                    ff_in=extra_ff_mix_layer,
                    use_spatial_context=use_spatial_context,
                    merge_strategy=merge_strategy,
                    merge_factor=merge_factor,
                    checkpoint=use_checkpoint,
                    use_linear=use_linear_in_transformer,
                    disable_self_attn=disable_self_attn,
                    disable_temporal_crossattention=disable_temporal_crossattention,
                    max_time_embed_period=max_ddpm_temb_period,
                    attn_precision=attn_precision,
                    dtype=self.dtype, device=device, operations=operations
                )
            else:
                return SpatialTransformer(
                    ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
                    disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
                    use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                )

        def get_resblock(
            merge_factor,
            merge_strategy,
            video_kernel_size,
            ch,
            time_embed_dim,
            dropout,
            out_channels,
            dims,
            use_checkpoint,
            use_scale_shift_norm,
            down=False,
            up=False,
            dtype=None,
            device=None,
            operations=ops
        ):
            if self.use_temporal_resblocks:
                return VideoResBlock(
                    merge_factor=merge_factor,
                    merge_strategy=merge_strategy,
                    video_kernel_size=video_kernel_size,
                    channels=ch,
                    emb_channels=time_embed_dim,
                    dropout=dropout,
                    out_channels=out_channels,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                    down=down,
                    up=up,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )
            else:
                return ResBlock(
                    channels=ch,
                    emb_channels=time_embed_dim,
                    dropout=dropout,
                    out_channels=out_channels,
                    use_checkpoint=use_checkpoint,
                    dims=dims,
                    use_scale_shift_norm=use_scale_shift_norm,
                    down=down,
                    up=up,
                    dtype=dtype,
                    device=device,
                    operations=operations
                )

        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    get_resblock(
                        merge_factor=merge_factor,
                        merge_strategy=merge_strategy,
                        video_kernel_size=video_kernel_size,
                        ch=ch,
                        time_embed_dim=time_embed_dim,
                        dropout=dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    )
                ]
                ch = mult * model_channels
                num_transformers = transformer_depth.pop(0)
                if num_transformers > 0:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(get_attention_layer(
                            ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        get_resblock(
                            merge_factor=merge_factor,
                            merge_strategy=merge_strategy,
                            video_kernel_size=video_kernel_size,
                            ch=ch,
                            time_embed_dim=time_embed_dim,
                            dropout=dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        mid_block = [
            get_resblock(
                merge_factor=merge_factor,
                merge_strategy=merge_strategy,
                video_kernel_size=video_kernel_size,
                ch=ch,
                time_embed_dim=time_embed_dim,
                dropout=dropout,
                out_channels=None,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            )]

        self.middle_block = None
        if transformer_depth_middle >= -1:
            if transformer_depth_middle >= 0:
                mid_block += [get_attention_layer(  # always uses a self-attn
                    ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                    disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
                ),
                get_resblock(
                    merge_factor=merge_factor,
                    merge_strategy=merge_strategy,
                    video_kernel_size=video_kernel_size,
                    ch=ch,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
|
||||
out_channels=None,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(self.num_res_blocks[level] + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch + ich,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
num_transformers = transformer_depth_output.pop()
|
||||
if num_transformers > 0:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
get_attention_layer(
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
get_resblock(
|
||||
merge_factor=merge_factor,
|
||||
merge_strategy=merge_strategy,
|
||||
video_kernel_size=video_kernel_size,
|
||||
ch=ch,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param context: conditioning plugged in via crossattn
|
||||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
transformer_options["original_shape"] = list(x.shape)
|
||||
transformer_options["transformer_index"] = 0
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
|
||||
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
||||
image_only_indicator = kwargs.get("image_only_indicator", None)
|
||||
time_context = kwargs.get("time_context", None)
|
||||
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
transformer_options["block"] = ("input", id)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'input')
|
||||
if "input_block_patch" in transformer_patches:
|
||||
patch = transformer_patches["input_block_patch"]
|
||||
for p in patch:
|
||||
h = p(h, transformer_options)
|
||||
|
||||
hs.append(h)
|
||||
if "input_block_patch_after_skip" in transformer_patches:
|
||||
patch = transformer_patches["input_block_patch_after_skip"]
|
||||
for p in patch:
|
||||
h = p(h, transformer_options)
|
||||
|
||||
transformer_options["block"] = ("middle", 0)
|
||||
if self.middle_block is not None:
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
hsp = hs.pop()
|
||||
hsp = apply_control(hsp, control, 'output')
|
||||
|
||||
if "output_block_patch" in transformer_patches:
|
||||
patch = transformer_patches["output_block_patch"]
|
||||
for p in patch:
|
||||
h, hsp = p(h, hsp, transformer_options)
|
||||
|
||||
h = th.cat([h, hsp], dim=1)
|
||||
del hsp
|
||||
if len(hs) > 0:
|
||||
output_shape = hs[-1].shape
|
||||
else:
|
||||
output_shape = None
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
else:
|
||||
return self.out(h)
|
||||
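Usage sketch (illustrative, not part of the committed file): driving the forward() above. `unet` stands in for an already constructed UNetModel; the shapes follow the docstring and are assumptions for a typical SD-style latent model.

import torch

x = torch.randn(2, 4, 64, 64)           # [N x C x H x W] latent batch
timesteps = torch.tensor([999, 500])    # a 1-D batch of timesteps
context = torch.randn(2, 77, 768)       # cross-attention conditioning
eps = unet(x, timesteps=timesteps, context=context)
assert eps.shape == x.shape             # output matches the input shape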
85
comfy/ldm/modules/diffusionmodules/upscaling.py
Normal file
@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import numpy as np
from functools import partial

from .util import extract_into_tensor, make_beta_schedule
from comfy.ldm.util import default


class AbstractLowScaleModel(nn.Module):
    # for concatenating a downsampled image to the latent representation
    def __init__(self, noise_schedule_config=None):
        super(AbstractLowScaleModel, self).__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None, seed=None):
        if noise is None:
            if seed is None:
                noise = torch.randn_like(x_start)
            else:
                noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
        return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)

    def forward(self, x):
        return x, None

    def decode(self, x):
        return x


class SimpleImageConcat(AbstractLowScaleModel):
    # no noise level conditioning
    def __init__(self):
        super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
        self.max_noise_level = 0

    def forward(self, x):
        # fix to constant noise level
        return x, torch.zeros(x.shape[0], device=x.device).long()


class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None, seed=None):
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        z = self.q_sample(x, noise_level, seed=seed)
        return z, noise_level
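Usage sketch (illustrative, not part of the committed file): noise-augmenting a low-resolution conditioning image with ImageConcatWithNoiseAugmentation; the schedule settings below are assumptions for illustration.

import torch

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"beta_schedule": "linear", "timesteps": 1000},
    max_noise_level=350,
)
x_low = torch.randn(1, 3, 64, 64)       # low-res image to be concatenated later
z, noise_level = aug(x_low, seed=42)    # z is x_low noised to a random level
assert z.shape == x_low.shape and noise_level.shape == (1,)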
306
comfy/ldm/modules/diffusionmodules/util.py
Normal file
@@ -0,0 +1,306 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!


import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat, rearrange

from comfy.ldm.util import instantiate_from_config


class AlphaBlender(nn.Module):
    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern

        assert (
            merge_strategy in self.strategies
        ), f"merge_strategy needs to be in {self.strategies}"

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif (
            self.merge_strategy == "learned"
            or self.merge_strategy == "learned_with_images"
        ):
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor:
        # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
        if self.merge_strategy == "fixed":
            # make shape compatible
            # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
            alpha = self.mix_factor.to(device)
        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor.to(device))
            # make shape compatible
            # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
        elif self.merge_strategy == "learned_with_images":
            if image_only_indicator is None:
                alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1")
            else:
                alpha = torch.where(
                    image_only_indicator.bool(),
                    torch.ones(1, 1, device=image_only_indicator.device),
                    rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
                )
            alpha = rearrange(alpha, self.rearrange_pattern)
            # make shape compatible
            # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
        else:
            raise NotImplementedError()
        return alpha

    def forward(
        self,
        x_spatial,
        x_temporal,
        image_only_indicator=None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator, x_spatial.device)
        x = (
            alpha.to(x_spatial.dtype) * x_spatial
            + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
        )
        return x


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "squaredcos_cap_v2":  # used for karlo prior
        # return early
        return betas_for_alpha_bar(
            n_timestep,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according the the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad(), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
        )
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()
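Usage sketch (illustrative, not part of the committed file): the two helpers most callers of this module touch, make_beta_schedule and timestep_embedding.

import torch

betas = make_beta_schedule("linear", 1000)          # (1000,) float64 betas
alphas_cumprod = torch.cumprod(1. - betas, dim=0)   # standard DDPM accumulation
emb = timestep_embedding(torch.tensor([0, 500, 999]), dim=320)
assert emb.shape == (3, 320)                        # [N x dim] sinusoidal embeddings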
0
comfy/ldm/modules/distributions/__init__.py
Normal file
92
comfy/ldm/modules/distributions/distributions.py
Normal file
@@ -0,0 +1,92 @@
import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1,2,3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
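Usage sketch (illustrative, not part of the committed file): a VAE encoder emits 2*C channels (mean stacked on log-variance along dim 1); DiagonalGaussianDistribution splits, samples and scores them.

import torch

params = torch.randn(1, 8, 32, 32)          # 2 * 4 latent channels
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                      # (1, 4, 32, 32) reparameterized draw
kl = posterior.kl()                         # (1,) KL against a standard normal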
80
comfy/ldm/modules/ema.py
Normal file
@@ -0,0 +1,80 @@
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        del self.num_updates
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
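Usage sketch (illustrative, not part of the committed file): the store/copy_to/restore cycle described in the docstrings above, on a toy module.

import torch
from torch import nn

model = nn.Linear(4, 4)
ema = LitEma(model)
ema(model)                        # after each optimizer step: update shadow weights
ema.store(model.parameters())     # stash the live weights
ema.copy_to(model)                # evaluate with the EMA weights
# ... run validation ...
ema.restore(model.parameters())   # put the live weights back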
0
comfy/ldm/modules/encoders/__init__.py
Normal file
35
comfy/ldm/modules/encoders/noise_aug_modules.py
Normal file
@@ -0,0 +1,35 @@
from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ..diffusionmodules.openaimodel import Timestep
import torch

class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
    def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
        super().__init__(*args, **kwargs)
        if clip_stats_path is None:
            clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
        else:
            clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
        self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
        self.register_buffer("data_std", clip_std[None, :], persistent=False)
        self.time_embed = Timestep(timestep_dim)

    def scale(self, x):
        # re-normalize to centered mean and unit variance
        x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
        return x

    def unscale(self, x):
        # back to original data stats
        x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
        return x

    def forward(self, x, noise_level=None, seed=None):
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        x = self.scale(x)
        z = self.q_sample(x, noise_level, seed=seed)
        z = self.unscale(z)
        noise_level = self.time_embed(noise_level)
        return z, noise_level
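Usage sketch (illustrative, not part of the committed file): unCLIP-style augmentation of a CLIP image embedding; the schedule and dimensions here are assumptions for illustration.

import torch

aug = CLIPEmbeddingNoiseAugmentation(
    noise_schedule_config={"beta_schedule": "squaredcos_cap_v2", "timesteps": 1000},
    max_noise_level=1000,
    timestep_dim=1280,
)
emb = torch.randn(1, 1280)                                # CLIP image embedding
z, level_emb = aug(emb, noise_level=torch.tensor([100]))  # renoised embedding + embedded level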
274
comfy/ldm/modules/sub_quadratic_attention.py
Normal file
@@ -0,0 +1,274 @@
# original source:
#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
#   MIT
# credit:
#   Amin Rezaei (original author)
#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# implementation of:
#   "Self-attention Does Not Need O(n2) Memory":
#   https://arxiv.org/abs/2112.05682v2

from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
import logging

try:
    from typing import Optional, NamedTuple, List, Protocol
except ImportError:
    from typing import Optional, NamedTuple, List
    from typing_extensions import Protocol

from torch import Tensor
from typing import List

from comfy import model_management

def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
    return x[slicing]

class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor

class SummarizeChunk(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...

class ComputeQueryChunkAttn(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> Tensor: ...

def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> AttnChunk:
    if upcast_attention:
        with torch.autocast(enabled=False, device_type = 'cuda'):
            query = query.float()
            key_t = key_t.float()
            attn_weights = torch.baddbmm(
                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
                query,
                key_t,
                alpha=scale,
                beta=0,
            )
    else:
        attn_weights = torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    attn_weights -= max_score
    if mask is not None:
        attn_weights += mask
    torch.exp(attn_weights, out=attn_weights)
    exp_weights = attn_weights.to(value.dtype)
    exp_values = torch.bmm(exp_weights, value)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)

def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
    mask,
) -> Tensor:
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
        key_chunk = dynamic_slice(
            key_t,
            (0, 0, chunk_idx),
            (batch_x_heads, k_channels_per_head, kv_chunk_size)
        )
        value_chunk = dynamic_slice(
            value,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, v_channels_per_head)
        )
        if mask is not None:
            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]

        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

    chunks: List[AttnChunk] = [
        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
    ]
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights

# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> Tensor:
    if upcast_attention:
        with torch.autocast(enabled=False, device_type = 'cuda'):
            query = query.float()
            key_t = key_t.float()
            attn_scores = torch.baddbmm(
                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
                query,
                key_t,
                alpha=scale,
                beta=0,
            )
    else:
        attn_scores = torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )

    if mask is not None:
        attn_scores += mask
    try:
        attn_probs = attn_scores.softmax(dim=-1)
        del attn_scores
    except model_management.OOM_EXCEPTION:
        logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
        torch.exp(attn_scores, out=attn_scores)
        summed = torch.sum(attn_scores, dim=-1, keepdim=True)
        attn_scores /= summed
        attn_probs = attn_scores

    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
    return hidden_states_slice

class ScannedChunk(NamedTuple):
    chunk_idx: int
    attn_chunk: AttnChunk

def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
    upcast_attention=False,
    mask = None,
):
    """Computes efficient dot-product attention given query, transposed key, and value.
    This is efficient version of attention presented in
    https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
    Args:
        query: queries for calculating attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        key_t: keys for calculating attention with shape of
            `[batch * num_heads, channels_per_head, tokens]`.
        value: values to be used in attention with shape of
            `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
    Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    if mask is not None and len(mask.shape) == 2:
        mask = mask.unsqueeze(0)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
            query,
            (0, chunk_idx, 0),
            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
        )

    def get_mask_chunk(chunk_idx: int) -> Tensor:
        if mask is None:
            return None
        chunk = min(query_chunk_size, q_tokens)
        return mask[:,chunk_idx:chunk_idx + chunk]

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
        _get_attention_scores_no_kv_chunking,
        scale=scale,
        upcast_attention=upcast_attention
    ) if k_tokens <= kv_chunk_size else (
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )
    )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
            mask=mask,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key_t=key_t,
            value=value,
            mask=get_mask_chunk(i * query_chunk_size)
        ) for i in range(math.ceil(q_tokens / query_chunk_size))
    ], dim=1)
    return res
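Usage sketch (illustrative, not part of the committed file): note that the key is passed pre-transposed, exactly as the docstring above specifies.

import torch

b_heads, tokens, ch = 8, 4096, 64
q = torch.randn(b_heads, tokens, ch)
k_t = torch.randn(b_heads, ch, tokens)      # transposed key: [batch * heads, channels, tokens]
v = torch.randn(b_heads, tokens, ch)
out = efficient_dot_product_attention(q, k_t, v, query_chunk_size=1024, use_checkpoint=False)
assert out.shape == (b_heads, tokens, ch)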
245
comfy/ldm/modules/temporal_ae.py
Normal file
@@ -0,0 +1,245 @@
import functools
from typing import Callable, Iterable, Union

import torch
from einops import rearrange, repeat

import comfy.ops
ops = comfy.ops.disable_weight_init

from .diffusionmodules.model import (
    AttnBlock,
    Decoder,
    ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock

def partialclass(cls, *args, **kwargs):
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls


class VideoResBlock(ResnetBlock):
    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
        if video_kernel_size is None:
            video_kernel_size = [3, 1, 1]
        self.time_stack = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=False,
            skip_t_emb=True,
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, bs):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()

    def forward(self, x, temb, skip_video=False, timesteps=None):
        b, c, h, w = x.shape
        if timesteps is None:
            timesteps = b

        x = super().forward(x, temb)

        if not skip_video:
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = self.time_stack(x, temb)

            alpha = self.get_alpha(bs=b // timesteps).to(x.device)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x


class AE3DConv(ops.Conv2d):
    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = ops.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )

    def forward(self, input, timesteps=None, skip_video=False):
        if timesteps is None:
            timesteps = input.shape[0]
        x = super().forward(input)
        if skip_video:
            return x
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")


class AttnVideoBlock(AttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # no context, single headed, as in base class
        self.time_mix_block = BasicTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
        )

        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            ops.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            ops.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def forward(self, x, timesteps=None, skip_time_block=False):
        if skip_time_block:
            return super().forward(x)

        if timesteps is None:
            timesteps = x.shape[0]

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)  # b, n_channels
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        alpha = self.get_alpha().to(x.device)
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix  # alpha merge

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x

    def get_alpha(
        self,
    ):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")



def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
):
    return partialclass(
        AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
    )


class Conv2DWrapper(torch.nn.Conv2d):
    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
        return super().forward(input)


class VideoDecoder(Decoder):
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = 3,
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"

        if self.time_mode != "attn-only":
            kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
        if self.time_mode not in ["conv-only", "only-last-conv"]:
            kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
        if self.time_mode not in ["attn-only", "only-last-conv"]:
            kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)

        super().__init__(*args, **kwargs)

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )
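Usage sketch (illustrative, not part of the committed file): partialclass pins constructor arguments, which is how VideoDecoder above swaps the temporal variants into the base Decoder.

TimeConv = partialclass(AE3DConv, video_kernel_size=3)   # video_kernel_size now fixed
conv = TimeConv(in_channels=4, out_channels=4, kernel_size=3, padding=1)
# conv behaves as a Conv2d followed by a temporal Conv3d mix over the frame axis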
197
comfy/ldm/util.py
Normal file
@@ -0,0 +1,197 @@
import importlib

import torch
from torch import optim
import numpy as np

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont


def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
        nc = int(40 * (wh[0] / 256))
        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts


def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


class AdamWwithEMAandWings(optim.Optimizer):
    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,  # ema decay to match previous code
                 ema_power=1., param_names=()):
        """AdamW that saves EMA versions of the parameters."""
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
                        ema_power=ema_power, param_names=param_names)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            ema_params_with_grad = []
            state_sums = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            ema_decay = group['ema_decay']
            ema_power = group['ema_power']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of parameter values
                    state['param_exp_avg'] = p.detach().float().clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                ema_params_with_grad.append(state['param_exp_avg'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            optim._functional.adamw(params_with_grad,
                                    grads,
                                    exp_avgs,
                                    exp_avg_sqs,
                                    max_exp_avg_sqs,
                                    state_steps,
                                    amsgrad=amsgrad,
                                    beta1=beta1,
                                    beta2=beta2,
                                    lr=group['lr'],
                                    weight_decay=group['weight_decay'],
                                    eps=group['eps'],
                                    maximize=False)

            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

        return loss
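Usage sketch (illustrative, not part of the committed file): the config-driven factory at the heart of this module; the config values are arbitrary examples.

cfg = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
}
conv = instantiate_from_config(cfg)   # imports torch.nn, builds Conv2d(3, 8, 3)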