# pulsar_clip.py
from transformers import set_seed
from tqdm.auto import trange
from PIL import Image
import numpy as np
import random
import utils
import torch
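# Config schema (used to build the notebook form UI): each section is (title, options),
# and each option is (name, default, spec), where spec is a Python type, a (min, max)
# range, or a list of allowed choices.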
CONFIG_SPEC = [
("General", [
("text", "A cloud at dawn", str),
("iterations", 5000, (0, 7500)),
("seed", 12, int),
("show_every", 10, int),
]),
("Rendering", [
("w", 224, [224, 252]),
("h", 224, [224, 252]),
("showoff", 5000, (0, 10000)),
("turns", 4, int),
("focal_length", 0.1, float),
("plane_width", 0.1, float),
("shade_strength", 0.25, float),
("gamma", 0.5, float),
("max_depth", 7, float),
("offset", 5, float),
("offset_random", 0.75, float),
("xyz_random", 0.25, float),
("altitude_range", 0.3, float),
("augments", 4, int),
]),
("Optimization", [
("epochs", 6, int),
("lr", 0.6, float),
        #@markdown CLIP loss type; changing it might improve the results
("loss_type", "spherical", ["spherical", "cosine"]),
#@markdown CLIP loss weight
("clip_weight", 1.0, float), #@param {type: "number"}
]),
("Elements", [
("num_objects", 256, int),
        #@markdown Number of dimensions: 0 produces point clouds (default),
        #@markdown 1 strokes, 2 planes, and 3 little cubes.
("ndim", 0, [0, 1, 2, 3]), #@param {type: "integer"}
#@markdown Opacity scale:
("min_opacity", 1e-4, float), #@param {type: "number"}
("max_opacity", 1.0, float), #@param {type: "number"}
("log_opacity", False, bool), #@param {type: "boolean"}
("min_radius", 0.030, float),
("max_radius", 0.170, float),
("log_radius", False, bool),
# TODO dynamically decide bezier_res
        #@markdown Bezier resolution: how many sample points per dimension a line/plane/cube has. Not applicable to points.
("bezier_res", 8, int), #@param {type: "integer"}
#@markdown Maximum scale of parameters: position, velocity, acceleration
("pos_scale", 0.4, float), #@param {type: "number"}
("vel_scale", 0.15, float), #@param {type: "number"}
("acc_scale", 0.15, float), #@param {type: "number"}
#@markdown Scale of each individual 3D object. Master control for velocity and acceleration scale.
("scale", 1, float), #@param {type: "number"}
]),
]
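# A minimal helper sketch (an addition, not part of the original notebook): flatten
# CONFIG_SPEC into the flat kwargs dict that PulsarCLIP expects, dropping each
# option's spec element.
def default_config():
    return {name: default for _, options in CONFIG_SPEC for name, default, _ in options}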
# TODO: one day separate the config into multiple parts and split this megaobject into multiple objects
# 2022/08/09: halfway done
class PulsarCLIP(object):
def __init__(self, args):
args = DotDict(**args)
set_seed(args.seed)
self.args = args
self.device = args.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        # Deferred import so that `pulsar_clip` can be imported before `pytorch3d` is installed
import pytorch3d.renderer.points.pulsar as ps
self.ndim = int(self.args.ndim)
self.renderer = ps.Renderer(self.args.w, self.args.h,
self.args.num_objects * (self.args.bezier_res ** self.ndim)).to(self.device)
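        # Learnable per-object parameters: bezier_pos packs xyz position plus radius;
        # bezier_vel / bezier_acc hold the curve's velocity and acceleration per
        # parameter dimension; bezier_col holds RGBA and its per-dimension velocity
        # and acceleration.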
self.bezier_pos = torch.nn.Parameter(torch.randn((args.num_objects, 4)).to(self.device))
self.bezier_vel = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
self.bezier_acc = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
self.bezier_col = torch.nn.Parameter(torch.randn((args.num_objects, 4 * (1 + self.ndim))).to(self.device))
self.optimizer = torch.optim.Adam([dict(params=[self.bezier_col], lr=5e-1 * args.lr),
dict(params=[self.bezier_pos], lr=1e-1 * args.lr),
dict(params=[self.bezier_vel, self.bezier_acc], lr=5e-2 * args.lr),
])
self.model_clip, self.preprocess_clip = utils.load_clip()
self.model_clip.visual.requires_grad_(False)
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer,
int(self.args.iterations
/ self.args.augments
/ self.args.epochs),
eta_min=args.lr / 100)
import clip
self.txt_emb = self.model_clip.encode_text(clip.tokenize([self.args.text]).to(self.device))[0].detach()
self.txt_emb = torch.nn.functional.normalize(self.txt_emb, dim=-1)
def get_points(self):
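        """Evaluate the learnable Bezier parameters into renderable sphere attributes.

        Returns (positions, colors, radii, opacities), one entry per sphere;
        there are num_objects * bezier_res ** ndim spheres in total.
        """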
if self.ndim > 0:
bezier_ts = torch.stack(torch.meshgrid(
(torch.linspace(0, 1, self.args.bezier_res, device=self.device),) * self.ndim), dim=0
).unsqueeze(1).repeat((1, self.args.num_objects) + (1,) * self.ndim).unsqueeze(-1)
def interpolate_3D(pos, vel=0.0, acc=0.0, pos_scale=None, vel_scale=None, acc_scale=None, scale=None):
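            """Evaluate a quadratic (Bezier-like) curve/surface/volume on the bezier_ts grid:
            tanh(pos) * pos_scale + scale * sum_d (tanh(vel_d) * vel_scale * t_d
                                                   + tanh(acc_d) * acc_scale * t_d ** 2).
            For ndim == 0 this reduces to pos * pos_scale (no tanh applied)."""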
pos_scale = self.args.pos_scale if pos_scale is None else pos_scale
vel_scale = self.args.vel_scale if vel_scale is None else vel_scale
acc_scale = self.args.acc_scale if acc_scale is None else acc_scale
scale = self.args.scale if scale is None else scale
if self.ndim == 0:
return pos * pos_scale
result = 0.0
s = pos.shape[-1]
assert s * self.ndim == vel.shape[-1] == acc.shape[-1]
            # O(ndim) sequential loop over parameter dimensions
for d, bezier_t in zip(range(self.ndim), bezier_ts): # TODO replace with fused dimension operation
result = (result
+ torch.tanh(vel[..., d * s:(d + 1) * s]).view(
(-1,) + (1,) * self.ndim + (s,)) * vel_scale * bezier_t
+ torch.tanh(acc[..., d * s:(d + 1) * s]).view(
(-1,) + (1,) * self.ndim + (s,)) * acc_scale * bezier_t.pow(2))
result = (result * scale
+ torch.tanh(pos[..., :s]).view((-1,) + (1,) * self.ndim + (s,)) * pos_scale).view(-1, s)
return result
vert_pos = interpolate_3D(self.bezier_pos[..., :3], self.bezier_vel, self.bezier_acc)
vert_col = interpolate_3D(self.bezier_col[..., :4],
self.bezier_col[..., 4:4 + 4 * self.ndim],
self.bezier_col[..., -4 * self.ndim:])
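        # to_bezier broadcasts one value per object to all of its Bezier sample points;
        # rescale maps sigmoid outputs from [0, 1] onto [a, b], linearly or log-linearly.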
to_bezier = lambda x: x.view((-1,) + (1,) * self.ndim + (x.shape[-1],)).repeat(
(1,) + (self.args.bezier_res,) * self.ndim + (1,)).reshape(-1, x.shape[-1])
        rescale = lambda x, a, b, is_log=False: (
            torch.exp(x * np.log(b / a) + np.log(a)) if is_log else x * (b - a) + a)
return (
vert_pos,
torch.sigmoid(vert_col[..., :3]),
rescale(
torch.sigmoid(to_bezier(self.bezier_pos[..., -1:])[..., 0]),
self.args.min_radius, self.args.max_radius, is_log=self.args.log_radius
),
rescale(torch.sigmoid(vert_col[..., -1]),
self.args.min_opacity, self.args.max_opacity, is_log=self.args.log_opacity))
def camera(self, angle, altitude=0.0, offset=None, use_random=True, offset_random=None,
xyz_random=None, focal_length=None, plane_width=None):
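        """Build the 8-value pulsar camera parameter tensor
        (x, y, z, rot_x, rot_y, rot_z, focal_length, sensor_width):
        the camera orbits the origin at distance `offset`, optionally jittered."""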
if offset is None:
offset = self.args.offset
if xyz_random is None:
xyz_random = self.args.xyz_random
if focal_length is None:
focal_length = self.args.focal_length
if plane_width is None:
plane_width = self.args.plane_width
if offset_random is None:
offset_random = self.args.offset_random
device = self.device
offset = offset + np.random.normal() * offset_random * int(use_random)
position = torch.tensor([0, 0, -offset], dtype=torch.float)
position = utils.rotate_axis(position, altitude, 0)
position = utils.rotate_axis(position, angle, 1)
position = position + torch.randn(3) * xyz_random * int(use_random)
return torch.tensor([position[0], position[1], position[2],
altitude, angle, 0,
focal_length, plane_width], dtype=torch.float, device=device)
def render(self, cam_params=None):
if cam_params is None:
cam_params = self.camera(0, 0)
vert_pos, vert_col, radius, opacity = self.get_points()
rgb = self.renderer(vert_pos, vert_col, radius, cam_params,
self.args.gamma, self.args.max_depth, opacity=opacity)
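        # A second pass over the same geometry with all-zero colors isolates the
        # background contribution, which random_view_render uses as a per-pixel
        # background weight when compositing.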
opacity = self.renderer(vert_pos, vert_col * 0, radius, cam_params,
self.args.gamma, self.args.max_depth, opacity=opacity)
return rgb, opacity
def random_view_render(self):
angle = random.uniform(0, np.pi * 2)
altitude = random.uniform(-self.args.altitude_range / 2, self.args.altitude_range / 2)
cam_params = self.camera(angle, altitude)
result, alpha = self.render(cam_params)
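        # Composite the render over random multi-octave Perlin noise; randomizing the
        # background helps keep CLIP from latching onto an empty backdrop.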
back = torch.zeros_like(result)
s = back.shape
for j in range(s[-1]):
n = random.choice([7, 14, 28])
back[..., j] = utils.rand_perlin_2d_octaves(s[:-1], (n, n)).clip(-0.5, 0.5) + 0.5
result = result * (1 - alpha) + back * alpha
return result
def generate(self):
self.optimizer.zero_grad()
try:
for i in trange(self.args.iterations + self.args.showoff):
if i < self.args.iterations:
result = self.random_view_render()
img_emb = self.model_clip.encode_image(
self.preprocess_clip(result.permute(2, 0, 1)).unsqueeze(0).clamp(0., 1.))
img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
if self.args.loss_type == "spherical":
clip_loss = (img_emb - self.txt_emb).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
elif self.args.loss_type == "cosine":
clip_loss = (1 - img_emb @ self.txt_emb.T).mean()
else:
raise NotImplementedError(f"CLIP loss type not supported: {self.args.loss_type}")
                    loss = clip_loss * self.args.clip_weight  # TODO add more loss types
loss.backward()
if i % self.args.augments == self.args.augments - 1:
self.optimizer.step()
self.optimizer.zero_grad()
try:
self.scheduler.step()
except AttributeError:
pass
if i % self.args.show_every == 0:
cam_params = self.camera(i / self.args.iterations * np.pi * 2 * self.args.turns, use_random=False)
img_show, _ = self.render(cam_params)
img = Image.fromarray((img_show.cpu().detach().numpy() * 255).astype(np.uint8))
yield img
except KeyboardInterrupt:
pass
def save_obj(self, fn):
utils.save_obj(self.get_points(), fn)
class DotDict(dict):
    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:  # raise AttributeError so hasattr()/getattr() behave correctly
            raise AttributeError(item)
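# A hypothetical usage sketch (not part of the original file), assuming `utils` and
# pytorch3d are available: optimize with the CONFIG_SPEC defaults and save each
# preview frame that `generate` yields (one every `show_every` iterations).
if __name__ == "__main__":
    model = PulsarCLIP(default_config())
    for step, frame in enumerate(model.generate()):
        frame.save(f"frame_{step:04d}.png")
    model.save_obj("pulsar_clip.obj")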