callum-canavan commited on
Commit
954caab
1 Parent(s): b4209f3

Add helpers, change to hot dog example

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ env/
2
+ __pycache__/
3
+ assets/
app.py CHANGED
@@ -1,9 +1,25 @@
1
  import gradio as gr
 
2
 
 
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!"
6
 
 
 
 
7
 
8
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
9
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
 
4
+ pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
5
 
 
 
6
 
7
+ def predict(input_img):
8
+ predictions = pipeline(input_img)
9
+ return input_img, {p["label"]: p["score"] for p in predictions}
10
 
11
+
12
+ gradio_app = gr.Interface(
13
+ predict,
14
+ inputs=gr.Image(
15
+ label="Select hot dog candidate", sources=["upload", "webcam"], type="pil"
16
+ ),
17
+ outputs=[
18
+ gr.Image(label="Processed Image"),
19
+ gr.Label(label="Result", num_top_classes=2),
20
+ ],
21
+ title="Hot Dog? Or Not?",
22
+ )
23
+
24
+ if __name__ == "__main__":
25
+ gradio_app.launch()
diffuse.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+ from diffusers.utils import pt_to_pil
3
+ import torch
4
+
5
+ # stage 1
6
+ stage_1 = DiffusionPipeline.from_pretrained(
7
+ "DeepFloyd/IF-I-M-v1.0", variant="fp16", torch_dtype=torch.float16
8
+ )
9
+ stage_1.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
10
+ stage_1.enable_model_cpu_offload()
11
+
12
+ # stage 2
13
+ stage_2 = DiffusionPipeline.from_pretrained(
14
+ "DeepFloyd/IF-II-M-v1.0",
15
+ text_encoder=None,
16
+ variant="fp16",
17
+ torch_dtype=torch.float16,
18
+ )
19
+ stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
20
+ stage_2.enable_model_cpu_offload()
21
+
22
+ # stage 3
23
+ safety_modules = {
24
+ "feature_extractor": stage_1.feature_extractor,
25
+ "safety_checker": stage_1.safety_checker,
26
+ "watermarker": stage_1.watermarker,
27
+ }
28
+ stage_3 = DiffusionPipeline.from_pretrained(
29
+ "stabilityai/stable-diffusion-x4-upscaler",
30
+ **safety_modules,
31
+ torch_dtype=torch.float16
32
+ )
33
+ stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
34
+ stage_3.enable_model_cpu_offload()
generate.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from diffusers import DiffusionPipeline
6
+
7
+ from visual_anagrams.views import get_views
8
+ from visual_anagrams.samplers import sample_stage_1, sample_stage_2
9
+ from visual_anagrams.utils import add_args, save_illusion, save_metadata
10
+
11
+
12
+
13
+ # Parse args
14
+ parser = argparse.ArgumentParser()
15
+ parser = add_args(parser)
16
+ args = parser.parse_args()
17
+
18
+ # Do admin stuff
19
+ save_dir = Path(args.save_dir) / args.name
20
+ save_dir.mkdir(exist_ok=True, parents=True)
21
+
22
+ # Make models
23
+ stage_1 = DiffusionPipeline.from_pretrained(
24
+ "DeepFloyd/IF-I-M-v1.0",
25
+ variant="fp16",
26
+ torch_dtype=torch.float16)
27
+ stage_2 = DiffusionPipeline.from_pretrained(
28
+ "DeepFloyd/IF-II-M-v1.0",
29
+ text_encoder=None,
30
+ variant="fp16",
31
+ torch_dtype=torch.float16,
32
+ )
33
+ stage_1.enable_model_cpu_offload()
34
+ stage_2.enable_model_cpu_offload()
35
+ stage_1 = stage_1.to(args.device)
36
+ stage_2 = stage_2.to(args.device)
37
+
38
+ # Get prompt embeddings
39
+ prompt_embeds = [stage_1.encode_prompt(f'{args.style} {p}'.strip()) for p in args.prompts]
40
+ prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds)
41
+ prompt_embeds = torch.cat(prompt_embeds)
42
+ negative_prompt_embeds = torch.cat(negative_prompt_embeds) # These are just null embeds
43
+
44
+ # Get views
45
+ views = get_views(args.views)
46
+
47
+ # Save metadata
48
+ save_metadata(views, args, save_dir)
49
+
50
+ # Sample illusions
51
+ for i in range(args.num_samples):
52
+ # Admin stuff
53
+ generator = torch.manual_seed(args.seed + i)
54
+ sample_dir = save_dir / f'{i:04}'
55
+ sample_dir.mkdir(exist_ok=True, parents=True)
56
+
57
+ # Sample 64x64 image
58
+ image = sample_stage_1(stage_1,
59
+ prompt_embeds,
60
+ negative_prompt_embeds,
61
+ views,
62
+ num_inference_steps=args.num_inference_steps,
63
+ guidance_scale=args.guidance_scale,
64
+ reduction=args.reduction,
65
+ generator=generator)
66
+ save_illusion(image, views, sample_dir)
67
+
68
+ # Sample 256x256 image, by upsampling 64x64 image
69
+ image = sample_stage_2(stage_2,
70
+ image,
71
+ prompt_embeds,
72
+ negative_prompt_embeds,
73
+ views,
74
+ num_inference_steps=args.num_inference_steps,
75
+ guidance_scale=args.guidance_scale,
76
+ reduction=args.reduction,
77
+ noise_level=args.noise_level,
78
+ generator=generator)
79
+ save_illusion(image, views, sample_dir)
requirements.txt ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.4.0
2
+ aiohttp==3.8.5
3
+ aiosignal==1.3.1
4
+ annotated-types==0.5.0
5
+ anyio==3.7.1
6
+ argcomplete @ file:///private/tmp/python-argcomplete-20231112-5493-8o8e4p/argcomplete-3.1.6
7
+ arrow==1.2.3
8
+ astroid==2.15.6
9
+ astunparse==1.6.3
10
+ async-timeout==4.0.3
11
+ attrs==23.1.0
12
+ aws-cdk-lib==2.104.0
13
+ aws-cdk.asset-awscli-v1==2.2.201
14
+ aws-cdk.asset-kubectl-v20==2.1.2
15
+ aws-cdk.asset-node-proxy-agent-v6==2.0.1
16
+ backoff==2.2.1
17
+ beautifulsoup4==4.12.2
18
+ black==23.9.1
19
+ blessed==1.20.0
20
+ cachetools==5.3.1
21
+ cattrs==23.1.2
22
+ certifi==2022.12.7
23
+ charset-normalizer==3.1.0
24
+ click==8.1.7
25
+ constructs==10.3.0
26
+ contourpy==1.0.7
27
+ croniter==1.4.1
28
+ cycler==0.11.0
29
+ dataclasses-json==0.6.1
30
+ dateutils==0.6.12
31
+ deepdiff==6.5.0
32
+ diffusers==0.24.0
33
+ dill==0.3.7
34
+ distlib==0.3.6
35
+ easydict==1.10
36
+ fastapi==0.103.1
37
+ filelock==3.9.0
38
+ flatbuffers==23.5.26
39
+ fonttools==4.39.3
40
+ frozenlist==1.4.0
41
+ fsspec==2023.9.0
42
+ gast==0.4.0
43
+ gitdb==4.0.10
44
+ GitPython==3.1.36
45
+ google-auth==2.22.0
46
+ google-auth-oauthlib==1.0.0
47
+ google-pasta==0.2.0
48
+ grpcio==1.57.0
49
+ h11==0.14.0
50
+ h5py==3.9.0
51
+ huggingface-hub==0.19.4
52
+ idna==3.4
53
+ importlib-metadata==6.9.0
54
+ importlib-resources==6.1.0
55
+ iniconfig==2.0.0
56
+ inquirer==3.1.3
57
+ isort==5.12.0
58
+ itsdangerous==2.1.2
59
+ Jinja2==3.1.2
60
+ joblib==1.3.2
61
+ jsii==1.91.0
62
+ jsonpatch==1.33
63
+ jsonpointer==2.4
64
+ keras==2.13.1
65
+ kiwisolver==1.4.4
66
+ langchain==0.0.330
67
+ langsmith==0.0.57
68
+ lazy-object-proxy==1.9.0
69
+ libclang==16.0.6
70
+ lightning==2.0.8
71
+ lightning-cloud==0.5.38
72
+ lightning-utilities==0.9.0
73
+ Markdown==3.4.4
74
+ markdown-it-py==3.0.0
75
+ MarkupSafe==2.1.2
76
+ marshmallow==3.20.1
77
+ matplotlib==3.7.2
78
+ mccabe==0.7.0
79
+ mdurl==0.1.2
80
+ mpmath==1.3.0
81
+ multidict==6.0.4
82
+ mypy-extensions==1.0.0
83
+ networkx==3.1
84
+ numpy==1.24.2
85
+ oauthlib==3.2.2
86
+ opencv-python==4.7.0.72
87
+ opt-einsum==3.3.0
88
+ ordered-set==4.1.0
89
+ packaging==23.1
90
+ pandas==2.0.3
91
+ pathspec==0.11.2
92
+ Pillow==9.5.0
93
+ platformdirs==3.1.0
94
+ pluggy==1.3.0
95
+ protobuf==4.24.0
96
+ psutil==5.9.5
97
+ publication==0.0.3
98
+ py-cpuinfo==9.0.0
99
+ pyasn1==0.5.0
100
+ pyasn1-modules==0.3.0
101
+ pybind11==2.11.1
102
+ pydantic==2.1.1
103
+ pydantic_core==2.4.0
104
+ Pygments==2.16.1
105
+ PyJWT==2.8.0
106
+ pylint==2.17.5
107
+ pyparsing==3.0.9
108
+ pytest==7.4.2
109
+ python-dateutil==2.8.2
110
+ python-dotenv==1.0.0
111
+ python-editor==1.0.4
112
+ python-multipart==0.0.6
113
+ pytorch-lightning==2.0.8
114
+ pytz==2023.3
115
+ PyYAML==6.0.1
116
+ readchar==4.0.5
117
+ regex==2023.10.3
118
+ requests==2.28.2
119
+ requests-oauthlib==1.3.1
120
+ rich==13.5.2
121
+ rsa==4.9
122
+ safetensors==0.4.1
123
+ scikit-learn==1.3.0
124
+ seaborn==0.12.2
125
+ six==1.16.0
126
+ smmap==5.0.0
127
+ sniffio==1.3.0
128
+ soupsieve==2.5
129
+ SQLAlchemy==2.0.23
130
+ starlette==0.27.0
131
+ starsessions==1.3.0
132
+ sympy==1.11.1
133
+ tenacity==8.2.3
134
+ tensorboard==2.13.0
135
+ tensorboard-data-server==0.7.1
136
+ tensorflow==2.13.0
137
+ tensorflow-estimator==2.13.0
138
+ termcolor==2.3.0
139
+ threadpoolctl==3.2.0
140
+ tokenizers==0.15.0
141
+ tomlkit==0.12.1
142
+ torch==2.0.1
143
+ torchaudio==2.0.2
144
+ torchmetrics==1.1.2
145
+ torchvision==0.15.2
146
+ tqdm==4.65.0
147
+ traitlets==5.10.0
148
+ transformers==4.35.2
149
+ typeguard==2.13.3
150
+ typing-inspect==0.9.0
151
+ typing_extensions==4.6.1
152
+ tzdata==2023.3
153
+ ultralytics==8.0.178
154
+ urllib3==1.26.15
155
+ uvicorn==0.23.2
156
+ virtualenv==20.20.0
157
+ wcwidth==0.2.6
158
+ websocket-client==1.6.3
159
+ websockets==11.0.3
160
+ Werkzeug==2.3.7
161
+ wrapt==1.15.0
162
+ yacs==0.1.8
163
+ yarl==1.9.2
164
+ yolov4==2.0.3
165
+ zipp==3.17.0
visual_anagrams/__init__.py ADDED
File without changes
visual_anagrams/samplers.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from diffusers.utils.torch_utils import randn_tensor
7
+
8
+ @torch.no_grad()
9
+ def sample_stage_1(model,
10
+ prompt_embeds,
11
+ negative_prompt_embeds,
12
+ views,
13
+ num_inference_steps=100,
14
+ guidance_scale=7.0,
15
+ reduction='mean',
16
+ generator=None):
17
+
18
+ # Params
19
+ num_images_per_prompt = 1
20
+ device = model.device
21
+ height = model.unet.config.sample_size
22
+ width = model.unet.config.sample_size
23
+ batch_size = 1 # TODO: Support larger batch sizes, maybe
24
+ num_prompts = prompt_embeds.shape[0]
25
+ assert num_prompts == len(views), \
26
+ "Number of prompts must match number of views!"
27
+
28
+ # For CFG
29
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
30
+
31
+ # Setup timesteps
32
+ model.scheduler.set_timesteps(num_inference_steps, device=device)
33
+ timesteps = model.scheduler.timesteps
34
+
35
+ # Make intermediate_images
36
+ noisy_images = model.prepare_intermediate_images(
37
+ batch_size * num_images_per_prompt,
38
+ model.unet.config.in_channels,
39
+ height,
40
+ width,
41
+ prompt_embeds.dtype,
42
+ device,
43
+ generator,
44
+ )
45
+
46
+ for i, t in enumerate(tqdm(timesteps)):
47
+ # Apply views to noisy_image
48
+ viewed_noisy_images = []
49
+ for view_fn in views:
50
+ viewed_noisy_images.append(view_fn.view(noisy_images[0]))
51
+ viewed_noisy_images = torch.stack(viewed_noisy_images)
52
+
53
+ # Duplicate inputs for CFG
54
+ # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
55
+ model_input = torch.cat([viewed_noisy_images] * 2)
56
+ model_input = model.scheduler.scale_model_input(model_input, t)
57
+
58
+ # Predict noise estimate
59
+ noise_pred = model.unet(
60
+ model_input,
61
+ t,
62
+ encoder_hidden_states=prompt_embeds,
63
+ cross_attention_kwargs=None,
64
+ return_dict=False,
65
+ )[0]
66
+
67
+ # Extract uncond (neg) and cond noise estimates
68
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
69
+
70
+ # Invert the unconditional (negative) estimates
71
+ inverted_preds = []
72
+ for pred, view in zip(noise_pred_uncond, views):
73
+ inverted_pred = view.inverse_view(pred)
74
+ inverted_preds.append(inverted_pred)
75
+ noise_pred_uncond = torch.stack(inverted_preds)
76
+
77
+ # Invert the conditional estimates
78
+ inverted_preds = []
79
+ for pred, view in zip(noise_pred_text, views):
80
+ inverted_pred = view.inverse_view(pred)
81
+ inverted_preds.append(inverted_pred)
82
+ noise_pred_text = torch.stack(inverted_preds)
83
+
84
+ # Split into noise estimate and variance estimates
85
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
86
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
87
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
88
+
89
+ # Reduce predicted noise and variances
90
+ noise_pred = noise_pred.view(-1,num_prompts,3,64,64)
91
+ predicted_variance = predicted_variance.view(-1,num_prompts,3,64,64)
92
+ if reduction == 'mean':
93
+ noise_pred = noise_pred.mean(1)
94
+ predicted_variance = predicted_variance.mean(1)
95
+ elif reduction == 'alternate':
96
+ noise_pred = noise_pred[:,i%num_prompts]
97
+ predicted_variance = predicted_variance[:,i%num_prompts]
98
+ else:
99
+ raise ValueError('Reduction must be either `mean` or `alternate`')
100
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
101
+
102
+ # compute the previous noisy sample x_t -> x_t-1
103
+ noisy_images = model.scheduler.step(
104
+ noise_pred, t, noisy_images, generator=generator, return_dict=False
105
+ )[0]
106
+
107
+ # Return denoised images
108
+ return noisy_images
109
+
110
+
111
+
112
+
113
+
114
+
115
+
116
+ @torch.no_grad()
117
+ def sample_stage_2(model,
118
+ image,
119
+ prompt_embeds,
120
+ negative_prompt_embeds,
121
+ views,
122
+ num_inference_steps=100,
123
+ guidance_scale=7.0,
124
+ reduction='mean',
125
+ noise_level=50,
126
+ generator=None):
127
+
128
+ # Params
129
+ batch_size = 1 # TODO: Support larger batch sizes, maybe
130
+ num_prompts = prompt_embeds.shape[0]
131
+ height = model.unet.config.sample_size
132
+ width = model.unet.config.sample_size
133
+ device = model.device
134
+ num_images_per_prompt = 1
135
+
136
+ # For CFG
137
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
138
+
139
+ # Get timesteps
140
+ model.scheduler.set_timesteps(num_inference_steps, device=device)
141
+ timesteps = model.scheduler.timesteps
142
+
143
+ num_channels = model.unet.config.in_channels // 2
144
+ noisy_images = model.prepare_intermediate_images(
145
+ batch_size * num_images_per_prompt,
146
+ num_channels,
147
+ height,
148
+ width,
149
+ prompt_embeds.dtype,
150
+ device,
151
+ generator,
152
+ )
153
+
154
+ # Prepare upscaled image and noise level
155
+ image = model.preprocess_image(image, num_images_per_prompt, device)
156
+ upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
157
+
158
+ noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
159
+ noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
160
+ upscaled = model.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
161
+
162
+ # Condition on noise level, for each model input
163
+ noise_level = torch.cat([noise_level] * num_prompts * 2)
164
+
165
+ # Denoising Loop
166
+ for i, t in enumerate(tqdm(timesteps)):
167
+ # Cat noisy image with upscaled conditioning image
168
+ model_input = torch.cat([noisy_images, upscaled], dim=1)
169
+
170
+ # Apply views to noisy_image
171
+ viewed_inputs = []
172
+ for view_fn in views:
173
+ viewed_inputs.append(view_fn.view(model_input[0]))
174
+ viewed_inputs = torch.stack(viewed_inputs)
175
+
176
+ # Duplicate inputs for CFG
177
+ # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
178
+ model_input = torch.cat([viewed_inputs] * 2)
179
+ model_input = model.scheduler.scale_model_input(model_input, t)
180
+
181
+ # predict the noise residual
182
+ noise_pred = model.unet(
183
+ model_input,
184
+ t,
185
+ encoder_hidden_states=prompt_embeds,
186
+ class_labels=noise_level,
187
+ cross_attention_kwargs=None,
188
+ return_dict=False,
189
+ )[0]
190
+
191
+ # Extract uncond (neg) and cond noise estimates
192
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
193
+
194
+ # Invert the unconditional (negative) estimates
195
+ # TODO: pretty sure you can combine these into one loop
196
+ inverted_preds = []
197
+ for pred, view in zip(noise_pred_uncond, views):
198
+ inverted_pred = view.inverse_view(pred)
199
+ inverted_preds.append(inverted_pred)
200
+ noise_pred_uncond = torch.stack(inverted_preds)
201
+
202
+ # Invert the conditional estimates
203
+ inverted_preds = []
204
+ for pred, view in zip(noise_pred_text, views):
205
+ inverted_pred = view.inverse_view(pred)
206
+ inverted_preds.append(inverted_pred)
207
+ noise_pred_text = torch.stack(inverted_preds)
208
+
209
+ # Split predicted noise and predicted variances
210
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
211
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
212
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
213
+
214
+ # Combine noise estimates (and variance estimates)
215
+ noise_pred = noise_pred.view(-1,num_prompts,3,256,256)
216
+ predicted_variance = predicted_variance.view(-1,num_prompts,3,256,256)
217
+ if reduction == 'mean':
218
+ noise_pred = noise_pred.mean(1)
219
+ predicted_variance = predicted_variance.mean(1)
220
+ elif reduction == 'alternate':
221
+ noise_pred = noise_pred[:,i%num_prompts]
222
+ predicted_variance = predicted_variance[:,i%num_prompts]
223
+
224
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
225
+
226
+ # compute the previous noisy sample x_t -> x_t-1
227
+ noisy_images = model.scheduler.step(
228
+ noise_pred, t, noisy_images, generator=generator, return_dict=False
229
+ )[0]
230
+
231
+ # Return denoised images
232
+ return noisy_images
visual_anagrams/utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from torchvision.utils import save_image
6
+
7
+
8
+ def add_args(parser):
9
+ """
10
+ Add arguments for sampling to a parser
11
+ """
12
+
13
+ parser.add_argument("--name", required=True, type=str)
14
+ parser.add_argument(
15
+ "--save_dir",
16
+ type=str,
17
+ default="results",
18
+ help="Location to samples and metadata",
19
+ )
20
+ parser.add_argument(
21
+ "--prompts",
22
+ required=True,
23
+ type=str,
24
+ nargs="+",
25
+ help="Prompts to use, corresponding to each view.",
26
+ )
27
+ parser.add_argument(
28
+ "--views",
29
+ required=True,
30
+ type=str,
31
+ nargs="+",
32
+ help="Name of views to use. See `get_views` in `views.py`.",
33
+ )
34
+ parser.add_argument(
35
+ "--style", default="", type=str, help="Optional string to prepend prompt with"
36
+ )
37
+ parser.add_argument("--num_inference_steps", type=int, default=100)
38
+ parser.add_argument("--num_samples", type=int, default=100)
39
+ parser.add_argument("--reduction", type=str, default="mean")
40
+ parser.add_argument("--seed", type=int, default=0)
41
+ parser.add_argument("--guidance_scale", type=float, default=7.0)
42
+ parser.add_argument(
43
+ "--noise_level", type=int, default=50, help="Noise level for stage 2"
44
+ )
45
+ parser.add_argument("--device", type=str, default="cpu")
46
+ parser.add_argument(
47
+ "--save_metadata",
48
+ action="store_true",
49
+ help="If true, save metadata about the views. May use lots of disk space, particular for permutation views.",
50
+ )
51
+
52
+ return parser
53
+
54
+
55
+ def save_illusion(image, views, sample_dir):
56
+ """
57
+ Saves the illusion (`image`), as well as all views of the illusion
58
+
59
+ image (torch.tensor) :
60
+ Tensor of shape (1,3,H,W) representing the image
61
+
62
+ views (views.BaseView) :
63
+ Represents the view, inherits from BaseView
64
+
65
+ sample_dir (pathlib.Path) :
66
+ pathlib Path object, representing the directory to save to
67
+ """
68
+
69
+ size = image.shape[-1]
70
+
71
+ # Save illusion
72
+ save_image(image / 2.0 + 0.5, sample_dir / f"sample_{size}.png", padding=0)
73
+
74
+ # Save views of the illusion
75
+ im_views = torch.stack([view.view(image[0]) for view in views])
76
+ save_image(im_views / 2.0 + 0.5, sample_dir / f"sample_{size}.views.png", padding=0)
77
+
78
+
79
+ def save_metadata(views, args, save_dir):
80
+ """
81
+ Saves the following the sample_dir
82
+ 1) pickled view object
83
+ 2) args for the illusion
84
+ """
85
+
86
+ metadata = {"views": views, "args": args}
87
+ with open(save_dir / "metadata.pkl", "wb") as f:
88
+ pickle.dump(metadata, f)
89
+
90
+
91
+ def get_courier_font_path():
92
+ font_path = Path(__file__).parent / "assets" / "CourierPrime-Regular.ttf"
93
+ return str(font_path)
visual_anagrams/views/__init__.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
+ from .view_identity import IdentityView
6
+ from .view_flip import FlipView
7
+ from .view_rotate import Rotate180View, Rotate90CCWView, Rotate90CWView
8
+ from .view_negate import NegateView
9
+ from .view_skew import SkewView
10
+ from .view_patch_permute import PatchPermuteView
11
+ from .view_jigsaw import JigsawView
12
+ from .view_inner_circle import InnerCircleView
13
+
14
+ VIEW_MAP = {
15
+ 'identity': IdentityView,
16
+ 'flip': FlipView,
17
+ 'rotate_cw': Rotate90CWView,
18
+ 'rotate_ccw': Rotate90CCWView,
19
+ 'rotate_180': Rotate180View,
20
+ 'negate': NegateView,
21
+ 'skew': SkewView,
22
+ 'patch_permute': PatchPermuteView,
23
+ 'pixel_permute': PatchPermuteView,
24
+ 'jigsaw': JigsawView,
25
+ 'inner_circle': InnerCircleView,
26
+ }
27
+
28
+ def get_views(view_names):
29
+ '''
30
+ Bespoke function to get views (just to make command line usage easier)
31
+ '''
32
+ views = []
33
+ for view_name in view_names:
34
+ if view_name == 'patch_permute':
35
+ args = [8]
36
+ elif view_name == 'pixel_permute':
37
+ args = [64]
38
+ elif view_name == 'skew':
39
+ args = [1.5]
40
+ else:
41
+ args = []
42
+
43
+ view = VIEW_MAP[view_name](*args)
44
+ views.append(view)
45
+
46
+ return views
visual_anagrams/views/jigsaw_helpers.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
+ def get_jigsaw_pieces(size):
6
+ '''
7
+ Load all pieces of the 4x4 jigsaw puzzle.
8
+
9
+ size (int) :
10
+ Should be 64 or 256, indicating side length of jigsaw puzzle
11
+ '''
12
+
13
+ # Location of pieces
14
+ piece_dir = Path(__file__).parent / 'assets'
15
+
16
+ # Helper function to load pieces as np arrays
17
+ def load_pieces(path):
18
+ '''
19
+ Load a piece, from the given path, as a binary numpy array.
20
+ Return a list of the "base" piece, and all four of its rotations.
21
+ '''
22
+ piece = Image.open(path)
23
+ piece = np.array(piece)[:,:,0] // 255
24
+ pieces = np.stack([np.rot90(piece, k=-i) for i in range(4)])
25
+ return pieces
26
+
27
+ # Load pieces and rotate to get 16 pieces, and cat
28
+ pieces_corner = load_pieces(piece_dir / f'4x4/4x4_corner_{size}.png')
29
+ pieces_inner = load_pieces(piece_dir / f'4x4/4x4_inner_{size}.png')
30
+ pieces_edge1 = load_pieces(piece_dir / f'4x4/4x4_edge1_{size}.png')
31
+ pieces_edge2 = load_pieces(piece_dir / f'4x4/4x4_edge2_{size}.png')
32
+ pieces = np.concatenate([pieces_corner, pieces_inner, pieces_edge1, pieces_edge2])
33
+
34
+ return pieces
35
+
visual_anagrams/views/permutations.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import numpy as np
3
+ import torch
4
+ import torchvision.transforms.functional as TF
5
+ from einops import rearrange, repeat
6
+
7
+ from .jigsaw_helpers import get_jigsaw_pieces
8
+
9
+ def get_inv_perm(perm):
10
+ '''
11
+ Get the inverse permutation of a permutation. That is, the array such that
12
+ perm[perm_inv] = perm_inv[perm] = arange(len(perm))
13
+
14
+ perm (torch.tensor) :
15
+ A 1-dimensional integer array, representing a permutation. Indicates
16
+ that element i should move to index perm[i]
17
+ '''
18
+ perm_inv = torch.empty_like(perm)
19
+ perm_inv[perm] = torch.arange(len(perm))
20
+ return perm_inv
21
+
22
+
23
+ def make_inner_circle_perm(im_size=64, r=24):
24
+ '''
25
+ Makes permutations for "inner circle" view. Given size of image, and
26
+ `r`, the radius of the circle. We do this by iterating through every
27
+ pixel and figuring out where it should go.
28
+ '''
29
+ perm = [] # Permutation array
30
+
31
+ # Iterate through all positions, in order
32
+ for iy in range(im_size):
33
+ for ix in range(im_size):
34
+ # Get coordinates, with origin at (0, 0)
35
+ x = ix - im_size // 2 + 0.5
36
+ y = iy - im_size // 2 + 0.5
37
+
38
+ # Do 180 deg rotation if in circle
39
+ if x**2 + y**2 < r**2:
40
+ x = -x
41
+ y = -y
42
+
43
+ # Convert back to integer coordinates
44
+ x = int(x + im_size // 2 - 0.5)
45
+ y = int(y + im_size // 2 - 0.5)
46
+
47
+ # Append destination pixel index to permutation
48
+ perm.append(x + y * im_size)
49
+ perm = torch.tensor(perm)
50
+
51
+ return perm
52
+
53
+
54
+
55
+
56
+ def make_jigsaw_perm(size, seed=0):
57
+ '''
58
+ Returns a permutation of pixels that is a jigsaw permutation
59
+
60
+ There are 3 types of pieces: corner, edge, and inner pieces. These were
61
+ created in MS Paint. They are all identical and laid out like:
62
+
63
+ c0 e0 f0 c1
64
+ f3 i0 i1 e1
65
+ e3 i3 i2 f1
66
+ c3 f2 e2 c2
67
+
68
+ where c is "corner," i is "inner," and "e" and "f" are "edges."
69
+ "e" and "f" pieces are identical, but labeled differently such that
70
+ to move any piece to the next index you can apply a 90 deg rotation.
71
+
72
+ Pieces c0, e0, f0, and i0 are defined by pngs, and will be loaded in. All
73
+ other pieces are obtained by 90 deg rotations of these "base" pieces.
74
+
75
+ Permutations are defined by:
76
+ 1. permutation of corner (c) pieces (length 4 perm list)
77
+ 2. permutation of inner (i) pieces (length 4 perm list)
78
+ 3. permutation of edge (e) pieces (length 4 perm list)
79
+ 4. permutation of edge (f) pieces (length 4 perm list)
80
+ 5. list of four swaps, indicating swaps between e and f
81
+ edge pieces along the same edge (length 4 bit list)
82
+
83
+ Note these perm indexes will just be a "rotation index" indicating
84
+ how many 90 deg rotations to apply to the base pieces. The swaps
85
+ ensure that any edge piece can go to any edge piece, and are indexed
86
+ by the indexes of the "e" and "f" pieces on the edge.
87
+
88
+ Also note, order of indexes in permutation array is raster scan order. So,
89
+ go along x's first, then y's. This means y * size + x gives us the
90
+ 1-D location in the permutation array. And image arrays are in
91
+ (y,x) order.
92
+
93
+ Plan of attack for making a pixel permutation array that represents
94
+ a jigsaw permutation:
95
+
96
+ 1. Iterate through all pixels (in raster scan order)
97
+ 2. Figure out which puzzle piece it is in initially
98
+ 3. Look at the permutations, and see where it should go
99
+ 4. Additionally, see if it's an edge piece, and needs to be swapped
100
+ 5. Add the new (1-D) index to the permutation array
101
+
102
+ '''
103
+ np.random.seed(seed)
104
+
105
+ # Get location of puzzle pieces
106
+ piece_dir = Path(__file__).parent / 'assets'
107
+
108
+ # Get random permutations of groups of 4, and cat
109
+ identity = np.arange(4)
110
+ perm_corner = np.random.permutation(identity)
111
+ perm_inner = np.random.permutation(identity)
112
+ perm_edge1 = np.random.permutation(identity)
113
+ perm_edge2 = np.random.permutation(identity)
114
+ edge_swaps = np.random.randint(2, size=4)
115
+ piece_perms = np.concatenate([perm_corner, perm_inner, perm_edge1, perm_edge2])
116
+
117
+ # Get all 16 jigsaw pieces (in the order above)
118
+ pieces = get_jigsaw_pieces(size)
119
+
120
+ # Make permutation array to fill
121
+ perm = []
122
+
123
+ # For each pixel, figure out where it should go
124
+ for y in range(size):
125
+ for x in range(size):
126
+ # Figure out which piece (x,y) is in:
127
+ piece_idx = pieces[:,y,x].argmax()
128
+
129
+ # Figure out how many 90 deg rotations are on the piece
130
+ rot_idx = piece_idx % 4
131
+
132
+ # The perms tells us how many 90 deg rotations to apply to
133
+ # arrive at new pixel location
134
+ dest_rot_idx = piece_perms[piece_idx]
135
+ angle = (dest_rot_idx - rot_idx) * 90 / 180 * np.pi
136
+
137
+ # Center coordinates on origin
138
+ cx = x - (size - 1) / 2.
139
+ cy = y - (size - 1) / 2.
140
+
141
+ # Perform rotation
142
+ nx = np.cos(angle) * cx - np.sin(angle) * cy
143
+ ny = np.sin(angle) * cx + np.cos(angle) * cy
144
+
145
+ # Translate back and round coordinates to _nearest_ integer
146
+ nx = nx + (size - 1) / 2.
147
+ ny = ny + (size - 1) / 2.
148
+ nx = int(np.rint(nx))
149
+ ny = int(np.rint(ny))
150
+
151
+ # Perform swap if piece is an edge, and swap == 1 at NEW location
152
+ new_piece_idx = pieces[:,ny,nx].argmax()
153
+ edge_idx = new_piece_idx % 4
154
+ if new_piece_idx >= 8 and edge_swaps[edge_idx] == 1:
155
+ is_f_edge = (new_piece_idx - 8) // 4 # 1 if f, 0 if e edge
156
+ edge_type_parity = 1 - 2 * is_f_edge
157
+ rotation_parity = 1 - 2 * (edge_idx // 2)
158
+ swap_dist = size // 4
159
+
160
+ # if edge_idx is even, swap in x direction, else y
161
+ if edge_idx % 2 == 0:
162
+ nx = nx + swap_dist * edge_type_parity * rotation_parity
163
+ else:
164
+ ny = ny + swap_dist * edge_type_parity * rotation_parity
165
+
166
+ # append new index to permutation array
167
+ new_idx = int(ny * size + nx)
168
+ perm.append(new_idx)
169
+
170
+ # sanity check
171
+ #import matplotlib.pyplot as plt
172
+ #missing = sorted(set(range(size*size)).difference(set(perm)))
173
+ #asdf = np.zeros(size*size)
174
+ #asdf[missing] = 1
175
+ #plt.imshow(asdf.reshape(size,size))
176
+ #plt.savefig('tmp.png')
177
+ #plt.show()
178
+ #print(np.sum(asdf))
179
+
180
+ #viz = np.zeros((64,64))
181
+ #for idx in perm:
182
+ # y, x = idx // 64, idx % 64
183
+ # viz[y,x] = 1
184
+ #plt.imshow(viz)
185
+ #plt.savefig('tmp.png')
186
+ #Image.fromarray(viz * 255).convert('RGB').save('tmp.png')
187
+ #Image.fromarray(pieces_edge1[0] * 255).convert('RGB').save('tmp.png')
188
+
189
+ # sanity check on test image
190
+ #im = Image.open('results/flip.campfire.man/0000/sample_64.png')
191
+ #im = Image.open('results/flip.campfire.man/0000/sample_256.png')
192
+ #im = np.array(im)
193
+ #Image.fromarray(im.reshape(-1, 3)[perm].reshape(size,size,3)).save('test.png')
194
+
195
+ return torch.tensor(perm), (piece_perms, edge_swaps)
196
+
197
+ #for i in range(100):
198
+ #make_jigsaw_perm(64, seed=i)
199
+ #make_jigsaw_perm(256, seed=11)
200
+
201
+
202
+ def recover_patch_permute(im_0, im_1, patch_size):
203
+ '''
204
+ Given two views of a patch permutation illusion, recover the patch
205
+ permutation used.
206
+
207
+ im_0 (PIL.Image) :
208
+ Identity view of the illusion
209
+
210
+ im_1 (PIL.Image) :
211
+ Patch permuted view of the illusion
212
+
213
+ patch_size (int) :
214
+ Size of the patches in the image
215
+ '''
216
+
217
+ # Convert to tensors
218
+ im_0 = TF.to_tensor(im_0)
219
+ im_1 = TF.to_tensor(im_1)
220
+
221
+ # Extract patches
222
+ patches_0 = rearrange(im_0,
223
+ 'c (h p1) (w p2) -> (h w) c p1 p2',
224
+ p1=patch_size,
225
+ p2=patch_size)
226
+ patches_1 = rearrange(im_1,
227
+ 'c (h p1) (w p2) -> (h w) c p1 p2',
228
+ p1=patch_size,
229
+ p2=patch_size)
230
+
231
+ # Repeat patches_1 for each patch in patches_0
232
+ patches_1_repeated = repeat(patches_1,
233
+ 'np c p1 p2 -> np1 np c p1 p2',
234
+ np=patches_1.shape[0],
235
+ np1=patches_1.shape[0],
236
+ p1=patch_size,
237
+ p2=patch_size)
238
+
239
+ # Find closest patch in other image by L1 dist, and return indexes
240
+ perm = (patches_1_repeated - patches_0[:,None]).abs().sum((2,3,4)).argmin(1)
241
+
242
+ return perm
visual_anagrams/views/view_base.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class BaseView:
2
+ '''
3
+ BaseView class, from which all views inherit. Implements the
4
+ following functions:
5
+ '''
6
+
7
+ def __init__(self):
8
+ pass
9
+
10
+ def view(self, im):
11
+ '''
12
+ Apply transform to an image.
13
+
14
+ im (`torch.tensor`):
15
+ For stage 1: Tensor of shape (3, H, W) representing a noisy image
16
+ OR
17
+ For stage 2: Tensor of shape (6, H, W) representing a noisy image
18
+ concatenated with an upsampled conditioning image from stage 1
19
+ '''
20
+ raise NotImplementedError()
21
+
22
+ def inverse_view(self, noise):
23
+ '''
24
+ Apply inverse transform to noise estimates.
25
+ Because DeepFloyd estimates the variance in addition to
26
+ the noise, this function must apply the inverse to the
27
+ variance as well.
28
+
29
+ im (`torch.tensor`):
30
+ Tensor of shape (6, H, W) representing the noise estimate
31
+ (first three channel dims) and variacne estimates (last
32
+ three channel dims)
33
+ '''
34
+ raise NotImplementedError()
35
+
36
+ def make_frame(self, im, t):
37
+ '''
38
+ Make a frame, transitioning linearly from the identity view (t=0)
39
+ to this view (t=1)
40
+
41
+ im (`PIL.Image`) :
42
+ A PIL Image of the illusion
43
+
44
+ t (float) :
45
+ A float in [0,1] indicating time in the animation. Should start
46
+ at the identity view at t=0, and continuously transition to the
47
+ view at t=1.
48
+ '''
49
+ raise NotImplementedError()
visual_anagrams/views/view_flip.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ import torch
4
+
5
+ from .view_base import BaseView
6
+
7
+ class FlipView(BaseView):
8
+ def __init__(self):
9
+ pass
10
+
11
+ def view(self, im):
12
+ return torch.flip(im, [1])
13
+
14
+ def inverse_view(self, noise):
15
+ return torch.flip(noise, [1])
16
+
17
+ def make_frame(self, im, t):
18
+ im_size = im.size[0]
19
+ frame_size = int(im_size * 1.5)
20
+ theta = t * 180
21
+
22
+ # TODO: Technically not a flip, change this to a homography later
23
+ frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
24
+ frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
25
+ frame = frame.rotate(theta,
26
+ resample=Image.Resampling.BILINEAR,
27
+ expand=False,
28
+ fillcolor=(255,255,255))
29
+
30
+ return frame
visual_anagrams/views/view_identity.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .view_base import BaseView
2
+
3
+ class IdentityView(BaseView):
4
+ def __init__(self):
5
+ pass
6
+
7
+ def view(self, im):
8
+ return im
9
+
10
+ def inverse_view(self, noise):
11
+ return noise
visual_anagrams/views/view_inner_circle.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torchvision.transforms.functional as TF
6
+
7
+ from .permutations import make_inner_circle_perm
8
+ from .view_permute import PermuteView
9
+
10
+ class InnerCircleView(PermuteView):
11
+ '''
12
+ Implements an "inner circle" view, where a circle inside the image spins
13
+ but the border stays still. Inherits from `PermuteView`, which implements
14
+ the `view` and `inverse_view` functions as permutations. We just make
15
+ the correct permutation here, and implement the `make_frame` method
16
+ for animation
17
+ '''
18
+ def __init__(self):
19
+ '''
20
+ Make the correct "inner circle" permutations and pass it to the
21
+ parent class constructor.
22
+ '''
23
+ self.perm_64 = make_inner_circle_perm(im_size=64, r=24)
24
+ self.perm_256 = make_inner_circle_perm(im_size=256, r=96)
25
+
26
+ super().__init__(self.perm_64, self.perm_256)
27
+
28
+ def make_frame(self, im, t):
29
+ im_size = im.size[0]
30
+ frame_size = int(im_size * 1.5)
31
+ theta = -t * 180
32
+
33
+ # Convert to tensor
34
+ im = torch.tensor(np.array(im) / 255.).permute(2,0,1)
35
+
36
+ # Get mask of circle (TODO: assuming size 256)
37
+ coords = torch.arange(0, 256) - 127.5
38
+ xx, yy = torch.meshgrid(coords, coords)
39
+ mask = xx**2 + yy**2 < (24*4)**2
40
+ mask = torch.stack([mask]*3).float()
41
+
42
+ # Get rotate image
43
+ im_rotated = TF.rotate(im, theta)
44
+
45
+ # Composite rotated circle + border together
46
+ im = im * (1 - mask) + im_rotated * mask
47
+
48
+ # Convert back to PIL
49
+ im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8))
50
+
51
+ # Paste on to canvas
52
+ frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
53
+ frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
54
+
55
+ return frame
56
+
visual_anagrams/views/view_jigsaw.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch
4
+ from einops import einsum, rearrange
5
+
6
+ from .permutations import make_jigsaw_perm, get_inv_perm
7
+ from .view_permute import PermuteView
8
+ from .jigsaw_helpers import get_jigsaw_pieces
9
+
10
+ class JigsawView(PermuteView):
11
+ '''
12
+ Implements a 4x4 jigsaw puzzle view...
13
+ '''
14
+ def __init__(self, seed=11):
15
+ '''
16
+ '''
17
+ # Get pixel permutations, corresponding to jigsaw permutations
18
+ self.perm_64, _ = make_jigsaw_perm(64, seed=seed)
19
+ self.perm_256, (jigsaw_perm) = make_jigsaw_perm(256, seed=seed)
20
+
21
+ # keep track of jigsaw permutation as well
22
+ self.piece_perms, self.edge_swaps = jigsaw_perm
23
+
24
+ # Init parent PermuteView, with above pixel perms
25
+ super().__init__(self.perm_64, self.perm_256)
26
+
27
+ def extract_pieces(self, im):
28
+ '''
29
+ Given an image, extract jigsaw puzzle pieces from it
30
+
31
+ im (PIL.Image) :
32
+ PIL Image of the jigsaw illusion
33
+ '''
34
+ im = np.array(im)
35
+ size = im.shape[0]
36
+ pieces = []
37
+
38
+ # Get jigsaw pieces
39
+ piece_masks = get_jigsaw_pieces(size)
40
+
41
+ # Save pieces
42
+ for piece_mask in piece_masks:
43
+ # Add mask as alpha mask to image
44
+ im_piece = np.concatenate([im, piece_mask[:,:,None] * 255], axis=2)
45
+
46
+ # Get extents of piece, and crop
47
+ x_min = np.nonzero(im_piece[:,:,-1].sum(0))[0].min()
48
+ x_max = np.nonzero(im_piece[:,:,-1].sum(0))[0].max()
49
+ y_min = np.nonzero(im_piece[:,:,-1].sum(1))[0].min()
50
+ y_max = np.nonzero(im_piece[:,:,-1].sum(1))[0].max()
51
+ im_piece = im_piece[y_min:y_max+1, x_min:x_max+1]
52
+
53
+ pieces.append(Image.fromarray(im_piece))
54
+
55
+ return pieces
56
+
57
+
58
+ def paste_piece(self, piece, x, y, theta, xc, yc, canvas_size=384):
59
+ '''
60
+ Given a PIL Image of a piece, place it so that it's center is at
61
+ (x,y) and it's rotate about that center at theta degrees
62
+
63
+ x (float) : x coordinate to place piece at
64
+ y (float) : y coordinate to place piece at
65
+ theta (float) : degrees to rotate piece about center
66
+ xc (float) : x coordinate of center of piece
67
+ yc (float) : y coordinate of center of piece
68
+ '''
69
+
70
+ # Make canvas
71
+ canvas = Image.new("RGBA",
72
+ (canvas_size, canvas_size),
73
+ (255, 255, 255, 0))
74
+
75
+ # Past piece so center is at (x, y)
76
+ canvas.paste(piece, (x-xc,y-yc), piece)
77
+
78
+ # Rotate about (x, y)
79
+ canvas = canvas.rotate(theta, resample=Image.BILINEAR, center=(x, y))
80
+ return canvas
81
+
82
+
83
+ def make_frame(self, im, t, canvas_size=384, knot_seed=0):
84
+ '''
85
+ This function returns a PIL image of a frame animating a jigsaw
86
+ permutation. Pieces move and rotate from the identity view
87
+ (t = 0) to the rearranged view (t = 1) along splines.
88
+
89
+ The approach is as follows:
90
+
91
+ 1. Extract all 16 pieces
92
+ 2. Figure out start locations for each of these pieces (t=0)
93
+ 3. Figure out how these pieces permute
94
+ 4. Using these permutations, figure out end locations (t=1)
95
+ 5. Make knots for splines, randomly offset normally from the
96
+ midpoint of the start and end locations
97
+ 6. Paste pieces into correct locations, determined by
98
+ spline interpolation
99
+
100
+ im (PIL.Image) :
101
+ PIL image representing the jigsaw illusion
102
+
103
+ t (float) :
104
+ Interpolation parameter in [0,1] indicating what frame of the
105
+ animation to generate
106
+
107
+ canvas_size (int) :
108
+ Side length of the frame
109
+
110
+ knot_seed (int) :
111
+ Seed for random offsets for the knots
112
+ '''
113
+ im_size = im.size[0]
114
+
115
+ # Extract 16 jigsaw pieces
116
+ pieces = self.extract_pieces(im)
117
+
118
+ # Rotate all pieces to "base" piece orientation
119
+ pieces = [p.rotate(90 * (i % 4),
120
+ resample=Image.BILINEAR,
121
+ expand=1) for i, p in enumerate(pieces)]
122
+
123
+ # Get (hardcoded) start locations for each base piece, on a
124
+ # 4x4 grid centered on the origin.
125
+ corner_start_loc = np.array([-1.5, -1.5])
126
+ inner_start_loc = np.array([-0.5, -0.5])
127
+ edge_e_start_loc = np.array([-1.5, -0.5])
128
+ edge_f_start_loc = np.array([-1.5, 0.5])
129
+ base_start_locs = np.stack([corner_start_loc,
130
+ inner_start_loc,
131
+ edge_e_start_loc,
132
+ edge_f_start_loc])
133
+
134
+ # Construct all start locations by rotating around (0,0)
135
+ # by 90 degrees, 4 times, and concatenating the results
136
+ rot_mats = []
137
+ for theta in -np.arange(4) * 90 / 180 * np.pi:
138
+ rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
139
+ [np.sin(theta), np.cos(theta)]])
140
+ rot_mats.append(rot_mat)
141
+ rot_mats = np.stack(rot_mats)
142
+ start_locs = einsum(base_start_locs, rot_mats,
143
+ 'start i, rot j i -> start rot j')
144
+ start_locs = rearrange(start_locs,
145
+ 'start rot j -> (start rot) j')
146
+
147
+ # Add rotation information to start locations
148
+ thetas = np.tile(np.arange(4) * -90, 4)[:, None]
149
+ start_locs = np.concatenate([start_locs, thetas], axis=1)
150
+
151
+ # Get explicit permutation of pieces from permutation metadata
152
+ perm = self.piece_perms + np.repeat(np.arange(4), 4) * 4
153
+ for edge_idx, to_swap in enumerate(self.edge_swaps):
154
+ if to_swap:
155
+ # Make swap permutation array
156
+ swap_perm = np.arange(16)
157
+ swap_perm[8 + edge_idx], swap_perm[12 + edge_idx] = \
158
+ swap_perm[12 + edge_idx], swap_perm[8 + edge_idx]
159
+
160
+ # Apply swap permutation after perm
161
+ perm = np.array([swap_perm[perm[i]] for i in range(16)])
162
+
163
+ # Get inverse perm (the actual permutation needed)...
164
+ perm_inv = get_inv_perm(torch.tensor(perm))
165
+
166
+ # ...and use it to get the final locations of pieces
167
+ end_locs = start_locs[perm_inv]
168
+
169
+ # Convert start and end locations to pixel coordinate system
170
+ start_locs[:,:2] = (start_locs[:,:2] + 2) * 64
171
+ end_locs[:,:2] = (end_locs[:,:2] + 2) * 64
172
+
173
+ # Add offset so pieces are centered on canvas
174
+ start_locs[:,:2] = start_locs[:,:2] + (canvas_size - im_size) // 2
175
+ end_locs[:,:2] = end_locs[:,:2] + (canvas_size - im_size) // 2
176
+
177
+ # Get random offsets from middle for spline knot (so path is pretty)
178
+ # Wrapped in a set seed
179
+ original_state = np.random.get_state()
180
+ np.random.seed(knot_seed)
181
+ rand_offsets = np.random.rand(16, 1) * 2 - 1
182
+ rand_offsets = rand_offsets * 2
183
+ eps = np.random.randn(16, 2) # Add epsilon for divide by zero
184
+ np.random.set_state(original_state)
185
+
186
+ # Make spline knots by taking average of start and end,
187
+ # and offsetting by some amount normal from the line
188
+ avg_locs = (start_locs[:, :2] + end_locs[:, :2]) / 2.
189
+ norm = (end_locs[:, :2] - start_locs[:, :2])
190
+ norm = norm + eps
191
+ norm = norm / np.linalg.norm(norm, axis=1, keepdims=True)
192
+ rot_mat = np.array([[0,1], [-1,0]])
193
+ norm = norm @ rot_mat
194
+ rand_offsets = rand_offsets * (im_size / 4)
195
+ knot_locs = avg_locs + norm * rand_offsets
196
+
197
+ # Paste pieces on to a canvas
198
+ canvas = Image.new("RGBA", (canvas_size, canvas_size), (255,255,255,255))
199
+ for i in range(16):
200
+ # Get start and end coords
201
+ y_0, x_0, theta_0 = start_locs[i]
202
+ y_1, x_1, theta_1 = end_locs[i]
203
+ y_k, x_k = knot_locs[i]
204
+
205
+ # Take spline interpolation for x and y
206
+ x_int_0 = x_0 * (1-t) + x_k * t
207
+ y_int_0 = y_0 * (1-t) + y_k * t
208
+ x_int_1 = x_k * (1-t) + x_1 * t
209
+ y_int_1 = y_k * (1-t) + y_1 * t
210
+ x = int(np.round(x_int_0 * (1-t) + x_int_1 * t))
211
+ y = int(np.round(y_int_0 * (1-t) + y_int_1 * t))
212
+
213
+ # Just take normal interpolation for theta
214
+ theta = int(np.round(theta_0 * (1-t) + theta_1 * t))
215
+
216
+ # Get piece in location and rotation
217
+ xc = yc = im_size // 4 // 2
218
+ pasted_piece = self.paste_piece(pieces[i], x, y, theta, xc, yc)
219
+
220
+ canvas.paste(pasted_piece, (0,0), pasted_piece)
221
+
222
+ return canvas
visual_anagrams/views/view_negate.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import torch
5
+
6
+ from .view_base import BaseView
7
+
8
+ class NegateView(BaseView):
9
+ def __init__(self):
10
+ pass
11
+
12
+ def view(self, im):
13
+ return -im
14
+
15
+ def inverse_view(self, noise):
16
+ '''
17
+ Negating the variance estimate is "weird" so just don't do it.
18
+ This hack seems to work just fine
19
+ '''
20
+ invert_mask = torch.ones_like(noise)
21
+ invert_mask[:3] = -1
22
+ return noise * invert_mask
23
+
24
+ def make_frame(self, im, t):
25
+ im_size = im.size[0]
26
+ frame_size = int(im_size * 1.5)
27
+
28
+ # map t from [0, 1] -> [1, -1]
29
+ t = 1 - t
30
+ t = t * 2 - 1
31
+
32
+ # Interpolate from pixels from [0, 1] to [1, 0]
33
+ im = np.array(im) / 255.
34
+ im = ((2 * im - 1) * t + 1) / 2.
35
+ im = Image.fromarray((im * 255.).astype(np.uint8))
36
+
37
+ # Paste on to canvas
38
+ frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
39
+ frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
40
+
41
+ return frame
visual_anagrams/views/view_patch_permute.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms.functional as TF
6
+ from einops import rearrange
7
+
8
+ from .permutations import get_inv_perm
9
+ from .view_base import BaseView
10
+
11
+
12
+ class PatchPermuteView(BaseView):
13
+ def __init__(self, num_patches=8):
14
+ '''
15
+ Implements random patch permutations, with `num_patches`
16
+ patches per side
17
+
18
+ num_patches (int) :
19
+ Number of patches in one dimension. Total number
20
+ of patches will be num_patches**2. Should be a power of 2.
21
+ '''
22
+
23
+ assert 64 % num_patches == 0 and 256 % num_patches == 0, \
24
+ "`num_patches` must divide image side lengths of 64 and 256"
25
+
26
+ self.num_patches = num_patches
27
+
28
+ # Get random permutation and inverse permutation
29
+ self.perm = torch.randperm(self.num_patches**2)
30
+ self.perm_inv = get_inv_perm(self.perm)
31
+
32
+ def view(self, im):
33
+ im_size = im.shape[-1]
34
+
35
+ # Get number of pixels on one side of a patch
36
+ patch_size = int(im_size / self.num_patches)
37
+
38
+ # Reshape into patches of size (c, patch_size, patch_size)
39
+ patches = rearrange(im,
40
+ 'c (h p1) (w p2) -> (h w) c p1 p2',
41
+ p1=patch_size,
42
+ p2=patch_size)
43
+
44
+ # Permute
45
+ patches = patches[self.perm]
46
+
47
+ # Reshape back into image
48
+ im_rearr = rearrange(patches,
49
+ '(h w) c p1 p2 -> c (h p1) (w p2)',
50
+ h=self.num_patches,
51
+ w=self.num_patches,
52
+ p1=patch_size,
53
+ p2=patch_size)
54
+ return im_rearr
55
+
56
+ def inverse_view(self, noise):
57
+ im_size = noise.shape[-1]
58
+
59
+ # Get number of pixels on one side of a patch
60
+ patch_size = int(im_size / self.num_patches)
61
+
62
+ # Reshape into patches of size (c, patch_size, patch_size)
63
+ patches = rearrange(noise,
64
+ 'c (h p1) (w p2) -> (h w) c p1 p2',
65
+ p1=patch_size,
66
+ p2=patch_size)
67
+
68
+ # Apply inverse permutation
69
+ patches = patches[self.perm_inv]
70
+
71
+ # Reshape back into image
72
+ im_rearr = rearrange(patches,
73
+ '(h w) c p1 p2 -> c (h p1) (w p2)',
74
+ h=self.num_patches,
75
+ w=self.num_patches,
76
+ p1=patch_size,
77
+ p2=patch_size)
78
+ return im_rearr
79
+
80
+ def make_frame(self, im, t, canvas_size=384, scale=4, knot_seed=0):
81
+ '''
82
+ Scale is a hack, because PIL for some reason doesn't support pasting
83
+ at floating point coordinates. So just render at larger scale
84
+ and resize by 1/scale
85
+ '''
86
+ # Get useful info
87
+ im_size = im.size[0]
88
+ offset = (canvas_size - im_size) // 2 # offset to center animation
89
+
90
+ canvas_size = canvas_size * scale
91
+ offset = offset * scale
92
+
93
+ im = TF.to_tensor(im)
94
+
95
+ # Get number of pixels on one side of a patch
96
+ im_size = im.shape[-1]
97
+ patch_size = int(im_size / self.num_patches)
98
+
99
+ # Extract patches
100
+ patches = rearrange(im,
101
+ 'c (h p1) (w p2) -> (h w) c p1 p2',
102
+ p1=patch_size,
103
+ p2=patch_size)
104
+
105
+ # Get start locations (top left corner of patch)
106
+ yy, xx = torch.meshgrid(
107
+ torch.arange(self.num_patches),
108
+ torch.arange(self.num_patches)
109
+ )
110
+ xx = xx.flatten()
111
+ yy = yy.flatten()
112
+ start_locs = torch.stack([xx, yy], dim=1) * patch_size * scale
113
+ start_locs = start_locs + offset
114
+
115
+ # Get end locations by permuting
116
+ end_locs = start_locs[self.perm]
117
+
118
+ # Get random anchor locations
119
+ original_state = np.random.get_state()
120
+ np.random.seed(knot_seed)
121
+ rand_offsets = np.random.rand(self.num_patches**2, 1) * 2 - 1
122
+ rand_offsets = rand_offsets * 2 * scale
123
+ eps = np.random.randn(*start_locs.shape) # Add epsilon for divide by zero
124
+ np.random.set_state(original_state)
125
+
126
+ # Make spline knots by taking average of start and end,
127
+ # and offsetting by some amount normal from the line
128
+ avg_locs = (start_locs + end_locs) / 2.
129
+ norm = (end_locs - start_locs)
130
+ norm = norm + eps
131
+ norm = norm / np.linalg.norm(norm, axis=1, keepdims=True)
132
+ rot_mat = np.array([[0,1], [-1,0]])
133
+ norm = norm @ rot_mat
134
+ rand_offsets = rand_offsets * (im_size / 4)
135
+ knot_locs = avg_locs + norm * rand_offsets
136
+
137
+ # Get paste locations
138
+ spline_0 = start_locs * (1 - t) + knot_locs * t
139
+ spline_1 = knot_locs * (1 - t) + end_locs * t
140
+ paste_locs = spline_0 * (1 - t) + spline_1 * t
141
+ paste_locs = paste_locs.to(int)
142
+
143
+ # Paste patches onto canvas
144
+ canvas = Image.new("RGBA", (canvas_size, canvas_size), (255,255,255,255))
145
+ for patch, paste_loc in zip(patches, paste_locs):
146
+ patch = TF.to_pil_image(patch).convert('RGBA')
147
+ patch = patch.resize((patch_size * scale, patch_size * scale))
148
+ paste_loc = (paste_loc[0].item(), paste_loc[1].item())
149
+ canvas.paste(patch, paste_loc, patch)
150
+
151
+ if scale != 1.0:
152
+ canvas = canvas.resize((canvas_size // scale, canvas_size // scale))
153
+
154
+ return canvas
visual_anagrams/views/view_permute.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+
4
+ from .permutations import get_inv_perm
5
+ from .view_base import BaseView
6
+
7
+ class PermuteView(BaseView):
8
+ def __init__(self, perm_64, perm_256):
9
+ '''
10
+ Implements arbitrary pixel permutations, for a given permutation.
11
+ We need two permutations. One of size 64x64 for stage 1, and
12
+ one of size 256x256 for stage 2.
13
+
14
+ perm_64 (torch.tensor) :
15
+ Tensor of integer indexes, defining a permutation, of size 64*64
16
+
17
+ perm_256 (torch.tensor) :
18
+ Tensor of integer indexes, defining a permutation, of size 256*256
19
+ '''
20
+
21
+ assert perm_64.shape == torch.Size([64*64]), \
22
+ "`perm_64` must be a permutation tensor of size 64*64"
23
+
24
+ assert perm_256.shape == torch.Size([256*256]), \
25
+ "`perm_256` must be a permutation tensor of size 256*256"
26
+
27
+ # Get random permutation and inverse permutation for stage 1
28
+ self.perm_64 = perm_64
29
+ self.perm_64_inv = get_inv_perm(self.perm_64)
30
+
31
+ # Get random permutation and inverse permutation for stage 2
32
+ self.perm_256 = perm_256
33
+ self.perm_256_inv = get_inv_perm(self.perm_256)
34
+
35
+ def view(self, im):
36
+ im_size = im.shape[-1]
37
+ perm = self.perm_64 if im_size == 64 else self.perm_256
38
+ num_patches = im_size
39
+
40
+ # Permute every pixel in the image
41
+ patch_size = 1
42
+
43
+ # Reshape into patches of size (c, patch_size, patch_size)
44
+ patches = rearrange(im,
45
+ 'c (h p1) (w p2) -> (h w) c p1 p2',
46
+ p1=patch_size,
47
+ p2=patch_size)
48
+
49
+ # Permute
50
+ patches = patches[perm]
51
+
52
+ # Reshape back into image
53
+ im_rearr = rearrange(patches,
54
+ '(h w) c p1 p2 -> c (h p1) (w p2)',
55
+ h=num_patches,
56
+ w=num_patches,
57
+ p1=patch_size,
58
+ p2=patch_size)
59
+ return im_rearr
60
+
61
+ def inverse_view(self, noise):
62
+ im_size = noise.shape[-1]
63
+ perm_inv = self.perm_64_inv if im_size == 64 else self.perm_256_inv
64
+ num_patches = im_size
65
+
66
+ # Permute every pixel in the image
67
+ patch_size = 1
68
+
69
+ # Reshape into patches of size (c, patch_size, patch_size)
70
+ patches = rearrange(noise,
71
+ 'c (h p1) (w p2) -> (h w) c p1 p2',
72
+ p1=patch_size,
73
+ p2=patch_size)
74
+
75
+ # Apply inverse permutation
76
+ patches = patches[perm_inv]
77
+
78
+ # Reshape back into image
79
+ im_rearr = rearrange(patches,
80
+ '(h w) c p1 p2 -> c (h p1) (w p2)',
81
+ h=num_patches,
82
+ w=num_patches,
83
+ p1=patch_size,
84
+ p2=patch_size)
85
+ return im_rearr
86
+
87
+ def make_frame(self, im, t):
88
+ # TODO: Implement this, as just moving pixels around
89
+ raise NotImplementedError()
90
+
91
+
visual_anagrams/views/view_rotate.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ import torchvision.transforms.functional as TF
4
+ from torchvision.transforms import InterpolationMode
5
+
6
+ from .view_base import BaseView
7
+
8
+
9
+ class Rotate90CWView(BaseView):
10
+ def __init__(self):
11
+ pass
12
+
13
+ def view(self, im):
14
+ # TODO: Is nearest-exact better?
15
+ return TF.rotate(im, -90, interpolation=InterpolationMode.NEAREST)
16
+
17
+ def inverse_view(self, noise):
18
+ return TF.rotate(noise, 90, interpolation=InterpolationMode.NEAREST)
19
+
20
+ def make_frame(self, im, t):
21
+ im_size = im.size[0]
22
+ frame_size = int(im_size * 1.5)
23
+ theta = t * -90
24
+
25
+ frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
26
+ centered_loc = (frame_size - im_size) // 2
27
+ frame.paste(im, (centered_loc, centered_loc))
28
+ frame = frame.rotate(theta,
29
+ resample=Image.Resampling.BILINEAR,
30
+ expand=False,
31
+ fillcolor=(255,255,255))
32
+
33
+ return frame
34
+
35
+
36
+ class Rotate90CCWView(BaseView):
37
+ def __init__(self):
38
+ pass
39
+
40
+ def view(self, im):
41
+ # TODO: Is nearest-exact better?
42
+ return TF.rotate(im, 90, interpolation=InterpolationMode.NEAREST)
43
+
44
+ def inverse_view(self, noise):
45
+ return TF.rotate(noise, -90, interpolation=InterpolationMode.NEAREST)
46
+
47
+ def make_frame(self, im, t):
48
+ im_size = im.size[0]
49
+ frame_size = int(im_size * 1.5)
50
+ theta = t * 90
51
+
52
+ frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
53
+ centered_loc = (frame_size - im_size) // 2
54
+ frame.paste(im, (centered_loc, centered_loc))
55
+ frame = frame.rotate(theta,
56
+ resample=Image.Resampling.BILINEAR,
57
+ expand=False,
58
+ fillcolor=(255,255,255))
59
+
60
+ return frame
61
+
62
+
63
+ class Rotate180View(BaseView):
64
+ def __init__(self):
65
+ pass
66
+
67
+ def view(self, im):
68
+ # TODO: Is nearest-exact better?
69
+ return TF.rotate(im, 180, interpolation=InterpolationMode.NEAREST)
70
+
71
+ def inverse_view(self, noise):
72
+ return TF.rotate(noise, -180, interpolation=InterpolationMode.NEAREST)
73
+
74
+ def make_frame(self, im, t):
75
+ im_size = im.size[0]
76
+ frame_size = int(im_size * 1.5)
77
+ theta = t * 180
78
+
79
+ frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
80
+ centered_loc = (frame_size - im_size) // 2
81
+ frame.paste(im, (centered_loc, centered_loc))
82
+ frame = frame.rotate(theta,
83
+ resample=Image.Resampling.BILINEAR,
84
+ expand=False,
85
+ fillcolor=(255,255,255))
86
+
87
+ return frame
visual_anagrams/views/view_skew.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+ import torch
5
+
6
+ from .view_base import BaseView
7
+
8
+
9
+ class SkewView(BaseView):
10
+ def __init__(self, skew_factor=1.5):
11
+ self.skew_factor = skew_factor
12
+
13
+ def skew_image(self, im, skew_factor):
14
+ '''
15
+ Roll each column of the image by increasing displacements.
16
+ This is a permutation of pixels
17
+ '''
18
+
19
+ # Params
20
+ c,h,w = im.shape
21
+ h_center = h//2
22
+
23
+ # Roll columns
24
+ cols = []
25
+ for i in range(w):
26
+ d = int(skew_factor * (i - h_center)) # Displacement
27
+ col = im[:,:,i]
28
+ cols.append(col.roll(d, dims=1))
29
+
30
+ # Stack rolled columns
31
+ skewed = torch.stack(cols, dim=2)
32
+ return skewed
33
+
34
+ def view(self, im):
35
+ return self.skew_image(im, self.skew_factor)
36
+
37
+ def inverse_view(self, noise):
38
+ return self.skew_image(noise, -self.skew_factor)
39
+
40
+ def make_frame(self, im, t):
41
+ im_size = im.size[0]
42
+ frame_size = int(im_size * 1.5)
43
+ skew_factor = t * self.skew_factor
44
+
45
+ # Convert to tensor, skew, then convert back to PIL
46
+ im = torch.tensor(np.array(im) / 255.).permute(2,0,1)
47
+ im = self.skew_image(im, skew_factor)
48
+ im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8))
49
+
50
+ # Paste on to canvas
51
+ frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
52
+ frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
53
+
54
+ return frame
55
+