[pull] master from comfyanonymous:master #148

Merged: 3 commits merged into master on Jul 1, 2025

Changes from all commits
73 changes: 42 additions & 31 deletions comfy/k_diffusion/sampling.py
@@ -1447,27 +1447,34 @@ def post_cfg_function(args):
             old_d = d
     return x
 
+
 @torch.no_grad()
 def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
     return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
 
+
 @torch.no_grad()
-def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
-    """
-    Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
+def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3):
+    """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169.
     Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
     """
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
     s_in = x.new_ones([x.shape[0]])
 
-    def default_noise_scaler(sigma):
-        return sigma * ((sigma ** 0.3).exp() + 10.0)
-    noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
+    def default_er_sde_noise_scaler(x):
+        return x * ((x ** 0.3).exp() + 10.0)
+
+    noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler
     num_integration_points = 200.0
     point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
 
+    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
+    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
+    half_log_snrs = sigma_to_half_log_snr(sigmas, model_sampling)
+    er_lambdas = half_log_snrs.neg().exp()  # er_lambda_t = sigma_t / alpha_t
+
     old_denoised = None
     old_denoised_d = None
 
@@ -1478,32 +1485,36 @@ def default_noise_scaler(sigma):
         stage_used = min(max_stage, i + 1)
         if sigmas[i + 1] == 0:
             x = denoised
-        elif stage_used == 1:
-            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
-            x = r * x + (1 - r) * denoised
         else:
-            r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
-            x = r * x + (1 - r) * denoised
-
-            dt = sigmas[i + 1] - sigmas[i]
-            sigma_step_size = -dt / num_integration_points
-            sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
-            scaled_pos = noise_scaler(sigma_pos)
-
-            # Stage 2
-            s = torch.sum(1 / scaled_pos) * sigma_step_size
-            denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
-            x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
-
-            if stage_used >= 3:
-                # Stage 3
-                s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
-                denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
-                x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
-            old_denoised_d = denoised_d
-
-        if s_noise != 0 and sigmas[i + 1] > 0:
-            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
+            er_lambda_s, er_lambda_t = er_lambdas[i], er_lambdas[i + 1]
+            alpha_s = sigmas[i] / er_lambda_s
+            alpha_t = sigmas[i + 1] / er_lambda_t
+            r_alpha = alpha_t / alpha_s
+            r = noise_scaler(er_lambda_t) / noise_scaler(er_lambda_s)
+
+            # Stage 1 Euler
+            x = r_alpha * r * x + alpha_t * (1 - r) * denoised
+
+            if stage_used >= 2:
+                dt = er_lambda_t - er_lambda_s
+                lambda_step_size = -dt / num_integration_points
+                lambda_pos = er_lambda_t + point_indice * lambda_step_size
+                scaled_pos = noise_scaler(lambda_pos)
+
+                # Stage 2
+                s = torch.sum(1 / scaled_pos) * lambda_step_size
+                denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1])
+                x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d
+
+                if stage_used >= 3:
+                    # Stage 3
+                    s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size
+                    denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2)
+                    x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u
+                old_denoised_d = denoised_d
+
+            if s_noise > 0:
+                x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
         old_denoised = denoised
     return x
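The heart of this change is a switch from the VE variable sigma to the VP variable er_lambda_t = sigma_t / alpha_t, recovered from the half-log-SNR; every update is then rescaled by alpha_t. A minimal sketch of that identity, assuming only that half_log_snr = log(alpha / sigma) (the helper name below is illustrative, not ComfyUI's API):

import torch

def er_lambda_from_half_log_snr(half_log_snr: torch.Tensor) -> torch.Tensor:
    # half_log_snr = log(alpha / sigma), so exp(-half_log_snr) = sigma / alpha,
    # matching the er_lambdas = half_log_snrs.neg().exp() line in the diff.
    return half_log_snr.neg().exp()

# For a VE model alpha == 1, so er_lambda reduces to sigma and the VP update
# above degenerates to the original VE ER-SDE step (alpha_t == r_alpha == 1).
sigma = torch.tensor([14.6, 7.0, 1.0])
alpha = torch.ones_like(sigma)
assert torch.allclose(er_lambda_from_half_log_snr(torch.log(alpha / sigma)), sigma / alpha)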

42 changes: 42 additions & 0 deletions comfy_extras/nodes_custom_sampler.py
@@ -2,6 +2,7 @@
 import comfy.samplers
 import comfy.sample
 from comfy.k_diffusion import sampling as k_diffusion_sampling
+from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
 import latent_preview
 import torch
 import comfy.utils
@@ -480,6 +481,46 @@ def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_
                                                            "s_noise":s_noise })
         return (sampler, )
 
+
+class SamplerER_SDE(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypeDict:
+        return {
+            "required": {
+                "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}),
+                "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}),
+                "eta": (
+                    IO.FLOAT,
+                    {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."},
+                ),
+                "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}),
+            }
+        }
+
+    RETURN_TYPES = (IO.SAMPLER,)
+    CATEGORY = "sampling/custom_sampling/samplers"
+
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, solver_type, max_stage, eta, s_noise):
+        if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
+            eta = 0
+            s_noise = 0
+
+        def reverse_time_sde_noise_scaler(x):
+            return x ** (eta + 1)
+
+        if solver_type == "ER-SDE":
+            # Use the default one in sample_er_sde()
+            noise_scaler = None
+        else:
+            noise_scaler = reverse_time_sde_noise_scaler
+
+        sampler_name = "er_sde"
+        sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
+        return (sampler,)
+
+
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0
@@ -787,6 +828,7 @@ def add_noise(self, model, noise, sigmas, latent_image):
     "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
     "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
     "SamplerDPMAdaptative": SamplerDPMAdaptative,
+    "SamplerER_SDE": SamplerER_SDE,
     "SplitSigmas": SplitSigmas,
     "SplitSigmasDenoise": SplitSigmasDenoise,
     "FlipSigmas": FlipSigmas,
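The node's solver_type switch only decides which noise_scaler closure is handed to sample_er_sde. A standalone sketch of that mapping (hypothetical helper mirroring get_sampler above, not part of the node API):

def pick_noise_scaler(solver_type: str, eta: float, s_noise: float):
    # ODE, or Reverse-time SDE with eta == 0, is the deterministic limit.
    if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
        eta = 0.0
        s_noise = 0.0
    if solver_type == "ER-SDE":
        return None, s_noise  # sample_er_sde falls back to its default scaler
    return (lambda x: x ** (eta + 1)), s_noise  # reverse-time SDE family

scaler, s_noise = pick_noise_scaler("Reverse-time SDE", eta=0.0, s_noise=1.0)
assert scaler(3.0) == 3.0 and s_noise == 0.0  # eta = 0 collapses to the ODE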
7 changes: 7 additions & 0 deletions comfy_extras/nodes_perpneg.py
@@ -4,6 +4,7 @@
 import comfy.samplers
 import comfy.utils
 import node_helpers
+import math
 
 def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale):
     pos = noise_pred_pos - noise_pred_nocond
@@ -69,6 +70,12 @@ def predict_noise(self, x, timestep, model_options={}, seed=None):
         negative_cond = self.conds.get("negative", None)
         empty_cond = self.conds.get("empty_negative_prompt", None)
 
+        if model_options.get("disable_cfg1_optimization", False) == False:
+            if math.isclose(self.neg_scale, 0.0):
+                negative_cond = None
+            if math.isclose(self.cfg, 1.0):
+                empty_cond = None
+
         conds = [positive_cond, negative_cond, empty_cond]
 
         out = comfy.samplers.calc_cond_batch(self.inner_model, conds, x, timestep, model_options)
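The new guard drops cond batches that cannot influence the result before they are run through calc_cond_batch. A tiny standalone mirror of the gating (hypothetical helper, not the node's API):

import math

def prune_perpneg_conds(positive, negative, empty, neg_scale, cfg, model_options):
    if not model_options.get("disable_cfg1_optimization", False):
        if math.isclose(neg_scale, 0.0):
            negative = None  # the perp-neg term is weighted by neg_scale
        if math.isclose(cfg, 1.0):
            empty = None     # the empty-prompt batch is skipped at cfg == 1
    return [positive, negative, empty]

print(prune_perpneg_conds("pos", "neg", "empty", neg_scale=0.0, cfg=1.0, model_options={}))
# -> ['pos', None, None]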
71 changes: 71 additions & 0 deletions comfy_extras/nodes_tcfg.py
@@ -0,0 +1,71 @@
+# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
+
+import torch
+
+from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
+
+
+def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
+    """Drop tangential components from uncond score to align with cond score."""
+    # (B, 1, ...)
+    batch_num = cond_score.shape[0]
+    cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
+    uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
+
+    # Score matrix A (B, 2, ...)
+    score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
+    try:
+        _, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
+    except RuntimeError:
+        # Fallback to CPU
+        _, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
+
+    # Drop the tangential components
+    v1 = Vh[:, 0:1, :].to(uncond_score_flat.device)  # (B, 1, ...)
+    uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
+    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
+
+
+class TCFG(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypeDict:
+        return {
+            "required": {
+                "model": (IO.MODEL, {}),
+            }
+        }
+
+    RETURN_TYPES = (IO.MODEL,)
+    RETURN_NAMES = ("patched_model",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/guidance"
+    DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
+
+    def patch(self, model):
+        m = model.clone()
+
+        def tangential_damping_cfg(args):
+            # Assume [cond, uncond, ...]
+            x = args["input"]
+            conds_out = args["conds_out"]
+            if len(conds_out) <= 1 or None in args["conds"][:2]:
+                # Skip when either cond or uncond is None
+                return conds_out
+            cond_pred = conds_out[0]
+            uncond_pred = conds_out[1]
+            uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
+            uncond_pred_td = x - uncond_td
+            return [cond_pred, uncond_pred_td] + conds_out[2:]
+
+        m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
+        return (m,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "TCFG": TCFG,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "TCFG": "Tangential Damping CFG",
+}
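For intuition: score_tangential_damping stacks the flattened uncond and cond scores into a (B, 2, D) matrix, takes its leading right singular vector v1, and keeps only the uncond component along v1; the part tangential to it is discarded. A toy numeric check (illustrative values only):

import torch

cond = torch.tensor([[1.0, 0.0, 0.0]])    # (B=1, D=3) toy cond score
uncond = torch.tensor([[0.6, 0.8, 0.0]])  # toy uncond score

cond_flat = cond.reshape(1, 1, -1)
uncond_flat = uncond.reshape(1, 1, -1)
score_matrix = torch.cat((uncond_flat, cond_flat), dim=1)  # (1, 2, 3)
_, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
v1 = Vh[:, 0:1, :]                                     # leading direction
uncond_td = (uncond_flat @ v1.transpose(-2, -1)) * v1  # rank-1 projection
print(uncond_td.reshape(1, -1))  # ~[0.8, 0.4, 0.0]; sign-invariant since v1 appears twice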
1 change: 1 addition & 0 deletions nodes.py
@@ -2283,6 +2283,7 @@ def init_builtin_extra_nodes():
         "nodes_string.py",
         "nodes_camera_trajectory.py",
         "nodes_edit_model.py",
+        "nodes_tcfg.py"
     ]
 
     import_failed = []