Files changed (1) hide show
  1. build_mlp.py +1 -1
build_mlp.py CHANGED
@@ -192,9 +192,9 @@ class PLoRA(nn.Linear):
192
  def forward(self, x, im_mask=None):
193
  B, N, C = x.shape
194
  x = x.reshape(-1, C)
195
- im_mask = im_mask.view(-1)
196
  res = super().forward(x)
197
  if im_mask is not None:
 
198
  if torch.sum(im_mask) > 0:
199
  part_x = x[im_mask]
200
  res[im_mask] += self.Plora_B(self.Plora_A(
 
192
  def forward(self, x, im_mask=None):
193
  B, N, C = x.shape
194
  x = x.reshape(-1, C)
 
195
  res = super().forward(x)
196
  if im_mask is not None:
197
+ im_mask = im_mask.view(-1)
198
  if torch.sum(im_mask) > 0:
199
  part_x = x[im_mask]
200
  res[im_mask] += self.Plora_B(self.Plora_A(