Pusheen committed
Commit 4abfe96
Parent: 1e34662

Update gligen/ldm/models/diffusion/plms.py

Files changed (1): gligen/ldm/models/diffusion/plms.py (+4 -177)

gligen/ldm/models/diffusion/plms.py CHANGED
@@ -108,21 +108,10 @@ class PLMSSampler(object):
             # three loss types
             if loss_type != None and loss_type != 'standard':
                 if input['object_position'] != []:
-                    if loss_type == 'SAR_CAR':
-                        x = self.update_loss_self_cross(input, i, index, ts)
-                    elif loss_type == 'SAR':
-                        x = self.update_only_self(input, i, index, ts)
-                    elif loss_type == 'CAR':
-                        x = self.update_loss_only_cross(input, i, index, ts)
-                    elif loss_type == 'LoCo':
-
-                        # print('Utilizing LoCo!!')
-                        time_factor = noise_scheduler.sigmas[i] ** 2
-                        x = self.update_loss_LoCo(input, i, index, ts, time_factor=time_factor)
-
-                    elif loss_type == 'LAC':
-                        # print('Utilizing LoCo!!')
-                        x = self.update_loss_LAC(input, i, index, ts)
+
+                    time_factor = noise_scheduler.sigmas[i] ** 2
+                    x = self.update_loss_LoCo(input, i, index, ts, time_factor=time_factor)
+
             input["x"] = x
             img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
             input["x"] = img
@@ -132,86 +121,7 @@ class PLMSSampler(object):
 
         return img
 
-    def update_loss_self_cross(self, input, index1, index, ts, type_loss='self_accross'):
-        if index1 < 10:
-            loss_scale = 3
-            max_iter = 5
-        elif index1 < 20:
-            loss_scale = 2
-            max_iter = 3
-        else:
-            loss_scale = 1
-            max_iter = 1
-
-        loss_threshold = 0.1
-        max_index = 30
-        x = deepcopy(input["x"])
-        iteration = 0
-        loss = torch.tensor(10000)
-        input["timesteps"] = ts
-
-        print("optimize", index1)
-        while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index):
-            print('iter', iteration)
-            x = x.requires_grad_(True)
-            input['x'] = x
-            e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
-            bboxes = input['boxes']
-            object_positions = input['object_position']
-            loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
-                                           object_positions=object_positions, t=index1) * loss_scale
-            loss2 = caculate_loss_att_fixed_cnt(att_second, att_first, att_third, bboxes=bboxes,
-                                                object_positions=object_positions, t=index1) * loss_scale
-            loss = loss1 + loss2
-            print('AR loss:', loss, 'SAR:', loss1, 'CAR:', loss2)
-            hh = torch.autograd.backward(loss)
-            grad_cond = x.grad
-            x = x - grad_cond
-            x = x.detach()
-            iteration += 1
-            torch.cuda.empty_cache()
-        return x
-
-    def update_loss_only_cross(self, input, index1, index, ts, type_loss='self_accross'):
-
-        if index1 < 10:
-            loss_scale = 3
-            max_iter = 5
-        elif index1 < 20:
-            loss_scale = 2
-            max_iter = 5
-        else:
-            loss_scale = 1
-            max_iter = 1
-        loss_threshold = 0.1
 
-        max_index = 30
-        x = deepcopy(input["x"])
-        iteration = 0
-        loss = torch.tensor(10000)
-        input["timesteps"] = ts
-
-        print("optimize", index1)
-        while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index):
-            print('iter', iteration)
-            x = x.requires_grad_(True)
-            print('x shape', x.shape)
-            input['x'] = x
-            e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
-
-            bboxes = input['boxes']
-            object_positions = input['object_position']
-            loss2 = caculate_loss_att_fixed_cnt(att_second, att_first, att_third, bboxes=bboxes,
-                                                object_positions=object_positions, t=index1) * loss_scale
-            loss = loss2
-            print('loss', loss)
-            hh = torch.autograd.backward(loss, retain_graph=True)
-            grad_cond = x.grad
-            x = x - grad_cond
-            x = x.detach()
-            iteration += 1
-            torch.cuda.empty_cache()
-        return x
 
     def update_loss_LoCo(self, input, index1, index, ts, time_factor, type_loss='self_accross'):
 
@@ -261,91 +171,8 @@ class PLMSSampler(object):
             torch.cuda.empty_cache()
         return x
 
-    def update_loss_LAC(self, input, index1, index, ts, type_loss='self_accross'):
-
-        # loss_scale = 30
-        # max_iter = 5
-
-        if index1 < 10:
-            loss_scale = 6
-            max_iter = 5
-        elif index1 < 20:
-            loss_scale = 4
-            max_iter = 3
-        else:
-            loss_scale = 1
-            max_iter = 1
-        loss_threshold = 0.002
-
-        max_index = 30
-        x = deepcopy(input["x"])
-        iteration = 0
-        loss = torch.tensor(10000)
-        input["timesteps"] = ts
-
-        print("optimize", index1)
-        while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index):
-            print('iter', iteration)
-            x = x.requires_grad_(True)
-            # print('x shape', x.shape)
-            input['x'] = x
-            e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
-
-            bboxes = input['boxes']
-            object_positions = input['object_position']
-            loss2 = caculate_loss_LAC(att_second, att_first, att_third, bboxes=bboxes,
-                                      object_positions=object_positions, t=index1) * loss_scale
-            loss = loss2
-            print('LoCo loss', loss)
-            hh = torch.autograd.backward(loss, retain_graph=True)
-            grad_cond = x.grad
-            x = x - grad_cond
-            x = x.detach()
-            iteration += 1
-            torch.cuda.empty_cache()
-        return x
-
 
 
-    def update_only_self(self, input, index1, index, ts, type_loss='self_accross'):
-        if index1 < 10:
-            loss_scale = 4
-            max_iter = 5
-        elif index1 < 20:
-            loss_scale = 3
-            max_iter = 5
-        else:
-            loss_scale = 1
-            max_iter = 1
-        loss_threshold = 0.1
-
-        max_index = 30
-        x = deepcopy(input["x"])
-        iteration = 0
-        loss = torch.tensor(10000)
-        input["timesteps"] = ts
-
-        print("optimize", index1)
-        while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index):
-            print('iter', iteration)
-            x = x.requires_grad_(True)
-            input['x'] = x
-            e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
-
-            bboxes = input['boxes']
-            object_positions = input['object_position']
-            loss = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
-                                          object_positions=object_positions, t=index1) * loss_scale
-            print('loss', loss)
-            hh = torch.autograd.backward(loss)
-            grad_cond = x.grad
-
-            x = x - grad_cond
-            x = x.detach()
-            iteration += 1
-            torch.cuda.empty_cache()
-        return x
-
     @torch.no_grad()
     def p_sample_plms(self, input, t, index, guidance_scale=1., uc=None, old_eps=None, t_next=None):
         x = deepcopy(input["x"])
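Note on the removed code: all four deleted helpers (update_loss_self_cross, update_loss_only_cross, update_loss_LAC, update_only_self) share one loss-guided refinement pattern: turn on gradients for the latent, run the model to collect attention maps, score the maps against the input boxes, and take a gradient step on the latent. The sketch below shows that shared pattern in a self-contained form; the toy model and attention_loss are hypothetical stand-ins for self.model and the repo's caculate_loss_* helpers, not its actual code.

import torch

def attention_loss(att, bbox):
    # Hypothetical stand-in for the repo's caculate_loss_* helpers:
    # penalize attention mass that falls outside the target box.
    y0, y1, x0, x1 = bbox
    mask = torch.zeros_like(att)
    mask[..., y0:y1, x0:x1] = 1.0
    return (att * (1 - mask)).sum() / att.sum().clamp(min=1e-8)

def update_latent(x, model, bbox, loss_scale=1.0, max_iter=3, loss_threshold=0.1):
    # Shared skeleton of the removed update_* methods: iterate until the
    # loss is small enough or the iteration budget is spent.
    loss = torch.tensor(1e4)
    iteration = 0
    while loss.item() > loss_threshold and iteration < max_iter:
        x = x.requires_grad_(True)
        att = model(x)                # plays the role of self.model(input)
        loss = attention_loss(att, bbox) * loss_scale
        loss.backward()
        x = (x - x.grad).detach()     # unscaled gradient step, as in the diff
        iteration += 1
    return x

# Toy "model": a softmax over the latent serves as a fake attention map.
model = lambda x: torch.softmax(x.flatten(), dim=0).view_as(x)
x = update_latent(torch.randn(1, 1, 16, 16), model, bbox=(4, 12, 4, 12))

The removed methods differ only in which loss they compute and in their step-indexed schedules (larger loss_scale and more iterations at early, noisy steps); this sketch replaces those schedules with constants.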
 
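The caculate_loss_self_att and caculate_loss_att_fixed_cnt functions referenced above are defined elsewhere in the repo, so this diff never shows their bodies. Purely as an illustration of the CAR-style (cross-attention) term, a box-constrained attention loss can be sketched as follows; the (heads, H, W, tokens) layout, the coordinate convention, and the normalization are assumptions for the sketch, not the repo's definitions.

import torch

def cross_attention_box_loss(att_maps, bboxes, object_positions):
    # For each object's prompt tokens, penalize the fraction of
    # cross-attention mass that falls outside the object's box.
    # att_maps: (heads, H, W, n_tokens); bboxes: (x0, y0, x1, y1) in [0, 1];
    # object_positions: one list of token indices per box.
    heads, H, W, _ = att_maps.shape
    loss = torch.tensor(0.0)
    for (x0, y0, x1, y1), tokens in zip(bboxes, object_positions):
        mask = torch.zeros(H, W)
        mask[int(y0 * H):int(y1 * H), int(x0 * W):int(x1 * W)] = 1.0
        for t in tokens:
            att = att_maps[..., t]                  # (heads, H, W)
            inside = (att * mask).sum(dim=(1, 2))
            total = att.sum(dim=(1, 2)).clamp(min=1e-8)
            loss = loss + ((1 - inside / total) ** 2).mean()
    return loss / max(len(bboxes), 1)

att = torch.rand(8, 16, 16, 77, requires_grad=True)
loss = cross_attention_box_loss(att, [(0.25, 0.25, 0.75, 0.75)], [[5, 6]])
loss.backward()

Driving this term toward zero concentrates each object token's attention inside its box, which is exactly what the refinement loop above exploits.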
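The only guidance path this commit keeps is update_loss_LoCo, whose strength is tied to the scheduler's noise level at step i: time_factor = noise_scheduler.sigmas[i] ** 2. The diff doesn't show how noise_scheduler is constructed, so the sketch below pairs the same weighting with a hand-rolled Karras-style sigma schedule (an assumption, not the repo's scheduler) to show the effect: early, high-noise steps get a much larger time_factor, i.e. stronger guidance.

import torch

def karras_sigmas(n_steps, sigma_min=0.03, sigma_max=14.6, rho=7.0):
    # Hypothetical noise schedule, high noise first, standing in for
    # whatever noise_scheduler.sigmas holds in the repo.
    ramp = torch.linspace(0, 1, n_steps)
    min_r, max_r = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return (max_r + ramp * (min_r - max_r)) ** rho

sigmas = karras_sigmas(50)
for i in (0, 25, 49):
    time_factor = sigmas[i] ** 2   # the weighting used in the kept LoCo branch
    print(f"step {i:2d}: sigma={sigmas[i].item():.3f}, time_factor={time_factor.item():.3f}")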