Safetensors
File size: 3,483 Bytes
c74db76
 
 
 
a20b574
e3fca71
328f0b8
94a3267
c74db76
a20b574
e38d61e
94a3267
 
c74db76
a20b574
e38d61e
e3fca71
94a3267
c74db76
3ae6e36
 
94a3267
 
 
c74db76
a20b574
e38d61e
c74db76
 
4aad6e1
c74db76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aad6e1
c74db76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a20b574
e38d61e
94a3267
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
---
license: cc-by-4.0
---

Mixtral8X7B Instructの日本語生成を安定させるためのLora実験モデルです。  

注:bf16での使用を想定しています。  
量子化推論する場合は、bf16でモデルを読み込んだ状態でLora適応またはマージ、その後に量子化してください。

**目的**

Mixtral-8x7Bは高性能な言語モデルですが、日本語出力に多言語が混入するcode-switchingがよく見られます。  
元の性能を維持しながら、日本語生成を安定させる方法として、Loraの効果を検証しました。

**学習データセット**

学習データセットとして、下記のDPOデータセットを使用しています。  
DPO trainingはVRAM消費が多く、今回はchosenのデータを使用したsft学習しています。

Chatbot Arena Conversations JA (calm2) Dataset  
:[cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental](https://huggingface.co/datasets/cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental)  
指示文 : [lmsys/chatbot_arena_conversations](https://huggingface.co/datasets/lmsys/chatbot_arena_conversations)のユーザ入力(CC-BY 4.0)を利用。  
指示文の和訳 : [facebookの翻訳モデル(MIT License)](https://huggingface.co/facebook/wmt21-dense-24-wide-en-x)が使用されています。  
応答文 : calm2-7b-chat(Apache 2.0)の出力です。

**evaluation**

大きな性能低下がないことを確認しました

##Lora

num_fewshot: 2, batch_size: 1
|         Task         |Version|  Metric   | Value |   |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3        |    1.1|exact_match|72.3323|   |      |
|                      |       |f1         |85.4772|   |      |
|jcommonsenseqa-1.1-0.3|    1.1|acc        | 0.7498|±  |0.0130|
|                      |       |acc_norm   | 0.4138|±  |0.0147|


num_fewshot: 2, batch_size: 1
|      Task       |Version|  Metric   | Value |   |Stderr|
|-----------------|------:|-----------|------:|---|-----:|
|jnli-1.1-0.3     |    1.1|acc        | 0.5912|±  |0.0100|
|                 |       |acc_norm   | 0.4108|±  |0.0100|
|marc_ja-1.1-0.3  |    1.1|acc        | 0.9620|±  |0.0025|
|                 |       |acc_norm   | 0.9620|±  |0.0025|
|jaqket_v2-0.1-0.3|    0.1|exact_match|71.6495|   |      |
|                 |       |f1         |79.4725|   |      |


##Base model

num_fewshot: 3,3, batch_size: 1
|         Task         |Version|  Metric   | Value |   |Stderr|
|----------------------|------:|-----------|------:|---|-----:|
|jsquad-1.1-0.3        |    1.1|exact_match|68.1225|   |      |
|                      |       |f1         |83.5285|   |      |
|jcommonsenseqa-1.1-0.3|    1.1|acc        | 0.7766|±  |0.0125|
|                      |       |acc_norm   | 0.4629|±  |0.0149|


num_fewshot: 2, batch_size: 1
|      Task       |Version|  Metric   | Value |   |Stderr|
|-----------------|------:|-----------|------:|---|-----:|
|jnli-1.1-0.3     |    1.1|acc        | 0.6228|±  |0.0098|
|                 |       |acc_norm   | 0.5288|±  |0.0101|
|marc_ja-1.1-0.3  |    1.1|acc        | 0.9630|±  |0.0025|
|                 |       |acc_norm   | 0.9630|±  |0.0025|
|jaqket_v2-0.1-0.3|    0.1|exact_match|67.9553|   |      |
|                 |       |f1         |78.7550|   |      |


**その他**

Lora学習時のcontext長は4096tokenまでですが、4k token以上の出力も可能です。