Text Generation
Transformers
PyTorch
Chinese
English
gpt2
text-generation-inference
Inference Endpoints
Hollway committed on
Commit
59db153
1 Parent(s): 09748a0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +33 -10
README.md CHANGED
@@ -13,10 +13,11 @@ pipeline_tag: text-generation
13
 
14
  ## 1 Usage
15
 
16
- ### 1.1 Initalization
17
- ```
18
  !pip install transformers[torch]
19
 
 
20
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
21
  import torch
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -25,21 +26,43 @@ tokenizer = GPT2Tokenizer.from_pretrained('Hollway/gpt2_finetune')
25
  model = GPT2LMHeadModel.from_pretrained('Hollway/gpt2_finetune').to(device)
26
  ```
27
 
28
- ### 1.2 Inference
29
  ```
30
- def generate(text):
31
- inputs = tokenizer(text, return_tensors="pt").to(model.device)
32
  with torch.no_grad():
33
  tokens = model.generate(
34
  **inputs,
35
- max_new_tokens=256,
36
  do_sample=True,
37
- temperature=0.7,
38
- top_p=0.9,
39
- repetition_penalty=1.05,
40
  pad_token_id=tokenizer.pad_token_id,
41
  )
42
  return tokenizer.decode(tokens[0], skip_special_tokens=True)
43
 
44
- generate("只因你")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  ```
 
13
 
14
  ## 1 Usage
15
 
16
+ ### 1.1 Initialization 初始化
17
+
18
  !pip install transformers[torch]
19
 
20
+ ```
21
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
22
  import torch
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
26
  model = GPT2LMHeadModel.from_pretrained('Hollway/gpt2_finetune').to(device)
27
  ```
28
 
29
+ ### 1.2 Inference 基本推理任务
30
  ```
31
+ def generate(text): # 基本的下文预测任务
32
+ inputs = tokenizer(text, return_tensors="pt").to(device)
33
  with torch.no_grad():
34
  tokens = model.generate(
35
  **inputs,
36
+ max_new_tokens=512,
37
  do_sample=True,
 
 
 
38
  pad_token_id=tokenizer.pad_token_id,
39
  )
40
  return tokenizer.decode(tokens[0], skip_special_tokens=True)
41
 
42
+ generate("派蒙是应急食品,但是不能吃派蒙,请分析不能吃的原因。")
43
+ ```
44
+
45
+ ### 1.3 Chatbot 聊天模式
46
+ ```
47
+ def chat(turns=5): # 多轮对话模式,通过字符串拼接实现。
48
+ for step in range(turns):
49
+ query = input(">> 用户:")
50
+ new_user_input_ids = tokenizer.encode(
51
+ f"用户: {query}\n\n系统: ", return_tensors='pt').to(device)
52
+ bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
53
+
54
+ base_tokens = bot_input_ids.shape[-1]
55
+ chat_history_ids = model.generate(
56
+ bot_input_ids,
57
+ max_length=base_tokens+64, # 单次回复的最大token数量
58
+ do_sample=True,
59
+ pad_token_id=tokenizer.eos_token_id)
60
+
61
+ response = tokenizer.decode(
62
+ chat_history_ids[:, bot_input_ids.shape[-1]:][0],
63
+ skip_special_tokens=True)
64
+
65
+ print(f"系统: {response}\n")
66
+
67
+ chat(turns=5)
68
  ```