Fixing the tensor size mismatch issue

#10 opened by IHateCrickets

I was running into errors like these:

!!! Exception during processing !!! The expanded size of the tensor (26) must match the existing size (27) at non-singleton dimension 3. Target sizes: [2, 128, 35, 26]. Tensor sizes: [2, 128, 36, 27]
File "*\ComfyUI\comfy\ldm\lightricks\model.py", line 437, in forward
x[:, :, 0] = guiding_latent[:, :, 0]
~^^^^^^^^^
RuntimeError: The expanded size of the tensor (26) must match the existing size (27) at non-singleton dimension 3. Target sizes: [2, 128, 35, 26]. Tensor sizes: [2, 128, 36, 27]
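In other words, the guiding latent produced from the VAE-encoded image ends up one latent pixel larger (36x27) than the slot it is written into (35x26), so the in-place assignment in model.py refuses the copy. Here is a minimal, self-contained sketch of the kind of crop/pad alignment my fix performs; the helper name and the hard-coded shapes are mine (taken from the error message above), not anything from ComfyUI:

import torch
import torch.nn.functional as F

def align_to(src: torch.Tensor, target_hw):
    """Crop or zero-pad the last two dims of `src` so they match `target_hw`."""
    th, tw = target_hw
    # Crop if the source is larger.
    src = src[..., :th, :tw]
    # Pad right/bottom if the source is smaller; F.pad takes (w_left, w_right, h_top, h_bottom).
    pad_w = max(tw - src.shape[-1], 0)
    pad_h = max(th - src.shape[-2], 0)
    if pad_w or pad_h:
        src = F.pad(src, (0, pad_w, 0, pad_h))
    return src

# Shapes taken from the error message above: the target frame slot is 35x26, the guiding latent is 36x27.
x_first_frame = torch.zeros(2, 128, 35, 26)
guiding = torch.randn(2, 128, 36, 27)

aligned = align_to(guiding, x_first_frame.shape[-2:])
print(aligned.shape)          # torch.Size([2, 128, 35, 26])
x_first_frame[:] = aligned    # assignment no longer raises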

I modified two files to fix the issue:
*\ComfyUI\comfy_extras\nodes_lt.py
*\ComfyUI\comfy\ldm\lightricks\model.py

Code for model.py:

import torch
from torch import nn
import comfy.ldm.modules.attention
from comfy.ldm.genmo.joint_model.layers import RMSNorm
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple

from .symmetric_patchifier import SymmetricPatchifier

def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings.

    Args:
        timesteps (torch.Tensor): A 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int): The dimension of the output.
        flip_sin_to_cos (bool): Whether to flip the sine and cosine embeddings.
        downscale_freq_shift (float): Controls the delta between frequencies between dimensions.
        scale (float): Scaling factor applied to the embeddings.
        max_period (int): Controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor: An [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # Scale embeddings
    emb = scale * emb

    # Concatenate sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # Flip sine and cosine embeddings if required
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # Zero pad if embedding_dim is odd
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
        dtype=None, device=None, operations=None,
    ):
        super().__init__()

        self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)

        if cond_proj_dim is not None:
            self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)

        if post_act_fn is None:
            self.post_act = None
        # else:
        #     self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample

class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb

class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
        return timesteps_emb

class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
        )

        self.silu = nn.SiLU()
        self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep

class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

class GELU_approx(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x), approximate="tanh")

class FeedForward(nn.Module):
    def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = int(dim * mult)
        project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
        )

    def forward(self, x):
        return self.net(x)

def apply_rotary_emb(input_tensor, freqs_cis):  # TODO: remove duplicate funcs and pick the best/fastest one
    cos_freqs = freqs_cis[0]
    sin_freqs = freqs_cis[1]

    t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
    t1, t2 = t_dup.unbind(dim=-1)
    t_dup = torch.stack((-t2, t1), dim=-1)
    input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

    return out

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
        self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None, pe=None):
        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if pe is not None:
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe)

        if mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        return self.to_out(out)

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()

        self.attn_precision = attn_precision
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
        self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)

        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)

        self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], x.shape[1], self.scale_shift_table.shape[0], -1)
        ).unbind(dim=2)

        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa

        x += self.attn2(x, context=context, mask=attention_mask)

        y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
        x += self.ff(y) * gate_mlp

        return x

def get_fractional_positions(indices_grid, max_pos):
    fractional_positions = torch.stack(
        [
            indices_grid[:, i] / max_pos[i]
            for i in range(3)
        ],
        dim=-1,
    )
    return fractional_positions

def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
    dtype = torch.float32  # self.dtype

    fractional_positions = get_fractional_positions(indices_grid, max_pos)

    start = 1
    end = theta
    device = fractional_positions.device

    indices = theta ** (
        torch.linspace(
            math.log(start, theta),
            math.log(end, theta),
            dim // 6,
            device=device,
            dtype=dtype,
        )
    )
    indices = indices.to(dtype=dtype)

    indices = indices * math.pi / 2

    freqs = (
        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
        .transpose(-1, -2)
        .flatten(2)
    )

    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if dim % 6 != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
        sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)

class LTXVModel(torch.nn.Module):
    def __init__(self,
                 in_channels=128,
                 cross_attention_dim=2048,
                 attention_head_dim=64,
                 num_attention_heads=32,
                 caption_channels=4096,
                 num_layers=28,
                 positional_embedding_theta=10000.0,
                 positional_embedding_max_pos=[20, 2048, 2048],
                 dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
        )

        self.caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    context_dim=cross_attention_dim,
                    dtype=dtype, device=device, operations=operations
                )
                for d in range(num_layers)
            ]
        )

        self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
        self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)

        self.patchifier = SymmetricPatchifier(1)

def forward(self, x, timestep, context, attention_mask, frame_rate, guiding_latent=None, transformer_options={}, **kwargs):
    # Debug: Print original input shape
    print(f"Original x shape: {x.shape}")  # Expected: [batch, channels, frames, height, width]

    # Calculate padding to ensure dimensions are divisible by patch size
    patch_size = self.patchifier.patch_size
    print(f"Patch size: {patch_size}")

    # Handle patch_size as a tuple (frames, height, width)
    if isinstance(patch_size, tuple):
        if len(patch_size) != 3:
            raise ValueError(f"Expected patch_size to be a tuple of length 3, but got length {len(patch_size)}")
        patch_size_frames, patch_size_height, patch_size_width = patch_size
    else:
        patch_size_frames = patch_size_height = patch_size_width = patch_size

    # Compute padding for each dimension
    padding_frames = (patch_size_frames - (x.shape[2] % patch_size_frames)) % patch_size_frames
    padding_height = (patch_size_height - (x.shape[3] % patch_size_height)) % patch_size_height
    padding_width = (patch_size_width - (x.shape[4] % patch_size_width)) % patch_size_width

    if padding_frames != 0 or padding_height != 0 or padding_width != 0:
        # Calculate padding in the order (W_left, W_right, H_top, H_bottom, F_front, F_back)
        padding = [
            padding_width // 2, padding_width - (padding_width // 2),
            padding_height // 2, padding_height - (padding_height // 2),
            padding_frames // 2, padding_frames - (padding_frames // 2)
        ]
        x = torch.nn.functional.pad(x, padding)
        print(f"x shape after padding: {x.shape}")

    indices_grid = self.patchifier.get_grid(
        orig_num_frames=x.shape[2],
        orig_height=x.shape[3],
        orig_width=x.shape[4],
        batch_size=x.shape[0],
        scale_grid=(1, 32, 32),  # Changed from ((1 / frame_rate) * 8, 32, 32)
        device=x.device,
    )
    print(f"indices_grid shape: {indices_grid.shape}")

    if guiding_latent is not None:
        ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
        input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
        ts *= input_ts
        ts[:, :, 0] = 0.0
        timestep = self.patchifier.patchify(ts)
        print(f"Timestep shape after patchify: {timestep.shape}")

        input_x = x.clone()
        # Debug: Print shapes before assignment
        print(f"x shape before assignment: {x[:, :, 0].shape}")  # [batch, channels, height, width]
        print(f"guiding_latent shape: {guiding_latent.shape}")  # Should match [batch, channels, height, width]

        # Ensure guiding_latent has the same shape as x[:, :, 0] before assignment
        if guiding_latent.shape != x[:, :, 0].shape:
            try:
                # Check if guiding_latent is 4D or 5D
                if guiding_latent.dim() == 4:
                    # Add a frames dimension
                    guiding_latent = guiding_latent.unsqueeze(2)  # [N, C, 1, H, W]
                    print(f"guiding_latent shape after unsqueeze: {guiding_latent.shape}")
                elif guiding_latent.dim() != 5:
                    raise ValueError(f"Expected guiding_latent to have 4 or 5 dimensions, but got {guiding_latent.dim()}")

                # Resize guiding_latent spatially to match x
                guiding_latent = torch.nn.functional.interpolate(
                    guiding_latent,
                    size=(1, x.shape[3], x.shape[4]),  # (D, H, W)
                    mode='trilinear',  # Use 'trilinear' for 5D tensors
                    align_corners=False
                )
                print(f"Resized guiding_latent shape after interpolate: {guiding_latent.shape}")
            except Exception as e:
                print(f"Interpolation failed: {e}")
                raise

            # After interpolation, check again
            if guiding_latent.shape[3] != x.shape[3] or guiding_latent.shape[4] != x.shape[4]:
                # Determine target height and width
                target_height, target_width = x.shape[3], x.shape[4]

                # Current height and width
                current_height, current_width = guiding_latent.shape[3], guiding_latent.shape[4]

                # Calculate necessary padding or cropping
                pad_height = target_height - current_height
                pad_width = target_width - current_width

                # Initialize padding parameters
                # For trilinear, padding is (W_left, W_right, H_top, H_bottom, D_front, D_back)
                padding = [0, 0, 0, 0, 0, 0]

                # Apply padding or cropping for height
                if pad_height > 0:
                    padding[2] = pad_height // 2
                    padding[3] = pad_height - padding[2]
                    guiding_latent = torch.nn.functional.pad(guiding_latent, padding, mode='constant', value=0)
                    print(f"Padded guiding_latent height to: {guiding_latent.shape[3]}")
                elif pad_height < 0:
                    guiding_latent = guiding_latent[:, :, :, :target_height, :]
                    print(f"Cropped guiding_latent height to: {guiding_latent.shape[3]}")

                # Apply padding or cropping for width
                if pad_width > 0:
                    padding = [pad_width // 2, pad_width - (pad_width // 2), 0, 0, 0, 0]
                    guiding_latent = torch.nn.functional.pad(guiding_latent, padding, mode='constant', value=0)
                    print(f"Padded guiding_latent width to: {guiding_latent.shape[4]}")
                elif pad_width < 0:
                    guiding_latent = guiding_latent[:, :, :, :, :target_width]
                    print(f"Cropped guiding_latent width to: {guiding_latent.shape[4]}")

                # Final shape check
                if guiding_latent.shape[3] != target_height or guiding_latent.shape[4] != target_width:
                    raise RuntimeError(f"Final guiding_latent shape {guiding_latent.shape} does not match x shape [N, C, H, W] [{x.shape}]")

            # If patch_size_frames is 1, remove the frames dimension
            if patch_size_frames == 1 and guiding_latent.shape[2] == 1:
                guiding_latent = guiding_latent.squeeze(2)  # [N, C, H, W]
                print(f"guiding_latent shape after squeeze: {guiding_latent.shape}")

        # Final assignment
        # Ensure shapes match before assignment
        if guiding_latent.shape != x[:, :, 0].shape:
            raise RuntimeError(f"Final guiding_latent shape {guiding_latent.shape} does not match x[:, :, 0] shape {x[:, :, 0].shape}")
        x[:, :, 0] = guiding_latent
        print(f"x shape after assignment: {x[:, :, 0].shape}")

    orig_shape = list(x.shape)

    # Patchify the input
    x = self.patchifier.patchify(x)
    print(f"x shape after patchify: {x.shape}")  # Expected: [batch, channels, frames, height, width] if patch_size=(1,1,1)

    x = self.patchify_proj(x)
    print(f"x shape after patchify_proj: {x.shape}")

    timestep = timestep * 1000.0
    print(f"Timestep after scaling: {timestep.shape}")

    attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
    attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))  # not sure about this

    pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
    print(f"Positional encoding shapes: cos_freq={pe[0].shape}, sin_freq={pe[1].shape}")

    batch_size = x.shape[0]
    timestep, embedded_timestep = self.adaln_single(
        timestep.flatten(),
        {"resolution": None, "aspect_ratio": None},
        batch_size=batch_size,
        hidden_dtype=x.dtype,
    )
    print(f"embedded_timestep shape: {embedded_timestep.shape}")

    # Reshape timestep
    timestep = timestep.view(batch_size, -1, timestep.shape[-1])
    embedded_timestep = embedded_timestep.view(
        batch_size, -1, embedded_timestep.shape[-1]
    )
    print(f"Timestep reshaped: {timestep.shape}, embedded_timestep reshaped: {embedded_timestep.shape}")

    # 2. Blocks
    if self.caption_projection is not None:
        batch_size = x.shape[0]
        context = self.caption_projection(context)
        context = context.view(
            batch_size, -1, x.shape[-1]
        )
        print(f"context shape after projection: {context.shape}")

    blocks_replace = transformer_options.get("dit", {})
    for i, block in enumerate(self.transformer_blocks):
        if ("double_block", i) in blocks_replace:
            def block_wrap(args):
                out = {}
                out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
                return out

            out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
            x = out["img"]
        else:
            x = block(
                x,
                context=context,
                attention_mask=attention_mask,
                timestep=timestep,
                pe=pe
            )
        print(f"x shape after transformer block {i}: {x.shape}")

    # 3. Output
    scale_shift_values = (
        self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
    )
    shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
    print(f"scale_shift_values shape: {scale_shift_values.shape}")

    x = self.norm_out(x)
    print(f"x shape after norm_out: {x.shape}")

    # Modulation
    x = x * (1 + scale) + shift
    print(f"x shape after modulation: {x.shape}")

    x = self.proj_out(x)
    print(f"x shape after proj_out: {x.shape}")

    # Unpatchify the latents
    x = self.patchifier.unpatchify(
        latents=x,
        out_channels=self.out_channels,  # Added this line
        output_height=orig_shape[3],
        output_width=orig_shape[4],
        output_num_frames=orig_shape[2]
    )
    print(f"x shape after unpatchify: {x.shape}")

    if guiding_latent is not None:
        # Ensure shapes match before final assignment
        if (input_x[:, :, 0].shape != guiding_latent.shape):
            raise RuntimeError(f"Final guiding latent shape {guiding_latent.shape} does not match input_x[:, :, 0] shape {input_x[:, :, 0].shape}")
        x[:, :, 0] = (input_x[:, :, 0] - guiding_latent) / input_ts[:, :, 0]
        print(f"x shape after final assignment with guiding_latent: {x.shape}")

    # Final shape check
    print(f"Final x shape: {x.shape}")

    return x
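In short, the forward changes above do two things before the rest of the model runs: x is zero-padded so that frames/height/width are divisible by the patchifier's patch size, and guiding_latent is interpolated and then padded or cropped until it exactly matches x[:, :, 0]. Here is a condensed standalone sketch of that second step; the function name is mine, and the trilinear/squeeze choices just mirror my code above, purely as illustration:

import torch
import torch.nn.functional as F

def fit_guiding_latent(guiding_latent: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Return guiding_latent resized to match x[:, :, 0], where x is [N, C, F, H, W]."""
    target_h, target_w = x.shape[3], x.shape[4]

    if guiding_latent.dim() == 4:            # [N, C, H, W] -> [N, C, 1, H, W]
        guiding_latent = guiding_latent.unsqueeze(2)

    # Resize spatially; trilinear because the tensor is 5D (N, C, D, H, W).
    guiding_latent = F.interpolate(
        guiding_latent, size=(1, target_h, target_w),
        mode="trilinear", align_corners=False,
    )

    # Drop the singleton frame dim so it matches x[:, :, 0].
    guiding_latent = guiding_latent.squeeze(2)

    assert guiding_latent.shape == x[:, :, 0].shape
    return guiding_latent

# Example with the mismatched sizes from the traceback.
x = torch.zeros(2, 128, 1, 35, 26)        # padded noise latent
guiding = torch.randn(2, 128, 36, 27)     # encoded first-frame latent
x[:, :, 0] = fit_guiding_latent(guiding, x)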

Here's the code for nodes_lt.py:

import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.utils
import math

class EmptyLTXVLatentVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

    CATEGORY = "latent/video/ltxv"

    def generate(self, width, height, length, batch_size=1):
        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        return ({"samples": latent}, )

class LTXVImgToVideo:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "vae": ("VAE",),
                             "image": ("IMAGE",),
                             "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                             "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}

    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    CATEGORY = "conditioning/video_models"
    FUNCTION = "generate"

    def generate(self, positive, negative, image, vae, width, height, length, batch_size):
        pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
        negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})

        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())

        # Ensure the dimensions match before assignment
        min_length = min(latent.shape[2], t.shape[2])
        min_height = min(latent.shape[3], t.shape[3])
        min_width = min(latent.shape[4], t.shape[4])

        latent[:, :, :min_length, :min_height, :min_width] = t[:, :, :min_length, :min_height, :min_width]
        return (positive, negative, {"samples": latent}, )

class LTXVConditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"positive": ("CONDITIONING", ),
                             "negative": ("CONDITIONING", ),
                             "frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "append"

    CATEGORY = "conditioning/video_models"

    def append(self, positive, negative, frame_rate):
        positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
        negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
        return (positive, negative)

class ModelSamplingLTXV:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step": 0.01}),
                             "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step": 0.01}),
                             },
                "optional": {"latent": ("LATENT",), }
                }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, max_shift, base_shift, latent=None):
        m = model.clone()

        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        shift = (tokens) * mm + b

        sampling_base = comfy.model_sampling.ModelSamplingFlux
        sampling_type = comfy.model_sampling.CONST

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(shift=shift)
        m.add_object_patch("model_sampling", model_sampling)
        return (m, )

class LTXVScheduler:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                     "max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step": 0.01}),
                     "base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step": 0.01}),
                     "stretch": ("BOOLEAN", {
                         "default": True,
                         "tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
                     }),
                     "terminal": (
                         "FLOAT",
                         {
                             "default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
                             "tooltip": "The terminal value of the sigmas after stretching."
                         },
                     ),
                     },
                "optional": {"latent": ("LATENT",), }
                }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"

    FUNCTION = "get_sigmas"

    def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        sigma_shift = (tokens) * mm + b

        power = 1
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that its final value matches the given terminal value.
        if stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            stretched = 1.0 - (one_minus_z / scale_factor)
            sigmas[non_zero_mask] = stretched

        return (sigmas,)

NODE_CLASS_MAPPINGS = {
    "EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
    "LTXVImgToVideo": LTXVImgToVideo,
    "ModelSamplingLTXV": ModelSamplingLTXV,
    "LTXVConditioning": LTXVConditioning,
    "LTXVScheduler": LTXVScheduler,
}
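The only functional change in nodes_lt.py is in LTXVImgToVideo.generate: instead of assuming the encoded image latent t and the freshly created empty latent have identical shapes, the copy is limited to their overlapping region. A tiny standalone illustration of that idea, with made-up shapes:

import torch

latent = torch.zeros(1, 128, 13, 16, 24)   # empty video latent, [N, C, F, H, W]
t = torch.randn(1, 128, 1, 17, 25)         # VAE-encoded image, one latent pixel larger spatially

# Copy only the overlapping region so a shape mismatch can't raise.
f = min(latent.shape[2], t.shape[2])
h = min(latent.shape[3], t.shape[3])
w = min(latent.shape[4], t.shape[4])
latent[:, :, :f, :h, :w] = t[:, :, :f, :h, :w]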

Please feel free to review, use, and modify my code for any future updates so this doesn't recur.
