sync with repo 28.08

This commit is contained in:
2024-08-28 19:33:34 +03:00
parent 727693318c
commit ad1e3ecbcb
134 changed files with 112534 additions and 12635 deletions

View File

@@ -6,9 +6,9 @@ import torch
import os
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -35,11 +35,11 @@ class FluxTokenizer:
class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options):
@@ -52,9 +52,9 @@ class FluxClipModel(torch.nn.Module):
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return t5_out, l_pooled
@@ -66,6 +66,6 @@ class FluxClipModel(torch.nn.Module):
def flux_clip(dtype_t5=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_