Shap-slahai / model.py
hysts's picture
hysts HF staff
Add files
feb3220
raw
history blame
5.76 kB
import tempfile
import imageio
import numpy as np
import PIL.Image
import torch
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.diffusion.sample import sample_latents
from shap_e.models.download import load_config, load_model
from shap_e.models.nn.camera import (DifferentiableCameraBatch,
DifferentiableProjectiveCamera)
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.util.collections import AttrDict
from shap_e.util.image_util import load_image
# Copied from https://github.com/openai/shap-e/blob/d99cedaea18e0989e340163dbaeb4b109fa9e8ec/shap_e/util/notebooks.py#L15-L42
def create_pan_cameras(size: int,
                       device: torch.device) -> DifferentiableCameraBatch:
    """Build a batch of cameras circling the origin.

    One camera is created for each of 20 angles evenly spaced over
    ``[0, 2 * pi]`` (both endpoints included, so the first and the last
    pose coincide). Each camera sits at distance 4 from the origin, at
    ``-4 * z`` where ``z`` is the camera's unit z axis.

    Args:
        size: Used as both ``width`` and ``height`` of every camera.
        device: Device the camera tensors are moved to.

    Returns:
        A ``DifferentiableCameraBatch`` of shape ``(1, 20)`` wrapping a
        ``DifferentiableProjectiveCamera`` with ``x_fov`` and ``y_fov``
        of 0.7.
    """
    def pose_at(theta):
        # z axis: tilted by a fixed -0.5 component, then normalised.
        forward = np.array([np.sin(theta), np.cos(theta), -0.5])
        forward /= np.sqrt(np.sum(forward**2))
        # x axis lies in the horizontal plane; y completes the frame.
        right = np.array([np.cos(theta), -np.sin(theta), 0.0])
        up = np.cross(forward, right)
        return -forward * 4, right, up, forward

    def as_tensor(vectors):
        stacked = np.stack(vectors, axis=0)
        return torch.from_numpy(stacked).float().to(device)

    angles = np.linspace(0, 2 * np.pi, num=20)
    origins, xs, ys, zs = zip(*(pose_at(theta) for theta in angles))

    flat_camera = DifferentiableProjectiveCamera(
        origin=as_tensor(origins),
        x=as_tensor(xs),
        y=as_tensor(ys),
        z=as_tensor(zs),
        width=size,
        height=size,
        x_fov=0.7,
        y_fov=0.7,
    )
    return DifferentiableCameraBatch(shape=(1, len(xs)),
                                     flat_camera=flat_camera)
# Copied from https://github.com/openai/shap-e/blob/d99cedaea18e0989e340163dbaeb4b109fa9e8ec/shap_e/util/notebooks.py#L45-L60
@torch.no_grad()
def decode_latent_images(
    xm: Transmitter | VectorDecoder,
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = 'stf',
):
    """Render a single latent from the given cameras into PIL images.

    Runs with gradient tracking disabled.

    Args:
        xm: Model providing ``renderer``. For a ``Transmitter`` the
            latent is converted to parameters through ``xm.encoder``;
            otherwise through ``xm`` itself.
        latent: One latent; a leading batch axis is added before it is
            converted with ``bottleneck_to_params``.
        cameras: Cameras handed to the renderer.
        rendering_mode: Forwarded to the renderer as the
            ``rendering_mode`` option.

    Returns:
        A list of ``PIL.Image`` objects built from the first batch entry
        of the rendered channels, one per slice along its leading axis.
    """
    if isinstance(xm, Transmitter):
        params_source = xm.encoder
    else:
        params_source = xm

    batch = AttrDict(cameras=cameras)
    params = params_source.bottleneck_to_params(latent[None])
    options = AttrDict(rendering_mode=rendering_mode,
                       render_with_direction=False)
    rendered = xm.renderer.render_views(batch, params=params, options=options)

    # Clamp to the valid 8-bit range before the cast to uint8.
    pixels = rendered.channels.clamp(0, 255).to(torch.uint8)
    views = pixels[0].cpu().numpy()
    return [PIL.Image.fromarray(view) for view in views]
class Model:
    """Shap-E sampling pipeline that returns its result as a video file.

    The transmitter and the diffusion configuration are loaded once in
    the constructor. The conditional model (``'text300M'`` or
    ``'image300M'``) is loaded on demand and kept until the other one is
    requested.
    """

    def __init__(self):
        """Select the device and load the transmitter and diffusion."""
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.xm = load_model('transmitter', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        # Name and instance of the currently loaded conditional model;
        # both are filled in by load_model().
        self.model_name = ''
        self.model = None

    def load_model(self, model_name: str) -> None:
        """Make ``model_name`` the active conditional model.

        Does nothing when that model is already the active one.

        Args:
            model_name: Either ``'text300M'`` or ``'image300M'``.

        Raises:
            ValueError: If ``model_name`` is not one of the two names.
        """
        # Explicit check instead of ``assert`` so that the validation is
        # still performed when Python runs with -O.
        if model_name not in ('text300M', 'image300M'):
            raise ValueError(
                f'Unknown model name {model_name!r}; '
                "expected 'text300M' or 'image300M'")
        if model_name == self.model_name:
            return
        self.model = load_model(model_name, device=self.device)
        self.model_name = model_name

    @staticmethod
    def to_video(frames: list['PIL.Image.Image'], fps: int = 5) -> str:
        """Encode ``frames`` into a temporary ``.mp4`` file.

        Args:
            frames: Images to write, in order.
            fps: Frame rate passed to the writer.

        Returns:
            Path of the written file. The file is not deleted
            automatically; removing it is up to the caller.
        """
        # Only the unique path is needed: close our own handle right away
        # so it is not leaked and does not keep the file open while the
        # writer writes to the same path.
        with tempfile.NamedTemporaryFile(suffix='.mp4',
                                         delete=False) as out_file:
            video_path = out_file.name

        writer = imageio.get_writer(video_path, format='FFMPEG', fps=fps)
        try:
            for frame in frames:
                writer.append_data(np.asarray(frame))
        finally:
            # Release the writer even if a frame fails to encode.
            writer.close()
        return video_path

    def _sample_and_render(self, model_kwargs: dict, guidance_scale: float,
                           num_steps: int, output_image_size: int,
                           render_mode: str) -> str:
        """Sample one latent with the active model and render it.

        Shared tail of run_text() and run_image(); the caller must have
        loaded the model and seeded the RNG beforehand.

        Args:
            model_kwargs: Conditioning passed to ``sample_latents``.
            guidance_scale: Passed to ``sample_latents``.
            num_steps: Passed to ``sample_latents`` as ``karras_steps``.
            output_image_size: Size passed to ``create_pan_cameras``.
            render_mode: Rendering mode passed to
                ``decode_latent_images``.

        Returns:
            Path of the video produced by to_video().
        """
        latents = sample_latents(
            batch_size=1,
            model=self.model,
            diffusion=self.diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=model_kwargs,
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=num_steps,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        cameras = create_pan_cameras(output_image_size, self.device)
        # batch_size is 1, so the only latent is at index 0.
        frames = decode_latent_images(self.xm,
                                      latents[0],
                                      cameras,
                                      rendering_mode=render_mode)
        return self.to_video(frames)

    def run_text(self,
                 prompt: str,
                 seed: int = 0,
                 guidance_scale: float = 15.0,
                 num_steps: int = 64,
                 output_image_size: int = 64,
                 render_mode: str = 'nerf') -> str:
        """Generate a video from a text prompt using ``'text300M'``.

        Args:
            prompt: Text passed to the model as ``texts=[prompt]``.
            seed: Value given to ``torch.manual_seed`` before sampling.
            guidance_scale: Passed to ``sample_latents``.
            num_steps: Passed to ``sample_latents`` as ``karras_steps``.
            output_image_size: Width and height of the rendered frames.
            render_mode: Rendering mode used when decoding the latent.

        Returns:
            Path of the generated ``.mp4`` file.
        """
        self.load_model('text300M')
        torch.manual_seed(seed)
        return self._sample_and_render(dict(texts=[prompt]), guidance_scale,
                                       num_steps, output_image_size,
                                       render_mode)

    def run_image(self,
                  image_path: str,
                  seed: int = 0,
                  guidance_scale: float = 3.0,
                  num_steps: int = 64,
                  output_image_size: int = 64,
                  render_mode: str = 'nerf') -> str:
        """Generate a video from an input image using ``'image300M'``.

        Args:
            image_path: Path given to ``load_image``; the loaded image is
                passed to the model as ``images=[image]``.
            seed: Value given to ``torch.manual_seed`` before sampling.
            guidance_scale: Passed to ``sample_latents``.
            num_steps: Passed to ``sample_latents`` as ``karras_steps``.
            output_image_size: Width and height of the rendered frames.
            render_mode: Rendering mode used when decoding the latent.

        Returns:
            Path of the generated ``.mp4`` file.
        """
        self.load_model('image300M')
        torch.manual_seed(seed)
        image = load_image(image_path)
        return self._sample_and_render(dict(images=[image]), guidance_scale,
                                       num_steps, output_image_size,
                                       render_mode)