YeungNLP committed on
Commit
44fbebe
1 Parent(s): c80ef98

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +50 -0
README.md CHANGED
@@ -1,3 +1,53 @@
1
  ---
2
  license: apache-2.0
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ language:
4
+ - en
5
  ---
6
+
7
+ This model is fine-tuned from "mistralai/Mixtral-8x7B-v0.1" with Firefly.
8
+
9
+ ## Run the model
10
+
11
+ ```python
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+ import torch
14
+
15
+ model_name_or_path = 'YeungNLP/firefly-mixtral-8x7b'
16
+ max_new_tokens = 500
17
+ top_p = 0.9
18
+ temperature = 0.35
19
+ repetition_penalty = 1.0
20
+
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ model_name_or_path,
23
+ trust_remote_code=True,
24
+ low_cpu_mem_usage=True,
25
+ torch_dtype=torch.float16,
26
+ device_map='auto'
27
+ )
28
+ model = model.eval()
29
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
30
+
31
+ text = "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions."
32
+
33
+ inst_begin_tokens = tokenizer.encode('[INST]', add_special_tokens=False)
34
+ inst_end_tokens = tokenizer.encode('[/INST]', add_special_tokens=False)
35
+ human_tokens = tokenizer.encode(text, add_special_tokens=False)
36
+ input_ids = [tokenizer.bos_token_id] + inst_begin_tokens + human_tokens + inst_end_tokens
37
+
38
+ # input_ids = human_tokens
39
+ input_ids = torch.tensor([input_ids], dtype=torch.long).cuda()
40
+
41
+ with torch.no_grad():
42
+ outputs = model.generate(
43
+ input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
44
+ top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
45
+ eos_token_id=tokenizer.eos_token_id
46
+ )
47
+ outputs = outputs.tolist()[0][len(input_ids[0]):]
48
+ response = tokenizer.decode(outputs)
49
+ response = response.strip().replace(tokenizer.eos_token, "").strip()
50
+ print("Chatbot:{}".format(response))
51
+
52
+ ```
53
+