Spaces:
Sleeping
Sleeping
Update gligen/ldm/models/diffusion/plms.py
Browse files
gligen/ldm/models/diffusion/plms.py
CHANGED
@@ -108,21 +108,10 @@ class PLMSSampler(object):
|
|
108 |
# three loss types
|
109 |
if loss_type !=None and loss_type!='standard':
|
110 |
if input['object_position'] != []:
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
elif loss_type=='CAR':
|
116 |
-
x = self.update_loss_only_cross( input,i, index, ts )
|
117 |
-
elif loss_type=='LoCo':
|
118 |
-
|
119 |
-
#print('Utilizing LoCo!!')
|
120 |
-
time_factor = noise_scheduler.sigmas[i] ** 2
|
121 |
-
x = self.update_loss_LoCo( input,i, index, ts, time_factor = time_factor)
|
122 |
-
|
123 |
-
elif loss_type=='LAC':
|
124 |
-
#print('Utilizing LoCo!!')
|
125 |
-
x = self.update_loss_LAC( input,i, index, ts )
|
126 |
input["x"] = x
|
127 |
img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
|
128 |
input["x"] = img
|
@@ -132,86 +121,7 @@ class PLMSSampler(object):
|
|
132 |
|
133 |
return img
|
134 |
|
135 |
-
def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
|
136 |
-
if index1 < 10:
|
137 |
-
loss_scale = 3
|
138 |
-
max_iter = 5
|
139 |
-
elif index1 < 20:
|
140 |
-
loss_scale = 2
|
141 |
-
max_iter = 3
|
142 |
-
else:
|
143 |
-
loss_scale = 1
|
144 |
-
max_iter = 1
|
145 |
-
|
146 |
-
loss_threshold = 0.1
|
147 |
-
max_index = 30
|
148 |
-
x = deepcopy(input["x"])
|
149 |
-
iteration = 0
|
150 |
-
loss = torch.tensor(10000)
|
151 |
-
input["timesteps"] = ts
|
152 |
-
|
153 |
-
print("optimize", index1)
|
154 |
-
while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
155 |
-
print('iter', iteration)
|
156 |
-
x = x.requires_grad_(True)
|
157 |
-
input['x'] = x
|
158 |
-
e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
159 |
-
bboxes = input['boxes']
|
160 |
-
object_positions = input['object_position']
|
161 |
-
loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
|
162 |
-
object_positions=object_positions, t = index1)*loss_scale
|
163 |
-
loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
|
164 |
-
object_positions=object_positions, t = index1)*loss_scale
|
165 |
-
loss = loss1 + loss2
|
166 |
-
print('AR loss:', loss, 'SAR:', loss1, 'CAR:', loss2)
|
167 |
-
hh = torch.autograd.backward(loss)
|
168 |
-
grad_cond = x.grad
|
169 |
-
x = x - grad_cond
|
170 |
-
x = x.detach()
|
171 |
-
iteration += 1
|
172 |
-
torch.cuda.empty_cache()
|
173 |
-
return x
|
174 |
-
|
175 |
-
def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
|
176 |
-
|
177 |
-
if index1 < 10:
|
178 |
-
loss_scale = 3
|
179 |
-
max_iter = 5
|
180 |
-
elif index1 < 20:
|
181 |
-
loss_scale = 2
|
182 |
-
max_iter = 5
|
183 |
-
else:
|
184 |
-
loss_scale = 1
|
185 |
-
max_iter = 1
|
186 |
-
loss_threshold = 0.1
|
187 |
|
188 |
-
max_index = 30
|
189 |
-
x = deepcopy(input["x"])
|
190 |
-
iteration = 0
|
191 |
-
loss = torch.tensor(10000)
|
192 |
-
input["timesteps"] = ts
|
193 |
-
|
194 |
-
print("optimize", index1)
|
195 |
-
while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
196 |
-
print('iter', iteration)
|
197 |
-
x = x.requires_grad_(True)
|
198 |
-
print('x shape', x.shape)
|
199 |
-
input['x'] = x
|
200 |
-
e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
201 |
-
|
202 |
-
bboxes = input['boxes']
|
203 |
-
object_positions = input['object_position']
|
204 |
-
loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
|
205 |
-
object_positions=object_positions, t = index1)*loss_scale
|
206 |
-
loss = loss2
|
207 |
-
print('loss', loss)
|
208 |
-
hh = torch.autograd.backward(loss, retain_graph=True)
|
209 |
-
grad_cond = x.grad
|
210 |
-
x = x - grad_cond
|
211 |
-
x = x.detach()
|
212 |
-
iteration += 1
|
213 |
-
torch.cuda.empty_cache()
|
214 |
-
return x
|
215 |
|
216 |
def update_loss_LoCo(self, input,index1, index, ts, time_factor, type_loss='self_accross'):
|
217 |
|
@@ -261,91 +171,8 @@ class PLMSSampler(object):
|
|
261 |
torch.cuda.empty_cache()
|
262 |
return x
|
263 |
|
264 |
-
def update_loss_LAC(self, input,index1, index, ts,type_loss='self_accross'):
|
265 |
-
|
266 |
-
# loss_scale = 30
|
267 |
-
# max_iter = 5
|
268 |
-
|
269 |
-
if index1 < 10:
|
270 |
-
loss_scale = 6
|
271 |
-
max_iter = 5
|
272 |
-
elif index1 < 20:
|
273 |
-
loss_scale = 4
|
274 |
-
max_iter = 3
|
275 |
-
else:
|
276 |
-
loss_scale = 1
|
277 |
-
max_iter = 1
|
278 |
-
loss_threshold = 0.002
|
279 |
-
|
280 |
-
max_index = 30
|
281 |
-
x = deepcopy(input["x"])
|
282 |
-
iteration = 0
|
283 |
-
loss = torch.tensor(10000)
|
284 |
-
input["timesteps"] = ts
|
285 |
-
|
286 |
-
print("optimize", index1)
|
287 |
-
while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
288 |
-
print('iter', iteration)
|
289 |
-
x = x.requires_grad_(True)
|
290 |
-
# print('x shape', x.shape)
|
291 |
-
input['x'] = x
|
292 |
-
e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
293 |
-
|
294 |
-
bboxes = input['boxes']
|
295 |
-
object_positions = input['object_position']
|
296 |
-
loss2 = caculate_loss_LAC(att_second,att_first,att_third, bboxes=bboxes,
|
297 |
-
object_positions=object_positions, t = index1)*loss_scale
|
298 |
-
loss = loss2
|
299 |
-
print('LoCo loss', loss)
|
300 |
-
hh = torch.autograd.backward(loss, retain_graph=True)
|
301 |
-
grad_cond = x.grad
|
302 |
-
x = x - grad_cond
|
303 |
-
x = x.detach()
|
304 |
-
iteration += 1
|
305 |
-
torch.cuda.empty_cache()
|
306 |
-
return x
|
307 |
-
|
308 |
|
309 |
|
310 |
-
def update_only_self(self, input,index1, index, ts,type_loss='self_accross' ):
|
311 |
-
if index1 < 10:
|
312 |
-
loss_scale = 4
|
313 |
-
max_iter = 5
|
314 |
-
elif index1 < 20:
|
315 |
-
loss_scale = 3
|
316 |
-
max_iter = 5
|
317 |
-
else:
|
318 |
-
loss_scale = 1
|
319 |
-
max_iter = 1
|
320 |
-
loss_threshold = 0.1
|
321 |
-
|
322 |
-
max_index = 30
|
323 |
-
x = deepcopy(input["x"])
|
324 |
-
iteration = 0
|
325 |
-
loss = torch.tensor(10000)
|
326 |
-
input["timesteps"] = ts
|
327 |
-
|
328 |
-
print("optimize", index1)
|
329 |
-
while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
330 |
-
print('iter', iteration)
|
331 |
-
x = x.requires_grad_(True)
|
332 |
-
input['x'] = x
|
333 |
-
e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
334 |
-
|
335 |
-
bboxes = input['boxes']
|
336 |
-
object_positions = input['object_position']
|
337 |
-
loss = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
|
338 |
-
object_positions=object_positions, t = index1)*loss_scale
|
339 |
-
print('loss', loss)
|
340 |
-
hh = torch.autograd.backward(loss)
|
341 |
-
grad_cond = x.grad
|
342 |
-
|
343 |
-
x = x - grad_cond
|
344 |
-
x = x.detach()
|
345 |
-
iteration += 1
|
346 |
-
torch.cuda.empty_cache()
|
347 |
-
return x
|
348 |
-
|
349 |
@torch.no_grad()
|
350 |
def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None):
|
351 |
x = deepcopy(input["x"])
|
|
|
108 |
# three loss types
|
109 |
if loss_type !=None and loss_type!='standard':
|
110 |
if input['object_position'] != []:
|
111 |
+
|
112 |
+
time_factor = noise_scheduler.sigmas[i] ** 2
|
113 |
+
x = self.update_loss_LoCo( input,i, index, ts, time_factor = time_factor)
|
114 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
input["x"] = x
|
116 |
img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
|
117 |
input["x"] = img
|
|
|
121 |
|
122 |
return img
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
def update_loss_LoCo(self, input,index1, index, ts, time_factor, type_loss='self_accross'):
|
127 |
|
|
|
171 |
torch.cuda.empty_cache()
|
172 |
return x
|
173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
@torch.no_grad()
|
177 |
def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None):
|
178 |
x = deepcopy(input["x"])
|