hysts HF staff commited on
Commit
3fc678c
1 Parent(s): 121620a
Files changed (2) hide show
  1. app.py +2 -6
  2. model.py +8 -3
app.py CHANGED
@@ -82,12 +82,8 @@ def main():
82
  value=0.7,
83
  label='Truncation psi')
84
  truncation_type = gr.Dropdown(
85
- [
86
- 'Multimodal (LPIPS)',
87
- 'Multimodal (L2)',
88
- 'Global',
89
- ],
90
- value='Multimodal (LPIPS)',
91
  label='Truncation Type')
92
  run_button = gr.Button('Run')
93
  with gr.Column():
 
82
  value=0.7,
83
  label='Truncation psi')
84
  truncation_type = gr.Dropdown(
85
+ model.TRUNCATION_TYPES,
86
+ value=model.TRUNCATION_TYPES[0],
 
 
 
 
87
  label='Truncation Type')
88
  run_button = gr.Button('Run')
89
  with gr.Column():
model.py CHANGED
@@ -54,6 +54,11 @@ class Model:
54
  'giraffes_512',
55
  'parrots_512',
56
  ]
 
 
 
 
 
57
 
58
  def __init__(self, device: str | torch.device):
59
  self.device = torch.device(device)
@@ -193,12 +198,12 @@ class Model:
193
  truncation_type: str) -> np.ndarray:
194
  z = self.generate_z(seed)
195
  ws = self.compute_w(z)
196
- if truncation_type == 'Global':
197
  w0 = self.model.mapping.w_avg
198
  else:
199
- if truncation_type == 'Multimodal (LPIPS)':
200
  distance_type = 'lpips'
201
- elif truncation_type == 'Multimodal (L2)':
202
  distance_type = 'l2'
203
  else:
204
  raise ValueError
 
54
  'giraffes_512',
55
  'parrots_512',
56
  ]
57
+ TRUNCATION_TYPES = [
58
+ 'Multimodal (LPIPS)',
59
+ 'Multimodal (L2)',
60
+ 'Global',
61
+ ]
62
 
63
  def __init__(self, device: str | torch.device):
64
  self.device = torch.device(device)
 
198
  truncation_type: str) -> np.ndarray:
199
  z = self.generate_z(seed)
200
  ws = self.compute_w(z)
201
+ if truncation_type == self.TRUNCATION_TYPES[2]:
202
  w0 = self.model.mapping.w_avg
203
  else:
204
+ if truncation_type == self.TRUNCATION_TYPES[0]:
205
  distance_type = 'lpips'
206
+ elif truncation_type == self.TRUNCATION_TYPES[1]:
207
  distance_type = 'l2'
208
  else:
209
  raise ValueError