SunderAli17 commited on
Commit
27faac2
1 Parent(s): b2fb80e

Update flux/model.py

Browse files
Files changed (1) hide show
  1. flux/model.py +7 -7
flux/model.py CHANGED
@@ -79,9 +79,9 @@ class Flux(nn.Module):
79
 
80
  self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
81
 
82
- self.pulid_ca = None
83
- self.pulid_double_interval = 2
84
- self.pulid_single_interval = 4
85
 
86
  def forward(
87
  self,
@@ -115,8 +115,8 @@ class Flux(nn.Module):
115
  for i, block in enumerate(self.double_blocks):
116
  img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
117
 
118
- if i % self.pulid_double_interval == 0 and id is not None:
119
- img = img + id_weight * self.pulid_ca[ca_idx](id, img)
120
  ca_idx += 1
121
 
122
  img = torch.cat((txt, img), 1)
@@ -124,8 +124,8 @@ class Flux(nn.Module):
124
  x = block(img, vec=vec, pe=pe)
125
  real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]
126
 
127
- if i % self.pulid_single_interval == 0 and id is not None:
128
- real_img = real_img + id_weight * self.pulid_ca[ca_idx](id, real_img)
129
  ca_idx += 1
130
 
131
  img = torch.cat((txt, real_img), 1)
 
79
 
80
  self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
81
 
82
+ self.toonmage_ca = None
83
+ self.toonmage_double_interval = 2
84
+ self.toonmage_single_interval = 4
85
 
86
  def forward(
87
  self,
 
115
  for i, block in enumerate(self.double_blocks):
116
  img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
117
 
118
+ if i % self.toonmage_double_interval == 0 and id is not None:
119
+ img = img + id_weight * self.toonmage_ca[ca_idx](id, img)
120
  ca_idx += 1
121
 
122
  img = torch.cat((txt, img), 1)
 
124
  x = block(img, vec=vec, pe=pe)
125
  real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]
126
 
127
+ if i % self.toonmage_single_interval == 0 and id is not None:
128
+ real_img = real_img + id_weight * self.toonmage_ca[ca_idx](id, real_img)
129
  ca_idx += 1
130
 
131
  img = torch.cat((txt, real_img), 1)