Spaces:
Runtime error
Runtime error
solved merge
Browse files- .gitignore +2 -1
- DisentanglementBase.py +16 -5
- Home.py +13 -9
- README.md +2 -1
- backend/disentangle_concepts.py +14 -4
- config.toml +8 -0
- data/grammar_ornaments/1_colors_generally.png +3 -0
- data/grammar_ornaments/2_proportions_and_contrasts.png +3 -0
- data/grammar_ornaments/3_positions_simultanous.png +3 -0
- data/grammar_ornaments/4_juxtapositions.png +3 -0
- grammar_ornament_test.ipynb +256 -0
- interfacegan_colour_disentanglement.ipynb +145 -16
- pages/{1_Textiles_Disentanglement.py → 1_Textiles_Manipulation.py} +51 -41
- pages/{2_Colours_comparison.py → 2_Network_comparison.py} +45 -34
- pages/3_Vectors_algebra.py +163 -0
- pyproject.toml +0 -0
- structure_annotations.ipynb +0 -0
- test-docker.sh +743 -0
.gitignore
CHANGED
@@ -189,4 +189,5 @@ cython_debug/
|
|
189 |
data/images/
|
190 |
tmp/
|
191 |
figures/
|
192 |
-
archive/
|
|
|
|
189 |
data/images/
|
190 |
tmp/
|
191 |
figures/
|
192 |
+
archive/
|
193 |
+
segment-anything/
|
DisentanglementBase.py
CHANGED
@@ -181,12 +181,20 @@ class DisentanglementBase:
|
|
181 |
bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
|
182 |
else 1 for x in range(len(self.colors_list) + 1)]
|
183 |
bins[0] = 0
|
|
|
184 |
y_cat = pd.cut(y,
|
185 |
bins=bins,
|
186 |
labels=self.colors_list,
|
187 |
include_lowest=True
|
188 |
)
|
189 |
print(y_cat.value_counts())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
x_train, x_val, y_train, y_val = train_test_split(X, y_cat, test_size=0.2)
|
191 |
else:
|
192 |
if extremes:
|
@@ -567,11 +575,14 @@ def main():
|
|
567 |
with dnnlib.util.open_url(model_file) as f:
|
568 |
model = legacy.load_network_pkl(f)['G_ema'] # type: ignore
|
569 |
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
|
|
|
|
|
|
575 |
|
576 |
scores = []
|
577 |
kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
|
|
|
181 |
bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
|
182 |
else 1 for x in range(len(self.colors_list) + 1)]
|
183 |
bins[0] = 0
|
184 |
+
|
185 |
y_cat = pd.cut(y,
|
186 |
bins=bins,
|
187 |
labels=self.colors_list,
|
188 |
include_lowest=True
|
189 |
)
|
190 |
print(y_cat.value_counts())
|
191 |
+
|
192 |
+
y_h_cat[y_s == 0] = 'Gray'
|
193 |
+
y_h_cat[y_s == 100] = 'Gray'
|
194 |
+
y_h_cat[y_v == 0] = 'Gray'
|
195 |
+
y_h_cat[y_v == 100] = 'Gray'
|
196 |
+
|
197 |
+
print(y_cat.value_counts())
|
198 |
x_train, x_val, y_train, y_val = train_test_split(X, y_cat, test_size=0.2)
|
199 |
else:
|
200 |
if extremes:
|
|
|
575 |
with dnnlib.util.open_url(model_file) as f:
|
576 |
model = legacy.load_network_pkl(f)['G_ema'] # type: ignore
|
577 |
|
578 |
+
|
579 |
+
# colors_list = ['Red', 'Orange', 'Yellow', 'Yellow Green', 'Chartreuse Green',
|
580 |
+
# 'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',
|
581 |
+
# 'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']
|
582 |
+
# colors_list = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue',
|
583 |
+
# 'Blue', 'Purple', 'Pink']
|
584 |
+
colors_list = ['Gray', 'Red', 'Yellow', 'Green', 'Cyan',
|
585 |
+
'Blue', 'Magenta']
|
586 |
|
587 |
scores = []
|
588 |
kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
|
Home.py
CHANGED
@@ -6,24 +6,28 @@ st.set_page_config(layout='wide')
|
|
6 |
|
7 |
st.title('About')
|
8 |
|
|
|
9 |
# INTRO
|
10 |
-
intro_text = """
|
|
|
11 |
The thesis concentrates mostly on the second part, exploring different avenues for understanding the space. Using a multitude of vision generative models, it discusses possibilities for the systematic exploration of space, including disentanglement properties and coverage of various guidance methods.
|
12 |
It also explores the possibility of comparison across latent spaces and investigates the differences and commonalities across different learning experiments. Furthermore, the thesis investigates the role of stochasticity in newer models.
|
13 |
As a case study, this thesis adopts art historical data, spanning classic art, photography, and modern and contemporary art.
|
14 |
-
|
15 |
-
The project aims to interpret the StyleGAN2 model by several techniques.
|
16 |
-
> “What concepts are disentangled in the latent space of StyleGAN2”\n
|
17 |
-
> “Can we quantify the complexity of such concepts?”.
|
18 |
-
|
19 |
"""
|
20 |
st.write(intro_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
# 4 PAGES
|
23 |
st.subheader('Pages')
|
24 |
-
sections_text = """Overall, there are
|
25 |
-
1)
|
26 |
-
2)
|
|
|
27 |
...
|
28 |
"""
|
29 |
st.write(sections_text)
|
|
|
6 |
|
7 |
st.title('About')
|
8 |
|
9 |
+
st.subheader('General aim of the Ph.D. (to be updated)')
|
10 |
# INTRO
|
11 |
+
intro_text = """
|
12 |
+
This project investigates the nature and nurture of latent spaces, with the aim of formulating a theory of this particular vectorial space. It draws together reflections on the inherent constraints of latent spaces in particular architectures and considers the learning-specific features that emerge.
|
13 |
The thesis concentrates mostly on the second part, exploring different avenues for understanding the space. Using a multitude of vision generative models, it discusses possibilities for the systematic exploration of space, including disentanglement properties and coverage of various guidance methods.
|
14 |
It also explores the possibility of comparison across latent spaces and investigates the differences and commonalities across different learning experiments. Furthermore, the thesis investigates the role of stochasticity in newer models.
|
15 |
As a case study, this thesis adopts art historical data, spanning classic art, photography, and modern and contemporary art.
|
|
|
|
|
|
|
|
|
|
|
16 |
"""
|
17 |
st.write(intro_text)
|
18 |
+
st.subheader('On this experiment')
|
19 |
+
st.write(
|
20 |
+
"""The project aims to interpret the StyleGAN3 model trained on Textiles using disentanglement methods.
|
21 |
+
> “What features are disentangled in the latent space of StyleGAN3”\n
|
22 |
+
> “Can we quantify the complexity, quality and relations of such features?”.
|
23 |
+
""")
|
24 |
|
25 |
# 4 PAGES
|
26 |
st.subheader('Pages')
|
27 |
+
sections_text = """Overall, there are 3 features in this web app:
|
28 |
+
1) Textiles manipulation
|
29 |
+
2) Features comparison
|
30 |
+
3) Vectors algebra manipulation
|
31 |
...
|
32 |
"""
|
33 |
st.write(sections_text)
|
README.md
CHANGED
@@ -13,4 +13,5 @@ pinned: false
|
|
13 |
|
14 |
To be change name: latent-space-theory
|
15 |
|
16 |
-
This app was built with Streamlit. To run the app, `streamlit run Home.py` in the terminal.
|
|
|
|
13 |
|
14 |
To be change name: latent-space-theory
|
15 |
|
16 |
+
This app was built with Streamlit. To run the app, `streamlit run Home.py` in the terminal.
|
17 |
+
python -m streamlit run Home.py
|
backend/disentangle_concepts.py
CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
|
|
7 |
|
8 |
|
9 |
|
10 |
-
def generate_composite_images(model, z, decision_boundaries, lambdas, latent_space='W'):
|
11 |
"""
|
12 |
The regenerate_images function takes a model, z, and decision_boundary as input. It then
|
13 |
constructs an inverse rotation/translation matrix and passes it to the generator. The generator
|
@@ -33,9 +33,19 @@ def generate_composite_images(model, z, decision_boundaries, lambdas, latent_spa
|
|
33 |
repetitions = 16
|
34 |
z_0 = z
|
35 |
|
36 |
-
|
37 |
-
decision_boundary
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
|
41 |
if latent_space == 'Z':
|
|
|
7 |
|
8 |
|
9 |
|
10 |
+
def generate_composite_images(model, z, decision_boundaries, lambdas, latent_space='W', negative_colors=None):
|
11 |
"""
|
12 |
The regenerate_images function takes a model, z, and decision_boundary as input. It then
|
13 |
constructs an inverse rotation/translation matrix and passes it to the generator. The generator
|
|
|
33 |
repetitions = 16
|
34 |
z_0 = z
|
35 |
|
36 |
+
if negative_colors:
|
37 |
+
for decision_boundary, lmbd, neg_boundary in zip(decision_boundaries, lambdas, negative_colors):
|
38 |
+
decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
|
39 |
+
if neg_boundary != 'None':
|
40 |
+
neg_boundary = torch.from_numpy(neg_boundary.copy()).to(device)
|
41 |
+
|
42 |
+
z_0 = z_0 + int(lmbd) * (decision_boundary - (neg_boundary.T * decision_boundary) * neg_boundary)
|
43 |
+
else:
|
44 |
+
z_0 = z_0 + int(lmbd) * decision_boundary
|
45 |
+
else:
|
46 |
+
for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
|
47 |
+
decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
|
48 |
+
z_0 = z_0 + int(lmbd) * decision_boundary
|
49 |
|
50 |
|
51 |
if latent_space == 'Z':
|
config.toml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[server]
|
2 |
+
enableCORS = false
|
3 |
+
headless = true
|
4 |
+
|
5 |
+
[browser]
|
6 |
+
serverAddress = 0.0.0.0
|
7 |
+
gatherUsageStats = false
|
8 |
+
serverPort = 8501
|
data/grammar_ornaments/1_colors_generally.png
ADDED
Git LFS Details
|
data/grammar_ornaments/2_proportions_and_contrasts.png
ADDED
Git LFS Details
|
data/grammar_ornaments/3_positions_simultanous.png
ADDED
Git LFS Details
|
data/grammar_ornaments/4_juxtapositions.png
ADDED
Git LFS Details
|
grammar_ornament_test.ipynb
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import os \n",
|
10 |
+
"from glob import glob \n",
|
11 |
+
"import pandas as pd\n",
|
12 |
+
"import numpy as np\n",
|
13 |
+
"\n",
|
14 |
+
"from PIL import Image, ImageColor\n",
|
15 |
+
"import extcolors\n",
|
16 |
+
"\n",
|
17 |
+
"import matplotlib.pyplot as plt\n",
|
18 |
+
"\n",
|
19 |
+
"import torch\n",
|
20 |
+
"\n",
|
21 |
+
"import dnnlib \n",
|
22 |
+
"import legacy\n",
|
23 |
+
"\n",
|
24 |
+
"\n",
|
25 |
+
"%load_ext autoreload\n",
|
26 |
+
"%autoreload 2"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": null,
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [],
|
34 |
+
"source": [
|
35 |
+
"images_textiles = glob('/Users/ludovicaschaerf/Desktop/TextAIles/TextileGAN/Original Textiles/*')"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "markdown",
|
40 |
+
"metadata": {},
|
41 |
+
"source": [
|
42 |
+
"### LAWS\n",
|
43 |
+
"\n",
|
44 |
+
"1. primary colours on small surfaces and secondary or tertiary colors on large backgrounds\n",
|
45 |
+
"2. primary in upper portions and sec/third in lower portions of objects\n",
|
46 |
+
"3. primaries of equal intensities harmonize, secondaries harmonized by opposite primary in equal intensity, tertiary by remaining secondary\n",
|
47 |
+
"4. a full colors contrasted by a lower tone color should have the latter in larger proportion\n",
|
48 |
+
"5. when a primary has a hue (second coloration) of another primary, the secondary must have the hue of the third primary\n",
|
49 |
+
"6. blue in concave surfaces, yellow in convex, red in undersites\n",
|
50 |
+
"7. if too much of a color, the other colors should have the hue version without that color\n",
|
51 |
+
"8. all three primaries should be present\n",
|
52 |
+
"9. ..."
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "markdown",
|
57 |
+
"metadata": {},
|
58 |
+
"source": [
|
59 |
+
"Test 1\n",
|
60 |
+
"\n",
|
61 |
+
"primary - secondary - tertiary "
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": null,
|
67 |
+
"metadata": {},
|
68 |
+
"outputs": [],
|
69 |
+
"source": [
|
70 |
+
"def get_color_rank(hue, saturation, value):\n",
|
71 |
+
" if value < 5:\n",
|
72 |
+
" color = 'Black'\n",
|
73 |
+
" rank = 'None'\n",
|
74 |
+
" elif saturation < 3:\n",
|
75 |
+
" color = 'White'\n",
|
76 |
+
" rank = 'None'\n",
|
77 |
+
" elif saturation < 15:\n",
|
78 |
+
" color = 'Gray'\n",
|
79 |
+
" rank = 'None'\n",
|
80 |
+
" elif hue == 0:\n",
|
81 |
+
" color = 'Gray'\n",
|
82 |
+
" rank = 'None'\n",
|
83 |
+
" \n",
|
84 |
+
" elif hue >= 330 or hue <= 15:\n",
|
85 |
+
" color = 'Red'\n",
|
86 |
+
" rank = 'Primary'\n",
|
87 |
+
" elif hue > 15 and hue < 25:\n",
|
88 |
+
" color = 'Red Orange'\n",
|
89 |
+
" rank = 'Tertiary'\n",
|
90 |
+
" elif hue >= 25 and hue <= 40:\n",
|
91 |
+
" color = 'Orange'\n",
|
92 |
+
" rank = 'Secondary'\n",
|
93 |
+
" elif hue > 40 and hue < 50:\n",
|
94 |
+
" color = 'Orange Yellow'\n",
|
95 |
+
" rank = 'Tertiary'\n",
|
96 |
+
" elif hue >= 50 and hue <= 85:\n",
|
97 |
+
" color = 'Yellow'\n",
|
98 |
+
" rank = 'Primary'\n",
|
99 |
+
" elif hue > 85 and hue < 95:\n",
|
100 |
+
" color = 'Yellow Green'\n",
|
101 |
+
" rank = 'Tertiary'\n",
|
102 |
+
" elif hue >= 95 and hue <= 145:\n",
|
103 |
+
" color = 'Green'\n",
|
104 |
+
" rank = 'Secondary'\n",
|
105 |
+
" elif hue >= 145 and hue < 180:\n",
|
106 |
+
" color = 'Green Blue'\n",
|
107 |
+
" rank = 'Tertiary'\n",
|
108 |
+
" elif hue >= 180 and hue <= 245:\n",
|
109 |
+
" color = 'Blue'\n",
|
110 |
+
" rank = 'Primary'\n",
|
111 |
+
" elif hue > 245 and hue < 265:\n",
|
112 |
+
" color = 'Blue Violet'\n",
|
113 |
+
" rank = 'Tertiary'\n",
|
114 |
+
" elif hue >= 265 and hue <= 290:\n",
|
115 |
+
" color = 'Violet'\n",
|
116 |
+
" rank = 'Secondary'\n",
|
117 |
+
" elif hue > 290 and hue < 330:\n",
|
118 |
+
" color = 'Violet Red'\n",
|
119 |
+
" rank = 'Tertiary'\n",
|
120 |
+
" return color, rank"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": null,
|
126 |
+
"metadata": {},
|
127 |
+
"outputs": [],
|
128 |
+
"source": [
|
129 |
+
"def rgb2hsv(r, g, b):\n",
|
130 |
+
" # Normalize R, G, B values\n",
|
131 |
+
" r, g, b = r / 255.0, g / 255.0, b / 255.0\n",
|
132 |
+
" \n",
|
133 |
+
" # h, s, v = hue, saturation, value\n",
|
134 |
+
" max_rgb = max(r, g, b) \n",
|
135 |
+
" min_rgb = min(r, g, b) \n",
|
136 |
+
" difference = max_rgb-min_rgb \n",
|
137 |
+
" \n",
|
138 |
+
" # if max_rgb and max_rgb are equal then h = 0\n",
|
139 |
+
" if max_rgb == min_rgb:\n",
|
140 |
+
" h = 0\n",
|
141 |
+
" \n",
|
142 |
+
" # if max_rgb==r then h is computed as follows\n",
|
143 |
+
" elif max_rgb == r:\n",
|
144 |
+
" h = (60 * ((g - b) / difference) + 360) % 360\n",
|
145 |
+
" \n",
|
146 |
+
" # if max_rgb==g then compute h as follows\n",
|
147 |
+
" elif max_rgb == g:\n",
|
148 |
+
" h = (60 * ((b - r) / difference) + 120) % 360\n",
|
149 |
+
" \n",
|
150 |
+
" # if max_rgb=b then compute h\n",
|
151 |
+
" elif max_rgb == b:\n",
|
152 |
+
" h = (60 * ((r - g) / difference) + 240) % 360\n",
|
153 |
+
" \n",
|
154 |
+
" # if max_rgb==zero then s=0\n",
|
155 |
+
" if max_rgb == 0:\n",
|
156 |
+
" s = 0\n",
|
157 |
+
" else:\n",
|
158 |
+
" s = (difference / max_rgb) * 100\n",
|
159 |
+
" \n",
|
160 |
+
" # compute v\n",
|
161 |
+
" v = max_rgb * 100\n",
|
162 |
+
" # return rounded values of H, S and V\n",
|
163 |
+
" return tuple(map(round, (h, s, v)))\n",
|
164 |
+
" "
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": null,
|
170 |
+
"metadata": {},
|
171 |
+
"outputs": [],
|
172 |
+
"source": [
|
173 |
+
"def obtain_hsv_colors(img):\n",
|
174 |
+
" colors = extcolors.extract_from_path(img, tolerance=7, limit=7)\n",
|
175 |
+
" colors = [(rgb2hsv(h[0][0], h[0][1], h[0][2]), h[1]) for h in colors[0] if h[0] != (0,0,0)]\n",
|
176 |
+
" return colors"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": null,
|
182 |
+
"metadata": {},
|
183 |
+
"outputs": [],
|
184 |
+
"source": [
|
185 |
+
"colors = obtain_hsv_colors(images_textiles[0])\n",
|
186 |
+
"print(colors)"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "code",
|
191 |
+
"execution_count": null,
|
192 |
+
"metadata": {},
|
193 |
+
"outputs": [],
|
194 |
+
"source": [
|
195 |
+
"for col in colors:\n",
|
196 |
+
" print(get_color_rank(*col[0]))"
|
197 |
+
]
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"cell_type": "code",
|
201 |
+
"execution_count": null,
|
202 |
+
"metadata": {},
|
203 |
+
"outputs": [],
|
204 |
+
"source": [
|
205 |
+
"for img in images_textiles[:30]:\n",
|
206 |
+
" colors = obtain_hsv_colors(img)\n",
|
207 |
+
" plt.imshow(plt.imread(img))\n",
|
208 |
+
" plt.show()\n",
|
209 |
+
" for col in colors:\n",
|
210 |
+
" print(col[0])\n",
|
211 |
+
" print(get_color_rank(*col[0]))\n",
|
212 |
+
" \n",
|
213 |
+
" print()"
|
214 |
+
]
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"cell_type": "markdown",
|
218 |
+
"metadata": {},
|
219 |
+
"source": [
|
220 |
+
"### use for training only images with medium saturation and value\n",
|
221 |
+
"\n",
|
222 |
+
"use codes and not only hue for color categorization\n",
|
223 |
+
"or remove colors that are creater with black and whites"
|
224 |
+
]
|
225 |
+
},
|
226 |
+
{
|
227 |
+
"cell_type": "code",
|
228 |
+
"execution_count": null,
|
229 |
+
"metadata": {},
|
230 |
+
"outputs": [],
|
231 |
+
"source": []
|
232 |
+
}
|
233 |
+
],
|
234 |
+
"metadata": {
|
235 |
+
"kernelspec": {
|
236 |
+
"display_name": "art-reco_x86",
|
237 |
+
"language": "python",
|
238 |
+
"name": "python3"
|
239 |
+
},
|
240 |
+
"language_info": {
|
241 |
+
"codemirror_mode": {
|
242 |
+
"name": "ipython",
|
243 |
+
"version": 3
|
244 |
+
},
|
245 |
+
"file_extension": ".py",
|
246 |
+
"mimetype": "text/x-python",
|
247 |
+
"name": "python",
|
248 |
+
"nbconvert_exporter": "python",
|
249 |
+
"pygments_lexer": "ipython3",
|
250 |
+
"version": "3.8.16"
|
251 |
+
},
|
252 |
+
"orig_nbformat": 4
|
253 |
+
},
|
254 |
+
"nbformat": 4,
|
255 |
+
"nbformat_minor": 2
|
256 |
+
}
|
interfacegan_colour_disentanglement.ipynb
CHANGED
@@ -15,6 +15,7 @@
|
|
15 |
"\n",
|
16 |
"from PIL import Image, ImageColor\n",
|
17 |
"import matplotlib.pyplot as plt\n",
|
|
|
18 |
"\n",
|
19 |
"import numpy as np\n",
|
20 |
"import torch\n",
|
@@ -33,6 +34,87 @@
|
|
33 |
"%autoreload 2"
|
34 |
]
|
35 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
{
|
37 |
"cell_type": "code",
|
38 |
"execution_count": null,
|
@@ -43,6 +125,21 @@
|
|
43 |
"num_colors = 7"
|
44 |
]
|
45 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
{
|
47 |
"cell_type": "code",
|
48 |
"execution_count": null,
|
@@ -50,8 +147,7 @@
|
|
50 |
"metadata": {},
|
51 |
"outputs": [],
|
52 |
"source": [
|
53 |
-
"
|
54 |
-
"centers = [int((values[i-1]+values[i])/2) for i in range(len(values)) if i > 0]"
|
55 |
]
|
56 |
},
|
57 |
{
|
@@ -61,7 +157,7 @@
|
|
61 |
"metadata": {},
|
62 |
"outputs": [],
|
63 |
"source": [
|
64 |
-
"print(
|
65 |
"print(centers)"
|
66 |
]
|
67 |
},
|
@@ -100,9 +196,9 @@
|
|
100 |
"metadata": {},
|
101 |
"outputs": [],
|
102 |
"source": [
|
103 |
-
"def to_256(val):\n",
|
104 |
-
"
|
105 |
-
"
|
106 |
]
|
107 |
},
|
108 |
{
|
@@ -112,9 +208,7 @@
|
|
112 |
"metadata": {},
|
113 |
"outputs": [],
|
114 |
"source": [
|
115 |
-
"names = ['
|
116 |
-
" 'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',\n",
|
117 |
-
" 'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']"
|
118 |
]
|
119 |
},
|
120 |
{
|
@@ -127,7 +221,7 @@
|
|
127 |
"saturation = 1 # Saturation value (0 to 1)\n",
|
128 |
"value = 1 # Value (brightness) value (0 to 1)\n",
|
129 |
"for hue, name in zip(centers, names[:num_colors]):\n",
|
130 |
-
" image = create_color_image(
|
131 |
" display_image(image, name) # Display the generated color image"
|
132 |
]
|
133 |
},
|
@@ -148,6 +242,37 @@
|
|
148 |
" model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
|
149 |
]
|
150 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
{
|
152 |
"cell_type": "code",
|
153 |
"execution_count": null,
|
@@ -155,7 +280,6 @@
|
|
155 |
"metadata": {},
|
156 |
"outputs": [],
|
157 |
"source": [
|
158 |
-
"ann_df = tohsv(ann_df)\n",
|
159 |
"ann_df.head()"
|
160 |
]
|
161 |
},
|
@@ -291,9 +415,7 @@
|
|
291 |
"metadata": {},
|
292 |
"outputs": [],
|
293 |
"source": [
|
294 |
-
"colors_list =
|
295 |
-
" 'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',\n",
|
296 |
-
" 'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']"
|
297 |
]
|
298 |
},
|
299 |
{
|
@@ -313,10 +435,17 @@
|
|
313 |
"source": [
|
314 |
"from sklearn import svm\n",
|
315 |
"\n",
|
316 |
-
"
|
317 |
-
"y_h_cat = pd.cut(y_h,bins=[x*256/12 if x<12 else 256 for x in range(13)],labels=colors_list).fillna('Warm Pink Red')\n",
|
318 |
"\n",
|
319 |
"print(y_h_cat.value_counts(dropna=False))\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
"x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)"
|
321 |
]
|
322 |
},
|
|
|
15 |
"\n",
|
16 |
"from PIL import Image, ImageColor\n",
|
17 |
"import matplotlib.pyplot as plt\n",
|
18 |
+
"from sklearn.model_selection import train_test_split\n",
|
19 |
"\n",
|
20 |
"import numpy as np\n",
|
21 |
"import torch\n",
|
|
|
34 |
"%autoreload 2"
|
35 |
]
|
36 |
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"id": "03efb8c0",
|
40 |
+
"metadata": {},
|
41 |
+
"source": [
|
42 |
+
"0-60\n",
|
43 |
+
"\n",
|
44 |
+
"Red\n",
|
45 |
+
"\n",
|
46 |
+
"60-120\n",
|
47 |
+
"\n",
|
48 |
+
"Yellow\n",
|
49 |
+
"\n",
|
50 |
+
"120-180\n",
|
51 |
+
"\n",
|
52 |
+
"Green\n",
|
53 |
+
"\n",
|
54 |
+
"180-240\n",
|
55 |
+
"\n",
|
56 |
+
"Cyan\n",
|
57 |
+
"\n",
|
58 |
+
"240-300\n",
|
59 |
+
"\n",
|
60 |
+
"Blue\n",
|
61 |
+
"\n",
|
62 |
+
"300-360\n",
|
63 |
+
"\n",
|
64 |
+
"Magenta\n",
|
65 |
+
"\n",
|
66 |
+
"Standard classification"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": null,
|
72 |
+
"id": "00a35126",
|
73 |
+
"metadata": {},
|
74 |
+
"outputs": [],
|
75 |
+
"source": [
|
76 |
+
"def hex2rgb(hex_value):\n",
|
77 |
+
" h = hex_value.strip(\"#\") \n",
|
78 |
+
" rgb = tuple(int(h[i:i+2], 16) for i in (0, 2, 4))\n",
|
79 |
+
" return rgb\n",
|
80 |
+
"\n",
|
81 |
+
"def rgb2hsv(r, g, b):\n",
|
82 |
+
" # Normalize R, G, B values\n",
|
83 |
+
" r, g, b = r / 255.0, g / 255.0, b / 255.0\n",
|
84 |
+
" \n",
|
85 |
+
" # h, s, v = hue, saturation, value\n",
|
86 |
+
" max_rgb = max(r, g, b) \n",
|
87 |
+
" min_rgb = min(r, g, b) \n",
|
88 |
+
" difference = max_rgb-min_rgb \n",
|
89 |
+
" \n",
|
90 |
+
" # if max_rgb and max_rgb are equal then h = 0\n",
|
91 |
+
" if max_rgb == min_rgb:\n",
|
92 |
+
" h = 0\n",
|
93 |
+
" \n",
|
94 |
+
" # if max_rgb==r then h is computed as follows\n",
|
95 |
+
" elif max_rgb == r:\n",
|
96 |
+
" h = (60 * ((g - b) / difference) + 360) % 360\n",
|
97 |
+
" \n",
|
98 |
+
" # if max_rgb==g then compute h as follows\n",
|
99 |
+
" elif max_rgb == g:\n",
|
100 |
+
" h = (60 * ((b - r) / difference) + 120) % 360\n",
|
101 |
+
" \n",
|
102 |
+
" # if max_rgb=b then compute h\n",
|
103 |
+
" elif max_rgb == b:\n",
|
104 |
+
" h = (60 * ((r - g) / difference) + 240) % 360\n",
|
105 |
+
" \n",
|
106 |
+
" # if max_rgb==zero then s=0\n",
|
107 |
+
" if max_rgb == 0:\n",
|
108 |
+
" s = 0\n",
|
109 |
+
" else:\n",
|
110 |
+
" s = (difference / max_rgb) * 100\n",
|
111 |
+
" \n",
|
112 |
+
" # compute v\n",
|
113 |
+
" v = max_rgb * 100\n",
|
114 |
+
" # return rounded values of H, S and V\n",
|
115 |
+
" return tuple(map(round, (h, s, v)))"
|
116 |
+
]
|
117 |
+
},
|
118 |
{
|
119 |
"cell_type": "code",
|
120 |
"execution_count": null,
|
|
|
125 |
"num_colors = 7"
|
126 |
]
|
127 |
},
|
128 |
+
{
|
129 |
+
"cell_type": "code",
|
130 |
+
"execution_count": null,
|
131 |
+
"id": "c8428918",
|
132 |
+
"metadata": {},
|
133 |
+
"outputs": [],
|
134 |
+
"source": [
|
135 |
+
"bins = [(x-1) * 360 / (num_colors - 1) if x != 1 \n",
|
136 |
+
" else 1 for x in range(num_colors + 1)]\n",
|
137 |
+
"bins[0] = 0\n",
|
138 |
+
"\n",
|
139 |
+
"bins\n",
|
140 |
+
" "
|
141 |
+
]
|
142 |
+
},
|
143 |
{
|
144 |
"cell_type": "code",
|
145 |
"execution_count": null,
|
|
|
147 |
"metadata": {},
|
148 |
"outputs": [],
|
149 |
"source": [
|
150 |
+
"centers = [int((bins[i-1]+bins[i])/2) for i in range(len(bins)) if i > 0]"
|
|
|
151 |
]
|
152 |
},
|
153 |
{
|
|
|
157 |
"metadata": {},
|
158 |
"outputs": [],
|
159 |
"source": [
|
160 |
+
"print(bins)\n",
|
161 |
"print(centers)"
|
162 |
]
|
163 |
},
|
|
|
196 |
"metadata": {},
|
197 |
"outputs": [],
|
198 |
"source": [
|
199 |
+
"# def to_256(val):\n",
|
200 |
+
"# x = val*360/256\n",
|
201 |
+
"# return int(x)"
|
202 |
]
|
203 |
},
|
204 |
{
|
|
|
208 |
"metadata": {},
|
209 |
"outputs": [],
|
210 |
"source": [
|
211 |
+
"names = ['Gray', 'Red', 'Yellow', 'Green', 'Cyan', 'Blue','Magenta']"
|
|
|
|
|
212 |
]
|
213 |
},
|
214 |
{
|
|
|
221 |
"saturation = 1 # Saturation value (0 to 1)\n",
|
222 |
"value = 1 # Value (brightness) value (0 to 1)\n",
|
223 |
"for hue, name in zip(centers, names[:num_colors]):\n",
|
224 |
+
" image = create_color_image(hue, saturation, value)\n",
|
225 |
" display_image(image, name) # Display the generated color image"
|
226 |
]
|
227 |
},
|
|
|
242 |
" model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
|
243 |
]
|
244 |
},
|
245 |
+
{
|
246 |
+
"cell_type": "code",
|
247 |
+
"execution_count": null,
|
248 |
+
"id": "065cd656",
|
249 |
+
"metadata": {},
|
250 |
+
"outputs": [],
|
251 |
+
"source": [
|
252 |
+
"from DisentanglementBase import DisentanglementBase"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"cell_type": "code",
|
257 |
+
"execution_count": null,
|
258 |
+
"id": "afb8a611",
|
259 |
+
"metadata": {},
|
260 |
+
"outputs": [],
|
261 |
+
"source": [
|
262 |
+
"variable = 'H1'\n",
|
263 |
+
"disentanglemnet_exp = DisentanglementBase('.', model, annotations, ann_df, space='W', colors_list=names, compute_s=False, variable=variable)\n"
|
264 |
+
]
|
265 |
+
},
|
266 |
+
{
|
267 |
+
"cell_type": "code",
|
268 |
+
"execution_count": null,
|
269 |
+
"id": "a7398217",
|
270 |
+
"metadata": {},
|
271 |
+
"outputs": [],
|
272 |
+
"source": [
|
273 |
+
"ann_df = disentanglemnet_exp.df"
|
274 |
+
]
|
275 |
+
},
|
276 |
{
|
277 |
"cell_type": "code",
|
278 |
"execution_count": null,
|
|
|
280 |
"metadata": {},
|
281 |
"outputs": [],
|
282 |
"source": [
|
|
|
283 |
"ann_df.head()"
|
284 |
]
|
285 |
},
|
|
|
415 |
"metadata": {},
|
416 |
"outputs": [],
|
417 |
"source": [
|
418 |
+
"colors_list = names"
|
|
|
|
|
419 |
]
|
420 |
},
|
421 |
{
|
|
|
435 |
"source": [
|
436 |
"from sklearn import svm\n",
|
437 |
"\n",
|
438 |
+
"y_h_cat = pd.cut(y_h,bins=bins,labels=colors_list, include_lowest=True)\n",
|
|
|
439 |
"\n",
|
440 |
"print(y_h_cat.value_counts(dropna=False))\n",
|
441 |
+
"\n",
|
442 |
+
"y_h_cat[y_s == 0] = 'Gray'\n",
|
443 |
+
"y_h_cat[y_s == 100] = 'Gray'\n",
|
444 |
+
"y_h_cat[y_v == 0] = 'Gray'\n",
|
445 |
+
"y_h_cat[y_v == 100] = 'Gray'\n",
|
446 |
+
"\n",
|
447 |
+
"print(y_h_cat.value_counts(dropna=False))\n",
|
448 |
+
"\n",
|
449 |
"x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)"
|
450 |
]
|
451 |
},
|
pages/{1_Textiles_Disentanglement.py → 1_Textiles_Manipulation.py}
RENAMED
@@ -20,10 +20,14 @@ BACKGROUND_COLOR = '#bcd0e7'
|
|
20 |
SECONDARY_COLOR = '#bce7db'
|
21 |
|
22 |
|
23 |
-
st.title('Disentanglement
|
24 |
st.markdown(
|
25 |
"""
|
26 |
-
This is a demo of the Disentanglement
|
|
|
|
|
|
|
|
|
27 |
""",
|
28 |
unsafe_allow_html=False,)
|
29 |
|
@@ -49,7 +53,7 @@ with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pk
|
|
49 |
COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink']
|
50 |
|
51 |
if 'image_id' not in st.session_state:
|
52 |
-
st.session_state.image_id =
|
53 |
if 'color_ids' not in st.session_state:
|
54 |
st.session_state.concept_ids = COLORS_LIST[-1]
|
55 |
if 'space_id' not in st.session_state:
|
@@ -73,9 +77,6 @@ if 'num_factors' not in st.session_state:
|
|
73 |
if 'best' not in st.session_state:
|
74 |
st.session_state.best = True
|
75 |
|
76 |
-
# def on_change_random_input():
|
77 |
-
# st.session_state.image_id = st.session_state.image_id
|
78 |
-
|
79 |
# ----------------------------- INPUT ----------------------------------
|
80 |
st.header('Input')
|
81 |
input_col_1, input_col_2, input_col_3, input_col_4 = st.columns(4)
|
@@ -83,8 +84,7 @@ input_col_1, input_col_2, input_col_3, input_col_4 = st.columns(4)
|
|
83 |
with input_col_1:
|
84 |
with st.form('image_form'):
|
85 |
|
86 |
-
|
87 |
-
st.write('**Choose or generate a random image to test the disentanglement**')
|
88 |
chosen_image_id_input = st.empty()
|
89 |
image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
90 |
|
@@ -103,16 +103,15 @@ with input_col_1:
|
|
103 |
with input_col_2:
|
104 |
with st.form('text_form_1'):
|
105 |
|
106 |
-
st.write('**Choose
|
107 |
-
type_col = st.selectbox('
|
108 |
-
colors_button = st.form_submit_button('Choose the defined color')
|
109 |
|
110 |
st.write('**Set range of change**')
|
111 |
chosen_color_lambda_input = st.empty()
|
112 |
color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=-100, step=1, value=7)
|
113 |
-
color_lambda_button = st.form_submit_button('Choose the defined
|
114 |
|
115 |
-
if
|
116 |
st.session_state.image_id = image_id
|
117 |
st.session_state.concept_ids = type_col
|
118 |
st.session_state.color_lambda = color_lambda
|
@@ -121,45 +120,48 @@ with input_col_2:
|
|
121 |
with input_col_3:
|
122 |
with st.form('text_form'):
|
123 |
|
124 |
-
st.write('**
|
125 |
chosen_saturation_lambda_input = st.empty()
|
126 |
saturation_lambda = chosen_saturation_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=0, value=0)
|
127 |
-
saturation_lambda_button = st.form_submit_button('Choose the defined lambda for saturation')
|
128 |
|
129 |
-
st.write('**
|
130 |
chosen_value_lambda_input = st.empty()
|
131 |
value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=1, value=0)
|
132 |
-
value_lambda_button = st.form_submit_button('Choose the defined lambda for
|
133 |
|
134 |
-
if
|
135 |
st.session_state.saturation_lambda = int(saturation_lambda)
|
136 |
st.session_state.value_lambda = int(value_lambda)
|
137 |
|
138 |
with input_col_4:
|
139 |
with st.form('text_form_2'):
|
140 |
-
st.write('Use best
|
141 |
best = st.selectbox('Option:', tuple([True, False]), index=0)
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
choose_options_button = st.form_submit_button('Choose the defined options')
|
152 |
-
# st.write('**Choose a latent space to disentangle**')
|
153 |
-
# # chosen_text_id_input = st.empty()
|
154 |
-
# # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
|
155 |
-
# space_id = st.selectbox('Space:', tuple(['Z', 'W']))
|
156 |
if choose_options_button:
|
157 |
-
st.session_state.sign = sign
|
158 |
-
st.session_state.num_factors = num_factors
|
159 |
-
st.session_state.cl_method = cl_method
|
160 |
-
st.session_state.regularization = regularization
|
161 |
-
st.session_state.extremes = extremes
|
162 |
st.session_state.best = best
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
# with input_col_4:
|
165 |
# with st.form('Network specifics:'):
|
@@ -178,7 +180,7 @@ with input_col_4:
|
|
178 |
# ---------------------------- SET UP OUTPUT ------------------------------
|
179 |
epsilon_container = st.empty()
|
180 |
st.header('Image Manipulation')
|
181 |
-
st.
|
182 |
|
183 |
header_col_1, header_col_2 = st.columns([1,1])
|
184 |
output_col_1, output_col_2 = st.columns([1,1])
|
@@ -193,7 +195,7 @@ output_col_1, output_col_2 = st.columns([1,1])
|
|
193 |
|
194 |
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
|
195 |
with header_col_1:
|
196 |
-
st.write(f'Original image')
|
197 |
|
198 |
with header_col_2:
|
199 |
if st.session_state.best:
|
@@ -209,8 +211,15 @@ with header_col_2:
|
|
209 |
tmp_sat = concept_vectors[concept_vectors['color'] == 'Saturation'][concept_vectors['extremes'] == st.session_state.extremes]
|
210 |
saturation_separation_vector, performance_saturation = tmp_sat.reset_index().loc[0, ['vector', 'score']]
|
211 |
|
212 |
-
st.write(
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
215 |
|
216 |
if st.session_state.space_id == 'Z':
|
@@ -226,3 +235,4 @@ with output_col_1:
|
|
226 |
with output_col_2:
|
227 |
image_updated = generate_composite_images(model, original_image_vec, [color_separation_vector, saturation_separation_vector, value_separation_vector], lambdas=[st.session_state.color_lambda, st.session_state.saturation_lambda, st.session_state.value_lambda])
|
228 |
st.image(image_updated)
|
|
|
|
20 |
SECONDARY_COLOR = '#bce7db'
|
21 |
|
22 |
|
23 |
+
st.title('Disentanglement on Textile Datasets')
|
24 |
st.markdown(
|
25 |
"""
|
26 |
+
This is a demo of the Disentanglement experiment on the [iMET Textiles Dataset](https://www.metmuseum.org/art/collection/search/85531).
|
27 |
+
|
28 |
+
In this page, the user can adjust the colors of textile images generated by an AI by simply traversing the latent space of the AI.
|
29 |
+
The colors can be adjusted following the human-intuitive encoding of HSV, adjusting the main Hue of the image with an option of 7 colors + Gray,
|
30 |
+
the saturation (the amount of Gray) and the value of the image (the amount of Black).
|
31 |
""",
|
32 |
unsafe_allow_html=False,)
|
33 |
|
|
|
53 |
COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink']
|
54 |
|
55 |
if 'image_id' not in st.session_state:
|
56 |
+
st.session_state.image_id = 52921
|
57 |
if 'color_ids' not in st.session_state:
|
58 |
st.session_state.concept_ids = COLORS_LIST[-1]
|
59 |
if 'space_id' not in st.session_state:
|
|
|
77 |
if 'best' not in st.session_state:
|
78 |
st.session_state.best = True
|
79 |
|
|
|
|
|
|
|
80 |
# ----------------------------- INPUT ----------------------------------
|
81 |
st.header('Input')
|
82 |
input_col_1, input_col_2, input_col_3, input_col_4 = st.columns(4)
|
|
|
84 |
with input_col_1:
|
85 |
with st.form('image_form'):
|
86 |
|
87 |
+
st.write('**Choose or generate a random base image**')
|
|
|
88 |
chosen_image_id_input = st.empty()
|
89 |
image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
90 |
|
|
|
103 |
with input_col_2:
|
104 |
with st.form('text_form_1'):
|
105 |
|
106 |
+
st.write('**Choose hue to vary**')
|
107 |
+
type_col = st.selectbox('Hue:', tuple(COLORS_LIST), index=7)
|
|
|
108 |
|
109 |
st.write('**Set range of change**')
|
110 |
chosen_color_lambda_input = st.empty()
|
111 |
color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=-100, step=1, value=7)
|
112 |
+
color_lambda_button = st.form_submit_button('Choose the defined hue and lambda')
|
113 |
|
114 |
+
if color_lambda_button:
|
115 |
st.session_state.image_id = image_id
|
116 |
st.session_state.concept_ids = type_col
|
117 |
st.session_state.color_lambda = color_lambda
|
|
|
120 |
with input_col_3:
|
121 |
with st.form('text_form'):
|
122 |
|
123 |
+
st.write('**Choose saturation variation**')
|
124 |
chosen_saturation_lambda_input = st.empty()
|
125 |
saturation_lambda = chosen_saturation_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=0, value=0)
|
|
|
126 |
|
127 |
+
st.write('**Choose value variation**')
|
128 |
chosen_value_lambda_input = st.empty()
|
129 |
value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=1, value=0)
|
130 |
+
value_lambda_button = st.form_submit_button('Choose the defined lambda for value and saturation')
|
131 |
|
132 |
+
if value_lambda_button:
|
133 |
st.session_state.saturation_lambda = int(saturation_lambda)
|
134 |
st.session_state.value_lambda = int(value_lambda)
|
135 |
|
136 |
with input_col_4:
|
137 |
with st.form('text_form_2'):
|
138 |
+
st.write('Use the best vectors (after hyperparameter tuning)')
|
139 |
best = st.selectbox('Option:', tuple([True, False]), index=0)
|
140 |
+
sign = True
|
141 |
+
num_factors=10
|
142 |
+
cl_method='LR'
|
143 |
+
regularization=0.1
|
144 |
+
extremes=True
|
145 |
+
if st.session_state.best is False:
|
146 |
+
st.write('Options for StyleSpace (not available for Saturation and Value)')
|
147 |
+
sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
|
148 |
+
num_factors = st.selectbox('Number of factors option:', tuple([1, 5, 10, 20, False]), index=4)
|
149 |
+
st.write('Options for InterFaceGAN (not available for Saturation and Value)')
|
150 |
+
cl_method = st.selectbox('Classification method option:', tuple(['LR', 'SVM', False]), index=2)
|
151 |
+
regularization = st.selectbox('Regularization option:', tuple([0.1, 1.0, False]), index=2)
|
152 |
+
st.write('Options for InterFaceGAN (only for Saturation and Value)')
|
153 |
+
extremes = st.selectbox('Extremes option:', tuple([True, False]), index=1)
|
154 |
+
|
155 |
choose_options_button = st.form_submit_button('Choose the defined options')
|
|
|
|
|
|
|
|
|
156 |
if choose_options_button:
|
|
|
|
|
|
|
|
|
|
|
157 |
st.session_state.best = best
|
158 |
+
if st.session_state.best is False:
|
159 |
+
st.session_state.sign = sign
|
160 |
+
st.session_state.num_factors = num_factors
|
161 |
+
st.session_state.cl_method = cl_method
|
162 |
+
st.session_state.regularization = regularization
|
163 |
+
st.session_state.extremes = extremes
|
164 |
+
|
165 |
|
166 |
# with input_col_4:
|
167 |
# with st.form('Network specifics:'):
|
|
|
180 |
# ---------------------------- SET UP OUTPUT ------------------------------
|
181 |
epsilon_container = st.empty()
|
182 |
st.header('Image Manipulation')
|
183 |
+
st.write('Using selected vectors to modify the original image...')
|
184 |
|
185 |
header_col_1, header_col_2 = st.columns([1,1])
|
186 |
output_col_1, output_col_2 = st.columns([1,1])
|
|
|
195 |
|
196 |
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
|
197 |
with header_col_1:
|
198 |
+
st.write(f'### Original image')
|
199 |
|
200 |
with header_col_2:
|
201 |
if st.session_state.best:
|
|
|
211 |
tmp_sat = concept_vectors[concept_vectors['color'] == 'Saturation'][concept_vectors['extremes'] == st.session_state.extremes]
|
212 |
saturation_separation_vector, performance_saturation = tmp_sat.reset_index().loc[0, ['vector', 'score']]
|
213 |
|
214 |
+
st.write('### Modified image')
|
215 |
+
st.write(f"""
|
216 |
+
Change in hue: {st.session_state.concept_ids} of amount: {np.round(st.session_state.color_lambda, 2)},
|
217 |
+
in: saturation of amount: {np.round(st.session_state.saturation_lambda, 2)},
|
218 |
+
in: value of amount: {np.round(st.session_state.value_lambda, 2)}.\
|
219 |
+
Verification performance of hue vector: {performance_color},
|
220 |
+
saturation vector: {performance_saturation/100},
|
221 |
+
value vector: {performance_value/100}""")
|
222 |
+
|
223 |
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
224 |
|
225 |
if st.session_state.space_id == 'Z':
|
|
|
235 |
with output_col_2:
|
236 |
image_updated = generate_composite_images(model, original_image_vec, [color_separation_vector, saturation_separation_vector, value_separation_vector], lambdas=[st.session_state.color_lambda, st.session_state.saturation_lambda, st.session_state.value_lambda])
|
237 |
st.image(image_updated)
|
238 |
+
|
pages/{2_Colours_comparison.py → 2_Network_comparison.py}
RENAMED
@@ -24,7 +24,11 @@ st.set_page_config(layout='wide')
|
|
24 |
|
25 |
st.title('Comparison among color directions')
|
26 |
st.write('> **How do the color directions relate to each other?**')
|
27 |
-
st.write(
|
|
|
|
|
|
|
|
|
28 |
|
29 |
|
30 |
annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
|
@@ -46,10 +50,8 @@ with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pk
|
|
46 |
|
47 |
COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
|
48 |
|
49 |
-
if 'image_id' not in st.session_state:
|
50 |
-
st.session_state.image_id = 0
|
51 |
if 'concept_ids' not in st.session_state:
|
52 |
-
st.session_state.concept_ids =
|
53 |
if 'sign' not in st.session_state:
|
54 |
st.session_state.sign = False
|
55 |
if 'extremes' not in st.session_state:
|
@@ -60,10 +62,8 @@ if 'cl_method' not in st.session_state:
|
|
60 |
st.session_state.cl_method = False
|
61 |
if 'num_factors' not in st.session_state:
|
62 |
st.session_state.num_factors = False
|
63 |
-
|
64 |
-
|
65 |
-
if 'space_id' not in st.session_state:
|
66 |
-
st.session_state.space_id = 'W'
|
67 |
|
68 |
# ----------------------------- INPUT ----------------------------------
|
69 |
st.header('Input')
|
@@ -76,7 +76,7 @@ with input_col_1:
|
|
76 |
st.write('**Choose a series of colors to compare**')
|
77 |
# chosen_text_id_input = st.empty()
|
78 |
# concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
|
79 |
-
concept_ids = st.multiselect('Color (including Saturation and Value):', tuple(COLORS_LIST), default=
|
80 |
choose_text_button = st.form_submit_button('Choose the defined colors')
|
81 |
|
82 |
if choose_text_button:
|
@@ -85,27 +85,33 @@ with input_col_1:
|
|
85 |
|
86 |
with input_col_2:
|
87 |
with st.form('text_form_1'):
|
88 |
-
st.write('
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
cl_method
|
93 |
-
regularization
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
choose_options_button = st.form_submit_button('Choose the defined options')
|
98 |
-
# st.write('**Choose a latent space to disentangle**')
|
99 |
-
# # chosen_text_id_input = st.empty()
|
100 |
-
# # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
|
101 |
-
# space_id = st.selectbox('Space:', tuple(['Z', 'W']))
|
102 |
if choose_options_button:
|
103 |
-
st.session_state.
|
104 |
-
st.session_state.
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
|
|
|
|
109 |
# ---------------------------- SET UP OUTPUT ------------------------------
|
110 |
epsilon_container = st.empty()
|
111 |
st.header('Comparison')
|
@@ -115,23 +121,28 @@ header_col_1, header_col_2 = st.columns([3,1])
|
|
115 |
output_col_1, output_col_2 = st.columns([3,1])
|
116 |
|
117 |
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
|
118 |
-
|
119 |
-
tmp =
|
|
|
|
|
|
|
|
|
120 |
info = tmp.loc[:, ['vector', 'score', 'color', 'kwargs']].values
|
121 |
concept_ids = [i[2] for i in info] #+ ' ' + i[3]
|
122 |
|
123 |
with header_col_1:
|
124 |
-
st.write('Similarity graph')
|
125 |
|
126 |
with header_col_2:
|
127 |
-
st.write('Information')
|
128 |
|
129 |
with output_col_2:
|
130 |
for i,concept_id in enumerate(concept_ids):
|
131 |
-
st.write(f'Color {info[i][2]}
|
|
|
|
|
132 |
|
133 |
with output_col_1:
|
134 |
-
|
135 |
edges = []
|
136 |
for i in range(len(concept_ids)):
|
137 |
for j in range(len(concept_ids)):
|
|
|
24 |
|
25 |
st.title('Comparison among color directions')
|
26 |
st.write('> **How do the color directions relate to each other?**')
|
27 |
+
st.write("""
|
28 |
+
This page provides a simple network-based framework to inspect the vector similarity (cosine similarity) among the found color vectors.
|
29 |
+
The nodes are the colors chosen for comparison and the strength of the edge represents the similarity.
|
30 |
+
|
31 |
+
""")
|
32 |
|
33 |
|
34 |
annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
|
|
|
50 |
|
51 |
COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
|
52 |
|
|
|
|
|
53 |
if 'concept_ids' not in st.session_state:
|
54 |
+
st.session_state.concept_ids = COLORS_LIST
|
55 |
if 'sign' not in st.session_state:
|
56 |
st.session_state.sign = False
|
57 |
if 'extremes' not in st.session_state:
|
|
|
62 |
st.session_state.cl_method = False
|
63 |
if 'num_factors' not in st.session_state:
|
64 |
st.session_state.num_factors = False
|
65 |
+
if 'best' not in st.session_state:
|
66 |
+
st.session_state.best = True
|
|
|
|
|
67 |
|
68 |
# ----------------------------- INPUT ----------------------------------
|
69 |
st.header('Input')
|
|
|
76 |
st.write('**Choose a series of colors to compare**')
|
77 |
# chosen_text_id_input = st.empty()
|
78 |
# concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
|
79 |
+
concept_ids = st.multiselect('Color (including Saturation and Value):', tuple(COLORS_LIST), default=COLORS_LIST)
|
80 |
choose_text_button = st.form_submit_button('Choose the defined colors')
|
81 |
|
82 |
if choose_text_button:
|
|
|
85 |
|
86 |
with input_col_2:
|
87 |
with st.form('text_form_1'):
|
88 |
+
st.write('Use the best vectors (after hyperparameter tuning)')
|
89 |
+
best = st.selectbox('Option:', tuple([True, False]), index=0)
|
90 |
+
sign = True
|
91 |
+
num_factors=10
|
92 |
+
cl_method='LR'
|
93 |
+
regularization=0.1
|
94 |
+
extremes=True
|
95 |
+
if st.session_state.best is False:
|
96 |
+
st.write('Options for StyleSpace (not available for Saturation and Value)')
|
97 |
+
sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
|
98 |
+
num_factors = st.selectbox('Number of factors option:', tuple([1, 5, 10, 20, False]), index=4)
|
99 |
+
st.write('Options for InterFaceGAN (not available for Saturation and Value)')
|
100 |
+
cl_method = st.selectbox('Classification method option:', tuple(['LR', 'SVM', False]), index=2)
|
101 |
+
regularization = st.selectbox('Regularization option:', tuple([0.1, 1.0, False]), index=2)
|
102 |
+
st.write('Options for InterFaceGAN (only for Saturation and Value)')
|
103 |
+
extremes = st.selectbox('Extremes option:', tuple([True, False]), index=1)
|
104 |
+
|
105 |
choose_options_button = st.form_submit_button('Choose the defined options')
|
|
|
|
|
|
|
|
|
106 |
if choose_options_button:
|
107 |
+
st.session_state.best = best
|
108 |
+
if st.session_state.best is False:
|
109 |
+
st.session_state.sign = sign
|
110 |
+
st.session_state.num_factors = num_factors
|
111 |
+
st.session_state.cl_method = cl_method
|
112 |
+
st.session_state.regularization = regularization
|
113 |
+
st.session_state.extremes = extremes
|
114 |
+
|
115 |
# ---------------------------- SET UP OUTPUT ------------------------------
|
116 |
epsilon_container = st.empty()
|
117 |
st.header('Comparison')
|
|
|
121 |
output_col_1, output_col_2 = st.columns([3,1])
|
122 |
|
123 |
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
|
124 |
+
if st.session_state.best:
|
125 |
+
tmp = concept_vectors[concept_vectors['color'].isin(st.session_state.concept_ids)].groupby('color').first().reset_index()
|
126 |
+
else:
|
127 |
+
tmp = concept_vectors[concept_vectors['color'].isin(st.session_state.concept_ids)]
|
128 |
+
tmp = tmp[tmp['sign'] == st.session_state.sign][tmp['extremes'] == st.session_state.extremes][tmp['num_factors'] == st.session_state.num_factors][tmp['cl_method'] == st.session_state.cl_method][tmp['regularization'] == st.session_state.regularization]
|
129 |
+
|
130 |
info = tmp.loc[:, ['vector', 'score', 'color', 'kwargs']].values
|
131 |
concept_ids = [i[2] for i in info] #+ ' ' + i[3]
|
132 |
|
133 |
with header_col_1:
|
134 |
+
st.write('### Similarity graph')
|
135 |
|
136 |
with header_col_2:
|
137 |
+
st.write('### Information')
|
138 |
|
139 |
with output_col_2:
|
140 |
for i,concept_id in enumerate(concept_ids):
|
141 |
+
st.write(f'''Color: {info[i][2]}.\
|
142 |
+
Settings: {info[i][3]}\
|
143 |
+
''')
|
144 |
|
145 |
with output_col_1:
|
|
|
146 |
edges = []
|
147 |
for i in range(len(concept_ids)):
|
148 |
for j in range(len(concept_ids)):
|
pages/3_Vectors_algebra.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pickle
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from matplotlib.backends.backend_agg import RendererAgg
|
9 |
+
|
10 |
+
from backend.disentangle_concepts import *
|
11 |
+
import torch_utils
|
12 |
+
import dnnlib
|
13 |
+
import legacy
|
14 |
+
|
15 |
+
_lock = RendererAgg.lock
|
16 |
+
|
17 |
+
|
18 |
+
st.set_page_config(layout='wide')
|
19 |
+
BACKGROUND_COLOR = '#bcd0e7'
|
20 |
+
SECONDARY_COLOR = '#bce7db'
|
21 |
+
|
22 |
+
|
23 |
+
st.title('Vector algebra using disentangled vectors')
|
24 |
+
st.markdown(
|
25 |
+
"""
|
26 |
+
This page offers the possibility to edit the colors of a given textile image using vector algebra and projections.
|
27 |
+
It allows to select several colors to move towards and against (selecting a positive or negative lambda).
|
28 |
+
Furthermore, it offers the possibility of conditional manipulation, by moving in the direction of a color n1 without affecting the color n2.
|
29 |
+
This is done using a projected direction n1 - (n1.T n2) n2.
|
30 |
+
""",
|
31 |
+
unsafe_allow_html=False,)
|
32 |
+
|
33 |
+
annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
|
34 |
+
with open(annotations_file, 'rb') as f:
|
35 |
+
annotations = pickle.load(f)
|
36 |
+
|
37 |
+
concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
|
38 |
+
concept_vectors['vector'] = [np.array([float(xx) for xx in x]) for x in concept_vectors['vector'].str.split(', ')]
|
39 |
+
concept_vectors['score'] = concept_vectors['score'].astype(float)
|
40 |
+
|
41 |
+
concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()
|
42 |
+
|
43 |
+
with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
|
44 |
+
model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
|
45 |
+
|
46 |
+
COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
|
47 |
+
COLORS_NEGATIVE = COLORS_LIST + ['None']
|
48 |
+
|
49 |
+
if 'image_id' not in st.session_state:
|
50 |
+
st.session_state.image_id = 52921
|
51 |
+
if 'colors' not in st.session_state:
|
52 |
+
st.session_state.colors = [COLORS_LIST[5], COLORS_LIST[7]]
|
53 |
+
if 'non_colors' not in st.session_state:
|
54 |
+
st.session_state.non_colors = ['None']
|
55 |
+
if 'color_lambda' not in st.session_state:
|
56 |
+
st.session_state.color_lambda = [5]
|
57 |
+
|
58 |
+
# ----------------------------- INPUT ----------------------------------
|
59 |
+
epsilon_container = st.empty()
|
60 |
+
st.header('Image Manipulation with Vector Algebra')
|
61 |
+
|
62 |
+
header_col_1, header_col_2, header_col_3, header_col_4 = st.columns([1,1,1,1])
|
63 |
+
input_col_1, output_col_2, output_col_3, input_col_4 = st.columns([1,1,1,1])
|
64 |
+
|
65 |
+
# --------------------------- INPUT column 1 ---------------------------
|
66 |
+
with input_col_1:
|
67 |
+
with st.form('image_form'):
|
68 |
+
|
69 |
+
# image_id = st.number_input('Image ID: ', format='%d', step=1)
|
70 |
+
st.write('**Choose or generate a random image**')
|
71 |
+
chosen_image_id_input = st.empty()
|
72 |
+
image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
73 |
+
|
74 |
+
choose_image_button = st.form_submit_button('Choose the defined image')
|
75 |
+
random_id = st.form_submit_button('Generate a random image')
|
76 |
+
|
77 |
+
if random_id:
|
78 |
+
image_id = random.randint(0, 100000)
|
79 |
+
st.session_state.image_id = image_id
|
80 |
+
chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
81 |
+
|
82 |
+
if choose_image_button:
|
83 |
+
image_id = int(image_id)
|
84 |
+
st.session_state.image_id = image_id
|
85 |
+
|
86 |
+
with header_col_1:
|
87 |
+
st.write('### Input image selection')
|
88 |
+
|
89 |
+
original_image_vec = annotations['w_vectors'][st.session_state.image_id]
|
90 |
+
img = generate_original_image(original_image_vec, model)
|
91 |
+
|
92 |
+
with output_col_2:
|
93 |
+
st.image(img)
|
94 |
+
|
95 |
+
with header_col_2:
|
96 |
+
st.write('### Original image')
|
97 |
+
|
98 |
+
with input_col_4:
|
99 |
+
with st.form('text_form_1'):
|
100 |
+
|
101 |
+
st.write('**Colors to vary (including Saturation and Value)**')
|
102 |
+
colors = st.multiselect('Color:', tuple(COLORS_LIST), default=[COLORS_LIST[5], COLORS_LIST[7]])
|
103 |
+
colors_button = st.form_submit_button('Choose the defined colors')
|
104 |
+
|
105 |
+
st.session_state.image_id = image_id
|
106 |
+
st.session_state.colors = colors
|
107 |
+
st.session_state.color_lambda = [5]*len(colors)
|
108 |
+
st.session_state.non_colors = ['None']*len(colors)
|
109 |
+
|
110 |
+
lambdas = []
|
111 |
+
negative_cols = []
|
112 |
+
for color in colors:
|
113 |
+
st.write('### '+color )
|
114 |
+
st.write('**Set range of change (can be negative)**')
|
115 |
+
chosen_color_lambda_input = st.empty()
|
116 |
+
color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=-100, step=1, value=5, key=color+'_number')
|
117 |
+
lambdas.append(color_lambda)
|
118 |
+
|
119 |
+
st.write('**Set dimensions of change to not consider**')
|
120 |
+
chosen_color_negative_input = st.empty()
|
121 |
+
color_negative = chosen_color_negative_input.selectbox('Color:', tuple(COLORS_NEGATIVE), index=len(COLORS_NEGATIVE)-1, key=color+'_noncolor')
|
122 |
+
negative_cols.append(color_negative)
|
123 |
+
|
124 |
+
lambdas_button = st.form_submit_button('Submit options')
|
125 |
+
if lambdas_button:
|
126 |
+
st.session_state.color_lambda = lambdas
|
127 |
+
st.session_state.non_colors = negative_cols
|
128 |
+
|
129 |
+
|
130 |
+
with header_col_4:
|
131 |
+
st.write('### Color settings')
|
132 |
+
# print(st.session_state.colors)
|
133 |
+
# print(st.session_state.color_lambda)
|
134 |
+
# print(st.session_state.non_colors)
|
135 |
+
|
136 |
+
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
|
137 |
+
|
138 |
+
with header_col_3:
|
139 |
+
separation_vectors = []
|
140 |
+
for col in st.session_state.colors:
|
141 |
+
separation_vector, score_1 = concept_vectors[concept_vectors['color'] == col].reset_index().loc[0, ['vector', 'score']]
|
142 |
+
separation_vectors.append(separation_vector)
|
143 |
+
|
144 |
+
negative_separation_vectors = []
|
145 |
+
for non_col in st.session_state.non_colors:
|
146 |
+
if non_col != 'None':
|
147 |
+
negative_separation_vector, score_2 = concept_vectors[concept_vectors['color'] == non_col].reset_index().loc[0, ['vector', 'score']]
|
148 |
+
negative_separation_vectors.append(negative_separation_vector)
|
149 |
+
else:
|
150 |
+
negative_separation_vectors.append('None')
|
151 |
+
## n1 − (n1T n2)n2
|
152 |
+
# print(negative_separation_vectors, separation_vectors)
|
153 |
+
st.write('### Output Image')
|
154 |
+
st.write(f'''Change in colors: {str(st.session_state.colors)},\
|
155 |
+
without affecting colors {str(st.session_state.non_colors)}''')
|
156 |
+
|
157 |
+
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
|
158 |
+
|
159 |
+
with output_col_3:
|
160 |
+
image_updated = generate_composite_images(model, original_image_vec, separation_vectors,
|
161 |
+
lambdas=st.session_state.color_lambda,
|
162 |
+
negative_colors=negative_separation_vectors)
|
163 |
+
st.image(image_updated)
|
pyproject.toml
ADDED
File without changes
|
structure_annotations.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
test-docker.sh
ADDED
@@ -0,0 +1,743 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/sh
|
2 |
+
set -e
|
3 |
+
# Docker Engine for Linux installation script.
|
4 |
+
#
|
5 |
+
# This script is intended as a convenient way to configure docker's package
|
6 |
+
# repositories and to install Docker Engine, This script is not recommended
|
7 |
+
# for production environments. Before running this script, make yourself familiar
|
8 |
+
# with potential risks and limitations, and refer to the installation manual
|
9 |
+
# at https://docs.docker.com/engine/install/ for alternative installation methods.
|
10 |
+
#
|
11 |
+
# The script:
|
12 |
+
#
|
13 |
+
# - Requires `root` or `sudo` privileges to run.
|
14 |
+
# - Attempts to detect your Linux distribution and version and configure your
|
15 |
+
# package management system for you.
|
16 |
+
# - Doesn't allow you to customize most installation parameters.
|
17 |
+
# - Installs dependencies and recommendations without asking for confirmation.
|
18 |
+
# - Installs the latest stable release (by default) of Docker CLI, Docker Engine,
|
19 |
+
# Docker Buildx, Docker Compose, containerd, and runc. When using this script
|
20 |
+
# to provision a machine, this may result in unexpected major version upgrades
|
21 |
+
# of these packages. Always test upgrades in a test environment before
|
22 |
+
# deploying to your production systems.
|
23 |
+
# - Isn't designed to upgrade an existing Docker installation. When using the
|
24 |
+
# script to update an existing installation, dependencies may not be updated
|
25 |
+
# to the expected version, resulting in outdated versions.
|
26 |
+
#
|
27 |
+
# Source code is available at https://github.com/docker/docker-install/
|
28 |
+
#
|
29 |
+
# Usage
|
30 |
+
# ==============================================================================
|
31 |
+
#
|
32 |
+
# To install the latest stable versions of Docker CLI, Docker Engine, and their
|
33 |
+
# dependencies:
|
34 |
+
#
|
35 |
+
# 1. download the script
|
36 |
+
#
|
37 |
+
# $ curl -fsSL https://get.docker.com -o install-docker.sh
|
38 |
+
#
|
39 |
+
# 2. verify the script's content
|
40 |
+
#
|
41 |
+
# $ cat install-docker.sh
|
42 |
+
#
|
43 |
+
# 3. run the script with --dry-run to verify the steps it executes
|
44 |
+
#
|
45 |
+
# $ sh install-docker.sh --dry-run
|
46 |
+
#
|
47 |
+
# 4. run the script either as root, or using sudo to perform the installation.
|
48 |
+
#
|
49 |
+
# $ sudo sh install-docker.sh
|
50 |
+
#
|
51 |
+
# Command-line options
|
52 |
+
# ==============================================================================
|
53 |
+
#
|
54 |
+
# --version <VERSION>
|
55 |
+
# Use the --version option to install a specific version, for example:
|
56 |
+
#
|
57 |
+
# $ sudo sh install-docker.sh --version 23.0
|
58 |
+
#
|
59 |
+
# --channel <stable|test>
|
60 |
+
#
|
61 |
+
# Use the --channel option to install from an alternative installation channel.
|
62 |
+
# The following example installs the latest versions from the "test" channel,
|
63 |
+
# which includes pre-releases (alpha, beta, rc):
|
64 |
+
#
|
65 |
+
# $ sudo sh install-docker.sh --channel test
|
66 |
+
#
|
67 |
+
# Alternatively, use the script at https://test.docker.com, which uses the test
|
68 |
+
# channel as default.
|
69 |
+
#
|
70 |
+
# --mirror <Aliyun|AzureChinaCloud>
|
71 |
+
#
|
72 |
+
# Use the --mirror option to install from a mirror supported by this script.
|
73 |
+
# Available mirrors are "Aliyun" (https://mirrors.aliyun.com/docker-ce), and
|
74 |
+
# "AzureChinaCloud" (https://mirror.azure.cn/docker-ce), for example:
|
75 |
+
#
|
76 |
+
# $ sudo sh install-docker.sh --mirror AzureChinaCloud
|
77 |
+
#
|
78 |
+
# ==============================================================================
|
79 |
+
|
80 |
+
|
81 |
+
# Git commit from https://github.com/docker/docker-install when
|
82 |
+
# the script was uploaded (Should only be modified by upload job):
|
83 |
+
SCRIPT_COMMIT_SHA="e5543d473431b782227f8908005543bb4389b8de"

# strip "v" prefix if present
VERSION="${VERSION#v}"

# The channel to install from:
#   * stable
#   * test
#   * edge (deprecated)
#   * nightly (deprecated)
# NOTE(review): the default here is "test" (not "stable"), matching the
# test.docker.com variant of this script described in the header comment.
DEFAULT_CHANNEL_VALUE="test"
if [ -z "$CHANNEL" ]; then
	CHANNEL=$DEFAULT_CHANNEL_VALUE
fi

# Base URL for packages; may be pre-set via the DOWNLOAD_URL environment
# variable or overridden below by the --mirror option.
DEFAULT_DOWNLOAD_URL="https://download.docker.com"
if [ -z "$DOWNLOAD_URL" ]; then
	DOWNLOAD_URL=$DEFAULT_DOWNLOAD_URL
fi

# Repo definition file used by the yum/dnf/zypper code paths.
DEFAULT_REPO_FILE="docker-ce.repo"
if [ -z "$REPO_FILE" ]; then
	REPO_FILE="$DEFAULT_REPO_FILE"
fi

mirror=''
DRY_RUN=${DRY_RUN:-}
# Parse command-line flags. Options that take a value consume it with an
# extra `shift` inside their case arm.
while [ $# -gt 0 ]; do
	case "$1" in
		--channel)
			CHANNEL="$2"
			shift
			;;
		--dry-run)
			DRY_RUN=1
			;;
		--mirror)
			mirror="$2"
			shift
			;;
		--version)
			# strip "v" prefix if present (e.g. --version v23.0)
			VERSION="${2#v}"
			shift
			;;
		--*)
			echo "Illegal option $1"
			;;
	esac
	# Guarded shift: avoids erroring out when a value-taking flag was the
	# last argument and its value was already consumed above.
	shift $(( $# > 0 ? 1 : 0 ))
done

# Map a requested mirror name to its download URL; reject anything unknown.
case "$mirror" in
	Aliyun)
		DOWNLOAD_URL="https://mirrors.aliyun.com/docker-ce"
		;;
	AzureChinaCloud)
		DOWNLOAD_URL="https://mirror.azure.cn/docker-ce"
		;;
	"")
		;;
	*)
		>&2 echo "unknown mirror '$mirror': use either 'Aliyun', or 'AzureChinaCloud'."
		exit 1
		;;
esac

# Validate the channel; edge/nightly were retired and now hard-fail.
case "$CHANNEL" in
	stable|test)
		;;
	edge|nightly)
		>&2 echo "DEPRECATED: the $CHANNEL channel has been deprecated and is no longer supported by this script."
		exit 1
		;;
	*)
		>&2 echo "unknown CHANNEL '$CHANNEL': use either stable or test."
		exit 1
		;;
esac
|
161 |
+
|
162 |
+
# Succeed when every named command resolves on PATH; all output is discarded.
command_exists() {
	if command -v "$@" > /dev/null 2>&1; then
		return 0
	else
		return 1
	fi
}
|
165 |
+
|
166 |
+
# version_gte checks whether the requested $VERSION is at least the given
# SemVer (Maj.Minor[.Patch]) or CalVer (YY.MM) version. It returns 0 (success)
# when $VERSION is either unset (= latest) or newer than / equal to the given
# version, and 1 (fail) otherwise.
#
# examples:
#
# VERSION=23.0
# version_gte 23.0  // 0 (success)
# version_gte 20.10 // 0 (success)
# version_gte 19.03 // 0 (success)
# version_gte 21.10 // 1 (fail)
version_gte() {
	# An empty VERSION means "install latest", which satisfies any minimum.
	[ -n "$VERSION" ] || return 0
	eval version_compare "$VERSION" "$1"
}
|
184 |
+
|
185 |
+
# version_compare compares two version strings (either SemVer
# (Major.Minor.Patch) or CalVer (YY.MM)). It returns 0 (success) when version
# A is newer than or equal to version B, and 1 (fail) otherwise. Patch levels
# and pre-release suffixes (-alpha/-beta) are ignored.
#
# examples:
#
# version_compare 23.0.0 20.10 // 0 (success)
# version_compare 23.0 20.10   // 0 (success)
# version_compare 20.10 19.03  // 0 (success)
# version_compare 20.10 20.10  // 0 (success)
# version_compare 19.03 20.10  // 1 (fail)
version_compare() (
	# Body runs in a subshell, so `set +x` and scratch variables stay local.
	set +x

	major_a="$(echo "$1" | cut -d'.' -f1)"
	major_b="$(echo "$2" | cut -d'.' -f1)"
	if [ "$major_a" -lt "$major_b" ]; then
		return 1
	elif [ "$major_a" -gt "$major_b" ]; then
		return 0
	fi

	# Majors are equal; fall through to the minor component.
	minor_a="$(echo "$1" | cut -d'.' -f2)"
	minor_b="$(echo "$2" | cut -d'.' -f2)"

	# Drop a single leading zero so CalVer months ("05") compare numerically.
	minor_a="${minor_a#0}"
	minor_b="${minor_b#0}"

	# A missing minor component counts as 0.
	if [ "${minor_a:-0}" -lt "${minor_b:-0}" ]; then
		return 1
	fi

	return 0
)
|
221 |
+
|
222 |
+
# Succeed only when dry-run mode is active (--dry-run flag or DRY_RUN env).
is_dry_run() {
	test -n "$DRY_RUN"
}
|
229 |
+
|
230 |
+
# Detect Windows Subsystem for Linux: the kernel release string contains
# "Microsoft" on WSL 1 and "microsoft" on WSL 2.
is_wsl() {
	case "$(uname -r)" in
		*[Mm]icrosoft* ) true ;;
		* ) false ;;
	esac
}
|
237 |
+
|
238 |
+
# Detect macOS via the kernel name reported by `uname -s`.
is_darwin() {
	case "$(uname -s)" in
		*[Dd]arwin* ) true ;;
		* ) false ;;
	esac
}
|
245 |
+
|
246 |
+
# Print a bold red end-of-life warning for the given distro release, then
# pause 10 seconds so the user can Ctrl+C before installation proceeds.
#   $1: distro name (e.g. "debian")
#   $2: distro version/codename (e.g. "stretch")
deprecation_notice() {
	distro=$1
	distro_version=$2
	echo
	# \033[91;1m = bright-red bold, \033[1m = bold, \033[0m = reset
	printf "\033[91;1mDEPRECATION WARNING\033[0m\n"
	printf "    This Linux distribution (\033[1m%s %s\033[0m) reached end-of-life and is no longer supported by this script.\n" "$distro" "$distro_version"
	echo "    No updates or security fixes will be released for this distribution, and users are recommended"
	echo "    to upgrade to a currently maintained version of $distro."
	echo
	printf "Press \033[1mCtrl+C\033[0m now to abort this script, or wait for the installation to continue."
	echo
	# Grace period before the install continues.
	sleep 10
}
|
259 |
+
|
260 |
+
# Print the distro ID (e.g. "ubuntu", "debian", "fedora") read from
# /etc/os-release; prints an empty string when the file is unreadable.
# Also leaves the result in the (global) lsb_dist variable as a side effect.
get_distribution() {
	lsb_dist=""
	# Every system that we officially support has /etc/os-release
	if [ -r /etc/os-release ]; then
		# Source in a subshell so os-release variables don't leak here.
		lsb_dist="$(. /etc/os-release && echo "$ID")"
	fi
	# Returning an empty string here should be alright since the
	# case statements don't act unless you provide an actual value
	echo "$lsb_dist"
}
|
270 |
+
|
271 |
+
# Print post-install guidance on running Docker without root (rootless mode
# and daemon-access docs), and show `docker version` when a daemon socket is
# already present. Skipped entirely in dry-run mode. Reads the (global) $sh_c
# command runner set up by do_install.
echo_docker_as_nonroot() {
	if is_dry_run; then
		return
	fi
	# Best effort: `|| true` keeps a non-running daemon from aborting the
	# script (which may be running under `set -e` from the caller's shell).
	if command_exists docker && [ -e /var/run/docker.sock ]; then
		(
			set -x
			$sh_c 'docker version'
		) || true
	fi

	# intentionally mixed spaces and tabs here -- tabs are stripped by "<<-EOF", spaces are kept in the output
	echo
	echo "================================================================================"
	echo
	# Rootless mode only exists from 20.10 onwards, so only advertise it then.
	if version_gte "20.10"; then
		echo "To run Docker as a non-privileged user, consider setting up the"
		echo "Docker daemon in rootless mode for your user:"
		echo
		echo "    dockerd-rootless-setuptool.sh install"
		echo
		echo "Visit https://docs.docker.com/go/rootless/ to learn about rootless mode."
		echo
	fi
	echo
	echo "To run the Docker daemon as a fully privileged service, but granting non-root"
	echo "users access, refer to https://docs.docker.com/go/daemon-access/"
	echo
	echo "WARNING: Access to the remote API on a privileged Docker daemon is equivalent"
	echo "         to root access on the host. Refer to the 'Docker daemon attack surface'"
	echo "         documentation for details: https://docs.docker.com/go/attack-surface/"
	echo
	echo "================================================================================"
	echo
}
|
306 |
+
|
307 |
+
# Check if this is a forked Linux distro.
# When `lsb_release -u` works we are on a fork (e.g. Linux Mint): rewrite the
# (global) lsb_dist/dist_version to the upstream distro's values. Otherwise,
# if /etc/debian_version exists but the ID isn't ubuntu/raspbian, treat the
# system as Debian (or Raspbian for OSMC) and map the numeric version to its
# codename.
check_forked() {

	# Check for lsb_release command existence, it usually exists in forked distros
	if command_exists lsb_release; then
		# Check if the `-u` option is supported
		set +e
		lsb_release -a -u > /dev/null 2>&1
		lsb_release_exit_code=$?
		set -e

		# Check if the command has exited successfully, it means we're in a forked distro
		if [ "$lsb_release_exit_code" = "0" ]; then
			# Print info about current distro
			cat <<-EOF
			You're using '$lsb_dist' version '$dist_version'.
			EOF

			# Get the upstream release info
			lsb_dist=$(lsb_release -a -u 2>&1 | tr '[:upper:]' '[:lower:]' | grep -E 'id' | cut -d ':' -f 2 | tr -d '[:space:]')
			dist_version=$(lsb_release -a -u 2>&1 | tr '[:upper:]' '[:lower:]' | grep -E 'codename' | cut -d ':' -f 2 | tr -d '[:space:]')

			# Print info about upstream distro
			cat <<-EOF
			Upstream release is '$lsb_dist' version '$dist_version'.
			EOF
		else
			if [ -r /etc/debian_version ] && [ "$lsb_dist" != "ubuntu" ] && [ "$lsb_dist" != "raspbian" ]; then
				if [ "$lsb_dist" = "osmc" ]; then
					# OSMC runs Raspbian
					lsb_dist=raspbian
				else
					# We're Debian and don't even know it!
					lsb_dist=debian
				fi
				# Strip any "/sid" suffix and the minor version, then map the
				# major number to the release codename.
				dist_version="$(sed 's/\/.*//' /etc/debian_version | sed 's/\..*//')"
				case "$dist_version" in
					12)
						dist_version="bookworm"
					;;
					11)
						dist_version="bullseye"
					;;
					10)
						dist_version="buster"
					;;
					9)
						dist_version="stretch"
					;;
					8)
						dist_version="jessie"
					;;
				esac
			fi
		fi
	fi
}
|
364 |
+
|
365 |
+
# Main entry point: warn on pre-existing installs, pick a root-capable command
# runner, detect distro + release, print EOL warnings, then configure the
# matching Docker package repository and install the engine, CLI, containerd
# and plugins. Honors the globals set up at the top of the script (CHANNEL,
# VERSION, DOWNLOAD_URL, REPO_FILE, DRY_RUN). Every branch ends with an
# explicit exit.
do_install() {
	echo "# Executing docker install script, commit: $SCRIPT_COMMIT_SHA"

	# Give the user a chance to bail if docker is already installed.
	if command_exists docker; then
		cat >&2 <<-'EOF'
			Warning: the "docker" command appears to already exist on this system.

			If you already have Docker installed, this script can cause trouble, which is
			why we're displaying this warning and provide the opportunity to cancel the
			installation.

			If you installed the current Docker package using this script and are using it
			again to update Docker, you can safely ignore this message.

			You may press Ctrl+C now to abort this script.
		EOF
		( set -x; sleep 20 )
	fi

	user="$(id -un 2>/dev/null || true)"

	# Pick the privileged command runner: plain sh as root, sudo/su otherwise.
	sh_c='sh -c'
	if [ "$user" != 'root' ]; then
		if command_exists sudo; then
			sh_c='sudo -E sh -c'
		elif command_exists su; then
			sh_c='su -c'
		else
			cat >&2 <<-'EOF'
			Error: this installer needs the ability to run commands as root.
			We are unable to find either "sudo" or "su" available to make this happen.
			EOF
			exit 1
		fi
	fi

	# In dry-run mode, echo every privileged command instead of running it.
	if is_dry_run; then
		sh_c="echo"
	fi

	# perform some very rudimentary platform detection
	lsb_dist=$( get_distribution )
	lsb_dist="$(echo "$lsb_dist" | tr '[:upper:]' '[:lower:]')"

	# On WSL, Docker Desktop is the recommended route; warn, then continue.
	if is_wsl; then
		echo
		echo "WSL DETECTED: We recommend using Docker Desktop for Windows."
		echo "Please get Docker Desktop from https://www.docker.com/products/docker-desktop/"
		echo
		cat >&2 <<-'EOF'

			You may press Ctrl+C now to abort this script.
		EOF
		( set -x; sleep 20 )
	fi

	# Resolve the distro release: codename for apt-based distros,
	# VERSION_ID (numeric) for everything else.
	case "$lsb_dist" in

		ubuntu)
			if command_exists lsb_release; then
				dist_version="$(lsb_release --codename | cut -f2)"
			fi
			if [ -z "$dist_version" ] && [ -r /etc/lsb-release ]; then
				dist_version="$(. /etc/lsb-release && echo "$DISTRIB_CODENAME")"
			fi
		;;

		debian|raspbian)
			dist_version="$(sed 's/\/.*//' /etc/debian_version | sed 's/\..*//')"
			case "$dist_version" in
				12)
					dist_version="bookworm"
				;;
				11)
					dist_version="bullseye"
				;;
				10)
					dist_version="buster"
				;;
				9)
					dist_version="stretch"
				;;
				8)
					dist_version="jessie"
				;;
			esac
		;;

		centos|rhel|sles)
			if [ -z "$dist_version" ] && [ -r /etc/os-release ]; then
				dist_version="$(. /etc/os-release && echo "$VERSION_ID")"
			fi
		;;

		*)
			if command_exists lsb_release; then
				dist_version="$(lsb_release --release | cut -f2)"
			fi
			if [ -z "$dist_version" ] && [ -r /etc/os-release ]; then
				dist_version="$(. /etc/os-release && echo "$VERSION_ID")"
			fi
		;;

	esac

	# Check if this is a forked Linux distro
	check_forked

	# Print deprecation warnings for distro versions that recently reached EOL,
	# but may still be commonly used (especially LTS versions).
	case "$lsb_dist.$dist_version" in
		debian.stretch|debian.jessie)
			deprecation_notice "$lsb_dist" "$dist_version"
			;;
		raspbian.stretch|raspbian.jessie)
			deprecation_notice "$lsb_dist" "$dist_version"
			;;
		ubuntu.xenial|ubuntu.trusty)
			deprecation_notice "$lsb_dist" "$dist_version"
			;;
		ubuntu.impish|ubuntu.hirsute|ubuntu.groovy|ubuntu.eoan|ubuntu.disco|ubuntu.cosmic)
			deprecation_notice "$lsb_dist" "$dist_version"
			;;
		fedora.*)
			if [ "$dist_version" -lt 36 ]; then
				deprecation_notice "$lsb_dist" "$dist_version"
			fi
			;;
	esac

	# Run setup for each distro accordingly
	case "$lsb_dist" in
		ubuntu|debian|raspbian)
			pre_reqs="apt-transport-https ca-certificates curl"
			if ! command -v gpg > /dev/null; then
				pre_reqs="$pre_reqs gnupg"
			fi
			# signed-by pins the repo to the key installed below.
			apt_repo="deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] $DOWNLOAD_URL/linux/$lsb_dist $dist_version $CHANNEL"
			(
				if ! is_dry_run; then
					set -x
				fi
				$sh_c 'apt-get update -qq >/dev/null'
				$sh_c "DEBIAN_FRONTEND=noninteractive apt-get install -y -qq $pre_reqs >/dev/null"
				$sh_c 'install -m 0755 -d /etc/apt/keyrings'
				$sh_c "curl -fsSL \"$DOWNLOAD_URL/linux/$lsb_dist/gpg\" | gpg --dearmor --yes -o /etc/apt/keyrings/docker.gpg"
				$sh_c "chmod a+r /etc/apt/keyrings/docker.gpg"
				$sh_c "echo \"$apt_repo\" > /etc/apt/sources.list.d/docker.list"
				$sh_c 'apt-get update -qq >/dev/null'
			)
			pkg_version=""
			# Resolve an optional version pin against apt-cache madison.
			if [ -n "$VERSION" ]; then
				if is_dry_run; then
					echo "# WARNING: VERSION pinning is not supported in DRY_RUN"
				else
					# Will work for incomplete versions IE (17.12), but may not actually grab the "latest" if in the test channel
					pkg_pattern="$(echo "$VERSION" | sed 's/-ce-/~ce~.*/g' | sed 's/-/.*/g')"
					search_command="apt-cache madison docker-ce | grep '$pkg_pattern' | head -1 | awk '{\$1=\$1};1' | cut -d' ' -f 3"
					pkg_version="$($sh_c "$search_command")"
					echo "INFO: Searching repository for VERSION '$VERSION'"
					echo "INFO: $search_command"
					if [ -z "$pkg_version" ]; then
						echo
						echo "ERROR: '$VERSION' not found amongst apt-cache madison results"
						echo
						exit 1
					fi
					if version_gte "18.09"; then
						search_command="apt-cache madison docker-ce-cli | grep '$pkg_pattern' | head -1 | awk '{\$1=\$1};1' | cut -d' ' -f 3"
						echo "INFO: $search_command"
						cli_pkg_version="=$($sh_c "$search_command")"
					fi
					# apt pins are expressed as "pkg=version".
					pkg_version="=$pkg_version"
				fi
			fi
			(
				# ${...%=} strips a dangling "=" when no version was pinned.
				pkgs="docker-ce${pkg_version%=}"
				if version_gte "18.09"; then
					# older versions didn't ship the cli and containerd as separate packages
					pkgs="$pkgs docker-ce-cli${cli_pkg_version%=} containerd.io"
				fi
				if version_gte "20.10"; then
					pkgs="$pkgs docker-compose-plugin docker-ce-rootless-extras$pkg_version"
				fi
				if version_gte "23.0"; then
					pkgs="$pkgs docker-buildx-plugin"
				fi
				if ! is_dry_run; then
					set -x
				fi
				$sh_c "DEBIAN_FRONTEND=noninteractive apt-get install -y -qq $pkgs >/dev/null"
			)
			echo_docker_as_nonroot
			exit 0
			;;
		centos|fedora|rhel)
			if [ "$(uname -m)" != "s390x" ] && [ "$lsb_dist" = "rhel" ]; then
				echo "Packages for RHEL are currently only available for s390x."
				exit 1
			fi
			# Fedora uses dnf; centos/rhel use yum. Package release suffix
			# differs accordingly (fcNN vs el).
			if [ "$lsb_dist" = "fedora" ]; then
				pkg_manager="dnf"
				config_manager="dnf config-manager"
				enable_channel_flag="--set-enabled"
				disable_channel_flag="--set-disabled"
				pre_reqs="dnf-plugins-core"
				pkg_suffix="fc$dist_version"
			else
				pkg_manager="yum"
				config_manager="yum-config-manager"
				enable_channel_flag="--enable"
				disable_channel_flag="--disable"
				pre_reqs="yum-utils"
				pkg_suffix="el"
			fi
			repo_file_url="$DOWNLOAD_URL/linux/$lsb_dist/$REPO_FILE"
			(
				if ! is_dry_run; then
					set -x
				fi
				$sh_c "$pkg_manager install -y -q $pre_reqs"
				$sh_c "$config_manager --add-repo $repo_file_url"

				# The repo file defaults to stable; flip to the requested
				# channel when it differs.
				if [ "$CHANNEL" != "stable" ]; then
					$sh_c "$config_manager $disable_channel_flag 'docker-ce-*'"
					$sh_c "$config_manager $enable_channel_flag 'docker-ce-$CHANNEL'"
				fi
				$sh_c "$pkg_manager makecache"
			)
			pkg_version=""
			# Resolve an optional version pin against the package list.
			if [ -n "$VERSION" ]; then
				if is_dry_run; then
					echo "# WARNING: VERSION pinning is not supported in DRY_RUN"
				else
					pkg_pattern="$(echo "$VERSION" | sed 's/-ce-/\\\\.ce.*/g' | sed 's/-/.*/g').*$pkg_suffix"
					search_command="$pkg_manager list --showduplicates docker-ce | grep '$pkg_pattern' | tail -1 | awk '{print \$2}'"
					pkg_version="$($sh_c "$search_command")"
					echo "INFO: Searching repository for VERSION '$VERSION'"
					echo "INFO: $search_command"
					if [ -z "$pkg_version" ]; then
						echo
						echo "ERROR: '$VERSION' not found amongst $pkg_manager list results"
						echo
						exit 1
					fi
					if version_gte "18.09"; then
						# older versions don't support a cli package
						search_command="$pkg_manager list --showduplicates docker-ce-cli | grep '$pkg_pattern' | tail -1 | awk '{print \$2}'"
						cli_pkg_version="$($sh_c "$search_command" | cut -d':' -f 2)"
					fi
					# Cut out the epoch and prefix with a '-'
					pkg_version="-$(echo "$pkg_version" | cut -d':' -f 2)"
				fi
			fi
			(
				pkgs="docker-ce$pkg_version"
				if version_gte "18.09"; then
					# older versions didn't ship the cli and containerd as separate packages
					if [ -n "$cli_pkg_version" ]; then
						pkgs="$pkgs docker-ce-cli-$cli_pkg_version containerd.io"
					else
						pkgs="$pkgs docker-ce-cli containerd.io"
					fi
				fi
				if version_gte "20.10"; then
					pkgs="$pkgs docker-compose-plugin docker-ce-rootless-extras$pkg_version"
				fi
				if version_gte "23.0"; then
					pkgs="$pkgs docker-buildx-plugin"
				fi
				if ! is_dry_run; then
					set -x
				fi
				$sh_c "$pkg_manager install -y -q $pkgs"
			)
			echo_docker_as_nonroot
			exit 0
			;;
		sles)
			if [ "$(uname -m)" != "s390x" ]; then
				echo "Packages for SLES are currently only available for s390x"
				exit 1
			fi
			# Map the SLES release to the matching openSUSE SELinux repo name.
			if [ "$dist_version" = "15.3" ]; then
				sles_version="SLE_15_SP3"
			else
				sles_minor_version="${dist_version##*.}"
				sles_version="15.$sles_minor_version"
			fi
			repo_file_url="$DOWNLOAD_URL/linux/$lsb_dist/$REPO_FILE"
			pre_reqs="ca-certificates curl libseccomp2 awk"
			(
				if ! is_dry_run; then
					set -x
				fi
				$sh_c "zypper install -y $pre_reqs"
				$sh_c "zypper addrepo $repo_file_url"
				if ! is_dry_run; then
					cat >&2 <<-'EOF'
						WARNING!!
						openSUSE repository (https://download.opensuse.org/repositories/security:SELinux) will be enabled now.
						Do you wish to continue?
						You may press Ctrl+C now to abort this script.
					EOF
					( set -x; sleep 30 )
				fi
				opensuse_repo="https://download.opensuse.org/repositories/security:SELinux/$sles_version/security:SELinux.repo"
				$sh_c "zypper addrepo $opensuse_repo"
				$sh_c "zypper --gpg-auto-import-keys refresh"
				$sh_c "zypper lr -d"
			)
			pkg_version=""
			# Resolve an optional version pin against zypper's package list.
			if [ -n "$VERSION" ]; then
				if is_dry_run; then
					echo "# WARNING: VERSION pinning is not supported in DRY_RUN"
				else
					pkg_pattern="$(echo "$VERSION" | sed 's/-ce-/\\\\.ce.*/g' | sed 's/-/.*/g')"
					search_command="zypper search -s --match-exact 'docker-ce' | grep '$pkg_pattern' | tail -1 | awk '{print \$6}'"
					pkg_version="$($sh_c "$search_command")"
					echo "INFO: Searching repository for VERSION '$VERSION'"
					echo "INFO: $search_command"
					if [ -z "$pkg_version" ]; then
						echo
						echo "ERROR: '$VERSION' not found amongst zypper list results"
						echo
						exit 1
					fi
					search_command="zypper search -s --match-exact 'docker-ce-cli' | grep '$pkg_pattern' | tail -1 | awk '{print \$6}'"
					# It's okay for cli_pkg_version to be blank, since older versions don't support a cli package
					cli_pkg_version="$($sh_c "$search_command")"
					pkg_version="-$pkg_version"
				fi
			fi
			(
				pkgs="docker-ce$pkg_version"
				if version_gte "18.09"; then
					if [ -n "$cli_pkg_version" ]; then
						# older versions didn't ship the cli and containerd as separate packages
						pkgs="$pkgs docker-ce-cli-$cli_pkg_version containerd.io"
					else
						pkgs="$pkgs docker-ce-cli containerd.io"
					fi
				fi
				if version_gte "20.10"; then
					pkgs="$pkgs docker-compose-plugin docker-ce-rootless-extras$pkg_version"
				fi
				if version_gte "23.0"; then
					pkgs="$pkgs docker-buildx-plugin"
				fi
				if ! is_dry_run; then
					set -x
				fi
				$sh_c "zypper -q install -y $pkgs"
			)
			echo_docker_as_nonroot
			exit 0
			;;
		*)
			# Unsupported platform: special-case macOS for a friendlier hint.
			if [ -z "$lsb_dist" ]; then
				if is_darwin; then
					echo
					echo "ERROR: Unsupported operating system 'macOS'"
					echo "Please get Docker Desktop from https://www.docker.com/products/docker-desktop"
					echo
					exit 1
				fi
			fi
			echo
			echo "ERROR: Unsupported distribution '$lsb_dist'"
			echo
			exit 1
			;;
	esac
	# Safety net: every branch above exits explicitly, so this is unreachable.
	exit 1
}
|
740 |
+
|
741 |
+
# wrapped up in a function so that we have some protection against only getting
# half the file during "curl | sh" -- a truncated download then fails to parse
# instead of executing a partial script.
do_install
|