import torch
from torch import einsum
from einops import rearrange, repeat
import types

def exists(val):
    return val is not None

# NaN/inf-safe mean of |x|; better than a division by 0 when the temperature is set to zero
abs_mean = lambda x: torch.where(torch.isnan(x) | torch.isinf(x), torch.zeros_like(x), x).abs().mean()

class temperature_patcher:
    def __init__(self, temperature, layer_name="None"):
        self.temperature = temperature
        self.layer_name = layer_name

    # taken from comfy.ldm.modules (attention_basic), with the temperature division added before the softmax
    def attention_basic_with_temperature(self, q, k, v, extra_options, mask=None, attn_precision=None):
        # the CLIP encoder passes the head count directly as an int; UNet patches pass an options dict
        if isinstance(extra_options, int):
            heads = extra_options
        else:
            heads = extra_options['n_heads']

        b, _, dim_head = q.shape
        dim_head //= heads
        scale = dim_head ** -0.5

        h = heads
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

        # force cast to fp32 to avoid overflowing
        if attn_precision == torch.float32:
            sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
        else:
            sim = einsum('b i d, b j d -> b i j', q, k) * scale

        del q, k

        if exists(mask):
            if mask.dtype == torch.bool:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                sim.masked_fill_(~mask, max_neg_value)
            else:
                if len(mask.shape) == 2:
                    bs = 1
                else:
                    bs = mask.shape[0]
                mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
                sim.add_(mask)

        # attention, what we cannot get enough of
        sim = sim.div(self.temperature if self.temperature > 0 else abs_mean(sim)).softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
        out = (
            out.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, heads * dim_head)
        )
        return out
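# ---------------------------------------------------------------------------
# Illustration (added commentary, not used by the nodes below): the key line
# above, sim.div(temperature).softmax(dim=-1), is plain temperature scaling
# of the attention logits. A minimal sketch with made-up scores; T < 1
# sharpens the distribution, T > 1 flattens it, and T == 0 falls back to
# dividing by abs_mean(sim).
def _temperature_softmax_demo():
    logits = torch.tensor([[2.0, 1.0, 0.1]])  # hypothetical attention scores
    for t in (0.5, 1.0, 2.0):
        # e.g. T=0.5 -> ~[0.864, 0.117, 0.019], T=2.0 -> ~[0.502, 0.304, 0.194]
        print(f"T={t}:", logits.div(t).softmax(dim=-1))
# ---------------------------------------------------------------------------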
layers_SD15 = {
    "input": [1, 2, 4, 5, 7, 8],
    "middle": [0],
    "output": [3, 4, 5, 6, 7, 8, 9, 10, 11],
}

layers_SDXL = {
    "input": [4, 5, 7, 8],
    "middle": [0],
    "output": [0, 1, 2, 3, 4, 5],
}

class ExperimentalTemperaturePatch:
    @classmethod
    def INPUT_TYPES(s):
        # one boolean toggle per patchable attention block, plus the shared inputs
        required_inputs = {
            f"{key}_{layer}": ("BOOLEAN", {"default": False})
            for key, layers in s.TOGGLES.items()
            for layer in layers
        }
        required_inputs["model"] = ("MODEL",)
        required_inputs["Temperature"] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": 0.01})
        required_inputs["Attention"] = (["both", "self", "cross"],)
        return {"required": required_inputs}

    TOGGLES = {}
    RETURN_TYPES = ("MODEL", "STRING",)
    RETURN_NAMES = ("Model", "String",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/Automatic_CFG/Standalone_temperature_patches"

    def patch(self, model, Temperature, Attention, **kwargs):
        m = model.clone()
        levels = ["input", "middle", "output"]
        parameters_output = {level: [] for level in levels}
        for key, toggle_enabled in kwargs.items():
            current_level = key.split("_")[0]
            if current_level in levels and toggle_enabled:
                b_number = int(key.split("_")[1])
                parameters_output[current_level].append(b_number)
                patcher = temperature_patcher(Temperature, key)
                if Attention in ["both", "self"]:
                    m.set_model_attn1_replace(patcher.attention_basic_with_temperature, current_level, b_number)
                if Attention in ["both", "cross"]:
                    m.set_model_attn2_replace(patcher.attention_basic_with_temperature, current_level, b_number)

        parameters_as_string = "\n".join(f"{k}: {','.join(map(str, v))}" for k, v in parameters_output.items())
        parameters_as_string = f"Temperature: {Temperature}\n{parameters_as_string}\nAttention: {Attention}"
        return (m, parameters_as_string,)

ExperimentalTemperaturePatchSDXL = type("ExperimentalTemperaturePatch_SDXL", (ExperimentalTemperaturePatch,), {"TOGGLES": layers_SDXL})
ExperimentalTemperaturePatchSD15 = type("ExperimentalTemperaturePatch_SD15", (ExperimentalTemperaturePatch,), {"TOGGLES": layers_SD15})

class CLIPTemperaturePatch:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "clip": ("CLIP",),
            "Temperature": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/Automatic_CFG/Standalone_temperature_patches"

    def patch(self, clip, Temperature):
        # stands in for comfy's optimized_attention_for_device; the
        # device/mask/small_input arguments are accepted but ignored
        def custom_optimized_attention(device, mask=None, small_input=True):
            return temperature_patcher(Temperature).attention_basic_with_temperature

        # mirrors comfy's CLIPEncoder.forward, swapping in the temperature attention
        def new_forward(self, x, mask=None, intermediate_output=None):
            optimized_attention = custom_optimized_attention(x.device, mask=mask is not None, small_input=True)

            if intermediate_output is not None:
                if intermediate_output < 0:
                    intermediate_output = len(self.layers) + intermediate_output

            intermediate = None
            for i, l in enumerate(self.layers):
                x = l(x, mask, optimized_attention)
                if i == intermediate_output:
                    intermediate = x.clone()
            return x, intermediate

        m = clip.clone()

        clip_encoder_instance = m.cond_stage_model.clip_l.transformer.text_model.encoder
        clip_encoder_instance.forward = types.MethodType(new_forward, clip_encoder_instance)

        if getattr(m.cond_stage_model, "clip_g", None) is not None:
            clip_encoder_instance_g = m.cond_stage_model.clip_g.transformer.text_model.encoder
            clip_encoder_instance_g.forward = types.MethodType(new_forward, clip_encoder_instance_g)

        return (m,)

class CLIPTemperaturePatchDual:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "clip": ("CLIP",),
            "Temperature": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
            "CLIP_Model": (["clip_g", "clip_l", "both"],),
        }}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "patch"
    CATEGORY = "model_patches/Automatic_CFG/Standalone_temperature_patches"

    def patch(self, clip, Temperature, CLIP_Model):
        def custom_optimized_attention(device, mask=None, small_input=True):
            return temperature_patcher(Temperature, "CLIP").attention_basic_with_temperature

        # same forward override as CLIPTemperaturePatch, applied selectively below
        def new_forward(self, x, mask=None, intermediate_output=None):
            optimized_attention = custom_optimized_attention(x.device, mask=mask is not None, small_input=True)

            if intermediate_output is not None:
                if intermediate_output < 0:
                    intermediate_output = len(self.layers) + intermediate_output

            intermediate = None
            for i, l in enumerate(self.layers):
                x = l(x, mask, optimized_attention)
                if i == intermediate_output:
                    intermediate = x.clone()
            return x, intermediate

        m = clip.clone()

        if CLIP_Model in ["clip_l", "both"]:
            clip_encoder_instance = m.cond_stage_model.clip_l.transformer.text_model.encoder
            clip_encoder_instance.forward = types.MethodType(new_forward, clip_encoder_instance)

        if CLIP_Model in ["clip_g", "both"]:
            if getattr(m.cond_stage_model, "clip_g", None) is not None:
                clip_encoder_instance_g = m.cond_stage_model.clip_g.transformer.text_model.encoder
                clip_encoder_instance_g.forward = types.MethodType(new_forward, clip_encoder_instance_g)

        return (m,)
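# ---------------------------------------------------------------------------
# Registration sketch, left commented out since node mappings for this module
# may already be declared elsewhere (e.g. in the package's __init__.py).
# The keys follow the usual ComfyUI NODE_CLASS_MAPPINGS convention; the exact
# key strings here are assumptions for illustration only:
#
# NODE_CLASS_MAPPINGS = {
#     "ExperimentalTemperaturePatchSDXL": ExperimentalTemperaturePatchSDXL,
#     "ExperimentalTemperaturePatchSD15": ExperimentalTemperaturePatchSD15,
#     "CLIPTemperaturePatch": CLIPTemperaturePatch,
#     "CLIPTemperaturePatchDual": CLIPTemperaturePatchDual,
# }
# ---------------------------------------------------------------------------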