Pusheen committed on
Commit
5c29177
1 Parent(s): be0890e

Update gligen/ldm/models/diffusion/plms.py

Browse files
Files changed (1) hide show
  1. gligen/ldm/models/diffusion/plms.py +10 -62
gligen/ldm/models/diffusion/plms.py CHANGED
@@ -5,7 +5,7 @@ from functools import partial
5
  from copy import deepcopy
6
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
  import math
8
- from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att, caculate_loss_LoCo_V2
9
  class PLMSSampler(object):
10
  def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
11
  super().__init__()
@@ -57,14 +57,14 @@ class PLMSSampler(object):
57
 
58
 
59
  # @torch.no_grad()
60
- def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='LoCo'):
61
  self.make_schedule(ddim_num_steps=S)
62
  # import pdb; pdb.set_trace()
63
  return self.plms_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0, loss_type=loss_type)
64
 
65
 
66
  # @torch.no_grad()
67
- def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='LoCo'):
68
 
69
  b = shape[0]
70
 
@@ -102,16 +102,12 @@ class PLMSSampler(object):
102
  # three loss types
103
  if loss_type !=None and loss_type!='standard':
104
  if input['object_position'] != []:
105
- # if loss_type=='SAR_CAR':
106
- # x = self.update_loss_self_cross( input,i, index, ts )
107
- # elif loss_type=='SAR':
108
- # x = self.update_only_self( input,i, index, ts )
109
- # elif loss_type=='CAR':
110
- # x = self.update_loss_only_cross( input,i, index, ts )
111
- # elif loss_type=='LoCo':
112
-
113
- x = self.update_loss_only_cross( input,i, index, ts )
114
- # x = self.update_loss_LoCo( input,i, index, ts, )
115
  input["x"] = x
116
  img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
117
  input["x"] = img
@@ -120,55 +116,7 @@ class PLMSSampler(object):
120
  old_eps.pop(0)
121
 
122
  return img
123
-
124
- def update_loss_LoCo(self, input,index1, index, ts, type_loss='self_accross'):
125
-
126
- # loss_scale = 30
127
- # max_iter = 5
128
- #print('time_factor is: ', time_factor)
129
- if index1 < 10:
130
- loss_scale = 8
131
- max_iter = 5
132
- elif index1 < 20:
133
- loss_scale = 5
134
- max_iter = 5
135
- else:
136
- loss_scale = 1
137
- max_iter = 1
138
- loss_threshold = 0.1
139
-
140
- max_index = 30
141
- x = deepcopy(input["x"])
142
- iteration = 0
143
- loss = torch.tensor(10000)
144
- input["timesteps"] = ts
145
-
146
- # print("optimize", index1)
147
- while loss.item() > loss_threshold and iteration < max_iter and (index1 < max_index) :
148
- # print('iter', iteration)
149
- x = x.requires_grad_(True)
150
- # print('x shape', x.shape)
151
- input['x'] = x
152
- e_t, att_first, att_second, att_third, self_first, self_second, self_third = self.model(input)
153
-
154
- bboxes = input['boxes']
155
- object_positions = input['object_position']
156
- loss2 = caculate_loss_LoCo_V2(att_second,att_first,att_third, bboxes=bboxes,
157
- object_positions=object_positions, t = index1)*loss_scale
158
- # loss = loss2
159
- # loss.requires_grad_(True)
160
- #print('LoCo loss', loss)
161
-
162
-
163
-
164
- hh = torch.autograd.backward(loss2, retain_graph=True)
165
- grad_cond = x.grad
166
- x = x - grad_cond
167
- x = x.detach()
168
- iteration += 1
169
- torch.cuda.empty_cache()
170
- return x
171
-
172
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
173
  if index1 < 10:
174
  loss_scale = 4
 
5
  from copy import deepcopy
6
  from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
7
  import math
8
+ from ldm.models.diffusion.loss import caculate_loss_att_fixed_cnt, caculate_loss_self_att
9
  class PLMSSampler(object):
10
  def __init__(self, diffusion, model, schedule="linear", alpha_generator_func=None, set_alpha_scale=None):
11
  super().__init__()
 
57
 
58
 
59
  # @torch.no_grad()
60
+ def sample(self, S, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
61
  self.make_schedule(ddim_num_steps=S)
62
  # import pdb; pdb.set_trace()
63
  return self.plms_sampling(shape, input, uc, guidance_scale, mask=mask, x0=x0, loss_type=loss_type)
64
 
65
 
66
  # @torch.no_grad()
67
+ def plms_sampling(self, shape, input, uc=None, guidance_scale=1, mask=None, x0=None, loss_type='SAR_CAR'):
68
 
69
  b = shape[0]
70
 
 
102
  # three loss types
103
  if loss_type !=None and loss_type!='standard':
104
  if input['object_position'] != []:
105
+ if loss_type=='SAR_CAR':
106
+ x = self.update_loss_self_cross( input,i, index, ts )
107
+ elif loss_type=='SAR':
108
+ x = self.update_only_self( input,i, index, ts )
109
+ elif loss_type=='CAR':
110
+ x = self.update_loss_only_cross( input,i, index, ts )
 
 
 
 
111
  input["x"] = x
112
  img, pred_x0, e_t = self.p_sample_plms(input, ts, index=index, uc=uc, guidance_scale=guidance_scale, old_eps=old_eps, t_next=ts_next)
113
  input["x"] = img
 
116
  old_eps.pop(0)
117
 
118
  return img
119
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  def update_loss_self_cross(self, input,index1, index, ts,type_loss='self_accross' ):
121
  if index1 < 10:
122
  loss_scale = 4