hysts HF staff commited on
Commit
121620a
1 Parent(s): ab445b8
Files changed (2) hide show
  1. app.py +9 -17
  2. model.py +12 -9
app.py CHANGED
@@ -53,10 +53,6 @@ def get_cluster_center_image_markdown(model_name: str) -> str:
53
  return f'![cluster center images]({url})'
54
 
55
 
56
- def update_distance_type(multimodal_truncation: bool) -> dict:
57
- return gr.Dropdown.update(visible=multimodal_truncation)
58
-
59
-
60
  def main():
61
  args = parse_args()
62
 
@@ -85,14 +81,14 @@ def main():
85
  step=0.05,
86
  value=0.7,
87
  label='Truncation psi')
88
- multimodal_truncation = gr.Checkbox(
89
- label='Multi-modal Truncation', value=True)
90
- distance_type = gr.Dropdown([
91
- 'lpips',
92
- 'l2',
93
- ],
94
- value='lpips',
95
- label='Distance Type')
96
  run_button = gr.Button('Run')
97
  with gr.Column():
98
  result = gr.Image(label='Result', elem_id='result')
@@ -116,16 +112,12 @@ def main():
116
  gr.Markdown(FOOTER)
117
 
118
  model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
119
- multimodal_truncation.change(fn=update_distance_type,
120
- inputs=multimodal_truncation,
121
- outputs=distance_type)
122
  run_button.click(fn=model.set_model_and_generate_image,
123
  inputs=[
124
  model_name,
125
  seed,
126
  psi,
127
- multimodal_truncation,
128
- distance_type,
129
  ],
130
  outputs=result)
131
  model_name2.change(fn=get_sample_image_markdown,
 
53
  return f'![cluster center images]({url})'
54
 
55
 
 
 
 
 
56
  def main():
57
  args = parse_args()
58
 
 
81
  step=0.05,
82
  value=0.7,
83
  label='Truncation psi')
84
+ truncation_type = gr.Dropdown(
85
+ [
86
+ 'Multimodal (LPIPS)',
87
+ 'Multimodal (L2)',
88
+ 'Global',
89
+ ],
90
+ value='Multimodal (LPIPS)',
91
+ label='Truncation Type')
92
  run_button = gr.Button('Run')
93
  with gr.Column():
94
  result = gr.Image(label='Result', elem_id='result')
 
112
  gr.Markdown(FOOTER)
113
 
114
  model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
 
 
 
115
  run_button.click(fn=model.set_model_and_generate_image,
116
  inputs=[
117
  model_name,
118
  seed,
119
  psi,
120
+ truncation_type,
 
121
  ],
122
  outputs=result)
123
  model_name2.change(fn=get_sample_image_markdown,
model.py CHANGED
@@ -190,15 +190,20 @@ class Model:
190
  return int(np.argmin(distances))
191
 
192
  def generate_image(self, seed: int, truncation_psi: float,
193
- multimodal_truncation: bool,
194
- distance_type: str) -> np.ndarray:
195
  z = self.generate_z(seed)
196
  ws = self.compute_w(z)
197
- if multimodal_truncation:
 
 
 
 
 
 
 
 
198
  cluster_index = self.find_nearest_cluster_center(ws, distance_type)
199
  w0 = self.cluster_centers[cluster_index]
200
- else:
201
- w0 = self.model.mapping.w_avg
202
  new_ws = self.truncate_w(w0, ws, truncation_psi)
203
  out = self.synthesize(new_ws)
204
  out = self.postprocess(out)
@@ -206,8 +211,6 @@ class Model:
206
 
207
  def set_model_and_generate_image(self, model_name: str, seed: int,
208
  truncation_psi: float,
209
- multimodal_truncation: bool,
210
- distance_type: str) -> np.ndarray:
211
  self.set_model(model_name)
212
- return self.generate_image(seed, truncation_psi, multimodal_truncation,
213
- distance_type)
 
190
  return int(np.argmin(distances))
191
 
192
  def generate_image(self, seed: int, truncation_psi: float,
193
+ truncation_type: str) -> np.ndarray:
 
194
  z = self.generate_z(seed)
195
  ws = self.compute_w(z)
196
+ if truncation_type == 'Global':
197
+ w0 = self.model.mapping.w_avg
198
+ else:
199
+ if truncation_type == 'Multimodal (LPIPS)':
200
+ distance_type = 'lpips'
201
+ elif truncation_type == 'Multimodal (L2)':
202
+ distance_type = 'l2'
203
+ else:
204
+ raise ValueError
205
  cluster_index = self.find_nearest_cluster_center(ws, distance_type)
206
  w0 = self.cluster_centers[cluster_index]
 
 
207
  new_ws = self.truncate_w(w0, ws, truncation_psi)
208
  out = self.synthesize(new_ws)
209
  out = self.postprocess(out)
 
211
 
212
  def set_model_and_generate_image(self, model_name: str, seed: int,
213
  truncation_psi: float,
214
+ truncation_type: str) -> np.ndarray:
 
215
  self.set_model(model_name)
216
+ return self.generate_image(seed, truncation_psi, truncation_type)