gan-control / app.py
hysts's picture
hysts HF staff
Update
d09e40b
#!/usr/bin/env python
from __future__ import annotations
import functools
import os
import pathlib
import shlex
import subprocess
import sys
import tarfile
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import torch
# On Hugging Face Spaces the vendored gan-control checkout is patched in place
# before it is imported; the file named "patch" is fed to `patch` on stdin.
if os.getenv("SYSTEM") == "spaces":
    with open("patch") as patch_file:
        subprocess.run(shlex.split("patch -p1"), stdin=patch_file, cwd="gan-control")

# Make the vendored sources importable; this must run before the import below.
sys.path.insert(0, "gan-control/src")

from gan_control.inference.controller import Controller

# Text passed to the Gradio interface as its title and description.
TITLE = "GAN-Control"
DESCRIPTION = "https://github.com/amazon-research/gan-control"
def download_models() -> None:
    """Fetch and unpack the pretrained controller archive if it is not present.

    Does nothing when the ``controller_age015id025exp02hai04ori02gam15``
    directory already exists in the current working directory. Otherwise the
    ``.tar.gz`` archive is downloaded from the ``public-data/gan-control`` Hub
    repository and extracted into the current working directory.
    """
    model_dir = pathlib.Path("controller_age015id025exp02hai04ori02gam15")
    if model_dir.exists():
        return

    path = huggingface_hub.hf_hub_download(
        "public-data/gan-control", "controller_age015id025exp02hai04ori02gam15.tar.gz"
    )
    with tarfile.open(path) as f:
        # The archive is downloaded content: when the interpreter supports
        # extraction filters, use the "data" filter so members cannot escape
        # the destination directory (and to avoid the deprecation warning
        # raised for unfiltered extraction on newer Python versions).
        if hasattr(tarfile, "data_filter"):
            f.extractall(filter="data")
        else:
            f.extractall()
@torch.inference_mode()
def run(
    seed: int,
    truncation: float,
    yaw: int,
    pitch: int,
    age: int,
    hair_color_r: float,
    hair_color_g: float,
    hair_color_b: float,
    nrows: int,
    ncols: int,
    controller: Controller,
    device: torch.device,
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """Generate a grid of samples plus three independently controlled variants.

    A batch of ``nrows * ncols`` latents is drawn from a NumPy ``RandomState``
    seeded with ``seed`` and passed through ``controller.gen_batch``. The
    resulting W latents are then re-rendered three times, each time with a
    single control applied (orientation, age, hair color); the controls are
    not combined with one another.

    Args:
        seed: Random seed; clipped to the unsigned 32-bit range.
        truncation: Forwarded to ``controller.gen_batch``.
        yaw: First component of the orientation control.
        pitch: Second component of the orientation control (the third
            component is fixed at 0).
        age: Value of the age control.
        hair_color_r: Red channel on a 0-255 scale.
        hair_color_g: Green channel on a 0-255 scale.
        hair_color_b: Blue channel on a 0-255 scale.
        nrows: Number of grid rows; only used to size the batch.
        ncols: Number of grid columns; passed as ``nrow`` (images per row)
            to ``controller.make_resized_grid_image``.
        controller: Loaded GAN-Control controller.
        device: Device the sampled latents are moved to.

    Returns:
        A 4-tuple of grid images as produced by
        ``controller.make_resized_grid_image``: unmodified samples,
        orientation-controlled, age-controlled and hair-color-controlled.
    """
    seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
    batch_size = nrows * ncols
    latent_size = controller.config.model_config["latent_size"]
    latent = torch.from_numpy(np.random.RandomState(seed).randn(batch_size, latent_size)).float().to(device)

    # Only the image tensors and the W latents are needed from the base pass.
    initial_image_tensors, _, initial_latent_w = controller.gen_batch(latent=latent, truncation=truncation)
    res0 = controller.make_resized_grid_image(initial_image_tensors, nrow=ncols)

    pose_control = torch.tensor([[yaw, pitch, 0]], dtype=torch.float32)
    image_tensors, _, _ = controller.gen_batch_by_controls(
        latent=initial_latent_w, input_is_latent=True, orientation=pose_control
    )
    res1 = controller.make_resized_grid_image(image_tensors, nrow=ncols)

    age_control = torch.tensor([[age]], dtype=torch.float32)
    image_tensors, _, _ = controller.gen_batch_by_controls(
        latent=initial_latent_w, input_is_latent=True, age=age_control
    )
    res2 = controller.make_resized_grid_image(image_tensors, nrow=ncols)

    # Rescale the 0-255 channel values to [0, 1]; out-of-range input is clamped.
    hair_color = torch.tensor([[hair_color_r, hair_color_g, hair_color_b]], dtype=torch.float32) / 255
    hair_color = torch.clamp(hair_color, 0, 1)
    image_tensors, _, _ = controller.gen_batch_by_controls(
        latent=initial_latent_w, input_is_latent=True, hair=hair_color
    )
    res3 = controller.make_resized_grid_image(image_tensors, nrow=ncols)

    return res0, res1, res2, res3
# Make sure the controller weights are on disk before loading them.
download_models()

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

path = "controller_age015id025exp02hai04ori02gam15/"
controller = Controller(path, device)

# Bind the controller and device so the UI callback only takes the slider values.
fn = functools.partial(run, controller=controller, device=device)
# One row per input slider, in the positional order expected by `fn`:
# (label, minimum, maximum, step, initial value).
_SLIDER_SPECS = [
    ("Seed", 0, 1000000, 1, 0),
    ("Truncation", 0, 1, 0.1, 0.7),
    ("Yaw", -90, 90, 1, 30),
    ("Pitch", -90, 90, 1, 0),
    ("Age", 15, 75, 1, 75),
    ("Hair Color (R)", 0, 255, 1, 186),
    ("Hair Color (G)", 0, 255, 1, 158),
    ("Hair Color (B)", 0, 255, 1, 92),
    ("Number of Rows", 1, 3, 1, 1),
    ("Number of Columns", 1, 5, 1, 5),
]

# Labels of the four image outputs, in the order `fn` returns them.
_OUTPUT_LABELS = [
    "Generated Image",
    "Head Pose Controlled",
    "Age Controlled",
    "Hair Color Controlled",
]

demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Slider(label=label, minimum=lower, maximum=upper, step=step, value=initial)
        for label, lower, upper, step, initial in _SLIDER_SPECS
    ],
    outputs=[gr.Image(label=label) for label in _OUTPUT_LABELS],
    title=TITLE,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    # Enable the request queue (at most 10 pending requests), then start the app.
    demo.queue(max_size=10).launch()