Pusheen committed on
Commit
ecd3e86
1 Parent(s): 5c29177

Update gligen/ldm/models/diffusion/plms.py

Browse files
Files changed (1) hide show
  1. gligen/ldm/models/diffusion/plms.py +58 -8
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -5,7 +5,7 @@ from functools import partial
5
  from copy import deepcopy
6
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
  import math
8
- from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att
9
  class PLMSSampler(object):
10
  def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
11
  super().__init__()
@@ -102,12 +102,13 @@ class PLMSSampler(object):
102
  # three loss types
103
  if loss_type !=None and loss_type!='standard':
104
  if input['object_position'] != []:
105
- if loss_type=='SAR_CAR':
106
- x = self.update_loss_self_cross( input,i, index, ts )
107
- elif loss_type=='SAR':
108
- x = self.update_only_self( input,i, index, ts )
109
- elif loss_type=='CAR':
110
- x = self.update_loss_only_cross( input,i, index, ts )
 
111
  input["x"] = x
112
  img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
113
  input["x"] = img
@@ -116,7 +117,56 @@ class PLMSSampler(object):
116
  old_eps.pop(0)
117
 
118
  return img
119
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
121
  if index1 < 10:
122
  loss_scale = 4
 
5
  from copy import deepcopy
6
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
  import math
8
+ from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att, caculate_loss_LoCo_V2
9
  class PLMSSampler(object):
10
  def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
11
  super().__init__()
 
102
  # three loss types
103
  if loss_type !=None and loss_type!='standard':
104
  if input['object_position'] != []:
105
+ # if loss_type=='SAR_CAR':
106
+ # x = self.update_loss_self_cross( input,i, index, ts )
107
+ # elif loss_type=='SAR':
108
+ # x = self.update_only_self( input,i, index, ts )
109
+ # elif loss_type=='CAR':
110
+ # x = self.update_loss_only_cross( input,i, index, ts )
111
+ x = self.update_loss_LoCo( input,i, index, ts )
112
  input["x"] = x
113
  img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
114
  input["x"] = img
 
117
  old_eps.pop(0)
118
 
119
  return img
120
+
121
+
122
+ def update_loss_LoCo(self, input,index1, index, ts, time_factor, type_loss='self_accross'):
123
+
124
+ # loss_scale = 30
125
+ # max_iter = 5
126
+ #print('time_factor is: ', time_factor)
127
+ if index1 < 10:
128
+ loss_scale = 8
129
+ max_iter = 5
130
+ elif index1 < 20:
131
+ loss_scale = 5
132
+ max_iter = 5
133
+ else:
134
+ loss_scale = 1
135
+ max_iter = 1
136
+ loss_threshold = 0.1
137
+
138
+ max_index = 30
139
+ x = deepcopy(input["x"])
140
+ iteration = 0
141
+ loss = torch.tensor(10000)
142
+ input["timesteps"] = ts
143
+
144
+ # print("optimize", index1)
145
+ while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
146
+ # print('iter', iteration)
147
+ x = x.requires_grad_(True)
148
+ # print('x shape', x.shape)
149
+ input['x'] = x
150
+ e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
151
+
152
+ bboxes = input['boxes']
153
+ object_positions = input['object_position']
154
+ loss2 = caculate_loss_LoCo_V2(att_second,att_first,att_third, bboxes=bboxes,
155
+ object_positions=object_positions, t = index1)*loss_scale
156
+ # loss = loss2
157
+ # loss.requires_grad_(True)
158
+ #print('LoCo loss', loss)
159
+
160
+
161
+
162
+ grad_cond = torch.autograd.grad(loss2.requires_grad_(True), [x])[0]
163
+ # grad_cond = x.grad
164
+ x = x - grad_cond
165
+ x = x.detach()
166
+ iteration += 1
167
+ torch.cuda.empty_cache()
168
+ return x
169
+
170
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
171
  if index1 < 10:
172
  loss_scale = 4