Pusheen commited on
Commit
a81b34e
1 Parent(s): 3059d49

Update gligen/ldm/models/diffusion/plms.py

Browse files
Files changed (1) hide show
  1. gligen/ldm/models/diffusion/plms.py +21 -16
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -108,7 +108,7 @@ class PLMSSampler(object):
108
  # x = self.update_only_self( input,i, index, ts )
109
  # elif loss_type=='CAR':
110
  # x = self.update_loss_only_cross( input,i, index, ts )
111
- x = self.update_loss_only_cross( input,i, index, ts )
112
  input["x"] = x
113
  img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
114
  input["x"] = img
@@ -171,16 +171,16 @@ class PLMSSampler(object):
171
  def update_loss_LoCo(self, input,index1, index, ts,type_loss='self_accross'):
172
 
173
  if index1 < 10:
174
- loss_scale = 3
175
- max_iter = 5
176
  elif index1 < 20:
177
- loss_scale = 2
178
- max_iter = 5
179
  else:
180
  loss_scale = 1
181
  max_iter = 1
 
182
  loss_threshold = 0.1
183
-
184
  max_index = 30
185
  x = deepcopy(input["x"])
186
  iteration = 0
@@ -188,24 +188,29 @@ class PLMSSampler(object):
188
  input["timesteps"] = ts
189
 
190
  print("optimize", index1)
 
191
  while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
192
  print('iter', iteration)
 
193
  x = x.requires_grad_(True)
194
  input['x'] = x
195
  e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
196
-
197
- bboxes = input['boxes']
198
  object_positions = input['object_position']
199
- loss2 = caculate_loss_LoCo(att_second,att_first,att_third, bboxes=bboxes,
200
- object_positions=object_positions, t = index1)*loss_scale
201
- loss = loss2
202
- print('loss', loss)
203
- hh = torch.autograd.backward(loss)
204
- grad_cond = x.grad
205
- x = x - grad_cond
 
 
 
206
  x = x.detach()
207
  iteration += 1
208
- torch.cuda.empty_cache()
 
209
  return x
210
 
211
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
 
108
  # x = self.update_only_self( input,i, index, ts )
109
  # elif loss_type=='CAR':
110
  # x = self.update_loss_only_cross( input,i, index, ts )
111
+ x = self.update_loss_LoCo( input,i, index, ts )
112
  input["x"] = x
113
  img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
114
  input["x"] = img
 
171
  def update_loss_LoCo(self, input,index1, index, ts,type_loss='self_accross'):
172
 
173
  if index1 < 10:
174
+ loss_scale = 4
175
+ max_iter = 1
176
  elif index1 < 20:
177
+ loss_scale = 3
178
+ max_iter = 1
179
  else:
180
  loss_scale = 1
181
  max_iter = 1
182
+
183
  loss_threshold = 0.1
 
184
  max_index = 30
185
  x = deepcopy(input["x"])
186
  iteration = 0
 
188
  input["timesteps"] = ts
189
 
190
  print("optimize", index1)
191
+ self.model.train()
192
  while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
193
  print('iter', iteration)
194
+ # import pdb; pdb.set_trace()
195
  x = x.requires_grad_(True)
196
  input['x'] = x
197
  e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
198
+ bboxes = input['boxes_att']
 
199
  object_positions = input['object_position']
200
+ loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
201
+ object_positions=object_positions, t = index1)*loss_scale
202
+ loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
203
+ object_positions=object_positions, t = index1)*loss_scale
204
+ loss = loss1 + loss2
205
+ print('loss', loss, loss1, loss2)
206
+ # hh = torch.autograd.backward(loss, retain_graph=True)
207
+ grad_cond = torch.autograd.grad(loss.requires_grad_(True), [x])[0]
208
+ # grad_cond = x.grad
209
+ x = x - grad_cond
210
  x = x.detach()
211
  iteration += 1
212
+
213
+
214
  return x
215
 
216
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):