tianyuz commited on
Commit
6c7ec60
1 Parent(s): 6804210

update readme

Browse files
Files changed (1) hide show
  1. README.md +147 -8
README.md CHANGED
@@ -21,23 +21,162 @@ This repository provides a base-sized Japanese RoBERTa model. The model is provi
21
 
22
  # How to use the model
23
 
24
- Since this is a private repo, first login your huggingface account from the command line:
25
-
26
- ~~~
27
- transformer-cli login
28
- ~~~
29
-
30
  *NOTE:* Use `T5Tokenizer` to initiate the tokenizer.
31
 
32
  ~~~~
33
  from transformers import T5Tokenizer, RobertaForMaskedLM
34
 
35
- tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base", use_auth_token=True)
36
  tokenizer.do_lower_case = True # due to some bug of tokenizer config loading
37
 
38
- model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base", use_auth_token=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  ~~~~
40
 
 
41
  # Model architecture
42
  A 12-layer, 768-hidden-size transformer-based masked language model.
43
 
 
21
 
22
  # How to use the model
23
 
 
 
 
 
 
 
24
  *NOTE:* Use `T5Tokenizer` to initiate the tokenizer.
25
 
26
  ~~~~
27
  from transformers import T5Tokenizer, RobertaForMaskedLM
28
 
29
+ tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
30
  tokenizer.do_lower_case = True # due to some bug of tokenizer config loading
31
 
32
+ model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base")
33
+ ~~~~
34
+
35
+ # How to use the model for masked token prediction
36
+
37
+ *NOTE:* To predict a masked token, be sure to add a `[CLS]` token before the sentence for the model to correctly encode it, as it is used during the model training.
38
+
39
+ Here we adopt the example by [kenta1984](https://qiita.com/kenta1984/items/7f3a5d859a15b20657f3#%E6%97%A5%E6%9C%AC%E8%AA%9Epre-trained-models) to illustrate how our model works as a masked language model.
40
+
41
+ ~~~~
42
+ # original text
43
+ text = "テレビでサッカーの試合を見る。"
44
+
45
+ # prepend [CLS]
46
+ text = "[CLS]" + text
47
+
48
+ # tokenize
49
+ tokens = tokenizer.tokenize(text)
50
+ print(tokens) # output: ['[CLS]', '▁', 'テレビ', 'で', 'サッカー', 'の試合', 'を見る', '。']
51
+
52
+ # mask a token
53
+ masked_idx = 4
54
+ tokens[masked_idx] = tokenizer.mask_token
55
+ print(tokens) # output: ['[CLS]', '▁', 'テレビ', 'で', '[MASK]', 'の試合', 'を見る', '。']
56
+
57
+ # convert to ids
58
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
59
+ print(token_ids) # output: [4, 9, 480, 19, 6, 8466, 6518, 8]
60
+
61
+ # convert to tensor
62
+ import torch
63
+ token_tensor = torch.tensor([token_ids])
64
+
65
+ # get the top 50 predictions of the masked token
66
+ model = model.eval()
67
+ with torch.no_grad():
68
+ outputs = model(token_tensor)
69
+ predictions = outputs[0][0, masked_idx].topk(100)
70
+ for i, index_t in enumerate(predictions.indices):
71
+ index = index_t.item()
72
+ token = tokenizer.convert_ids_to_tokens([index])[0]
73
+ print(i, token)
74
+
75
+ """
76
+ 0 サッカー
77
+ 1 メジャーリーグ
78
+ 2 フィギュアスケート
79
+ 3 バレーボール
80
+ 4 ボクシング
81
+ 5 ワールドカップ
82
+ 6 バスケットボール
83
+ 7 阪神タイガース
84
+ 8 プロ野球
85
+ 9 アメリカンフットボール
86
+ 10 日本代表
87
+ 11 高校野球
88
+ 12 福岡ソフトバンクホークス
89
+ 13 プレミアリーグ
90
+ 14 ファイターズ
91
+ 15 ラグビー
92
+ 16 東北楽天ゴールデンイーグルス
93
+ 17 中日ドラゴンズ
94
+ 18 アイスホッケー
95
+ 19 フットサル
96
+ 20 サッカー選手
97
+ 21 スポーツ
98
+ 22 チャンピオンズリーグ
99
+ 23 ジャイアンツ
100
+ 24 ソフトボール
101
+ 25 バスケット
102
+ 26 フットボール
103
+ 27 新日本プロレス
104
+ 28 バドミントン
105
+ 29 千葉ロッテマリーンズ
106
+ 30 <unk>
107
+ 31 北京オリンピック
108
+ 32 広島東洋カープ
109
+ 33 キックボクシング
110
+ 34 オリンピック
111
+ 35 ロンドンオリンピック
112
+ 36 読売ジャイアンツ
113
+ 37 テニス
114
+ 38 東京オリンピック
115
+ 39 日本シリーズ
116
+ 40 ヤクルトスワローズ
117
+ 41 タイガース
118
+ 42 サッカークラブ
119
+ 43 ハンドボール
120
+ 44 野球
121
+ 45 バルセロナ
122
+ 46 ホッケー
123
+ 47 格闘技
124
+ 48 大相撲
125
+ 49 ブンデスリーガ
126
+ 50 スキージャンプ
127
+ 51 プロサッカー選手
128
+ 52 ヤンキース
129
+ 53 社会人野球
130
+ 54 クライマックスシリーズ
131
+ 55 クリケット
132
+ 56 トップリーグ
133
+ 57 パラリンピック
134
+ 58 クラブチーム
135
+ 59 ニュージーランド
136
+ 60 総合格闘技
137
+ 61 ウィンブルドン
138
+ 62 ドラゴンボール
139
+ 63 レスリング
140
+ 64 ドラゴンズ
141
+ 65 プロ野球選手
142
+ 66 リオデジャネイロオリンピック
143
+ 67 ホークス
144
+ 68 全日本プロレス
145
+ 69 プロレス
146
+ 70 ヴェルディ
147
+ 71 都市対抗野球
148
+ 72 ライオンズ
149
+ 73 グランプリシリーズ
150
+ 74 日本プロ野球
151
+ 75 アテネオリンピック
152
+ 76 ヤクルト
153
+ 77 イーグルス
154
+ 78 巨人
155
+ 79 ワールドシリーズ
156
+ 80 アーセナル
157
+ 81 マスターズ
158
+ 82 ソフトバンク
159
+ 83 日本ハム
160
+ 84 クロアチア
161
+ 85 マリナーズ
162
+ 86 サッカーリーグ
163
+ 87 アトランタオリンピック
164
+ 88 ゴルフ
165
+ 89 ジャニーズ
166
+ 90 甲子園
167
+ 91 夏の甲子園
168
+ 92 陸上競技
169
+ 93 ベースボール
170
+ 94 卓球
171
+ 95 プロ
172
+ 96 南アフリカ
173
+ 97 レッズ
174
+ 98 ウルグアイ
175
+ 99 オールスターゲーム
176
+ """
177
  ~~~~
178
 
179
+
180
  # Model architecture
181
  A 12-layer, 768-hidden-size transformer-based masked language model.
182