liuyuan-pal commited on
Commit
ab287b7
1 Parent(s): 0fa63ef
Files changed (1) hide show
  1. app.py +28 -10
app.py CHANGED
@@ -17,12 +17,19 @@ _DESCRIPTION = '''
17
  <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.03453"><img src="https://img.shields.io/badge/2309.03453-f9f7f7?logo="></a>
18
  <a style="display:inline-block; margin-left: .5em" href='https://github.com/liuyuan-pal/SyncDreamer'><img src='https://img.shields.io/github/stars/liuyuan-pal/SyncDreamer?style=social' /></a>
19
  </div>
20
- Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss'''
 
 
 
 
 
 
21
  _USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example above). We use alpha values as object masks if given."
22
  _USER_GUIDE1 = "Step1: Please select a crop size using the glider."
23
  _USER_GUIDE2 = "Step2: Please choose a suitable elevation angle and then click the Generate button."
24
  _USER_GUIDE3 = "Generated multiview images are shown below!"
25
 
 
26
 
27
  def mask_prediction(mask_predictor, image_in: Image.Image):
28
  if image_in.mode=='RGBA':
@@ -56,11 +63,16 @@ def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, el
56
  elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
57
  data = {"input_image": image_input, "input_elevation": elevation_input}
58
  for k, v in data.items():
59
- data[k] = v.unsqueeze(0)#.cuda()
 
 
 
60
  data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
61
 
62
- x_sample = model.sample(data, cfg_scale, batch_view_num)
63
- # x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
 
 
64
 
65
  B, N, _, H, W = x_sample.shape
66
  x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
@@ -80,12 +92,15 @@ def run_demo():
80
  ckpt = 'ckpt/syncdreamer-pretrain.ckpt'
81
  config = OmegaConf.load(cfg)
82
  # model = None
83
- model = instantiate_from_config(config.model)
84
- print(f'loading model from {ckpt} ...')
85
- ckpt = torch.load(ckpt,map_location='cpu')
86
- model.load_state_dict(ckpt['state_dict'], strict=True)
87
- model = model.cuda().eval()
88
- del ckpt
 
 
 
89
 
90
  # init sam model
91
  mask_predictor = None # sam_init(device_idx)
@@ -121,10 +136,12 @@ def run_demo():
121
  examples_per_page=40
122
  )
123
 
 
124
  with gr.Column(scale=1):
125
  sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
126
  crop_size_slider = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
127
  crop_btn = gr.Button('Crop the image', variant='primary', interactive=True)
 
128
 
129
  with gr.Column(scale=1):
130
  input_block = gr.Image(type='pil', image_mode='RGB', label="Input to SyncDreamer", height=256, interactive=False)
@@ -134,6 +151,7 @@ def run_demo():
134
  # batch_view_num = gr.Slider(1, 16, 8, step=1, label='', interactive=True)
135
  seed = gr.Number(6033, label='Random seed', interactive=True)
136
  run_btn = gr.Button('Run Generation', variant='primary', interactive=True)
 
137
 
138
  output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
139
 
 
17
  <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.03453"><img src="https://img.shields.io/badge/2309.03453-f9f7f7?logo="></a>
18
  <a style="display:inline-block; margin-left: .5em" href='https://github.com/liuyuan-pal/SyncDreamer'><img src='https://img.shields.io/github/stars/liuyuan-pal/SyncDreamer?style=social' /></a>
19
  </div>
20
+ Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss
21
+
22
+ 1. Upload the image.
23
+ 2. Predict the mask for the foreground object.
24
+ 3. Crop the foreground object.
25
+ 4. Generate multiview images.
26
+ '''
27
  _USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example above). We use alpha values as object masks if given."
28
  _USER_GUIDE1 = "Step1: Please select a crop size using the glider."
29
  _USER_GUIDE2 = "Step2: Please choose a suitable elevation angle and then click the Generate button."
30
  _USER_GUIDE3 = "Generated multiview images are shown below!"
31
 
32
+ deployed = True
33
 
34
  def mask_prediction(mask_predictor, image_in: Image.Image):
35
  if image_in.mode=='RGBA':
 
63
  elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
64
  data = {"input_image": image_input, "input_elevation": elevation_input}
65
  for k, v in data.items():
66
+ if deployed:
67
+ data[k] = v.unsqueeze(0).cuda()
68
+ else:
69
+ data[k] = v.unsqueeze(0)
70
  data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
71
 
72
+ if deployed:
73
+ x_sample = model.sample(data, cfg_scale, batch_view_num)
74
+ else:
75
+ x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
76
 
77
  B, N, _, H, W = x_sample.shape
78
  x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
 
92
  ckpt = 'ckpt/syncdreamer-pretrain.ckpt'
93
  config = OmegaConf.load(cfg)
94
  # model = None
95
+ if deployed:
96
+ model = instantiate_from_config(config.model)
97
+ print(f'loading model from {ckpt} ...')
98
+ ckpt = torch.load(ckpt,map_location='cpu')
99
+ model.load_state_dict(ckpt['state_dict'], strict=True)
100
+ model = model.cuda().eval()
101
+ del ckpt
102
+ else:
103
+ model = None
104
 
105
  # init sam model
106
  mask_predictor = None # sam_init(device_idx)
 
136
  examples_per_page=40
137
  )
138
 
139
+
140
  with gr.Column(scale=1):
141
  sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
142
  crop_size_slider = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
143
  crop_btn = gr.Button('Crop the image', variant='primary', interactive=True)
144
+ fig0 = gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
145
 
146
  with gr.Column(scale=1):
147
  input_block = gr.Image(type='pil', image_mode='RGB', label="Input to SyncDreamer", height=256, interactive=False)
 
151
  # batch_view_num = gr.Slider(1, 16, 8, step=1, label='', interactive=True)
152
  seed = gr.Number(6033, label='Random seed', interactive=True)
153
  run_btn = gr.Button('Run Generation', variant='primary', interactive=True)
154
+ fig1 = gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
155
 
156
  output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
157