---
license: bigscience-bloom-rail-1.0
datasets:
- YeungNLP/firefly-train-1.1M
- BelleGroup/train_2M_CN
language:
- zh
---

# Langboat_bloom-6b4-zh-instruct_finetune-chat

A chat model based on Langboat_bloom-6b4-zh, fine-tuned with QLoRA on the firefly-train-1.1M and BelleGroup/train_2M_CN datasets.
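The card does not publish the training configuration, but for orientation the following is a minimal sketch of what a QLoRA setup for a BLOOM-family base model typically looks like. The base repo id `Langboat/bloom-6b4-zh` and all hyperparameters here are illustrative assumptions, not the values used to train this checkpoint.

```python
# Sketch of a typical QLoRA setup for a BLOOM-family model.
# All hyperparameters below are illustrative, not this model's actual config.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                       # QLoRA: base weights in 4-bit
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
base = AutoModelForCausalLM.from_pretrained(
    "Langboat/bloom-6b4-zh",                 # assumed base checkpoint id
    quantization_config=bnb_config,
    device_map="auto",
)
base = prepare_model_for_kbit_training(base)

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,  # illustrative values
    target_modules=["query_key_value"],      # BLOOM's fused attention projection
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()
```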

Results on the CEVAL benchmark:

| STEM | Social Sciences | Humanities | Others | Average | AVG(Hard) |
|------|-----------------|------------|--------|---------|-----------|
| 27.9 | 27.2 | 24.8 | 26.4 | 26.8 | 28.0 |

# Usage

## Single-turn instruction generation

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda"
model = AutoModelForCausalLM.from_pretrained("SmilePanda/Langboat_bloom-6b4-zh-instruct_finetune-chat", device_map=device)
tokenizer = AutoTokenizer.from_pretrained("SmilePanda/Langboat_bloom-6b4-zh-instruct_finetune-chat", use_fast=False)

# Conversation format the model expects: "human: \n{query}\n\nassistant: \n"
source_prefix = "human"
target_prefix = "assistant"
query = "你好"
sentence = f"{source_prefix}: \n{query}\n\n{target_prefix}: \n"
print("query: ", sentence)

input_ids = tokenizer(sentence, return_tensors='pt').input_ids.to(device)
outputs = model.generate(input_ids=input_ids, max_new_tokens=500,
                         do_sample=True,
                         top_p=0.8,
                         temperature=0.35,
                         repetition_penalty=1.2,
                         eos_token_id=tokenizer.eos_token_id)

# Decode, then strip the prompt to leave only the model's reply
rets = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
response = rets.replace(sentence, "")
print(response)
```
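Stripping the prompt with `str.replace` can fail if the decoded text differs slightly from the input string (for example, due to tokenizer normalization). A more robust variant, shown here as a sketch continuing the snippet above rather than part of the original card, decodes only the newly generated tokens:

```python
# Continuing from the snippet above: decode only the tokens produced
# after the prompt, instead of string-replacing the prompt away.
new_tokens = outputs[0][input_ids.shape[1]:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
print(response)
```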

## Multi-turn dialogue

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda"
model = AutoModelForCausalLM.from_pretrained("SmilePanda/Langboat_bloom-6b4-zh-instruct_finetune-chat", device_map=device)
tokenizer = AutoTokenizer.from_pretrained("SmilePanda/Langboat_bloom-6b4-zh-instruct_finetune-chat", use_fast=False)

source_prefix = "human"
target_prefix = "assistant"

history = ""

while True:
    query = input("user: ").strip()
    if not query:
        continue
    if query in ("q", "stop"):  # type 'q' or 'stop' to exit
        break
    # Append the new turn to the accumulated transcript
    if history:
        sentence = history + f"\n{source_prefix}: \n{query}\n\n{target_prefix}: \n"
    else:
        sentence = f"{source_prefix}: \n{query}\n\n{target_prefix}: \n"
    input_ids = tokenizer(sentence, return_tensors='pt').input_ids.to(device)
    outputs = model.generate(input_ids=input_ids, max_new_tokens=1024,
                             do_sample=True,
                             top_p=0.90,
                             temperature=0.1,
                             repetition_penalty=1.0,
                             eos_token_id=tokenizer.eos_token_id)
    rets = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
    print("bloom: {}".format(rets.replace(sentence, "")))
    # Carry the full decoded transcript forward as context for the next turn
    history = rets
```
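Note that `history` grows by the full transcript every turn, so long chats will eventually exceed the model's context window (BLOOM models are trained on 2048-token sequences). One simple guard, sketched here with an assumed token budget rather than taken from the original card, is to keep only the most recent tokens before building the next prompt:

```python
# Illustrative guard, to run inside the loop before building `sentence`:
# keep only the most recent `max_history_tokens` tokens of the transcript.
max_history_tokens = 1536  # assumed budget; tune for your context window
ids = tokenizer(history, return_tensors='pt').input_ids
if ids.shape[1] > max_history_tokens:
    history = tokenizer.decode(ids[0, -max_history_tokens:], skip_special_tokens=True)
```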