|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor, einsum
|
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
|
from einops import rearrange
|
|
import math
|
|
import comfy.ops
|
|
|
|
class LearnedPositionalEmbedding(nn.Module):
    """Learned Fourier-feature embedding for a 1-D batch of continuous values.

    Holds ``dim // 2`` learned frequencies. Each input value is returned
    alongside the sine and cosine of ``value * frequency * 2 * pi``, giving
    ``dim + 1`` features per value.
    """

    def __init__(self, dim: int):
        super().__init__()
        # The sin and cos halves must be the same width.
        assert (dim % 2) == 0
        # NOTE(review): torch.empty leaves the frequencies uninitialised;
        # presumably they are always overwritten by a loaded checkpoint -- confirm.
        self.weights = nn.Parameter(torch.empty(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        # (b,) -> (b, 1); the pattern expects a 1-D input tensor.
        column = rearrange(x, "b -> b 1")
        # Broadcast (b, 1) against (1, d) to get one phase per value and frequency.
        phases = column * self.weights[None, :] * 2 * math.pi
        # Row layout: [x, sin(phases)..., cos(phases)...]
        return torch.cat((column, phases.sin(), phases.cos()), dim=-1)
|
|
|
|
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a learned Fourier-feature embedder followed by a linear projection.

    Returns an ``nn.Sequential`` whose first stage maps each scalar to
    ``dim + 1`` features and whose second stage projects them to
    ``out_features``. ``nn.Sequential`` names its children by position
    ("0", "1"), so reordering the stages would rename their parameters.
    """
    fourier = LearnedPositionalEmbedding(dim)
    # The input width is dim + 1: the raw value plus dim sin/cos features.
    projection = comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(fourier, projection)
|
|
|
|
|
|
class NumberEmbedder(nn.Module):
    """Embed arbitrarily shaped numeric input into ``features``-wide vectors.

    The input is flattened, embedded value by value, and reshaped back so
    that the output has shape ``(*x.shape, features)``.
    """

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        # Learned Fourier features (dim + 1 wide), projected to `features`.
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if not isinstance(x, Tensor):
            # Place plain Python sequences on the same device as the module's parameters.
            target_device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=target_device)
        assert isinstance(x, Tensor)

        original_shape = x.shape
        # The embedding stage works on a flat 1-D batch of values.
        flat = rearrange(x, "... -> (...)")
        embedded = self.embedding(flat)
        return embedded.view(*original_shape, self.features)
|
|
|
|
|
|
class Conditioner(nn.Module):
    """Base class for conditioners that emit ``output_dim``-wide features.

    ``proj_out`` is a learned ``nn.Linear(dim, output_dim)`` when the two
    widths differ or ``project_out`` is set. Otherwise it is an
    ``nn.Identity``. Subclasses must override :meth:`forward`.
    """

    def __init__(
        self,
        dim: int,
        output_dim: int,
        project_out: bool = False
    ):
        super().__init__()

        self.dim = dim
        self.output_dim = output_dim

        if project_out or dim != output_dim:
            self.proj_out = nn.Linear(dim, output_dim)
        else:
            self.proj_out = nn.Identity()

    def forward(self, x):
        """Abstract; concrete conditioners provide the implementation."""
        raise NotImplementedError()
|
|
|
|
class NumberConditioner(Conditioner):
    """Conditioner that embeds a batch of scalar values.

    Each value is clamped to ``[min_val, max_val]``, rescaled linearly to
    ``[0, 1]`` and passed through a :class:`NumberEmbedder` that produces
    ``output_dim``-wide embeddings.
    """

    def __init__(self,
                 output_dim: int,
                 min_val: float = 0,
                 max_val: float = 1
                 ):
        # dim == output_dim and project_out is left False, so the inherited
        # ``proj_out`` is an ``nn.Identity``.
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats, device=None):
        """Embed ``floats`` (an iterable of numbers, N of them).

        Args:
            floats: values convertible with ``float()``.
            device: target device. Defaults to the device of the embedder's
                first parameter.

        Returns:
            A two-element list:
              * embeddings of shape ``(N, 1, output_dim)``, and
              * a tensor of ones of shape ``(N, 1)`` (default dtype) on
                ``device``. NOTE(review): presumably used as a mask by the
                caller -- confirm.

        If ``min_val == max_val``, every clamped value equals ``min_val``.
        Such inputs normalise to 0 instead of the NaN that 0/0 would give.
        """
        floats = [float(x) for x in floats]

        # Look up the device/dtype reference parameter once.
        reference_param = next(self.embedder.parameters())
        if device is None:
            device = reference_param.device

        # Allocate directly on the target device rather than building on CPU and copying.
        values = torch.tensor(floats, device=device)
        values = values.clamp(self.min_val, self.max_val)

        value_range = self.max_val - self.min_val
        if value_range == 0:
            # Degenerate range: (values - min_val) is all zeros, and dividing by 0 would give NaN.
            normalized_floats = torch.zeros_like(values)
        else:
            normalized_floats = (values - self.min_val) / value_range

        # Match the embedder's parameter dtype (e.g. when weights are half precision).
        normalized_floats = normalized_floats.to(reference_param.dtype)

        # NumberEmbedder returns (N, output_dim); add a length-1 sequence axis.
        float_embeds = self.embedder(normalized_floats).unsqueeze(1)

        return [float_embeds, torch.ones(float_embeds.shape[0], 1, device=device)]
|
|
|