# Hugging Face file-viewer header (kept as a comment so the module parses):
# AlekseyCalvin — "Create app.py" — commit eb164bc (verified) — 35.3 kB
import os
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        """Local stand-in for the Hugging Face ``spaces`` package.

        Defined only when the ``SPACES_ZERO_GPU`` environment variable is
        unset, so the ``spaces.GPU`` decorator can still be applied.

        NOTE(review): an unconditional ``import spaces`` further down the
        import list rebinds this name; confirm which one is intended.
        """

        @staticmethod
        def GPU(func=None, **_options):
            """No-op replacement for ``spaces.GPU``.

            Works both bare (``@spaces.GPU``) and parameterized
            (``@spaces.GPU(duration=70)``); any options are ignored.

            :param func: the function to wrap when used without arguments.
            :return: a wrapper calling ``func`` unchanged, or a decorator
                producing such a wrapper when called with options only.
            """
            import functools

            def decorate(fn):
                # Preserve name/signature so frameworks that introspect the
                # handler still see the original parameters.
                @functools.wraps(fn)
                def wrapper(*args, **kwargs):
                    return fn(*args, **kwargs)
                return wrapper

            if callable(func):
                return decorate(func)
            return decorate
import gradio as gr
import json
import logging
import argparse
import torch
import torchvision
from os import path
from PIL import Image
import numpy as np
import spaces
import copy
import random
import time
from torchvision import transforms
from dataclasses import dataclass
import math
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoPipelineForImage2Image
from diffusers.models.transformers import FluxTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
import safetensors.torch
from safetensors.torch import load_file
import random
from tqdm import tqdm
from einops import rearrange, repeat
from torch import Tensor, nn
from pipeline import FluxWithCFGPipeline
from diffusers.models.autoencoders import AutoencoderKL
from transformers import CLIPModel, CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPConfig, T5EncoderModel, T5Tokenizer
import gc
import warnings
# Download (or reuse a cached copy of) the OpenFLUX checkpoint repository.
model_path = snapshot_download(repo_id="Kijai/OpenFLUX-comfy")

# Point the Hugging Face cache variables at a "models" directory beside this file.
# NOTE(review): these are assigned after the Hugging Face libraries were
# imported and after the download above — confirm they actually take effect.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

device = "cuda" if torch.cuda.is_available() else "cpu"
#torch.backends.cuda.matmul.allow_tf32 = True

# Load LoRAs from JSON file.
# The encoding is explicit so parsing does not depend on the platform locale.
with open('loras.json', 'r', encoding="utf-8") as f:
    loras = json.load(f)

dtype = torch.bfloat16
#clipmodel = 'norm'
#if clipmodel == "long":
# model_id = "zer0int/LongCLIP-GmP-ViT-L-14"
# config = CLIPConfig.from_pretrained(model_id)
# maxtokens = 77
#if clipmodel == "norm":
# model_id = "zer0int/CLIP-GmP-ViT-L-14"
# config = CLIPConfig.from_pretrained(model_id)
# maxtokens = 77
#clip_model = CLIPModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, config=config, ignore_mismatched_sizes=True).to("cuda")
#clip_processor = CLIPProcessor.from_pretrained(model_id, padding="max_length", max_length=maxtokens, ignore_mismatched_sizes=True, return_tensors="pt", truncation=True)
#pipe.tokenizer = clip_processor.tokenizer
#pipe.text_encoder = clip_model.text_model
#pipe.tokenizer_max_length = maxtokens
#pipe.text_encoder.dtype = torch.bfloat16
class HFEmbedder(nn.Module):
    """Frozen Hugging Face text encoder bundled with its tokenizer.

    Versions whose name starts with ``"openai"`` are loaded as CLIP and
    return ``pooler_output``; everything else is loaded as a T5 encoder and
    returns ``last_hidden_state``.
    """

    def __init__(self, version: str, max_length: int, **hf_kwargs):
        """Load tokenizer and encoder for ``version``.

        :param version: Hugging Face model identifier.
        :param max_length: token length used for padding and truncation.
        :param hf_kwargs: forwarded to the encoder's ``from_pretrained``.
        """
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            tokenizer_cls, encoder_cls = CLIPTokenizer, CLIPTextModel
        else:
            tokenizer_cls, encoder_cls = T5Tokenizer, T5EncoderModel
        self.tokenizer = tokenizer_cls.from_pretrained(version, max_length=max_length)
        self.hf_module = encoder_cls.from_pretrained(version, **hf_kwargs)

        # Inference only: eval mode, no gradients.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        """Tokenize ``text`` (padded/truncated to ``max_length``) and encode it.

        :return: the encoder output selected by ``self.output_key``.
        """
        tokenized = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = tokenized["input_ids"].to(self.hf_module.device)
        encoded = self.hf_module(
            input_ids=token_ids,
            attention_mask=None,
            output_hidden_states=False,
        )
        return encoded[self.output_key]
# NOTE(review): this unconditionally rebinds `device`, overriding the
# CUDA-availability check performed earlier in the file.
device = "cuda"
# Text encoders (T5 with 512 tokens, CLIP with 77 tokens) and the VAE, all
# loaded in bfloat16 and moved to `device` at import time.
t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(device)
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
# quantize(t5, weights=qfloat8)
# freeze(t5)
# ---------------- Model ----------------
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Apply rotary embeddings to ``q``/``k`` and run scaled dot-product attention.

    :return: the attention output with the head axis folded into the last
        axis, i.e. ``(B, H, L, D)`` permuted and reshaped to ``(B, L, H * D)``.
    """
    q, k = apply_rope(q, k, pe)
    attended = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    batch, _, seq_len, _ = attended.shape
    # Move heads next to the feature axis, then merge the two.
    return attended.permute(0, 2, 1, 3).reshape(batch, seq_len, -1)
def rope(pos, dim, theta):
    """Build 2x2 rotation matrices for rotary position embeddings.

    :param pos: 2-D tensor of positions (the result is unpacked as
        ``(b, n, dim // 2, 4)`` before the final view).
    :param dim: number of features to rotate; frequencies use every second index.
    :param theta: base of the geometric frequency progression.
    :return: float32 tensor of shape ``(b, n, dim // 2, 2, 2)`` holding
        ``[[cos, -sin], [sin, cos]]`` per position and frequency.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    inv_freq = 1.0 / (theta ** exponents)
    # Outer product of positions and frequencies.
    angles = pos.unsqueeze(-1) * inv_freq.unsqueeze(0)
    cos, sin = torch.cos(angles), torch.sin(angles)
    flat = torch.stack([cos, -sin, sin, cos], dim=-1)
    b, n, d, _ = flat.shape
    return flat.view(b, n, d, 2, 2).float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query and key features pairwise with the matrices in ``freqs_cis``.

    The last axis of each input is viewed as consecutive pairs; each pair is
    multiplied by the matching 2x2 matrix in float32, and the result is
    returned in the input's original shape and dtype.
    """
    def _rotate(x: Tensor) -> Tensor:
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
class EmbedND(nn.Module):
    """Rotary embedding over several position axes, concatenated per axis."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        """Store the embedding width, frequency base and per-axis widths."""
        super().__init__()
        self.dim, self.theta, self.axes_dim = dim, theta, axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Build rotation tables for every axis in the last dimension of ``ids``.

        :return: the per-axis ``rope`` outputs concatenated on dim ``-3`` with
            an extra axis inserted at position 1.
        """
        axis_count = ids.shape[-1]
        tables = [
            rope(ids[..., axis], self.axes_dim[axis], self.theta)
            for axis in range(axis_count)
        ]
        return torch.cat(tables, dim=-3).unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before embedding.
    :return: an (N, D) Tensor of positional embeddings (cosines first, then sines).
    """
    t = time_factor * t
    half = dim // 2
    # The frequency table is built on the default device and then moved to
    # t.device, as the original comment notes this matches the official code
    # (creating it directly on t.device was reported to differ by about 1e-4).
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd output width: pad with a single zero column.
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, pad], dim=-1)
    if torch.is_floating_point(t):
        # Match the dtype/device of the (scaled) input.
        embedding = embedding.to(t)
    return embedding
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) producing ``hidden_dim`` features."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` through both layers with a SiLU in between."""
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        """Normalize in float32, cast back to the input dtype, then apply ``scale``."""
        input_dtype = x.dtype
        x32 = x.float()
        mean_square = torch.mean(x32**2, dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-6)
        return (x32 * inv_rms).to(dtype=input_dtype) * self.scale
class QKNorm(torch.nn.Module):
    """Independent RMS normalization for attention queries and keys."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Normalize ``q`` and ``k`` and convert both with ``.to(v)``.

        ``v`` is used only as the conversion target; it is not modified.
        """
        normed_q = self.query_norm(q)
        normed_k = self.key_norm(k)
        return normed_q.to(v), normed_k.to(v)
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and QK normalization."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        """Attend over ``x`` using positional tables ``pe`` and project the result."""
        fused = self.qkv(x)
        batch, seq_len, _ = fused.shape
        # (B, L, 3 * H * D) -> (3, B, H, L, D), then split along the first axis.
        q, k, v = fused.view(batch, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)
        return self.proj(attention(q, k, v, pe=pe))
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text weights and one joint attention.

    Each stream has its own modulation, QKV projection, output projection and
    MLP; queries, keys and values of both streams are concatenated (text
    first) for a single attention call.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        def make_norm() -> nn.LayerNorm:
            return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        def make_mlp() -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
                nn.GELU(approximate="tanh"),
                nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
            )

        # Image stream.
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = make_norm()
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.img_norm2 = make_norm()
        self.img_mlp = make_mlp()

        # Text stream.
        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = make_norm()
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
        self.txt_norm2 = make_norm()
        self.txt_mlp = make_mlp()

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        """Update both token streams; returns ``(img, txt)``."""
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)
        heads = self.num_heads

        def split_heads(qkv: Tensor) -> Tensor:
            # (B, L, 3 * H * D) -> (3, B, H, L, D)
            batch, seq_len, width = qkv.shape
            head_dim = width // (3 * heads)
            return qkv.view(batch, seq_len, 3, heads, head_dim).permute(2, 0, 3, 1, 4)

        # Image queries/keys/values from the modulated, normalized tokens.
        img_in = (1 + img_mod1.scale) * self.img_norm1(img) + img_mod1.shift
        img_q, img_k, img_v = split_heads(self.img_attn.qkv(img_in))
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # Same for the text tokens.
        txt_in = (1 + txt_mod1.scale) * self.txt_norm1(txt) + txt_mod1.shift
        txt_q, txt_k, txt_v = split_heads(self.txt_attn.qkv(txt_in))
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint attention over [text, image] tokens.
        joint = attention(
            torch.cat((txt_q, img_q), dim=2),
            torch.cat((txt_k, img_k), dim=2),
            torch.cat((txt_v, img_v), dim=2),
            pe=pe,
        )
        txt_len = txt.shape[1]
        txt_attn, img_attn = joint[:, :txt_len], joint[:, txt_len:]

        # Image stream: gated attention residual, then gated MLP residual.
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img_mlp_in = (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        img = img + img_mod2.gate * self.img_mlp(img_mlp_in)

        # Text stream: same two residual updates.
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt_mlp_in = (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        txt = txt + txt_mod2.gate * self.txt_mlp(txt_mlp_in)
        return img, txt
class SingleStreamBlock(nn.Module):
    """
    Single-stream DiT block whose attention and MLP branches share fused
    input and output projections (parallel layers as described in
    https://arxiv.org/abs/2302.05442), with an adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)

        # Fused input projection: QKV and the MLP hidden activations.
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # Fused output projection: attention output and MLP output.
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Return ``x`` plus the gated output of the fused attention/MLP branches."""
        mod, _ = self.modulation(vec)
        modulated = (1 + mod.scale) * self.pre_norm(x) + mod.shift

        fused = self.linear1(modulated)
        qkv, mlp = torch.split(fused, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        # (B, L, 3 * H * D) -> (3, B, H, L, D)
        batch, seq_len = qkv.size(0), qkv.size(1)
        head_dim = self.hidden_size // self.num_heads
        q, k, v = qkv.view(batch, seq_len, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        attn = attention(q, k, v, pe=pe)
        # Concatenate attention output with the activated MLP branch and project once.
        merged = torch.cat((attn, self.mlp_act(mlp)), 2)
        return x + mod.gate * self.linear2(merged)
class LastLayer(nn.Module):
    """Final adaptive-LayerNorm modulation followed by a linear projection."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        """Modulate the normalized tokens with shift/scale from ``vec`` and project."""
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        normed = self.norm_final(x)
        # shift/scale get a singleton token axis so they broadcast over x.
        modulated = (1 + scale[:, None, :]) * normed + shift[:, None, :]
        return self.linear(modulated)
class FluxParams:
    """Default hyper-parameters read by ``Flux.__init__``.

    These are plain class-level attributes (the class is not a dataclass), so
    every instance shares the same values, including the ``axes_dim`` list.
    """
    # Input width of ``Flux.img_in``; Flux also uses it as its output channel count.
    in_channels: int = 64
    # Input width of ``Flux.vector_in``.
    vec_in_dim: int = 768
    # Input width of ``Flux.txt_in``.
    context_in_dim: int = 4096
    # Token width used throughout the transformer; must be divisible by num_heads.
    hidden_size: int = 3072
    # MLP hidden width as a multiple of hidden_size.
    mlp_ratio: float = 4.0
    num_heads: int = 24
    # Number of DoubleStreamBlock layers.
    depth: int = 19
    # Number of SingleStreamBlock layers.
    depth_single_blocks: int = 38
    # Rotary widths per position axis; Flux requires the sum to equal hidden_size // num_heads.
    axes_dim: list = [16, 56, 56]
    # Frequency base passed to EmbedND.
    theta: int = 10_000
    # Whether the QKV projections in DoubleStreamBlock use a bias.
    qkv_bias: bool = True
    # Only referenced by commented-out code in Flux.
    guidance_embed: bool = True
class Flux(nn.Module):
    """
    Flow-matching transformer that operates on packed image and text token sequences.
    """

    def __init__(self, params = FluxParams()):
        """Build the embedders, the double/single stream stacks and the output layer.

        :raises ValueError: if ``hidden_size`` is not divisible by ``num_heads``
            or ``axes_dim`` does not sum to the per-head width.
        """
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels

        head_dim, remainder = divmod(params.hidden_size, params.num_heads)
        if remainder != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        if sum(params.axes_dim) != head_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {head_dim}")

        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=head_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        # No guidance embedder is created: the guidance-embedding path is disabled.
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        double_layers = [
            DoubleStreamBlock(
                self.hidden_size,
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                qkv_bias=params.qkv_bias,
            )
            for _ in range(params.depth)
        ]
        self.double_blocks = nn.ModuleList(double_layers)

        single_layers = [
            SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
            for _ in range(params.depth_single_blocks)
        ]
        self.single_blocks = nn.ModuleList(single_layers)

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        use_guidance_vec = True,
    ) -> Tensor:
        """Predict the update for the image tokens.

        ``guidance`` and ``use_guidance_vec`` are accepted but not used by
        this implementation.

        :raises ValueError: if ``img`` or ``txt`` is not 3-dimensional.
        :return: tensor of shape (N, T, patch_size ** 2 * out_channels).
        """
        if not (img.ndim == 3 and txt.ndim == 3):
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # Embed image tokens and build the conditioning vector from the
        # timestep embedding plus the pooled text vector.
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Positional tables cover [text, image] tokens in that order.
        pe = self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1))

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # Single-stream blocks see the concatenated sequence; the text part is
        # dropped again afterwards.
        txt_len = txt.shape[1]
        tokens = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            tokens = block(tokens, vec=vec, pe=pe)
        image_tokens = tokens[:, txt_len:, ...]

        return self.final_layer(image_tokens, vec)
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Pack a latent image and encode the prompt(s) into model inputs.

    The batch size is taken from ``img`` unless it is 1 and ``prompt`` is a
    list, in which case the number of prompts is used and single-sample
    tensors are repeated to that size.

    :return: dict with keys ``img``, ``img_ids``, ``txt``, ``txt_ids`` and
        ``vec``, all moved to ``img.device``.
    """
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    def _expand(single: Tensor) -> Tensor:
        # Repeat a batch-of-one tensor up to the target batch size.
        if single.shape[0] == 1 and bs > 1:
            return repeat(single, "1 ... -> bs ...", bs=bs)
        return single

    # Fold every 2x2 spatial patch into the channel axis: one token per patch.
    img = _expand(rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))

    # Per-token position ids: channel 0 stays zero, 1 is the row, 2 the column.
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] += torch.arange(h // 2)[:, None]
    img_ids[..., 2] += torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    prompts = [prompt] if isinstance(prompt, str) else prompt
    txt = _expand(t5(prompts))
    # Text tokens all get zero position ids.
    txt_ids = torch.zeros(bs, txt.shape[1], 3)
    vec = _expand(clip(prompts))

    target = img.device
    return {
        "img": img,
        "img_ids": img_ids.to(target),
        "txt": txt.to(target),
        "txt_ids": txt_ids.to(target),
        "vec": vec.to(target),
    }
def time_shift(mu: float, sigma: float, t: Tensor):
    """Return ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` for timestep(s) ``t``."""
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1) ** sigma)
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the straight line through ``(x1, y1)`` and ``(x2, y2)`` as a callable."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x):
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` timesteps running from 1 down to 0.

    With ``shift`` enabled the linear schedule is warped by ``time_shift``,
    using a ``mu`` obtained from ``get_lin_function(y1=base_shift,
    y2=max_shift)`` evaluated at ``image_seq_len``.
    """
    # One extra point so the schedule ends exactly at zero.
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return timesteps.tolist()

    # Estimate mu by linear interpolation on the image sequence length.
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, timesteps).tolist()
# NOTE(review): this property is defined at module level, outside any class,
# so it is never bound to an instance. It only creates a module-level
# `property` object named `joint_attention_kwargs`.
@property
def joint_attention_kwargs(self):
    return self._joint_attention_kwargs
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    use_cfg_guidance = False,
):
    """Integrate the model's prediction over ``timesteps`` with Euler steps.

    :param model: callable taking the keyword arguments of ``Flux.forward``.
    :param img: packed latent tokens; with ``use_cfg_guidance`` the batch is
        expected to hold the unconditional half first, then the conditional half.
    :param img_ids, txt, txt_ids, vec: remaining model inputs, as produced by ``prepare``.
    :param timesteps: descending schedule; consecutive pairs form the steps.
    :param guidance: classifier-free guidance scale (also sent as the guidance vector).
    :param use_cfg_guidance: combine the two batch halves as
        ``uncond + guidance * (cond - uncond)`` at every step.
    :return: the tokens after the last step (both halves identical under CFG).
    """
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    step_pairs = zip(timesteps[:-1], timesteps[1:])
    # zip() has no length, so the step count is given explicitly for the progress bar.
    for t_curr, t_prev in tqdm(step_pairs, total=max(len(timesteps) - 1, 0)):
        if use_cfg_guidance:
            # Both halves must start each step from the same latent.
            half_x = img[: len(img) // 2]
            img = torch.cat([half_x, half_x], dim=0)
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        # The previous `joint_attention_kwargs=self.joint_attention_kwargs`
        # argument was removed: this is a free function with no `self`, and
        # Flux.forward does not accept that keyword.
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            use_guidance_vec=not use_cfg_guidance,
        )

        if use_cfg_guidance:
            uncond, cond = pred.chunk(2, dim=0)
            model_output = uncond + guidance * (cond - uncond)
            pred = torch.cat([model_output, model_output], dim=0)

        img = img + (t_prev - t_curr) * pred

    return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Unfold packed 2x2-patch tokens back into a ``(b, c, H, W)`` latent grid.

    The token grid is ``ceil(height / 16)`` by ``ceil(width / 16)``.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=grid_h,
        w=grid_w,
        ph=2,
        pw=2,
    )
@dataclass
class SamplingOptions:
    """Container for per-request sampling settings.

    NOTE(review): nothing in the visible part of this file constructs or
    reads this class — confirm whether it is still needed.
    """
    prompt: str
    width: int
    height: int
    guidance: float
    # None is allowed by the annotation; how it is interpreted is not shown here.
    seed: int | None
    joint_attention_kwargs: Any | None
def get_image(image) -> torch.Tensor | None:
if image is None:
return None
image = Image.fromarray(image).convert("RGB")
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: 2.0 * x - 1.0),
])
img: torch.Tensor = transform(image)
return img[None, ...]
# ---------------- Demo ----------------
class EmptyInitWrapper(torch.overrides.TorchFunctionMode):
    """Torch-function mode that turns ``torch.nn.init`` calls into no-ops.

    Functions whose ``__module__`` is ``"torch.nn.init"`` return their tensor
    argument untouched. When a ``device`` was given, device-constructor
    functions called without an explicit ``device`` receive it.
    """

    def __init__(self, device=None):
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        if getattr(func, "__module__", None) == "torch.nn.init":
            # Skip initialization: hand back the tensor as it is.
            return kwargs["tensor"] if "tensor" in kwargs else args[0]

        wants_default_device = (
            self.device is not None
            and func in torch.utils._device._device_constructors()
            and kwargs.get("device") is None
        )
        if wants_default_device:
            kwargs["device"] = self.device
        return func(*args, **kwargs)
# Build the transformer with weight initialization skipped (the weights are
# overwritten by the checkpoint below), in bfloat16 on the GPU.
with EmptyInitWrapper():
    model = Flux().to(dtype=torch.bfloat16, device="cuda")

# Load the checkpoint and remove every "model." substring from its keys so
# they line up with this module's parameter names.
sd = load_file(f"{model_path}/OpenFlux-fp8_e4m3fn.safetensors")
sd = {k.replace("model.", ""): v for k, v in sd.items()}
# load_state_dict's return value is kept in `result`; it is not inspected here.
result = model.load_state_dict(sd)
@spaces.GPU(duration=70)
@torch.no_grad()
def generate_image(
    prompt, neg_prompt, num_steps, width, height, guidance, seed,
    do_img2img, init_image, image2image_strength, resize_img, lora_scale,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate one image with classifier-free guidance, optionally from an input image.

    :param prompt: positive prompt.
    :param neg_prompt: negative prompt; ``None`` is treated as an empty string.
    :param num_steps: number of denoising steps.
    :param width, height: output size in pixels (replaced by the cropped
        input size when img2img runs without resizing).
    :param guidance: classifier-free guidance scale.
    :param seed: RNG seed; ``0`` selects a random seed.
    :param do_img2img: start from ``init_image`` instead of pure noise.
    :param init_image: input image array, or ``None``.
    :param image2image_strength: fraction of the schedule to run in img2img mode.
    :param resize_img: resize the input to ``(height, width)`` instead of
        cropping it to a multiple of 16.
    :param lora_scale: accepted for interface compatibility; not used here.
    :param progress: Gradio progress tracker (follows tqdm).
    :return: ``(PIL image, seed actually used)``.
    """
    # Previously `torch.cuda.empty_cache()` was written as a decorator, which
    # calls it once at import and then tries to call its None result.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # UI sliders may deliver floats; the tensor/RNG APIs below need ints.
    num_steps, width, height, seed = int(num_steps), int(width), int(height), int(seed)
    if seed == 0:
        seed = int(random.random() * 1000000)
    # The tokenizer needs a string for the unconditional branch.
    neg_prompt = neg_prompt or ""

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_device = torch.device(device)

    if do_img2img and init_image is not None:
        init_image = get_image(init_image)
        if resize_img:
            init_image = torch.nn.functional.interpolate(init_image, (height, width))
        else:
            # Crop to a multiple of 16 and adopt the cropped size.
            h, w = init_image.shape[-2:]
            init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)]
            height = init_image.shape[-2]
            width = init_image.shape[-1]
        # Match the VAE's dtype: it is loaded in bfloat16, the image is float32.
        init_image = ae.encode(init_image.to(device=torch_device, dtype=ae.dtype)).latent_dist.sample()
        init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor

    generator = torch.Generator(device=device).manual_seed(seed)
    x = torch.randn(
        1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16),
        device=device, dtype=torch.bfloat16, generator=generator,
    )

    timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)

    if do_img2img and init_image is not None:
        # Skip the first part of the schedule and blend noise with the input latent.
        t_idx = int((1 - image2image_strength) * num_steps)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        x = t * x + (1.0 - t) * init_image.to(x.dtype)

    # Batch of two: negative prompt first (unconditional), then the prompt.
    inp = prepare(t5=t5, clip=clip, img=x, prompt=[neg_prompt, prompt])
    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance, use_cfg_guidance=True)

    x = unpack(x.float(), height, width)
    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
        x = ae.decode(x).sample

    # Map [-1, 1] to 8-bit RGB; only the first batch element is returned.
    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

    return img, seed
class calculateDuration:
    """Context manager that prints the wall-clock time spent in its block."""

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        # The activity name is included only when one was given.
        label = f" for {self.activity_name}" if self.activity_name else ""
        print(f"Elapsed time{label}: {self.elapsed_time:.6f} seconds")
def update_selection(evt: gr.SelectData, width, height):
    """React to a gallery click by selecting the LoRA at ``evt.index``.

    :return: ``(prompt placeholder update, markdown with a link to the LoRA
        repo, selected index, width, height)``. Width and height change only
        when the LoRA entry has an ``aspect`` of ``portrait`` or ``landscape``.
    """
    selected_lora = loras[evt.index]
    placeholder_update = gr.update(placeholder=f"Type a prompt for {selected_lora['title']}")
    lora_repo = selected_lora["repo"]
    selection_markdown = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"

    aspect = selected_lora.get("aspect")
    if aspect == "portrait":
        width, height = 768, 1024
    elif aspect == "landscape":
        width, height = 1024, 768

    return placeholder_update, selection_markdown, evt.index, width, height
def run_lora(
    prompt, neg_prompt, num_steps, width, height, selected_index, guidance, seed, do_img2img, init_image,
    image2image_strength, resize_img, lora_scale, progress=gr.Progress(track_tqdm=True),
    randomize_seed=False,
):
    """Gradio handler: generate an image for the selected LoRA.

    :param selected_index: index into ``loras``; must not be ``None``.
    :param randomize_seed: when true, ``seed`` is replaced by a random value.
    :param progress: Gradio progress tracker, forwarded to ``generate_image``.
    The remaining parameters are forwarded unchanged to ``generate_image``.
    :raises gr.Error: if no LoRA has been selected.
    :return: ``(image, seed actually used)``.
    """
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")
    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    # The text encoders need a string, so an empty negative prompt stays "".
    neg_prompt = neg_prompt or ""

    # NOTE(review): no `pipe` object is defined in the visible file. LoRA
    # weights are applied only if a module-level `pipe` exists; otherwise the
    # base model is used and a warning is logged instead of raising NameError.
    lora_pipe = globals().get("pipe")
    if lora_pipe is None:
        logging.warning("No diffusers pipeline available; LoRA %s was not applied.", lora_path)
    else:
        # Load LoRA weights
        with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
            if "weights" in selected_lora:
                lora_pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
            else:
                lora_pipe.load_lora_weights(lora_path)

    # Set random seed for reproducibility
    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, 2**32 - 1)

    try:
        # Arguments follow generate_image's parameter order; it returns
        # (image, seed), where seed may have been randomized when it was 0.
        image, seed = generate_image(
            prompt, neg_prompt, num_steps, width, height, guidance, seed,
            do_img2img, init_image, image2image_strength, resize_img, lora_scale,
            progress=progress,
        )
    finally:
        if lora_pipe is not None:
            lora_pipe.to("cpu")
            lora_pipe.unload_lora_weights()
    return image, seed

run_lora.zerogpu = True
# Stylesheet passed to gr.Blocks in create_demo; the selectors target the
# elem_id values "gen_btn", "title" and "gallery".
css = '''
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
'''
def create_demo():
    """Build the Gradio Blocks UI and wire its events.

    Fixes relative to the previous version: the Blocks object is returned, the
    duplicated "Image to Image" checkbox and "Generate" button are created
    once, the gallery selection writes to all five values returned by
    ``update_selection``, the img2img widgets are shown when the checkbox is
    ticked, and a single handler (``run_lora``) receives its inputs in
    signature order.

    :return: the configured ``gr.Blocks`` instance.
    """
    with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
        gr.HTML(
            """<h1><img src="https://huggingface.co/AlekseyCalvin/HSTklimbimOPENfluxLora/resolve/main/acs62iv.png" alt="LoRA">OpenFlux LoRAsoon®</h1>""",
            elem_id="title",
        )
        # Info blob stating what the app is running
        gr.HTML(
            """<div id="info_blob"> SOON®'s curated LoRa Gallery & Art Manufactory Space.|Runs on Ostris' OpenFLUX.1 model + fast-gen LoRA & Zer0int's fine-tuned CLIP-GmP-ViT-L-14*! (*'normal' 77 tokens)| Largely stocked w/our trained LoRAs: Historic Color, Silver Age Poets, Sots Art, more!|</div>"""
        )
        # Cheat-sheet of trigger phrases per LoRA row
        gr.HTML(
            """<div id="info_blob"> *Auto-planting of prompts with a choice LoRA trigger errors out in this space over flaws yet unclear. In its stead, we pose numbered LoRA-box rows & a matched token cheat-sheet: ungainly & free. So, prephrase your prompts w/: 1-2. HST style autochrome |3. RCA style Communist poster |4. SOTS art |5. HST Austin Osman Spare style |6. Vladimir Mayakovsky |7-8. Marina Tsvetaeva Tsvetaeva_02.CR2 |9. Anna Akhmatova |10. Osip Mandelshtam |11-12. Alexander Blok |13. Blok_02.CR2 |14. LEN Lenin |15. Leon Trotsky |16. Rosa Fluxemburg |17. HST Peterhof photo |18-19. HST |20. HST portrait |21. HST |22. HST 80s Perestroika-era Soviet photo |23-30. HST |31. How2Draw a__ |32. propaganda poster |33. TOK hybrid photo of__ with cartoon of__ |34. 2004 IMG_1099.CR2 photo |35. unexpected photo of |36. flmft |37. 80s yearbook photo |38. TOK portra |39. pficonics |40. retrofuturism |41. wh3r3sw4ld0 |42. amateur photo |43. crisp |44-45. IMG_1099.CR2 |46. FilmFotos |47. ff-collage |48. HST |49-50. AOS |51. cover </div>"""
        )
        selected_index = gr.State(None)

        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select LoRa/Style & type prompt!")
        with gr.Row():
            with gr.Column(scale=1):
                neg_prompt = gr.Textbox(label="Negative Prompt", lines=1, placeholder="List unwanted conditions, open-fluxedly!")
            with gr.Column(scale=1, elem_id="gen_column"):
                generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
        with gr.Row():
            with gr.Column(scale=1, elem_id="gen_column"):
                do_img2img = gr.Checkbox(label="Image to Image", value=False)
            with gr.Column(scale=1, elem_id="gen_column"):
                init_image = gr.Image(label="Input Image", visible=False)
            with gr.Column(scale=1, elem_id="gen_column"):
                image2image_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Noising strength", value=0.8, visible=False)
            with gr.Column(scale=1, elem_id="gen_column"):
                resize_img = gr.Checkbox(label="Resize image", value=True, visible=False)
        with gr.Row():
            with gr.Column(scale=1):
                selected_info = gr.Markdown("")
                gallery = gr.Gallery(
                    [(item["image"], item["title"]) for item in loras],
                    label="LoRA Inventory",
                    allow_preview=False,
                    columns=1,
                    elem_id="gallery")
            with gr.Column(scale=2):
                result = gr.Image(label="Generated Image")

        with gr.Row():
            with gr.Accordion("Advanced Settings", open=True):
                with gr.Column():
                    with gr.Row():
                        guidance = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=1, value=3)
                        num_steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=6)
                    with gr.Row():
                        width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=768)
                        height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=768)
                    with gr.Row():
                        randomize_seed = gr.Checkbox(True, label="Randomize seed")
                        seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=0, randomize=True)
                        lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)

        def _toggle_img2img(enabled):
            # Show or hide the three img2img-only widgets together.
            return [gr.update(visible=enabled)] * 3

        def _pick_seed(randomize, current_seed):
            # Resolve the seed in the UI so the slider shows the value actually used.
            return random.randint(0, 2**32 - 1) if randomize else current_seed

        do_img2img.change(
            _toggle_img2img,
            inputs=[do_img2img],
            outputs=[init_image, image2image_strength, resize_img],
        )

        # update_selection returns five values: placeholder update, info
        # markdown, selected index, width, height.
        gallery.select(
            update_selection,
            inputs=[width, height],
            outputs=[prompt, selected_info, selected_index, width, height]
        )

        # One handler only; inputs follow run_lora's positional parameter order.
        gr.on(
            triggers=[generate_button.click, prompt.submit],
            fn=_pick_seed,
            inputs=[randomize_seed, seed],
            outputs=seed,
        ).then(
            fn=run_lora,
            inputs=[
                prompt, neg_prompt, num_steps, width, height, selected_index, guidance, seed,
                do_img2img, init_image, image2image_strength, resize_img, lora_scale,
            ],
            outputs=[result, seed]
        )

    return demo
# Build the UI at import time and start the Gradio server, requesting a
# public share link.
demo = create_demo()
demo.launch(share=True)