# pulsar-clip / utils.py
# neverix
# For when the situation calms down
# fbf122b
# raw
# history blame
# 5.31 kB
import numpy as np
import random
import torch
import math
def rotate_axis(x, add_angle=0, axis=1):  # TODO Replace with a rotation matrix  # But this is more fun
    """Rotate 3-D points about one coordinate axis via polar coordinates.

    The two components perpendicular to ``axis`` are converted to an angle
    and an in-plane radius, ``add_angle`` is added to the angle, and the two
    components are rebuilt from the new angle. The component along ``axis``
    is passed through unchanged, so the length of every point is preserved.

    Args:
        x: Tensor whose last dimension holds the three coordinates.
        add_angle: Angle in radians to add; a number or a tensor. A tensor
            with fewer dimensions than ``x[..., 0]`` gets trailing singleton
            dimensions appended before it is broadcast against the points.
        axis: Index (0, 1 or 2) of the coordinate axis to rotate about.

    Returns:
        Tensor of rotated points with the coordinates in the last dimension.

    Raises:
        ValueError: If ``axis`` is not 0, 1 or 2.
    """
    axes = list(range(3))
    axes.remove(axis)
    ax1, ax2 = axes

    # The angle is measured from the ax2 direction towards the ax1 direction.
    angle = torch.atan2(x[..., ax1], x[..., ax2])
    if isinstance(add_angle, torch.Tensor):
        while add_angle.ndim < angle.ndim:
            add_angle = add_angle.unsqueeze(-1)
    angle = angle + add_angle

    # Use the radius inside the rotation plane only. Taking the norm of the
    # whole vector would stretch any point with a non-zero component along
    # the rotation axis.
    radius = x[..., [ax1, ax2]].norm(dim=-1)

    components = [None, None, None]
    components[axis] = x[..., axis]
    components[ax1] = torch.sin(angle) * radius
    components[ax2] = torch.cos(angle) * radius
    # add_angle may have widened the batch shape; align all parts before stacking.
    return torch.stack(torch.broadcast_tensors(*components), dim=-1)
# Amount subtracted from each random.random() draw that rand_perlin_2d_octaves
# uses as the random gain and random bias applied to its output.
noise_level = 0.5
# stolen from https://gist.github.com/ac1b097753f217c5c11bc2ff396e0a57
# ported from https://github.com/pvigier/perlin-numpy/blob/master/perlin2d.py
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
    """Generate one layer of 2-D Perlin noise as a torch tensor.

    Args:
        shape: Two-element sequence giving the output height and width.
            Each entry is expected to be an integer multiple of the matching
            ``res`` entry; otherwise the tiled gradients do not line up with
            the sample grid.
        res: Two-element sequence with the number of lattice cells along
            each dimension.
        fade: Interpolation curve applied to the fractional position inside
            each cell.

    Returns:
        Tensor of shape ``(shape[0], shape[1])`` containing the noise values.
    """
    height, width = shape[0], shape[1]
    cells_y, cells_x = res[0], res[1]
    step_y = cells_y / height
    step_x = cells_x / width
    reps_y = height // cells_y
    reps_x = width // cells_x

    # Fractional position of every sample inside its lattice cell.
    row_coords = torch.arange(0, cells_y, step_y)
    col_coords = torch.arange(0, cells_x, step_x)
    cell_pos = torch.stack(torch.meshgrid(row_coords, col_coords), dim=-1) % 1
    # arange with a float step may yield an extra element; trim to the output size.
    local = cell_pos[:height, :width]

    # One random unit gradient per lattice corner.
    theta = 2 * math.pi * torch.rand(cells_y + 1, cells_x + 1)
    lattice = torch.stack((torch.cos(theta), torch.sin(theta)), dim=-1)

    def corner_noise(row_span, col_span, shift):
        """Dot the offset from one cell corner with that corner's gradient."""
        picked = lattice[row_span[0]:row_span[1], col_span[0]:col_span[1]]
        tiled = picked.repeat_interleave(reps_y, 0).repeat_interleave(reps_x, 1)
        offset = torch.stack(
            (local[..., 0] + shift[0], local[..., 1] + shift[1]),
            dim=-1,
        )
        return (offset * tiled[:height, :width]).sum(dim=-1)

    n00 = corner_noise([0, -1], [0, -1], [0, 0])
    n10 = corner_noise([1, None], [0, -1], [-1, 0])
    n01 = corner_noise([0, -1], [1, None], [0, -1])
    n11 = corner_noise([1, None], [1, None], [-1, -1])

    weights = fade(local)
    along_first_low = torch.lerp(n00, n10, weights[..., 0])
    along_first_high = torch.lerp(n01, n11, weights[..., 0])
    blended = torch.lerp(along_first_low, along_first_high, weights[..., 1])
    return math.sqrt(2) * blended
def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
    """Sum several Perlin noise layers, then apply a random gain and bias.

    Each successive layer doubles the lattice resolution and multiplies the
    amplitude by ``persistence``. After the layers are summed, the result is
    multiplied by ``random.random() - noise_level`` and then shifted by a
    second, independent ``random.random() - noise_level`` draw.

    Args:
        shape: Output height and width, forwarded to ``rand_perlin_2d``.
        res: Lattice resolution of the first layer.
        octaves: Number of layers to accumulate.
        persistence: Amplitude multiplier applied between layers.

    Returns:
        Tensor of shape ``shape`` holding the combined noise.
    """
    total = torch.zeros(shape)
    freq, amp = 1, 1
    for _ in range(octaves):
        layer = rand_perlin_2d(shape, (freq * res[0], freq * res[1]))
        total += amp * layer
        freq, amp = freq * 2, amp * persistence

    gain = random.random() - noise_level  # haha
    total *= gain
    bias = random.random() - noise_level  # haha x2
    total += bias
    return total
def load_clip(model_name="ViT-B/16", device="cuda:0" if torch.cuda.is_available() else "cpu"):
    """Load a CLIP model together with a trimmed preprocessing pipeline.

    Args:
        model_name: Model identifier passed to ``clip.load``.
        device: Device passed to ``clip.load``. The default is chosen once,
            when this function is defined.

    Returns:
        ``(model, preprocess)`` as returned by ``clip.load`` with
        ``jit=False``. When ``preprocess`` has more than four transforms,
        only its last transform is kept.
    """
    import clip

    model, preprocess = clip.load(model_name, device=device, jit=False)
    pipeline = preprocess.transforms
    # NOTE(review): presumably the last step is the normalisation and the
    # dropped steps are PIL-specific; confirm against the clip package.
    if len(pipeline) > 4:
        preprocess.transforms = pipeline[-1:]
    return model, preprocess
# http://blog.andreaskahler.com/2009/06/creating-icosphere-mesh-in-code.html
def ico():
    """Return the vertices and triangular faces of an icosahedron.

    Returns:
        A ``(vertices, faces)`` pair. ``vertices`` is a ``(12, 3)`` NumPy
        array built from permutations of ``(0, ±1, ±phi)`` and divided by the
        golden ratio ``phi``, so the largest coordinate magnitude is 1.
        ``faces`` is a fresh list of 20 three-element index lists.
    """
    golden = (1 + 5 ** 0.5) / 2

    corners = [
        (-1, golden, 0), (1, golden, 0), (-1, -golden, 0), (1, -golden, 0),
        (0, -1, golden), (0, 1, golden), (0, -1, -golden), (0, 1, -golden),
        (golden, 0, -1), (golden, 0, 1), (-golden, 0, -1), (-golden, 0, 1),
    ]
    triangles = [
        [0, 11, 5], [0, 5, 1], [0, 1, 7], [0, 7, 10], [0, 10, 11],
        [1, 5, 9], [5, 11, 4], [11, 10, 2], [10, 7, 6], [7, 1, 8],
        [3, 9, 4], [3, 4, 2], [3, 2, 6], [3, 6, 8], [3, 8, 9],
        [4, 9, 5], [2, 4, 11], [6, 2, 10], [8, 6, 7], [9, 8, 1],
    ]

    vertices = np.array(corners) / golden
    return vertices, triangles
def ico_at(xyz=np.array([0, 0, 0]), radius=1.0, i=0):
    """Return an icosahedron scaled by ``radius`` and moved to ``xyz``.

    Args:
        xyz: Offset added to every vertex after scaling.
        radius: Factor the base vertices from ``ico()`` are multiplied by.
        i: Value added to every face index, so the faces can refer to
            vertices stored at position ``i`` onwards in a larger list.

    Returns:
        ``(vertices, faces)``: the transformed vertex array and a new list
        of index lists shifted by ``i``.
    """
    base_vertices, base_faces = ico()
    placed = base_vertices * radius + xyz
    shifted = []
    for face in base_faces:
        shifted.append([index + i for index in face])
    return placed, shifted
def save_ply(points, out_path):
    """Write a set of coloured spheres to an ASCII PLY mesh file.

    Every sphere is approximated by the icosahedron from ``ico_at``; all of
    its vertices carry the sphere's colour and opacity.

    Args:
        points: Iterable of four tensors, in order: positions, colours,
            radii and opacities. Each colour row must unpack into three
            channels. Colour and opacity are multiplied by 255 before
            writing, so they are presumably expected in [0, 1]; confirm
            against the caller. Out-of-range values are clamped.
        out_path: Path of the file to create or overwrite.
    """
    def to_uchar(unit_value):
        # PLY `uchar` properties only hold 0..255; clamp so out-of-range
        # colour/opacity values cannot produce an invalid file.
        return max(0, min(255, int(unit_value * 255)))

    with torch.inference_mode():
        vert_pos, vert_col, vert_rad, vert_opa = (x.detach().cpu().numpy() for x in points)

    verts = []
    faces = []
    for xyz, radius, (r, g, b), a in zip(vert_pos, vert_rad, vert_col, vert_opa):
        # len(verts) is the index at which this sphere's vertices will start.
        sphere_verts, sphere_faces = ico_at(xyz, radius, len(verts))
        rgba = (to_uchar(r), to_uchar(g), to_uchar(b), to_uchar(a))
        verts.extend((x, y, z, *rgba) for x, y, z in sphere_verts)
        faces += sphere_faces

    header = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(verts)}",
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "property uchar alpha",
        f"element face {len(faces)}",
        "property list uchar int vertex_index",
        "end_header",
    ]

    with open(out_path, "w", encoding="ascii") as out_file:
        out_file.writelines(line + "\n" for line in header)
        out_file.writelines(" ".join(map(str, vertex)) + "\n" for vertex in verts)
        # Each face line starts with its vertex count, followed by the indices.
        out_file.writelines(" ".join(map(str, [len(face)] + face)) + "\n" for face in faces)