update to latest version of gradio

#1
by akhaliq HF staff - opened
Files changed (38) hide show
  1. .pre-commit-config.yaml +0 -60
  2. .vscode/settings.json +0 -30
  3. README.md +1 -4
  4. app.py +146 -96
  5. cluster_center_images/horses_256.jpg +0 -3
  6. cluster_center_images/lions_512.jpg +0 -3
  7. cluster_center_images/parrots_512.jpg +0 -3
  8. model.py +0 -192
  9. requirements.txt +5 -6
  10. cluster_center_images/elephants_512.jpg β†’ samples/bicycles.jpg +2 -2
  11. samples/bicycles_256_global.jpg +0 -3
  12. samples/bicycles_256_multimodal_l2.jpg +0 -3
  13. samples/bicycles_256_multimodal_lpips.jpg +0 -3
  14. cluster_center_images/dogs_1024.jpg β†’ samples/dogs.jpg +2 -2
  15. samples/dogs_1024_global.jpg +0 -3
  16. samples/dogs_1024_multimodal_l2.jpg +0 -3
  17. samples/dogs_1024_multimodal_lpips.jpg +0 -3
  18. samples/elephants.jpg +3 -0
  19. samples/elephants_512_global.jpg +0 -3
  20. samples/elephants_512_multimodal_l2.jpg +0 -3
  21. samples/elephants_512_multimodal_lpips.jpg +0 -3
  22. samples/giraffes.jpg +3 -0
  23. samples/giraffes_512_global.jpg +0 -3
  24. samples/giraffes_512_multimodal_l2.jpg +0 -3
  25. samples/giraffes_512_multimodal_lpips.jpg +0 -3
  26. cluster_center_images/giraffes_512.jpg β†’ samples/horses.jpg +2 -2
  27. samples/horses_256_global.jpg +0 -3
  28. samples/horses_256_multimodal_l2.jpg +0 -3
  29. samples/horses_256_multimodal_lpips.jpg +0 -3
  30. samples/lions.jpg +3 -0
  31. samples/lions_512_global.jpg +0 -3
  32. samples/lions_512_multimodal_l2.jpg +0 -3
  33. samples/lions_512_multimodal_lpips.jpg +0 -3
  34. cluster_center_images/bicycles_256.jpg β†’ samples/parrots.jpg +2 -2
  35. samples/parrots_512_global.jpg +0 -3
  36. samples/parrots_512_multimodal_l2.jpg +0 -3
  37. samples/parrots_512_multimodal_lpips.jpg +0 -3
  38. style.css +0 -8
.pre-commit-config.yaml DELETED
@@ -1,60 +0,0 @@
1
- repos:
2
- - repo: https://github.com/pre-commit/pre-commit-hooks
3
- rev: v4.6.0
4
- hooks:
5
- - id: check-executables-have-shebangs
6
- - id: check-json
7
- - id: check-merge-conflict
8
- - id: check-shebang-scripts-are-executable
9
- - id: check-toml
10
- - id: check-yaml
11
- - id: end-of-file-fixer
12
- - id: mixed-line-ending
13
- args: ["--fix=lf"]
14
- - id: requirements-txt-fixer
15
- - id: trailing-whitespace
16
- - repo: https://github.com/myint/docformatter
17
- rev: v1.7.5
18
- hooks:
19
- - id: docformatter
20
- args: ["--in-place"]
21
- - repo: https://github.com/pycqa/isort
22
- rev: 5.13.2
23
- hooks:
24
- - id: isort
25
- args: ["--profile", "black"]
26
- - repo: https://github.com/pre-commit/mirrors-mypy
27
- rev: v1.10.0
28
- hooks:
29
- - id: mypy
30
- args: ["--ignore-missing-imports"]
31
- additional_dependencies:
32
- [
33
- "types-python-slugify",
34
- "types-requests",
35
- "types-PyYAML",
36
- "types-pytz",
37
- ]
38
- - repo: https://github.com/psf/black
39
- rev: 24.4.2
40
- hooks:
41
- - id: black
42
- language_version: python3.10
43
- args: ["--line-length", "119"]
44
- - repo: https://github.com/kynan/nbstripout
45
- rev: 0.7.1
46
- hooks:
47
- - id: nbstripout
48
- args:
49
- [
50
- "--extra-keys",
51
- "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
52
- ]
53
- - repo: https://github.com/nbQA-dev/nbQA
54
- rev: 1.8.5
55
- hooks:
56
- - id: nbqa-black
57
- - id: nbqa-pyupgrade
58
- args: ["--py37-plus"]
59
- - id: nbqa-isort
60
- args: ["--float-to-top"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.vscode/settings.json DELETED
@@ -1,30 +0,0 @@
1
- {
2
- "editor.formatOnSave": true,
3
- "files.insertFinalNewline": false,
4
- "[python]": {
5
- "editor.defaultFormatter": "ms-python.black-formatter",
6
- "editor.formatOnType": true,
7
- "editor.codeActionsOnSave": {
8
- "source.organizeImports": "explicit"
9
- }
10
- },
11
- "[jupyter]": {
12
- "files.insertFinalNewline": false
13
- },
14
- "black-formatter.args": [
15
- "--line-length=119"
16
- ],
17
- "isort.args": ["--profile", "black"],
18
- "flake8.args": [
19
- "--max-line-length=119"
20
- ],
21
- "ruff.lint.args": [
22
- "--line-length=119"
23
- ],
24
- "notebook.output.scrolling": true,
25
- "notebook.formatOnCellExecution": true,
26
- "notebook.formatOnSave.enabled": true,
27
- "notebook.codeActionsOnSave": {
28
- "source.organizeImports": "explicit"
29
- }
30
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -4,12 +4,9 @@ emoji: 🐨
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
- suggested_hardware: t4-small
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
14
-
15
- https://arxiv.org/abs/2202.12211
 
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
app.py CHANGED
@@ -2,105 +2,155 @@
2
 
3
  from __future__ import annotations
4
 
5
- import pathlib
 
 
 
 
6
 
7
  import gradio as gr
8
  import numpy as np
9
-
10
- from model import Model
11
-
12
- DESCRIPTION = "# [Self-Distilled StyleGAN](https://github.com/self-distilled-stylegan/self-distilled-internet-photos)"
13
-
14
-
15
- def get_sample_image_url(name: str) -> str:
16
- sample_image_dir = "https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/samples"
17
- return f"{sample_image_dir}/{name}.jpg"
18
-
19
-
20
- def get_sample_image_markdown(name: str) -> str:
21
- url = get_sample_image_url(name)
22
- size = name.split("_")[1]
23
- truncation_type = "_".join(name.split("_")[2:])
24
- return f"""
25
- - size: {size}x{size}
26
- - seed: 0-99
27
- - truncation: 0.7
28
- - truncation type: {truncation_type}
29
- ![sample images]({url})"""
30
-
31
-
32
- def get_cluster_center_image_url(model_name: str) -> str:
33
- cluster_center_image_dir = (
34
- "https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/cluster_center_images"
35
- )
36
- return f"{cluster_center_image_dir}/{model_name}.jpg"
37
-
38
-
39
- def get_cluster_center_image_markdown(model_name: str) -> str:
40
- url = get_cluster_center_image_url(model_name)
41
- return f"![cluster center images]({url})"
42
-
43
-
44
- model = Model()
45
-
46
- with gr.Blocks(css="style.css") as demo:
47
- gr.Markdown(DESCRIPTION)
48
-
49
- with gr.Tabs():
50
- with gr.TabItem("App"):
51
- with gr.Row():
52
- with gr.Column():
53
- with gr.Group():
54
- model_name = gr.Dropdown(label="Model", choices=model.MODEL_NAMES, value=model.MODEL_NAMES[0])
55
- seed = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.uint32).max, step=1, value=0)
56
- psi = gr.Slider(label="Truncation psi", minimum=0, maximum=2, step=0.05, value=0.7)
57
- truncation_type = gr.Dropdown(
58
- label="Truncation Type", choices=model.TRUNCATION_TYPES, value=model.TRUNCATION_TYPES[0]
59
- )
60
- run_button = gr.Button("Run")
61
- with gr.Column():
62
- result = gr.Image(label="Result", elem_id="result")
63
-
64
- with gr.TabItem("Sample Images"):
65
- with gr.Row():
66
- paths = sorted(pathlib.Path("samples").glob("*"))
67
- names = [path.stem for path in paths]
68
- model_name2 = gr.Dropdown(label="Type", choices=names, value="dogs_1024_multimodal_lpips")
69
- with gr.Row():
70
- text = get_sample_image_markdown(model_name2.value)
71
- sample_images = gr.Markdown(text)
72
-
73
- with gr.TabItem("Cluster Center Images"):
74
- with gr.Row():
75
- model_name3 = gr.Dropdown(label="Model", choices=model.MODEL_NAMES, value=model.MODEL_NAMES[0])
76
- with gr.Row():
77
- text = get_cluster_center_image_markdown(model_name3.value)
78
- cluster_center_images = gr.Markdown(value=text)
79
-
80
- model_name.change(
81
- fn=model.set_model,
82
- inputs=model_name,
83
- )
84
- run_button.click(
85
- fn=model.set_model_and_generate_image,
86
- inputs=[
87
- model_name,
88
- seed,
89
- psi,
90
- truncation_type,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  ],
92
- outputs=result,
93
- )
94
- model_name2.change(
95
- fn=get_sample_image_markdown,
96
- inputs=model_name2,
97
- outputs=sample_images,
98
- )
99
- model_name3.change(
100
- fn=get_cluster_center_image_markdown,
101
- inputs=model_name3,
102
- outputs=cluster_center_images,
103
  )
104
 
105
- if __name__ == "__main__":
106
- demo.queue(max_size=10).launch()
 
 
2
 
3
  from __future__ import annotations
4
 
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pickle
9
+ import sys
10
 
11
  import gradio as gr
12
  import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ sys.path.insert(0, 'stylegan3')
18
+
19
+ TITLE = 'Self-Distilled StyleGAN'
20
+ DESCRIPTION = '''This is an unofficial demo for models provided in https://github.com/self-distilled-stylegan/self-distilled-internet-photos.
21
+
22
+ Expected execution time on Hugging Face Spaces: 2s
23
+ '''
24
+ SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/samples'
25
+ ARTICLE = f'''## Generated images
26
+ - truncation: 0.7
27
+ ### Dogs
28
+ - size: 1024x1024
29
+ - seed: 0-99
30
+ ![Dogs]({SAMPLE_IMAGE_DIR}/dogs.jpg)
31
+ ### Elephants
32
+ - size: 512x512
33
+ - seed: 0-99
34
+ ![Elephants]({SAMPLE_IMAGE_DIR}/elephants.jpg)
35
+ ### Horses
36
+ - size: 256x256
37
+ - seed: 0-99
38
+ ![Horses]({SAMPLE_IMAGE_DIR}/horses.jpg)
39
+ ### Bicycles
40
+ - size: 256x256
41
+ - seed: 0-99
42
+ ![Bicycles]({SAMPLE_IMAGE_DIR}/bicycles.jpg)
43
+ ### Lions
44
+ - size: 512x512
45
+ - seed: 0-99
46
+ ![Lions]({SAMPLE_IMAGE_DIR}/lions.jpg)
47
+ ### Giraffes
48
+ - size: 512x512
49
+ - seed: 0-99
50
+ ![Giraffes]({SAMPLE_IMAGE_DIR}/giraffes.jpg)
51
+ ### Parrots
52
+ - size: 512x512
53
+ - seed: 0-99
54
+ ![Parrots]({SAMPLE_IMAGE_DIR}/parrots.jpg)
55
+
56
+ <center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.self-distilled-stylegan" alt="visitor badge"/></center>
57
+ '''
58
+
59
+ TOKEN = os.environ['TOKEN']
60
+
61
+
62
+ def parse_args() -> argparse.Namespace:
63
+ parser = argparse.ArgumentParser()
64
+ parser.add_argument('--device', type=str, default='cpu')
65
+ parser.add_argument('--theme', type=str)
66
+ parser.add_argument('--live', action='store_true')
67
+ parser.add_argument('--share', action='store_true')
68
+ parser.add_argument('--port', type=int)
69
+ parser.add_argument('--disable-queue',
70
+ dest='enable_queue',
71
+ action='store_false')
72
+ parser.add_argument('--allow-flagging', type=str, default='never')
73
+ return parser.parse_args()
74
+
75
+
76
+ def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
77
+ return torch.from_numpy(np.random.RandomState(seed).randn(
78
+ 1, z_dim)).to(device).float()
79
+
80
+
81
+ @torch.inference_mode()
82
+ def generate_image(model_name: str, seed: int, truncation_psi: float,
83
+ model_dict: dict[str, nn.Module],
84
+ device: torch.device) -> np.ndarray:
85
+ model = model_dict[model_name]
86
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
87
+
88
+ z = generate_z(model.z_dim, seed, device)
89
+ label = torch.zeros([1, model.c_dim], device=device)
90
+
91
+ out = model(z, label, truncation_psi=truncation_psi)
92
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
93
+ return out[0].cpu().numpy()
94
+
95
+
96
+ def load_model(model_name: str, device: torch.device) -> nn.Module:
97
+ path = hf_hub_download('hysts/Self-Distilled-StyleGAN',
98
+ f'models/{model_name}_pytorch.pkl',
99
+ use_auth_token=TOKEN)
100
+ with open(path, 'rb') as f:
101
+ model = pickle.load(f)['G_ema']
102
+ model.eval()
103
+ model.to(device)
104
+ with torch.inference_mode():
105
+ z = torch.zeros((1, model.z_dim)).to(device)
106
+ label = torch.zeros([1, model.c_dim], device=device)
107
+ model(z, label)
108
+ return model
109
+
110
+
111
+ def main():
112
+ args = parse_args()
113
+ device = torch.device(args.device)
114
+
115
+ model_names = [
116
+ 'dogs_1024',
117
+ 'elephants_512',
118
+ 'horses_256',
119
+ 'bicycles_256',
120
+ 'lions_512',
121
+ 'giraffes_512',
122
+ 'parrots_512',
123
+ ]
124
+
125
+ model_dict = {name: load_model(name, device) for name in model_names}
126
+
127
+ func = functools.partial(generate_image,
128
+ model_dict=model_dict,
129
+ device=device)
130
+ func = functools.update_wrapper(func, generate_image)
131
+
132
+ gr.Interface(
133
+ func,
134
+ [
135
+ gr.inputs.Radio(
136
+ model_names, type='value', default='dogs_1024', label='Model'),
137
+ gr.inputs.Number(default=0, label='Seed'),
138
+ gr.inputs.Slider(
139
+ 0, 2, step=0.05, default=0.7, label='Truncation psi'),
140
  ],
141
+ gr.outputs.Image(type='numpy', label='Output'),
142
+ title=TITLE,
143
+ description=DESCRIPTION,
144
+ article=ARTICLE,
145
+ theme=args.theme,
146
+ allow_flagging=args.allow_flagging,
147
+ live=args.live,
148
+ ).launch(
149
+ enable_queue=args.enable_queue,
150
+ server_port=args.port,
151
+ share=args.share,
152
  )
153
 
154
+
155
+ if __name__ == '__main__':
156
+ main()
cluster_center_images/horses_256.jpg DELETED

Git LFS Details

  • SHA256: 3327bc1fe938f27a60d3df46bde5620e50cd5abdfbffbcdc36a9f0d6ed1eaca1
  • Pointer size: 132 Bytes
  • Size of remote file: 2 MB
cluster_center_images/lions_512.jpg DELETED

Git LFS Details

  • SHA256: 1775e96848d8f2b8150a1a2f7ce09fe051566385c20a1aca317e5bd1cd1fbf0a
  • Pointer size: 132 Bytes
  • Size of remote file: 7.75 MB
cluster_center_images/parrots_512.jpg DELETED

Git LFS Details

  • SHA256: ee1f0758a580117cbeccda384441edb896e43e7d8295f56cc137962f4b7df2f9
  • Pointer size: 132 Bytes
  • Size of remote file: 4.12 MB
model.py DELETED
@@ -1,192 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import pathlib
4
- import pickle
5
- import sys
6
-
7
- import lpips
8
- import numpy as np
9
- import torch
10
- import torch.nn as nn
11
- from huggingface_hub import hf_hub_download
12
-
13
- current_dir = pathlib.Path(__file__).parent
14
- submodule_dir = current_dir / "stylegan3"
15
- sys.path.insert(0, submodule_dir.as_posix())
16
-
17
-
18
- class LPIPS(lpips.LPIPS):
19
- @staticmethod
20
- def preprocess(image: np.ndarray) -> torch.Tensor:
21
- data = torch.from_numpy(image).float() / 255
22
- data = data * 2 - 1
23
- return data.permute(2, 0, 1).unsqueeze(0)
24
-
25
- @torch.inference_mode()
26
- def compute_features(self, data: torch.Tensor) -> list[torch.Tensor]:
27
- data = self.scaling_layer(data)
28
- data = self.net(data)
29
- return [lpips.normalize_tensor(x) for x in data]
30
-
31
- @torch.inference_mode()
32
- def compute_distance(self, features0: list[torch.Tensor], features1: list[torch.Tensor]) -> float:
33
- res = 0
34
- for lin, x0, x1 in zip(self.lins, features0, features1):
35
- d = (x0 - x1) ** 2
36
- y = lin(d)
37
- y = lpips.lpips.spatial_average(y)
38
- res += y.item()
39
- return res
40
-
41
-
42
- class Model:
43
-
44
- MODEL_NAMES = [
45
- "dogs_1024",
46
- "elephants_512",
47
- "horses_256",
48
- "bicycles_256",
49
- "lions_512",
50
- "giraffes_512",
51
- "parrots_512",
52
- ]
53
- TRUNCATION_TYPES = [
54
- "Multimodal (LPIPS)",
55
- "Multimodal (L2)",
56
- "Global",
57
- ]
58
-
59
- def __init__(self):
60
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
61
- self._download_all_models()
62
- self._download_all_cluster_centers()
63
- self._download_all_cluster_center_images()
64
-
65
- self.model_name = self.MODEL_NAMES[0]
66
- self.model = self._load_model(self.model_name)
67
- self.cluster_centers = self._load_cluster_centers(self.model_name)
68
- self.cluster_center_images = self._load_cluster_center_images(self.model_name)
69
-
70
- self.lpips = LPIPS()
71
- self.cluster_center_lpips_feature_dict = self._compute_cluster_center_lpips_features()
72
-
73
- def _load_model(self, model_name: str) -> nn.Module:
74
- path = hf_hub_download("public-data/Self-Distilled-StyleGAN", f"models/{model_name}_pytorch.pkl")
75
- with open(path, "rb") as f:
76
- model = pickle.load(f)["G_ema"]
77
- model.eval()
78
- model.to(self.device)
79
- return model
80
-
81
- def _load_cluster_centers(self, model_name: str) -> torch.Tensor:
82
- path = hf_hub_download("public-data/Self-Distilled-StyleGAN", f"cluster_centers/{model_name}.npy")
83
- centers = np.load(path)
84
- centers = torch.from_numpy(centers).float().to(self.device)
85
- return centers
86
-
87
- def _load_cluster_center_images(self, model_name: str) -> np.ndarray:
88
- path = hf_hub_download("public-data/Self-Distilled-StyleGAN", f"cluster_center_images/{model_name}.npy")
89
- return np.load(path)
90
-
91
- def set_model(self, model_name: str) -> None:
92
- if model_name == self.model_name:
93
- return
94
- self.model_name = model_name
95
- self.model = self._load_model(model_name)
96
- self.cluster_centers = self._load_cluster_centers(model_name)
97
- self.cluster_center_images = self._load_cluster_center_images(model_name)
98
-
99
- def _download_all_models(self):
100
- for name in self.MODEL_NAMES:
101
- self._load_model(name)
102
-
103
- def _download_all_cluster_centers(self):
104
- for name in self.MODEL_NAMES:
105
- self._load_cluster_centers(name)
106
-
107
- def _download_all_cluster_center_images(self):
108
- for name in self.MODEL_NAMES:
109
- self._load_cluster_center_images(name)
110
-
111
- def generate_z(self, seed: int) -> torch.Tensor:
112
- seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
113
- return torch.from_numpy(np.random.RandomState(seed).randn(1, self.model.z_dim)).float().to(self.device)
114
-
115
- def compute_w(self, z: torch.Tensor) -> torch.Tensor:
116
- label = torch.zeros((1, self.model.c_dim), device=self.device)
117
- w = self.model.mapping(z, label)
118
- return w
119
-
120
- @staticmethod
121
- def truncate_w(w_center: torch.Tensor, w: torch.Tensor, psi: float) -> torch.Tensor:
122
- if psi == 1:
123
- return w
124
- return w_center.lerp(w, psi)
125
-
126
- @torch.inference_mode()
127
- def synthesize(self, w: torch.Tensor) -> torch.Tensor:
128
- return self.model.synthesis(w)
129
-
130
- def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
131
- tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
132
- return tensor.cpu().numpy()
133
-
134
- def compute_lpips_features(self, image: np.ndarray) -> list[torch.Tensor]:
135
- data = self.lpips.preprocess(image)
136
- return self.lpips.compute_features(data)
137
-
138
- def _compute_cluster_center_lpips_features(self) -> dict[str, list[list[torch.Tensor]]]:
139
- res = dict()
140
- for name in self.MODEL_NAMES:
141
- images = self._load_cluster_center_images(name)
142
- res[name] = [self.compute_lpips_features(image) for image in images]
143
- return res
144
-
145
- def compute_distance_to_cluster_centers(self, ws: torch.Tensor, distance_type: str) -> list[torch.Tensor]:
146
- if distance_type == "l2":
147
- return self._compute_l2_distance_to_cluster_centers(ws)
148
- elif distance_type == "lpips":
149
- return self._compute_lpips_distance_to_cluster_centers(ws)
150
- else:
151
- raise ValueError
152
-
153
- def _compute_l2_distance_to_cluster_centers(self, ws: torch.Tensor) -> np.ndarray:
154
- dist2 = ((self.cluster_centers - ws[0, 0]) ** 2).sum(dim=1)
155
- return dist2.cpu().numpy()
156
-
157
- def _compute_lpips_distance_to_cluster_centers(self, ws: torch.Tensor) -> np.ndarray:
158
- x = self.synthesize(ws)
159
- x = self.postprocess(x)[0]
160
- feat0 = self.compute_lpips_features(x)
161
- cluster_center_features = self.cluster_center_lpips_feature_dict[self.model_name]
162
- distances = [self.lpips.compute_distance(feat0, feat1) for feat1 in cluster_center_features]
163
- return np.asarray(distances)
164
-
165
- def find_nearest_cluster_center(self, ws: torch.Tensor, distance_type: str) -> int:
166
- distances = self.compute_distance_to_cluster_centers(ws, distance_type)
167
- return int(np.argmin(distances))
168
-
169
- def generate_image(self, seed: int, truncation_psi: float, truncation_type: str) -> np.ndarray:
170
- z = self.generate_z(seed)
171
- ws = self.compute_w(z)
172
- if truncation_type == self.TRUNCATION_TYPES[2]:
173
- w0 = self.model.mapping.w_avg
174
- else:
175
- if truncation_type == self.TRUNCATION_TYPES[0]:
176
- distance_type = "lpips"
177
- elif truncation_type == self.TRUNCATION_TYPES[1]:
178
- distance_type = "l2"
179
- else:
180
- raise ValueError
181
- cluster_index = self.find_nearest_cluster_center(ws, distance_type)
182
- w0 = self.cluster_centers[cluster_index]
183
- new_ws = self.truncate_w(w0, ws, truncation_psi)
184
- out = self.synthesize(new_ws)
185
- out = self.postprocess(out)
186
- return out[0]
187
-
188
- def set_model_and_generate_image(
189
- self, model_name: str, seed: int, truncation_psi: float, truncation_type: str
190
- ) -> np.ndarray:
191
- self.set_model(model_name)
192
- return self.generate_image(seed, truncation_psi, truncation_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
- lpips==0.1.4
2
- numpy==1.26.4
3
- Pillow==10.3.0
4
- scipy==1.13.1
5
- torch==2.0.1
6
- torchvision==0.15.2
 
1
+ numpy==1.22.3
2
+ Pillow==9.0.1
3
+ scipy==1.8.0
4
+ torch==1.11.0
5
+ torchvision==0.12.0
 
cluster_center_images/elephants_512.jpg β†’ samples/bicycles.jpg RENAMED
File without changes
samples/bicycles_256_global.jpg DELETED

Git LFS Details

  • SHA256: 7f298bfafbb01eb40b139478b765c5a3821d96c0f7768f626fd9aaa0aaff76fe
  • Pointer size: 132 Bytes
  • Size of remote file: 3.06 MB
samples/bicycles_256_multimodal_l2.jpg DELETED

Git LFS Details

  • SHA256: 5251d110bdb3137cd27e4a0f41dd067bd1e743a45f2066eea3f5640a4da09ec8
  • Pointer size: 132 Bytes
  • Size of remote file: 3.11 MB
samples/bicycles_256_multimodal_lpips.jpg DELETED

Git LFS Details

  • SHA256: 572a68fe268667c283d89b8262534b432aef6600c157838ac942e4b393ee9d30
  • Pointer size: 132 Bytes
  • Size of remote file: 3.1 MB
cluster_center_images/dogs_1024.jpg β†’ samples/dogs.jpg RENAMED
File without changes
samples/dogs_1024_global.jpg DELETED

Git LFS Details

  • SHA256: 01d084e70b5e203a86e80735df708e6e1dc4e809681d1581d757ef3d481a44e9
  • Pointer size: 133 Bytes
  • Size of remote file: 32.5 MB
samples/dogs_1024_multimodal_l2.jpg DELETED

Git LFS Details

  • SHA256: 4e05642979481ef2e69eb16f2c6de5b5d400ce91ac79b01be91f6c0f635bf6ea
  • Pointer size: 133 Bytes
  • Size of remote file: 32.7 MB
samples/dogs_1024_multimodal_lpips.jpg DELETED

Git LFS Details

  • SHA256: 069dc857838aba588903e87e769cfa587ea156db209a9d81fb1226dfe25fdb09
  • Pointer size: 133 Bytes
  • Size of remote file: 33 MB
samples/elephants.jpg ADDED

Git LFS Details

  • SHA256: 106b68d11a1c9d3d9b9f51ef80674bf351d1fd78291e8690d2ab4ed259986493
  • Pointer size: 133 Bytes
  • Size of remote file: 12.1 MB
samples/elephants_512_global.jpg DELETED

Git LFS Details

  • SHA256: 0a7454098794427b946405ab4acb3af2da50097cc8ab0f2068a1b7b49af16343
  • Pointer size: 133 Bytes
  • Size of remote file: 12.1 MB
samples/elephants_512_multimodal_l2.jpg DELETED

Git LFS Details

  • SHA256: dcea1f74b5acf885c9df7f49b66f66a66134c741be013bc1e74145952ddf27da
  • Pointer size: 133 Bytes
  • Size of remote file: 12 MB
samples/elephants_512_multimodal_lpips.jpg DELETED

Git LFS Details

  • SHA256: 5ad535f24a634eb6a51d87ac5ddb3dcdb8ee0c44bdb7eb8342851069fbee7d45
  • Pointer size: 133 Bytes
  • Size of remote file: 11.8 MB
samples/giraffes.jpg ADDED

Git LFS Details

  • SHA256: 3d91a32b61056874a698af5749e3d002e20bff608055b6104413081d845bedb4
  • Pointer size: 133 Bytes
  • Size of remote file: 10.6 MB
samples/giraffes_512_global.jpg DELETED

Git LFS Details

  • SHA256: d2df09be7e8305d6a4448b1d2c060340c67dbdaa461a2145e49f74deb6c59bab
  • Pointer size: 133 Bytes
  • Size of remote file: 10.6 MB
samples/giraffes_512_multimodal_l2.jpg DELETED

Git LFS Details

  • SHA256: def18b83f88b8c4ae5da399db7f61073c4fc757487ddc3a99f41e59821493a5a
  • Pointer size: 133 Bytes
  • Size of remote file: 10.4 MB
samples/giraffes_512_multimodal_lpips.jpg DELETED

Git LFS Details

  • SHA256: 75e1c466cdf2a4b302584c139b0e23d798a8cb031828a489352c1b9720339f84
  • Pointer size: 133 Bytes
  • Size of remote file: 10.4 MB
cluster_center_images/giraffes_512.jpg β†’ samples/horses.jpg RENAMED
File without changes
samples/horses_256_global.jpg DELETED

Git LFS Details

  • SHA256: be290bc1bf68aa5b1ed2fc2a267453868cb9cba64697dc38274b1e9068b7004a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.15 MB
samples/horses_256_multimodal_l2.jpg DELETED

Git LFS Details

  • SHA256: 5f8873d00066b11742c2d8e37f59e415757a6fbaefe36b60aa15235192700cb4
  • Pointer size: 132 Bytes
  • Size of remote file: 3.12 MB
samples/horses_256_multimodal_lpips.jpg DELETED

Git LFS Details

  • SHA256: 0de60b0c30d46bd2194a23299dbc7541dbc3a389fe14502e9aac2acc9be85cff
  • Pointer size: 132 Bytes
  • Size of remote file: 3.13 MB
samples/lions.jpg ADDED

Git LFS Details

  • SHA256: 4216e153da49fbff81ef41484f48f1c68c6c1d455cba0a1eed8458aa64dacccc
  • Pointer size: 133 Bytes
  • Size of remote file: 11.3 MB
samples/lions_512_global.jpg DELETED

Git LFS Details

  • SHA256: 239cdf4156b42105787baedc5334d5044ab9bfe3340dcfd7353c2c6e0eae7e03
  • Pointer size: 133 Bytes
  • Size of remote file: 11.3 MB
samples/lions_512_multimodal_l2.jpg DELETED

Git LFS Details

  • SHA256: 6f85c3528c313ccc8534c71ea242a02183d9e5cf6c20ab0cad18e22e39d0be8c
  • Pointer size: 133 Bytes
  • Size of remote file: 11.2 MB
samples/lions_512_multimodal_lpips.jpg DELETED

Git LFS Details

  • SHA256: a404d939b181cc708a3569458fee55eed72e5c66e550d23c3e77478787dd85a4
  • Pointer size: 133 Bytes
  • Size of remote file: 11.1 MB
cluster_center_images/bicycles_256.jpg β†’ samples/parrots.jpg RENAMED
File without changes
samples/parrots_512_global.jpg DELETED

Git LFS Details

  • SHA256: 3e77e7c12316d1855962d62cd2fbf9eda4c5cc5416e69f570a841be70ff3347f
  • Pointer size: 132 Bytes
  • Size of remote file: 6.71 MB
samples/parrots_512_multimodal_l2.jpg DELETED

Git LFS Details

  • SHA256: d743d436eb3585d0cb60a0b600d2db73ecb3b01558485281ff993ff39fe7c460
  • Pointer size: 132 Bytes
  • Size of remote file: 6.94 MB
samples/parrots_512_multimodal_lpips.jpg DELETED

Git LFS Details

  • SHA256: 5b5346871f6d5b15f5b6d570a1adef7c48516a6326e332232b5df8fd25eebe0f
  • Pointer size: 132 Bytes
  • Size of remote file: 6.93 MB
style.css DELETED
@@ -1,8 +0,0 @@
1
- h1 {
2
- text-align: center;
3
- display: block;
4
- }
5
- div#result {
6
- max-width: 600px;
7
- max-height: 600px;
8
- }