Spaces:
Build error
Build error
johnowhitaker
commited on
Commit
•
d171496
1
Parent(s):
36d8a22
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os, glob
|
3 |
+
from functools import partial
|
4 |
+
import glob
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
11 |
+
|
12 |
+
class RuleCA(nn.Module):
|
13 |
+
def __init__(self, hidden_n=6, rule_channels=4, zero_w2=True, device=device):
|
14 |
+
super().__init__()
|
15 |
+
# The hard-coded filters:
|
16 |
+
self.filters = torch.stack([torch.tensor([[0.0,0.0,0.0],[0.0,1.0,0.0],[0.0,0.0,0.0]]),
|
17 |
+
torch.tensor([[-1.0,0.0,1.0],[-2.0,0.0,2.0],[-1.0,0.0,1.0]]),
|
18 |
+
torch.tensor([[-1.0,0.0,1.0],[-2.0,0.0,2.0],[-1.0,0.0,1.0]]).T,
|
19 |
+
torch.tensor([[1.0,2.0,1.0],[2.0,-12,2.0],[1.0,2.0,1.0]])]).to(device)
|
20 |
+
self.chn = 4
|
21 |
+
self.rule_channels = rule_channels
|
22 |
+
self.w1 = nn.Conv2d(4*4+rule_channels, hidden_n, 1).to(device)
|
23 |
+
self.relu = nn.ReLU()
|
24 |
+
self.w2 = nn.Conv2d(hidden_n, 4, 1, bias=False).to(device)
|
25 |
+
if zero_w2:
|
26 |
+
self.w2.weight.data.zero_()
|
27 |
+
self.device = device
|
28 |
+
|
29 |
+
def perchannel_conv(self, x, filters):
|
30 |
+
'''filters: [filter_n, h, w]'''
|
31 |
+
b, ch, h, w = x.shape
|
32 |
+
y = x.reshape(b*ch, 1, h, w)
|
33 |
+
y = torch.nn.functional.pad(y, [1, 1, 1, 1], 'circular')
|
34 |
+
y = torch.nn.functional.conv2d(y, filters[:,None])
|
35 |
+
return y.reshape(b, -1, h, w)
|
36 |
+
|
37 |
+
def forward(self, x, rule=0, update_rate=0.5):
|
38 |
+
b, ch, xsz, ysz = x.shape
|
39 |
+
rule_grid = torch.zeros(b, self.rule_channels, xsz, ysz).to(self.device)
|
40 |
+
rule_grid[:,rule] = 1
|
41 |
+
y = self.perchannel_conv(x, self.filters) # Apply the filters
|
42 |
+
y = torch.cat([y, rule_grid], dim=1)
|
43 |
+
y = self.w2(self.relu(self.w1(y))) # pass the result through out 'brain'
|
44 |
+
b, c, h, w = y.shape
|
45 |
+
update_mask = (torch.rand(b, 1, h, w).to(self.device)+update_rate).floor()
|
46 |
+
return x+y*update_mask
|
47 |
+
|
48 |
+
def forward_w_rule_grid(self, x, rule_grid, update_rate=0.5):
|
49 |
+
y = self.perchannel_conv(x, self.filters) # Apply the filters
|
50 |
+
y = torch.cat([y, rule_grid], dim=1)
|
51 |
+
y = self.w2(self.relu(self.w1(y))) # pass the result through out 'brain'
|
52 |
+
b, c, h, w = y.shape
|
53 |
+
update_mask = (torch.rand(b, 1, h, w).to(self.device)+update_rate).floor()
|
54 |
+
return x+y*update_mask
|
55 |
+
|
56 |
+
def to_rgb(self, x):
|
57 |
+
# TODO: rename this to_rgb & explain
|
58 |
+
return x[...,:3,:,:]+0.5
|
59 |
+
|
60 |
+
def seed(self, n, sz=128):
|
61 |
+
"""Initializes n 'grids', size sz. In this case all 0s."""
|
62 |
+
return torch.zeros(n, self.chn, sz, sz).to(self.device)
|
63 |
+
|
64 |
+
def to_frames(video_file):
|
65 |
+
os.system('rm -r guide_frames;mkdir guide_frames')
|
66 |
+
os.system(f"ffmpeg -i {video_file} guide_frames/%04d.jpg")
|
67 |
+
|
68 |
+
def update(preset, enhance, video_file):
|
69 |
+
|
70 |
+
# Load presets
|
71 |
+
ca = RuleCA(hidden_n=32, rule_channels=3)
|
72 |
+
ca_fn = ''
|
73 |
+
if preset == 'Glowing Crystals':
|
74 |
+
ca_fn = 'glowing_crystals.pt'
|
75 |
+
elif preset == 'Rainbow Diamonds':
|
76 |
+
ca_fn = 'rainbow_diamonds.pt'
|
77 |
+
elif preset == 'Dark Diamonds':
|
78 |
+
ca_fn = 'dark_diamonds.pt'
|
79 |
+
elif preset == 'Dragon Scales':
|
80 |
+
ca = RuleCA(hidden_n=16, rule_channels=3)
|
81 |
+
ca_fn = 'dragon_scales.pt'
|
82 |
+
|
83 |
+
ca.load_state_dict(torch.load(ca_fn, map_location=device))
|
84 |
+
|
85 |
+
# Get video frames
|
86 |
+
to_frames(video_file)
|
87 |
+
|
88 |
+
size=(426, 240)
|
89 |
+
vid_size = Image.open(f'guide_frames/0001.jpg').size
|
90 |
+
if vid_size[0]>vid_size[1]:
|
91 |
+
size = (256, int(256*(vid_size[1]/vid_size[0])))
|
92 |
+
else:
|
93 |
+
size = (int(256*(vid_size[0]/vid_size[1])), 256)
|
94 |
+
|
95 |
+
# Starting grid
|
96 |
+
x = torch.zeros(1, 4, size[1], size[0]).to(ca.device)
|
97 |
+
os.system("rm -r steps;mkdir steps")
|
98 |
+
for i in range(2*len(glob.glob('guide_frames/*.jpg'))-1):
|
99 |
+
# load frame
|
100 |
+
im = Image.open(f'guide_frames/{i//2+1:04}.jpg').resize(size)
|
101 |
+
|
102 |
+
# make rule grid
|
103 |
+
rule_grid = torch.tensor(np.array(im)/255).permute(2, 0, 1).unsqueeze(0).to(ca.device)
|
104 |
+
if enhance:
|
105 |
+
rule_grid = rule_grid * 2 - 0.3 # Add * 2 - 0.3 to 'enhance' an effect
|
106 |
+
|
107 |
+
# Apply the updates
|
108 |
+
with torch.no_grad():
|
109 |
+
x = ca.forward_w_rule_grid(x, rule_grid.float())
|
110 |
+
if i%2==0:
|
111 |
+
img = ca.to_rgb(x).detach().cpu().clip(0, 1).squeeze().permute(1, 2, 0)
|
112 |
+
img = Image.fromarray(np.array(img*255).astype(np.uint8))
|
113 |
+
img.save(f'steps/{i//2:05}.jpeg')
|
114 |
+
|
115 |
+
# Write output video from saved frames
|
116 |
+
os.system("ffmpeg -y -v 0 -framerate 24 -i steps/%05d.jpeg video.mp4")
|
117 |
+
return 'video.mp4'
|
118 |
+
|
119 |
+
|
120 |
+
demo = gr.Blocks()
|
121 |
+
|
122 |
+
with demo:
|
123 |
+
gr.Markdown("Start typing below and then click **Run** to see the output.")
|
124 |
+
with gr.Row():
|
125 |
+
preset = gr.Dropdown(['Glowing Crystals', 'Rainbow Diamonds', 'Dark Diamonds', 'Dragon Scales'], label='Preset')
|
126 |
+
enhance = gr.Checkbox(label='Rescale inputs (more extreme results)')
|
127 |
+
with gr.Row():
|
128 |
+
inp = gr.Video(format='mp4', source='upload', label="Input video (ideally <30s)")
|
129 |
+
out = gr.Video(label="Output")
|
130 |
+
btn = gr.Button("Run")
|
131 |
+
btn.click(fn=update, inputs=[preset, enhance, inp], outputs=out)
|
132 |
+
|
133 |
+
demo.launch(enable_queue=True)
|