Update gligen/ldm/models/diffusion/plms.py
gligen/ldm/models/diffusion/plms.py
CHANGED
@@ -108,7 +108,7 @@ class PLMSSampler(object):
                 # x = self.update_only_self( input,i, index, ts )
             # elif loss_type=='CAR':
                 # x = self.update_loss_only_cross( input,i, index, ts )
-            x = self.
+            x = self.update_loss_LoCo( input,i, index, ts )
             input["x"] = x
             img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
             input["x"] = img

@@ -171,16 +171,16 @@ class PLMSSampler(object):
     def update_loss_LoCo(self, input,index1, index, ts,type_loss='self_accross'):
 
         if index1 < 10:
-            loss_scale =
-            max_iter =
+            loss_scale = 4
+            max_iter = 1
         elif index1 < 20:
-            loss_scale =
-            max_iter =
+            loss_scale = 3
+            max_iter = 1
         else:
             loss_scale = 1
             max_iter = 1
+
         loss_threshold = 0.1
-
         max_index = 30
         x = deepcopy(input["x"])
         iteration = 0

@@ -188,24 +188,29 @@ class PLMSSampler(object):
         input["timesteps"] = ts
 
         print("optimize", index1)
+        self.model.train()
         while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
             print('iter', iteration)
+            # import pdb; pdb.set_trace()
             x = x.requires_grad_(True)
             input['x'] = x
             e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
-
-            bboxes = input['boxes']
+            bboxes = input['boxes_att']
             object_positions = input['object_position']
-
-
-
-
-
-
-
+            loss1 = caculate_loss_self_att(self_first, self_second, self_third, bboxes=bboxes,
+                                           object_positions=object_positions, t = index1)*loss_scale
+            loss2 = caculate_loss_att_fixed_cnt(att_second,att_first,att_third, bboxes=bboxes,
+                                                object_positions=object_positions, t = index1)*loss_scale
+            loss = loss1 + loss2
+            print('loss', loss, loss1, loss2)
+            # hh = torch.autograd.backward(loss, retain_graph=True)
+            grad_cond = torch.autograd.grad(loss.requires_grad_(True), [x])[0]
+            # grad_cond = x.grad
+            x = x - grad_cond
             x = x.detach()
             iteration += 1
-
+
+
         return x
 
     def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
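For the gist of what this commit switches the sampler to: the new update_loss_LoCo step nudges the current latent down the gradient of a layout loss computed from the model's attention maps. Below is a minimal, self-contained sketch of that inner loop; loco_style_update, model_fn, and loss_fn are hypothetical stand-ins for the Space's own self.model(input) call and its caculate_loss_self_att / caculate_loss_att_fixed_cnt helpers, and are not part of the repository.

import torch
from copy import deepcopy

def loco_style_update(x, model_fn, loss_fn, loss_scale=4, max_iter=1, loss_threshold=0.1):
    # Sketch of the gradient-based latent update performed by update_loss_LoCo.
    # model_fn(x) stands in for self.model(input) and should return the attention maps;
    # loss_fn(...) stands in for the caculate_loss_* helpers and must return a scalar
    # that is differentiable with respect to x.
    x = deepcopy(x)                      # work on a copy of the latent, as in the diff
    loss = torch.tensor(10000.0)         # start above the threshold so the loop can run
    iteration = 0
    while loss.item() > loss_threshold and iteration < max_iter:
        x = x.requires_grad_(True)       # track gradients with respect to the latent only
        attention_maps = model_fn(x)     # forward pass through the diffusion model
        loss = loss_fn(attention_maps) * loss_scale
        grad_cond = torch.autograd.grad(loss, [x])[0]   # d(loss)/d(latent)
        x = (x - grad_cond).detach()     # one unscaled gradient step, then cut the graph
        iteration += 1
    return x

As in the diff itself, the step is plain gradient descent on the latent with no separate step size: loss_scale is folded into the loss, and max_iter caps the number of inner updates per denoising timestep.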