sync with repo 28.08
comfy/ops.py
@@ -18,29 +18,42 @@
 import torch
 import comfy.model_management
 from comfy.cli_args import args

-def cast_to(weight, dtype=None, device=None, non_blocking=False):
-    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
+    if device is None or weight.device == device:
+        if not copy:
+            if dtype is None or weight.dtype == dtype:
+                return weight
+        return weight.to(dtype=dtype, copy=copy)
+
+    r = torch.empty_like(weight, dtype=dtype, device=device)
+    r.copy_(weight, non_blocking=non_blocking)
+    return r

-def cast_to_input(weight, input, non_blocking=False):
-    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
+def cast_to_input(weight, input, non_blocking=False, copy=True):
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

-def cast_bias_weight(s, input=None, dtype=None, device=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     if input is not None:
         if dtype is None:
             dtype = input.dtype
+        if bias_dtype is None:
+            bias_dtype = dtype
         if device is None:
             device = input.device

     bias = None
-    non_blocking = comfy.model_management.device_should_use_non_blocking(device)
+    non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
-        if s.bias_function is not None:
+        has_function = s.bias_function is not None
+        bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        if has_function:
             bias = s.bias_function(bias)
-    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
-    if s.weight_function is not None:
+
+    has_function = s.weight_function is not None
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    if has_function:
         weight = s.weight_function(weight)
     return weight, bias
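Note (not part of the patch): the new copy argument lets cast_to hand back the stored tensor untouched when no device or dtype change is needed, and cast_bias_weight now forces a fresh copy whenever a weight_function/bias_function is about to transform the result, so the patch function never writes into the model's stored parameter. A minimal sketch of the behaviour, reusing the function exactly as defined above and assuming a CPU-only run:

    import torch

    def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
        if device is None or weight.device == device:
            if not copy:
                if dtype is None or weight.dtype == dtype:
                    return weight
            return weight.to(dtype=dtype, copy=copy)
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight, non_blocking=non_blocking)
        return r

    w = torch.randn(4, 4)
    print(cast_to(w) is w)                        # True: nothing to cast, stored tensor returned as-is
    print(cast_to(w, copy=True) is w)             # False: copy forced, safe for an in-place patch function
    print(cast_to(w, dtype=torch.float16).dtype)  # torch.float16: a dtype change still allocates a new tensor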
@@ -238,3 +251,59 @@ class manual_cast(disable_weight_init):
     class Embedding(disable_weight_init.Embedding):
         comfy_cast_weights = True

+
+def fp8_linear(self, input):
+    dtype = self.weight.dtype
+    if dtype not in [torch.float8_e4m3fn]:
+        return None
+
+    if len(input.shape) == 3:
+        inn = input.reshape(-1, input.shape[2]).to(dtype)
+        non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
+        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+        w = w.t()
+
+        scale_weight = self.scale_weight
+        scale_input = self.scale_input
+        if scale_weight is None:
+            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
+            if scale_input is None:
+                scale_input = scale_weight
+        if scale_input is None:
+            scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
+
+        if bias is not None:
+            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+        else:
+            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
+
+        if isinstance(o, tuple):
+            o = o[0]
+
+        return o.reshape((-1, input.shape[1], self.weight.shape[0]))
+    return None
+
+class fp8_ops(manual_cast):
+    class Linear(manual_cast.Linear):
+        def reset_parameters(self):
+            self.scale_weight = None
+            self.scale_input = None
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            out = fp8_linear(self, input)
+            if out is not None:
+                return out
+
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None):
+    if compute_dtype is None or weight_dtype == compute_dtype:
+        return disable_weight_init
+    if args.fast:
+        if comfy.model_management.supports_fp8_compute(load_device):
+            return fp8_ops
+    return manual_cast
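Note (not part of the patch): pick_operations is what model-loading code calls to choose the operation set; fp8_ops is only returned when the --fast flag is set and the load device passes supports_fp8_compute, otherwise the existing manual_cast/disable_weight_init behaviour is unchanged. A rough usage sketch, assuming a ComfyUI environment where comfy.ops and comfy.model_management are importable (the layer sizes are made up for the example):

    import torch
    import comfy.ops
    import comfy.model_management

    load_device = comfy.model_management.get_torch_device()

    # fp8 weights with an fp16 compute dtype: returns fp8_ops only if --fast was
    # passed and the device supports fp8 compute, else manual_cast is returned.
    ops = comfy.ops.pick_operations(torch.float8_e4m3fn, torch.float16, load_device=load_device)

    # Layers built from the returned class pick up the matching cast behaviour;
    # fp8_ops.Linear routes its cast-weight forward path through fp8_linear()
    # and torch._scaled_mm, falling back to a plain linear if that returns None.
    linear = ops.Linear(4096, 4096, device=load_device, dtype=torch.float16)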