shunxing1234
commited on
Commit
•
8b77d13
1
Parent(s):
49e363d
Update README.md
Browse files
README.md
CHANGED
@@ -34,6 +34,32 @@ license: other
|
|
34 |
|
35 |
### 1. 推理/Inference
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
|
39 |
|
|
|
34 |
|
35 |
### 1. 推理/Inference
|
36 |
|
37 |
+
```python
|
38 |
+
# Minimal inference example for BAAI/AquilaChat-7B.
# Loads the tokenizer and model, builds a chat-formatted prompt, and
# generates a sampled completion on the first CUDA device.
from transformers import AutoTokenizer, AquilaForCausalLM
import torch
# NOTE(review): the upstream helper really is spelled "covert" (not
# "convert") — do not "fix" the name, it must match cyg_conversation.
from cyg_conversation import covert_prompt_to_input_ids_with_history

tokenizer = AutoTokenizer.from_pretrained("BAAI/AquilaChat-7B")
model = AquilaForCausalLM.from_pretrained("BAAI/AquilaChat-7B")
model.eval()          # disable dropout etc. for deterministic-ish inference
model.to("cuda:0")    # assumes a CUDA device is available — TODO confirm
vocab = tokenizer.vocab
print(len(vocab))     # sanity check: vocabulary size

text = "请给出10个要到北京旅游的理由。"

# Wrap the raw prompt in the Aquila chat template (empty history),
# truncated/padded to at most 512 tokens.
tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=512)

# Add a batch dimension and move the input ids to the GPU.
tokens = torch.tensor(tokens)[None,].to("cuda:0")

# eos_token_id=100007 is Aquila's end-of-sequence id; sampling enabled.
with torch.no_grad():
    out = model.generate(tokens, do_sample=True, max_length=512, eos_token_id=100007)[0]

out = tokenizer.decode(out.cpu().numpy().tolist())

print(out)
62 |
+
```
|
63 |
|
64 |
|
65 |
|