sync with repo 28.08

This commit is contained in:
2024-08-28 19:33:34 +03:00
parent 727693318c
commit ad1e3ecbcb
134 changed files with 112534 additions and 12635 deletions

View File

@@ -1,5 +1,27 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
import logging
import torch
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@@ -218,11 +240,17 @@ def model_lora_keys_clip(model, key_map={}):
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
key_map[lora_key] = k
for k in sdk: #OneTrainer SD3 lora
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
for k in sdk:
if k.endswith(".weight"):
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
l_key = k[len("t5xxl.transformer."):-len(".weight")]
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
k = "clip_g.transformer.text_projection.weight"
if k in sdk:
@@ -245,6 +273,7 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
@@ -288,4 +317,240 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
Args:
tensor (torch.Tensor): The original tensor to be padded.
new_shape (List[int]): The desired shape of the padded tensor.
Returns:
torch.Tensor: A new tensor padded with zeros to the specified shape.
Note:
If the new shape is smaller than the original tensor in any dimension,
the original tensor will be truncated in that dimension.
"""
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
if len(new_shape) != len(tensor.shape):
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
# Create a new tensor filled with zeros
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
# Create slicing tuples for both tensors
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
# Copy the original tensor into the new tensor
padded_tensor[new_slices] = tensor[orig_slices]
return padded_tensor
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
diff: torch.Tensor = v[0]
# An extra flag to pad the weight if the diff's shape is larger than the weight
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
if do_pad_weight and diff.shape != weight.shape:
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
weight = pad_tensor_to_shape(weight, diff.shape)
if strength != 0.0:
if diff.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight