oshizo commited on
Commit
2b46c24
1 Parent(s): 89ad7ba

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +85 -0
README.md CHANGED
@@ -1,3 +1,88 @@
1
  ---
2
  license: mit
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ datasets:
4
+ - SkelterLabsInc/JaQuAD
5
+ language:
6
+ - ja
7
  ---
8
+ # Model Card for Model ID
9
+
10
+ <!-- Provide a quick summary of what the model is/does. -->
11
+
12
+ このモデルはrinna/japanese-gpt-1bをベースモデルとして、
13
+ コンテキストからの抽出型QAと、解答を新たなコンテキストでリファインするための学習を行ったモデルです。
14
+
15
+ gpt-index(v0.2.5)で利用することを前提に学習をしており、通常のQAタスクで使用することは想定していません。
16
+
17
+ 利用例はこのリポジトリを参照してください。
18
+ https://github.com/oshizo/gpt_index_japanese_trial
19
+
20
+
21
+ # Model Details
22
+
23
+ モデルは2種類のpromptテンプレートに対してQA応答するように訓練されています。
24
+
25
+ ```python
26
+ DEFAULT_PROMPT = """
27
+ 文脈情報は以下です。
28
+ ---
29
+ {context_str}
30
+ ---
31
+ 事前知識ではなく、文脈情報を参考に質問に答えてください。:{query_str}
32
+ """
33
+ ```
34
+
35
+ ```python
36
+ REFINE_PROMPT = """
37
+ 質問は以下です。:{query_str}
38
+ すでに答えの候補があります。:{existing_answer}
39
+ 必要な場合のみ、以下の文脈情報を使ってこの答えを改良することができます。
40
+ ---
41
+ {context_msg}
42
+ ---
43
+ この文脈情報により、元の答えを改良して質問に答えてください。
44
+ 文脈情報が有用でない場合は元の答えをそのまま返してください。
45
+ """
46
+ ```
47
+
48
+ ```python
49
+ import torch
50
+ from transformers import T5Tokenizer, AutoModelForCausalLM
51
+
52
+ tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b")
53
+ model = AutoModelForCausalLM.from_pretrained("oshizo/qa-refine-japanese-gpt-1b").to("cuda")
54
+
55
+
56
+ prompt = DEFAULT_PROMPT.format(
57
+ context_str="山路を登りながら、こう考えた。智に働けば角が立つ。情に棹させば流される。意地を通せば窮屈だ。とかくに人の世は住みにくい。住みにくさが高じると、安い所へ引き越したくなる。どこへ越しても住みにくいと悟った時、詩が生れて、画が出来る。",
58
+ query_str="意地を通すとどうなってしまう?"
59
+ )
60
+
61
+ token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
62
+ n = len(token_ids[0])
63
+
64
+ with torch.no_grad():
65
+ output_ids = model.generate(
66
+ token_ids.to(model.device),
67
+ max_length=n+100,
68
+ min_length=n+2,
69
+ do_sample=False,
70
+ pad_token_id=tokenizer.pad_token_id,
71
+ bos_token_id=tokenizer.bos_token_id,
72
+ eos_token_id=tokenizer.eos_token_id,
73
+ )
74
+ output = tokenizer.decode(output_ids.tolist()[0][n:])
75
+ output.replace("</s>", "")
76
+
77
+ # -> 窮屈
78
+
79
+ ```
80
+
81
+ # Training Details
82
+
83
+ JGLUE/JSQuADとJaQuADを用いて、コンテキストからの抽出型QAと、解答を新たなコンテキストでリファインするための学習を行いました。
84
+
85
+ 学習スクリプトについてはこのリポジトリを参照してください。
86
+ https://github.com/oshizo/gpt_index_japanese_trial
87
+
88
+ Google Colab Pro A100 で約3.5時間、9.9kステップ学習しました。