Pusheen committed on
Commit
d708ebf
1 Parent(s): f4f89f9

Update gligen/ldm/models/diffusion/plms.py

Browse files
Files changed (1) hide show
  1. gligen/ldm/models/diffusion/plms.py +62 -22
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -119,55 +119,95 @@ class PLMSSampler(object):
119
  return img
120
 
121
 
122
def update_loss_LoCo(self, input, index1, index, ts, type_loss='self_accross'):
    """Refine the latent ``input["x"]`` by gradient descent on the LoCo attention loss.

    Runs up to ``max_iter`` optimization steps (schedule depends on the sampling
    step ``index1``), each time re-evaluating the diffusion model to obtain
    cross-/self-attention maps and stepping the latent against the layout loss
    computed by ``caculate_loss_LoCo``.

    Args:
        input: dict carrying the model inputs; reads ``"x"`` (latent tensor),
            ``"boxes_att"`` and ``"object_position"`` (layout conditioning),
            and has ``"timesteps"`` overwritten with ``ts``.
        index1: current sampling-step index; controls loss scale / iteration
            budget and disables optimization entirely once ``index1 >= 30``.
        index: unused here; kept for signature compatibility with callers.
        ts: timestep tensor written into ``input["timesteps"]``.
        type_loss: unused here; kept for signature compatibility.

    Returns:
        The refined (detached) latent tensor.
    """
    # Stronger updates early in sampling, a single weak step late.
    if index1 < 10:
        loss_scale = 8
        max_iter = 5
    elif index1 < 20:
        loss_scale = 5
        max_iter = 5
    else:
        loss_scale = 1
        max_iter = 1
    loss_threshold = 0.1
    max_index = 30  # no latent optimization after this sampling step

    x = deepcopy(input["x"])
    iteration = 0
    # Sentinel so the first loop-condition check passes.
    loss = torch.tensor(10000)
    input["timesteps"] = ts

    self.model.train()
    while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index):
        x = x.requires_grad_(True)
        input['x'] = x
        e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)

        bboxes = input['boxes_att']
        object_positions = input['object_position']
        loss2 = caculate_loss_LoCo(att_second, att_first, att_third, bboxes=bboxes,
                                   object_positions=object_positions, t=index1) * loss_scale
        # BUGFIX: previously `loss` was never updated, so the
        # `loss_threshold` early-exit condition could never trigger and the
        # loop always ran the full `max_iter` iterations.
        loss = loss2

        # Plain (unscaled) gradient step on the latent.
        grad_cond = torch.autograd.grad(loss.requires_grad_(True), [x])[0]
        x = x - grad_cond
        x = x.detach()
        iteration += 1
        torch.cuda.empty_cache()
    return x
-
171
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
172
  if index1 < 10:
173
  loss_scale = 4
 
119
  return img
120
 
121
 
122
+ # def update_loss_LoCo(self, input,index1, index, ts, type_loss='self_accross'):
123
 
124
+ # # loss_scale = 30
125
+ # # max_iter = 5
126
+ # #print('time_factor is: ', time_factor)
127
+ # if index1 < 10:
128
+ # loss_scale = 8
129
+ # max_iter = 5
130
+ # elif index1 < 20:
131
+ # loss_scale = 5
132
+ # max_iter = 5
133
+ # else:
134
+ # loss_scale = 1
135
+ # max_iter = 1
136
+ # loss_threshold = 0.1
137
+
138
+ # max_index = 30
139
+ # x = deepcopy(input["x"])
140
+ # iteration = 0
141
+ # loss = torch.tensor(10000)
142
+ # input["timesteps"] = ts
143
+
144
+ # # print("optimize", index1)
145
+ # self.model.train()
146
+ # while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
147
+ # # print('iter', iteration)
148
+ # x = x.requires_grad_(True)
149
+ # # print('x shape', x.shape)
150
+ # input['x'] = x
151
+ # e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
152
+
153
+ # bboxes = input['boxes_att']
154
+ # object_positions = input['object_position']
155
+ # loss2 = caculate_loss_LoCo(att_second,att_first,att_third, bboxes=bboxes,
156
+ # object_positions=object_positions, t = index1)*loss_scale
157
+ # # loss = loss2
158
+ # # loss.requires_grad_(True)
159
+ # #print('LoCo loss', loss)
160
+
161
+
162
+
163
+ # grad_cond = torch.autograd.grad(loss2.requires_grad_(True), [x])[0]
164
+ # # grad_cond = x.grad
165
+ # x = x - grad_cond
166
+ # x = x.detach()
167
+ # iteration += 1
168
+ # torch.cuda.empty_cache()
169
+ # return x
170
+
171
def update_loss_LoCo(self, input, index1, index, ts, type_loss='self_accross'):
    """Refine the latent ``input["x"]`` by gradient descent on the LoCo attention loss.

    Runs up to ``max_iter`` optimization steps (schedule depends on the sampling
    step ``index1``), each time re-evaluating the diffusion model to obtain
    cross-/self-attention maps and stepping the latent against the layout loss
    computed by ``caculate_loss_LoCo``. Stops early once the loss drops below
    ``loss_threshold``.

    Args:
        input: dict carrying the model inputs; reads ``"x"`` (latent tensor),
            ``"boxes"`` and ``"object_position"`` (layout conditioning), and
            has ``"timesteps"`` overwritten with ``ts``.
        index1: current sampling-step index; controls loss scale / iteration
            budget and disables optimization entirely once ``index1 >= 30``.
        index: unused here; kept for signature compatibility with callers.
        ts: timestep tensor written into ``input["timesteps"]``.
        type_loss: unused here; kept for signature compatibility.

    Returns:
        The refined (detached) latent tensor.
    """
    # Stronger updates early in sampling, a single weak step late.
    if index1 < 10:
        loss_scale = 3
        max_iter = 5
    elif index1 < 20:
        loss_scale = 2
        max_iter = 5
    else:
        loss_scale = 1
        max_iter = 1
    loss_threshold = 0.1
    max_index = 30  # no latent optimization after this sampling step

    x = deepcopy(input["x"])
    iteration = 0
    # Sentinel so the first loop-condition check passes.
    loss = torch.tensor(10000)
    input["timesteps"] = ts

    while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index):
        x = x.requires_grad_(True)
        input['x'] = x
        e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)

        bboxes = input['boxes']
        object_positions = input['object_position']
        loss = caculate_loss_LoCo(att_second, att_first, att_third, bboxes=bboxes,
                                  object_positions=object_positions, t=index1) * loss_scale

        # CLEANUP: the previous `hh = torch.autograd.backward(loss)` bound
        # None to an unused name (backward() has no return value) and read the
        # gradient back via `x.grad`. `x` is a fresh leaf each iteration, so
        # querying the gradient directly is equivalent and avoids relying on
        # `.grad` accumulation. Leftover debug prints were also removed.
        grad_cond = torch.autograd.grad(loss, [x])[0]
        x = x - grad_cond
        x = x.detach()
        iteration += 1
        torch.cuda.empty_cache()
    return x
210
+
211
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
212
  if index1 < 10:
213
  loss_scale = 4