Spaces:
Build error
Build error
Jiatao Gu
commited on
Commit
•
df44b7d
1
Parent(s):
77c753d
fix bug for cpu running
Browse files- app.py +4 -14
- gradio_queue.db +0 -0
- training/networks.py +3 -3
app.py
CHANGED
@@ -9,29 +9,18 @@ import time
|
|
9 |
import legacy
|
10 |
import torch
|
11 |
import glob
|
12 |
-
|
13 |
import cv2
|
14 |
-
|
15 |
from torch_utils import misc
|
16 |
from renderer import Renderer
|
17 |
from training.networks import Generator
|
18 |
from huggingface_hub import hf_hub_download
|
19 |
|
20 |
|
21 |
-
device = torch.device('cuda')
|
22 |
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
|
23 |
|
24 |
|
25 |
-
|
26 |
-
def handler(signum, frame):
|
27 |
-
res = input("Ctrl-c was pressed. Do you really want to exit? y/n ")
|
28 |
-
if res == 'y':
|
29 |
-
gr.close_all()
|
30 |
-
exit(1)
|
31 |
-
|
32 |
-
signal.signal(signal.SIGINT, handler)
|
33 |
-
|
34 |
-
|
35 |
def set_random_seed(seed):
|
36 |
torch.manual_seed(seed)
|
37 |
np.random.seed(seed)
|
@@ -202,11 +191,12 @@ yaw = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="yaw")
|
|
202 |
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="pitch")
|
203 |
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="roll (optional, not suggested)")
|
204 |
fov = gr.inputs.Slider(minimum=9, maximum=15, default=12, label="fov")
|
205 |
-
css = ".
|
206 |
|
207 |
gr.Interface(fn=f_synthesis,
|
208 |
inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
|
209 |
title="Interctive Web Demo for StyleNeRF (ICLR 2022)",
|
|
|
210 |
outputs=["image", "state"],
|
211 |
layout='unaligned',
|
212 |
css=css, theme='dark-huggingface',
|
|
|
9 |
import legacy
|
10 |
import torch
|
11 |
import glob
|
|
|
12 |
import cv2
|
13 |
+
|
14 |
from torch_utils import misc
|
15 |
from renderer import Renderer
|
16 |
from training.networks import Generator
|
17 |
from huggingface_hub import hf_hub_download
|
18 |
|
19 |
|
20 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
port = int(sys.argv[1]) if len(sys.argv) > 1 else 21111
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def set_random_seed(seed):
|
25 |
torch.manual_seed(seed)
|
26 |
np.random.seed(seed)
|
|
|
191 |
pitch = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="pitch")
|
192 |
roll = gr.inputs.Slider(minimum=-1, maximum=1, default=0, label="roll (optional, not suggested)")
|
193 |
fov = gr.inputs.Slider(minimum=9, maximum=15, default=12, label="fov")
|
194 |
+
css = ".output-image, .input-image, .image-preview {height: 600px !important} "
|
195 |
|
196 |
gr.Interface(fn=f_synthesis,
|
197 |
inputs=[model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, "state"],
|
198 |
title="Interctive Web Demo for StyleNeRF (ICLR 2022)",
|
199 |
+
description="Demo for ICLR 2022 Papaer: A Style-based 3D-Aware Generator for High-resolution Image Synthesis. Currently the demo runs on CPU only."
|
200 |
outputs=["image", "state"],
|
201 |
layout='unaligned',
|
202 |
css=css, theme='dark-huggingface',
|
gradio_queue.db
ADDED
Binary file (856 kB). View file
|
|
training/networks.py
CHANGED
@@ -794,7 +794,7 @@ class SynthesisBlock(torch.nn.Module):
|
|
794 |
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, add_on=None, block_noise=None, disable_rgb=False, **layer_kwargs):
|
795 |
misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
|
796 |
w_iter = iter(ws.unbind(dim=1))
|
797 |
-
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
798 |
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
799 |
if fused_modconv is None:
|
800 |
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
@@ -937,7 +937,7 @@ class SynthesisBlock3(torch.nn.Module):
|
|
937 |
|
938 |
def forward(self, x, img, ws, force_fp32=False, add_on=None, disable_rgb=False, **layer_kwargs):
|
939 |
w_iter = iter(ws.unbind(dim=1))
|
940 |
-
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
941 |
memory_format = torch.contiguous_format
|
942 |
|
943 |
# Main layers.
|
@@ -1141,7 +1141,7 @@ class DiscriminatorBlock(torch.nn.Module):
|
|
1141 |
trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
|
1142 |
|
1143 |
def forward(self, x, img, force_fp32=False, downsampler=None):
|
1144 |
-
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
|
1145 |
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
1146 |
|
1147 |
# Input.
|
|
|
794 |
def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, add_on=None, block_noise=None, disable_rgb=False, **layer_kwargs):
|
795 |
misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
|
796 |
w_iter = iter(ws.unbind(dim=1))
|
797 |
+
dtype = torch.float16 if (self.use_fp16 and x.device.type == 'cuda') and not force_fp32 else torch.float32
|
798 |
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
799 |
if fused_modconv is None:
|
800 |
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
|
|
937 |
|
938 |
def forward(self, x, img, ws, force_fp32=False, add_on=None, disable_rgb=False, **layer_kwargs):
|
939 |
w_iter = iter(ws.unbind(dim=1))
|
940 |
+
dtype = torch.float16 if (self.use_fp16 and x.device.type == 'cuda') and not force_fp32 else torch.float32
|
941 |
memory_format = torch.contiguous_format
|
942 |
|
943 |
# Main layers.
|
|
|
1141 |
trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
|
1142 |
|
1143 |
def forward(self, x, img, force_fp32=False, downsampler=None):
|
1144 |
+
dtype = torch.float16 if (self.use_fp16 and x.device.type == 'cuda') and not force_fp32 else torch.float32
|
1145 |
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
|
1146 |
|
1147 |
# Input.
|