amaye15 commited on
Commit
9e4975c
1 Parent(s): 65e03b8

Upload AutoEncoder

Browse files
Files changed (2) hide show
  1. model.safetensors +1 -1
  2. modeling_autoencoder.py +1 -1
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e25c1acc996e01cfcae7e106a00d1cb9b0f87489ef8ab0adc42f7baf6c82666b
3
  size 133840
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf45d89286799f00cb616e4cf8a0e16c087162b7e23a5f4c4a98d5641ec51c4c
3
  size 133840
modeling_autoencoder.py CHANGED
@@ -331,7 +331,7 @@ class AutoEncoder(PreTrainedModel):
331
  outputs.loss = loss_fn(outputs.logits.view(-1), outputs.labels.view(-1))
332
  elif not torch.is_floating_point(outputs.labels) and not torch.is_complex(outputs.labels):
333
  loss_fn = nn.CrossEntropyLoss()
334
- outputs.loss = loss_fn(outputs.logits.view(-1, self.config.vocab_size), outputs.labels.view(-1))
335
  else:
336
  raise ValueError("Unsupported tensor dtype for these loss functions")
337
 
 
331
  outputs.loss = loss_fn(outputs.logits.view(-1), outputs.labels.view(-1))
332
  elif not torch.is_floating_point(outputs.labels) and not torch.is_complex(outputs.labels):
333
  loss_fn = nn.CrossEntropyLoss()
334
+ outputs.loss = loss_fn(outputs.logits.reshape(-1, self.config.vocab_size), outputs.labels.view(-1))
335
  else:
336
  raise ValueError("Unsupported tensor dtype for these loss functions")
337