Abhinav Jangra committed
Commit • 5f17cb0
1 Parent(s): fc4505f
Upload app.py
app.py ADDED
@@ -0,0 +1,243 @@

import streamlit as st

st.warning("For larger images, the processing time may be significant. Consider using a lower resolution image or be prepared to wait.")

import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms
import PIL
import PIL.Image
from PIL import Image as PILImage
import matplotlib.pyplot as plt
import scipy.ndimage as nd
from IPython.display import clear_output, Image, display
from io import BytesIO

def showarray(a, fmt='jpeg'):
    # Notebook-style preview: encode the array as an image and display it inline.
    a = np.uint8(np.clip(a, 0, 255))
    f = BytesIO()
    PIL.Image.fromarray(a).save(f, fmt)
    display(Image(data=f.getvalue()))

def showtensor(a):
    # Undo the ImageNet normalisation and show the first image in the batch.
    mean = np.array([0.485, 0.456, 0.406]).reshape([1, 1, 3])
    std = np.array([0.229, 0.224, 0.225]).reshape([1, 1, 3])
    inp = a[0, :, :, :]
    inp = inp.transpose(1, 2, 0)
    inp = std * inp + mean
    inp *= 255
    showarray(inp)
    clear_output(wait=True)

def plot_images(im, titles=None):
    plt.figure(figsize=(30, 20))

    for i in range(len(im)):
        # Five images per row; integer division keeps subplot() arguments integral.
        plt.subplot(len(im) // 5 + 1, 5, i + 1)
        plt.axis('off')
        if titles is not None:
            plt.title(titles[i])
        plt.imshow(im[i])

    plt.pause(0.001)

# Standard ImageNet mean/std, matching the pretrained VGG16 weights.
normalise = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

normalise_resize = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def init_image(size=(400, 400, 3)):
    # Start from light-grey noise rather than a real photograph.
    img = PIL.Image.fromarray(np.uint8(np.random.uniform(150, 180, size)))
    img_tensor = normalise(img).unsqueeze(0)
    img_np = img_tensor.numpy()
    return img, img_tensor, img_np

def load_image(path, resize=False, size=None):
    img = PIL.Image.open(path)
    # Normalize expects exactly three channels, so flatten RGBA/greyscale inputs.
    img = img.convert('RGB')

    if resize:
        img_tensor = normalise_resize(img).unsqueeze(0)
    else:
        img_tensor = normalise(img).unsqueeze(0)
    img_np = img_tensor.numpy()
    return img, img_tensor, img_np

def tensor_to_img(t):
    a = t.numpy()
    mean = np.array([0.485, 0.456, 0.406]).reshape([1, 1, 3])
    std = np.array([0.229, 0.224, 0.225]).reshape([1, 1, 3])
    inp = a[0, :, :, :]
    inp = inp.transpose(1, 2, 0)
    inp = std * inp + mean
    inp *= 255
    inp = np.uint8(np.clip(inp, 0, 255))
    return PIL.Image.fromarray(inp)

def image_to_variable(image, requires_grad=False, cuda=False):
    if cuda:
        image = Variable(image.cuda(), requires_grad=requires_grad)
    else:
        image = Variable(image, requires_grad=requires_grad)
    return image

model = models.vgg16(pretrained=True)

use_gpu = torch.cuda.is_available()

print(model)

# Freeze the network: gradient ascent updates the image, never the weights.
for param in model.parameters():
    param.requires_grad = False

if use_gpu:
    print("Using CUDA")
    model.cuda()

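# Note (not in the original commit): Streamlit re-runs this whole script on
# every interaction, which reloads VGG16 each time. A minimal sketch of caching
# the model with st.cache_resource, assuming a recent Streamlit release:
#
# @st.cache_resource
# def get_model():
#     m = models.vgg16(pretrained=True)
#     for p in m.parameters():
#         p.requires_grad = False
#     return m.cuda() if torch.cuda.is_available() else m
#
# model = get_model()
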
def octaver_fn(model, base_img, step_fn, octave_n=6, octave_scale=1.4, iter_n=10, **step_args):
    octaves = [base_img]  # list of octaves, with the base image as the first element

    for i in range(octave_n - 1):  # build progressively smaller copies of the image
        octaves.append(nd.zoom(octaves[-1], (1, 1, 1.0 / octave_scale, 1.0 / octave_scale), order=1))

    detail = np.zeros_like(octaves[-1])  # detail image of zeros, shaped like the smallest octave
    for octave, octave_base in enumerate(octaves[::-1]):  # iterate octaves from smallest to largest
        h, w = octave_base.shape[-2:]  # spatial dimensions of the current octave

        if octave > 0:
            h1, w1 = detail.shape[-2:]
            detail = nd.zoom(detail, (1, 1, 1.0 * h / h1, 1.0 * w / w1), order=1)  # upscale detail to this octave

        src = octave_base + detail

        for i in range(iter_n):
            src = step_fn(model, src, **step_args)

        detail = src.numpy() - octave_base  # dreamed image minus current base: the accumulated detail

    return src

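# For intuition: with octave_n=6 and octave_scale=1.4, a 400x400 input yields a
# pyramid of roughly 400, 286, 204, 146, 104 and 74 pixels per side. Dreaming
# starts at the smallest octave, and the accumulated detail is upscaled and
# re-dreamed at each larger size, so coarse structure emerges before fine texture.
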
def objective(dst, guide_features):  # gradient target used for the backward pass
    if guide_features is None:
        # Plain DeepDream: maximise the layer's own activations.
        return dst.data
    else:
        # Guided DeepDream: match each activation to its best-correlated guide activation.
        x = dst.data[0].cpu().numpy()
        y = guide_features.data[0].cpu().numpy()
        ch, h, w = x.shape
        x = x.reshape(ch, -1)
        y = y.reshape(ch, -1)
        A = x.T.dot(y)  # affinity between image activations and guide activations
        diff = y[:, A.argmax(1)]
        diff = torch.Tensor(np.array([diff.reshape(ch, h, w)]))
        if use_gpu:
            diff = diff.cuda()
        return diff

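# The UI below never passes `control`, so the guided branch is unused. A minimal
# sketch of guided dreaming under this API (assumed usage, not from the original
# commit; 'guide.jpg' is a hypothetical file):
#
# _, _, guide_np = load_image('guide.jpg', resize=True)
# g = image_to_variable(torch.Tensor(guide_np), cuda=use_gpu)
# for idx, layer in enumerate(model.features.children()):
#     g = layer(g)
#     if idx == 28:
#         break
# dream = deepdream(model, img_np, control=g)
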
def make_step(model, img, objective=objective, control=None, step_size=1.5, end=28, jitter=32):
    global use_gpu

    mean = np.array([0.485, 0.456, 0.406]).reshape([3, 1, 1])
    std = np.array([0.229, 0.224, 0.225]).reshape([3, 1, 1])

    # Jitter the image randomly to avoid local minima and grid artifacts.
    ox, oy = np.random.randint(-jitter, jitter + 1, 2)
    img = np.roll(np.roll(img, ox, -1), oy, -2)

    # Prepare for gradient ascent.
    tensor = torch.Tensor(img)
    img_var = image_to_variable(tensor, requires_grad=True, cuda=use_gpu)
    model.zero_grad()

    # Forward pass through the model, stopping at the chosen layer.
    x = img_var
    for index, layer in enumerate(model.features.children()):
        x = layer(x)
        if index == end:
            break

    delta = objective(x, control)
    x.backward(delta)  # backpropagate a custom gradient target instead of a scalar loss

    # Normalise the gradient by its RMS so the step size is scale-independent.
    mean_square = torch.Tensor([torch.mean(img_var.grad.data ** 2)])
    if use_gpu:
        mean_square = mean_square.cuda()
    img_var.grad.data /= torch.sqrt(mean_square)
    img_var.data.add_(img_var.grad.data * step_size)  # gradient-ascent update on the image

    result = img_var.data.cpu().numpy()
    result = np.roll(np.roll(result, -ox, -1), -oy, -2)  # undo the jitter
    result[0, :, :, :] = np.clip(result[0, :, :, :], -mean / std, (1 - mean) / std)  # keep pixels in the valid range
    showtensor(result)  # notebook-style preview; prints to the console under Streamlit

    return torch.Tensor(result)

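# The `end` argument indexes into vgg16().features (31 modules, indices 0-30).
# As a rough guide, assuming the standard torchvision VGG16 layout: lower
# indices (e.g. 14, conv3_3) dream in edges and textures, while deeper ones
# (e.g. 28, conv5_3) produce object-like features such as eyes and fur.
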
def deepdream(model, base_img, octave_n=6, octave_scale=1.4,
              iter_n=10, end=28, control=None, objective=objective,
              step_size=1.5, jitter=32):

    return octaver_fn(
        model, base_img, step_fn=make_step,
        octave_n=octave_n, octave_scale=octave_scale,
        iter_n=iter_n, end=end, control=control,
        objective=objective, step_size=step_size, jitter=jitter
    )

# input_img, input_tensor, input_np = load_image('IMG_20201204_125738.jpg')
# dream = deepdream(model, input_np, end=14, step_size=0.06, octave_n=6)
# dream = tensor_to_img(dream)
# dream.save('dream00.jpg')
# dream

st.title('Deep Dream Generator')

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Display the uploaded image
        image = PILImage.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_column_width=True)

        # User controls: which VGG16 feature layer to maximise, and how many octaves
        end_value = st.slider('End Value (VGG16 layer index)', min_value=0, max_value=30, value=14, step=1)
        octave_value = st.slider('Number of octaves', min_value=3, max_value=10, value=6, step=1)

        # Generate the deep dream
        if st.button('Generate Deep Dream'):
            img, img_tensor, img_np = load_image(uploaded_file)
            dream = deepdream(model, img_np, end=end_value, step_size=0.06, octave_n=octave_value)
            dream_img = tensor_to_img(dream)
            st.image(dream_img, caption='Generated Deep Dream Image', use_column_width=True)

    except PIL.UnidentifiedImageError:
        st.error("Unable to open the uploaded image. Please make sure it is a valid image file.")


st.text("Made with ❤️ by Abhinav")