vivicai committed e3f23e8 (1 parent: 2d80fa5)

Update README.md

Files changed (1)
1. README.md (+34 -0)
README.md CHANGED
@@ -23,4 +23,38 @@ tokenizer = AutoTokenizer.from_pretrained("TigerResearch/tigerbot-180b-research")
 
 model = AutoModelForCausalLM.from_pretrained("TigerResearch/tigerbot-180b-research")
 
+max_memory = get_balanced_memory(model)
+device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["BloomBlock"])
+model = dispatch_model(model, device_map=device_map, offload_buffers=True)
+
+device = torch.cuda.current_device()
+
+
+tok_ins = "\n\n### Instruction:\n"
+tok_res = "\n\n### Response:\n"
+prompt_input = tok_ins + "{instruction}" + tok_res
+
+input_text = "What is the next number after this list: [1, 2, 3, 5, 8, 13, 21]"
+input_text = prompt_input.format_map({'instruction': input_text})
+
+max_input_length = 512
+max_generate_length = 1024
+generation_kwargs = {
+    "top_p": 0.95,
+    "temperature": 0.8,
+    "max_length": max_generate_length,
+    "eos_token_id": tokenizer.eos_token_id,
+    "pad_token_id": tokenizer.pad_token_id,
+    "early_stopping": True,
+    "no_repeat_ngram_size": 4,
+}
+
+inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_input_length)
+inputs = {k: v.to(device) for k, v in inputs.items()}
+output = model.generate(**inputs, **generation_kwargs)
+answer = ''
+for tok_id in output[0][inputs['input_ids'].shape[1]:]:
+    if tok_id != tokenizer.eos_token_id:
+        answer += tokenizer.decode(tok_id)
+print(answer)
 ```
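
A note on the added snippet: it calls helpers that are not imported anywhere inside this hunk. Assuming the README above the hunk only imports from `transformers` (the earlier lines are not visible in this diff), a minimal import block that would make the new lines self-contained looks like this sketch:

```python
# Minimal imports assumed by the added snippet; the torch and accelerate
# imports are my assumption and do not appear in this hunk.
import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from transformers import AutoModelForCausalLM, AutoTokenizer
```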
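For readers skimming the diff, the instruction template simply wraps the query between two marker strings. A quick check of what `prompt_input.format_map` produces for the sample query (illustration only, not part of the commit):

```python
tok_ins = "\n\n### Instruction:\n"
tok_res = "\n\n### Response:\n"
prompt = (tok_ins + "{instruction}" + tok_res).format_map(
    {"instruction": "What is the next number after this list: [1, 2, 3, 5, 8, 13, 21]"}
)
print(repr(prompt))
# '\n\n### Instruction:\nWhat is the next number after this list: [1, 2, 3, 5, 8, 13, 21]\n\n### Response:\n'
```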
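One behavioral caveat, based on how `transformers` `generate` works rather than on anything in this commit: `top_p` and `temperature` only take effect when sampling is enabled, and `early_stopping` applies to beam search. With the defaults (`do_sample=False`, `num_beams=1`) the call above decodes greedily. A variant that makes sampling explicit:

```python
generation_kwargs = {
    "do_sample": True,  # without this, top_p and temperature are ignored
    "top_p": 0.95,
    "temperature": 0.8,
    "max_length": max_generate_length,  # counts prompt tokens plus generated tokens
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
    "no_repeat_ngram_size": 4,
}
```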
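Finally, the token-by-token decode loop can be collapsed into a single call. This sketch uses the standard `tokenizer.decode` API and should be equivalent in practice, though `skip_special_tokens=True` strips all special tokens, not only EOS:

```python
# Decode everything generated after the prompt in one call.
new_tokens = output[0][inputs["input_ids"].shape[1]:]
answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(answer)
```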