Q-bert commited on
Commit
6b64b3e
1 Parent(s): b671602

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +24 -0
README.md CHANGED
@@ -35,6 +35,30 @@ print(generated_text)
35
  ```
36
  > Hi, I'm looking for a new job. I've been working at a company for about a year now.
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # Credits:
39
 
40
  https://huggingface.co/state-spaces
 
35
  ```
36
  > Hi, I'm looking for a new job. I've been working at a company for about a year now.
37
 
38
+ # For Training:
39
+ ```python
40
+ from transformers import Trainer ,TrainingArguments
41
+ import torch
42
+ import os
43
+
44
+
45
+ class MambaTrainer(Trainer):
46
+ def compute_loss(self, model, inputs, return_outputs=False):
47
+ input_ids = inputs.pop("input_ids")
48
+ lm_logits = model(input_ids)[0]
49
+
50
+ labels = input_ids.to(lm_logits.device)
51
+ shift_logits = lm_logits[:, :-1, :].contiguous()
52
+ labels = labels[:, 1:].contiguous()
53
+
54
+ loss_fct = torch.nn.CrossEntropyLoss()
55
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
56
+
57
+ return lm_loss
58
+ ```
59
+
60
+ You must use this class for training. And fp16 must be **False**.
61
+
62
  # Credits:
63
 
64
  https://huggingface.co/state-spaces