---
language:
  - zh
  - en
license: mit
datasets:
  - TigerResearch/tigerbot-zhihu-zh-10k
  - TigerResearch/tigerbot-book-qa-1k
pipeline_tag: text-generation
---

Chinese text generation

1 Usage

1.1 Initialization

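Install the dependencies (in a Jupyter or Colab notebook, prefix the command with `!`):

```shell
pip install "transformers[torch]"
```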

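```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download (or load from the local cache) the fine-tuned tokenizer and model weights.
tokenizer = GPT2Tokenizer.from_pretrained('Hollway/gpt2_finetune')
model = GPT2LMHeadModel.from_pretrained('Hollway/gpt2_finetune').to(device)
```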

1.2 Basic inference

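```python
def generate(text):
    """Basic continuation task: let the model continue `text`."""
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        tokens = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,  # sample instead of greedy decoding, so output varies between runs
            # GPT-2 tokenizers often define no pad token; fall back to EOS in that case.
            pad_token_id=(tokenizer.pad_token_id
                          if tokenizer.pad_token_id is not None
                          else tokenizer.eos_token_id),
        )
    # The decoded text contains the prompt followed by the generated continuation.
    return tokenizer.decode(tokens[0], skip_special_tokens=True)

# Prompt: "Paimon is emergency food, but Paimon must not be eaten. Please analyze why."
print(generate("派蒙是应急食品,但是不能吃派蒙,请分析不能吃的原因。"))
```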
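`do_sample=True` makes `generate` draw each next token at random from the model's predicted distribution instead of always taking the most likely token, which is why the same prompt yields a different answer on each call. transformers' `GenerationConfig` defaults to `temperature=1.0` and `top_k=50` unless the model's own generation config overrides them. The sketch below shows one such sampling step in plain Python; `sample_next_token` and its parameters are hypothetical names chosen for illustration, and this is a simplified model of temperature + top-k sampling, not the library's implementation.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=50, rng=random):
    """Pick one token id from raw logits using temperature + top-k sampling.

    A simplified, framework-free sketch of a single decoding step
    (hypothetical helper, not part of transformers).
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Keep only the top_k highest-scoring token ids (all of them if top_k <= 0).
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    candidates = ranked[:top_k] if top_k > 0 else ranked
    # Softmax numerators over the survivors; subtracting the max avoids overflow.
    scaled = [logits[i] / temperature for i in candidates]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    # choices() normalizes the weights itself, so they need not sum to 1.
    return rng.choices(candidates, weights=weights, k=1)[0]


# Toy vocabulary of 5 tokens; token 2 has the highest logit.
toy_logits = [1.0, 0.5, 3.0, -1.0, 2.0]
rng = random.Random(0)
print(sample_next_token(toy_logits, top_k=1, rng=rng))  # top_k=1 is greedy: prints 2
print([sample_next_token(toy_logits, top_k=3, rng=rng) for _ in range(10)])
```

With `top_k=1` the step degenerates to greedy decoding; larger `top_k` or higher `temperature` spreads probability over more candidates and makes the output more varied.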

1.3 Chatbot mode

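```python
def chat(turns=5):
    """Multi-turn chat, implemented by concatenating each new turn onto the history."""
    chat_history_ids = None
    for _ in range(turns):
        query = input(">> 用户:")  # "用户" = user
        # Prompt template: "用户" (user) / "系统" (system, i.e. the bot).
        new_user_input_ids = tokenizer.encode(
            f"用户: {query}\n\n系统: ", return_tensors="pt").to(device)
        bot_input_ids = (new_user_input_ids if chat_history_ids is None
                         else torch.cat([chat_history_ids, new_user_input_ids], dim=-1))

        base_tokens = bot_input_ids.shape[-1]
        chat_history_ids = model.generate(
            bot_input_ids,
            attention_mask=torch.ones_like(bot_input_ids),  # no padding: attend to every token
            max_new_tokens=64,  # maximum number of tokens in a single reply
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id)

        # Decode only the newly generated tokens, i.e. the bot's reply.
        response = tokenizer.decode(
            chat_history_ids[0, base_tokens:], skip_special_tokens=True)
        print(f"系统: {response}\n")

chat(turns=5)
```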
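The chat loop keeps every previous turn, so the prompt grows with each exchange, while GPT-2-style models accept only a fixed number of positions (1024 for the original GPT-2; assumed here, not stated for this fine-tune). A minimal sketch of one way to bound the prompt, assuming the history is kept as text rather than token ids: rebuild the prompt in the same `用户: ...\n\n系统: ` template and drop the oldest exchanges until the prompt plus a 64-token reply budget (the README's per-reply limit) fits. `build_chat_prompt` and its `count_tokens` hook are hypothetical; in practice `count_tokens` could be `lambda s: len(tokenizer.encode(s))`.

```python
def build_chat_prompt(history, query, count_tokens, context_limit=1024, reply_budget=64):
    """Return a prompt in the README's template, dropping the oldest turns to fit.

    history:       list of (user_text, reply_text) pairs from earlier turns
    query:         the new user message
    count_tokens:  callable giving the token count of a string (hypothetical hook)
    context_limit: assumed model context size (1024 for the original GPT-2)
    reply_budget:  tokens reserved for the reply (64, as in the README's chat loop)
    """

    def turn(user, reply=""):
        # Same template as the README: "用户" = user, "系统" = system (the bot).
        return f"用户: {user}\n\n系统: {reply}"

    kept = list(history)
    while True:
        # Turns are joined with no separator, mirroring the README's concatenation.
        prompt = "".join(turn(u, r) for u, r in kept) + turn(query)
        # Stop once prompt + reply fits, or when no history is left to drop
        # (an over-long query on its own is returned unchanged).
        if count_tokens(prompt) + reply_budget <= context_limit or not kept:
            return prompt
        kept.pop(0)  # discard the oldest exchange first


# Usage, with character count standing in for a real tokenizer:
history = [("你好", "你好!"), ("今天天气怎么样?", "我不知道。")]
print(build_chat_prompt(history, "再见", count_tokens=len, context_limit=50, reply_budget=10))
```

In the chat loop this would mean storing `(query, response)` after each turn and encoding the returned prompt with `tokenizer.encode(prompt, return_tensors="pt")`; note that re-encoding decoded text can tokenize slightly differently from the token ids the model actually produced.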