Q-bert committed on
Commit
bdd4211
1 Parent(s): afe29cc

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +94 -0
README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - mamba-hf
7
+ ---
8
+
9
+ # MambaHermes-3B
10
+
11
+ <img src="https://th.bing.com/th/id/OIG.Jp5dA01tOAFcwSp544nv?pid=ImgGn" width="300" height="300" alt="mamba-hf">
12
+
13
+ Mamba models with Hugging Face integration.
14
+
15
+ For the modeling code, see [**mamba-hf**](https://github.com/LegallyCoder/mamba-hf).
16
+
17
+ # Usage:
18
+
19
+ ```python
20
+ import torch
21
+ from transformers import AutoTokenizer, AutoModelForCausalLM
22
+
23
+ CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"
24
+
25
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
26
+ model_name = "Q-bert/MambaHermes-3B"
27
+
28
+ eos_token = "<|endoftext|>"
29
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
30
+ tokenizer.eos_token = eos_token
31
+ tokenizer.pad_token = tokenizer.eos_token
32
+ tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template
33
+
34
+ model = AutoModelForCausalLM.from_pretrained(
35
+ model_name, device=device, trust_remote_code=True, dtype=torch.float16)
36
+
37
+ messages = []
38
+ prompt = "Tell me 5 sites to visit in Spain"
39
+ messages.append(dict(role="user", content=prompt))
40
+
41
+ input_ids = tokenizer.apply_chat_template(
42
+ messages, return_tensors="pt", add_generation_prompt=True
43
+ ).to(device)
44
+
45
+ out = model.generate(
46
+ input_ids=input_ids,
47
+ max_length=2000,
48
+ temperature=0.9,
49
+ top_p=0.7,
50
+ eos_token_id=tokenizer.eos_token_id,
51
+ )
52
+
53
+ decoded = tokenizer.batch_decode(out)
54
+ assistant_message = (
55
+ decoded[0].split("<|assistant|>\n")[-1].replace(eos_token, "")
56
+ )
57
+
58
+ print(assistant_message)
59
+
60
+ ```
61
+
62
+
63
+ # For Training:
64
+ ```python
65
+ from transformers import Trainer ,TrainingArguments
66
+ import torch
67
+ import os
68
+
69
+
70
+ class MambaTrainer(Trainer):
71
+ def compute_loss(self, model, inputs, return_outputs=False):
72
+ input_ids = inputs.pop("input_ids")
73
+ lm_logits = model(input_ids)[0]
74
+
75
+ labels = input_ids.to(lm_logits.device)
76
+ shift_logits = lm_logits[:, :-1, :].contiguous()
77
+ labels = labels[:, 1:].contiguous()
78
+
79
+ loss_fct = torch.nn.CrossEntropyLoss()
80
+ lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
81
+
82
+ return lm_loss
83
+ ```
84
+
85
+ You must use this class for training, and `fp16` must be set to **False**.
86
+
87
+ # Credits:
88
+
89
+ https://huggingface.co/state-spaces
90
+
91
+ https://huggingface.co/clibrain/mamba-2.8b-instruct-openhermes
92
+
93
+ Special thanks to Albert Gu and Tri Dao for their paper, [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
94
+