Spaces:
Sleeping
Sleeping
Update gligen/ldm/models/diffusion/plms.py
Browse files
gligen/ldm/models/diffusion/plms.py
CHANGED
@@ -119,55 +119,95 @@ class PLMSSampler(object):
|
|
119 |
return img
|
120 |
|
121 |
|
122 |
-
def update_loss_LoCo(self, input,index1, index, ts, type_loss='self_accross'):
|
123 |
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
if index1 < 10:
|
128 |
-
loss_scale =
|
129 |
max_iter = 5
|
130 |
elif index1 < 20:
|
131 |
-
loss_scale =
|
132 |
max_iter = 5
|
133 |
else:
|
134 |
loss_scale = 1
|
135 |
max_iter = 1
|
136 |
loss_threshold = 0.1
|
137 |
-
|
138 |
max_index = 30
|
139 |
x = deepcopy(input["x"])
|
140 |
iteration = 0
|
141 |
loss = torch.tensor(10000)
|
142 |
input["timesteps"] = ts
|
143 |
|
144 |
-
|
145 |
-
self.model.train()
|
146 |
while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
147 |
-
|
148 |
x = x.requires_grad_(True)
|
149 |
-
# print('x shape', x.shape)
|
150 |
input['x'] = x
|
151 |
e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
152 |
|
153 |
-
bboxes = input['
|
154 |
object_positions = input['object_position']
|
155 |
loss2 = caculate_loss_LoCo(att_second,att_first,att_third, bboxes=bboxes,
|
156 |
object_positions=object_positions, t = index1)*loss_scale
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
grad_cond = torch.autograd.grad(loss2.requires_grad_(True), [x])[0]
|
164 |
-
# grad_cond = x.grad
|
165 |
-
x = x - grad_cond
|
166 |
x = x.detach()
|
167 |
iteration += 1
|
168 |
torch.cuda.empty_cache()
|
169 |
return x
|
170 |
-
|
171 |
def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
|
172 |
if index1 < 10:
|
173 |
loss_scale = 4
|
|
|
119 |
return img
|
120 |
|
121 |
|
122 |
+
# def update_loss_LoCo(self, input,index1, index, ts, type_loss='self_accross'):
|
123 |
|
124 |
+
# # loss_scale = 30
|
125 |
+
# # max_iter = 5
|
126 |
+
# #print('time_factor is: ', time_factor)
|
127 |
+
# if index1 < 10:
|
128 |
+
# loss_scale = 8
|
129 |
+
# max_iter = 5
|
130 |
+
# elif index1 < 20:
|
131 |
+
# loss_scale = 5
|
132 |
+
# max_iter = 5
|
133 |
+
# else:
|
134 |
+
# loss_scale = 1
|
135 |
+
# max_iter = 1
|
136 |
+
# loss_threshold = 0.1
|
137 |
+
|
138 |
+
# max_index = 30
|
139 |
+
# x = deepcopy(input["x"])
|
140 |
+
# iteration = 0
|
141 |
+
# loss = torch.tensor(10000)
|
142 |
+
# input["timesteps"] = ts
|
143 |
+
|
144 |
+
# # print("optimize", index1)
|
145 |
+
# self.model.train()
|
146 |
+
# while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
147 |
+
# # print('iter', iteration)
|
148 |
+
# x = x.requires_grad_(True)
|
149 |
+
# # print('x shape', x.shape)
|
150 |
+
# input['x'] = x
|
151 |
+
# e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
152 |
+
|
153 |
+
# bboxes = input['boxes_att']
|
154 |
+
# object_positions = input['object_position']
|
155 |
+
# loss2 = caculate_loss_LoCo(att_second,att_first,att_third, bboxes=bboxes,
|
156 |
+
# object_positions=object_positions, t = index1)*loss_scale
|
157 |
+
# # loss = loss2
|
158 |
+
# # loss.requires_grad_(True)
|
159 |
+
# #print('LoCo loss', loss)
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
# grad_cond = torch.autograd.grad(loss2.requires_grad_(True), [x])[0]
|
164 |
+
# # grad_cond = x.grad
|
165 |
+
# x = x - grad_cond
|
166 |
+
# x = x.detach()
|
167 |
+
# iteration += 1
|
168 |
+
# torch.cuda.empty_cache()
|
169 |
+
# return x
|
170 |
+
|
171 |
+
def update_loss_LoCo(self, input,index1, index, ts,type_loss='self_accross'):
|
172 |
+
|
173 |
if index1 < 10:
|
174 |
+
loss_scale = 3
|
175 |
max_iter = 5
|
176 |
elif index1 < 20:
|
177 |
+
loss_scale = 2
|
178 |
max_iter = 5
|
179 |
else:
|
180 |
loss_scale = 1
|
181 |
max_iter = 1
|
182 |
loss_threshold = 0.1
|
183 |
+
|
184 |
max_index = 30
|
185 |
x = deepcopy(input["x"])
|
186 |
iteration = 0
|
187 |
loss = torch.tensor(10000)
|
188 |
input["timesteps"] = ts
|
189 |
|
190 |
+
print("optimize", index1)
|
|
|
191 |
while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
|
192 |
+
print('iter', iteration)
|
193 |
x = x.requires_grad_(True)
|
|
|
194 |
input['x'] = x
|
195 |
e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
|
196 |
|
197 |
+
bboxes = input['boxes']
|
198 |
object_positions = input['object_position']
|
199 |
loss2 = caculate_loss_LoCo(att_second,att_first,att_third, bboxes=bboxes,
|
200 |
object_positions=object_positions, t = index1)*loss_scale
|
201 |
+
loss = loss2
|
202 |
+
print('loss', loss)
|
203 |
+
hh = torch.autograd.backward(loss)
|
204 |
+
grad_cond = x.grad
|
205 |
+
x = x - grad_cond
|
|
|
|
|
|
|
|
|
206 |
x = x.detach()
|
207 |
iteration += 1
|
208 |
torch.cuda.empty_cache()
|
209 |
return x
|
210 |
+
|
211 |
def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
|
212 |
if index1 < 10:
|
213 |
loss_scale = 4
|