Upload AutoEncoder
Browse files- model.safetensors +1 -1
- modeling_autoencoder.py +1 -1
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 133840
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cf45d89286799f00cb616e4cf8a0e16c087162b7e23a5f4c4a98d5641ec51c4c
|
3 |
size 133840
|
modeling_autoencoder.py
CHANGED
@@ -331,7 +331,7 @@ class AutoEncoder(PreTrainedModel):
|
|
331 |
outputs.loss = loss_fn(outputs.logits.view(-1), outputs.labels.view(-1))
|
332 |
elif not torch.is_floating_point(outputs.labels) and not torch.is_complex(outputs.labels):
|
333 |
loss_fn = nn.CrossEntropyLoss()
|
334 |
-
outputs.loss = loss_fn(outputs.logits.
|
335 |
else:
|
336 |
raise ValueError("Unsupported tensor dtype for these loss functions")
|
337 |
|
|
|
331 |
outputs.loss = loss_fn(outputs.logits.view(-1), outputs.labels.view(-1))
|
332 |
elif not torch.is_floating_point(outputs.labels) and not torch.is_complex(outputs.labels):
|
333 |
loss_fn = nn.CrossEntropyLoss()
|
334 |
+
outputs.loss = loss_fn(outputs.logits.reshape(-1, self.config.vocab_size), outputs.labels.view(-1))
|
335 |
else:
|
336 |
raise ValueError("Unsupported tensor dtype for these loss functions")
|
337 |
|