inu-ai's picture
Update README.md
dc21d7b
|
raw
history blame
6.52 kB
metadata
language: ja
tags:
  - ja
  - japanese
  - gpt
  - text-generation
  - lm
  - nlp
license: mit
datasets:
  - kunishou/databricks-dolly-15k-ja
widget:
  - text: >-
      <s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n日本で一番広い湖は?\n[SEP]\n応答:\n

dolly-japanese-gpt-1b

1.3Bパラメータの日本語GPTモデルを使用した指示に応答するAIです。VRAM 7GB または RAM 7GB が必要で、問題なく動作すると思われます。

rinna社の「japanese-gpt-1b」を、日本語データセット「databricks-dolly-15k-ja」を使用して学習させました。

学習データやモデルを作成および配布してくださった方々に心から感謝申し上げます。

モデルの使用方法

モデルの読み込み

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("inu-ai/dolly-japanese-gpt-1b", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("inu-ai/dolly-japanese-gpt-1b").to(device)

ChatGPT4によるサンプルコード

MAX_ASSISTANT_LENGTH = 100
MAX_INPUT_LENGTH = 1024
INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'
NO_INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n'

def prepare_input(instruction, input_text):
    if input_text != "":
        prompt = INPUT_PROMPT.format(instruction=instruction, input=input_text)
    else:
        prompt = NO_INPUT_PROMPT.format(instruction=instruction)
    return prompt

def format_output(output):
    output = output.lstrip("<s>").rstrip("</s>").replace("[SEP]", "").replace("\\n", "\n")
    return output

def generate_response(instruction, input_text):
    prompt = prepare_input(instruction, input_text)
    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    n = len(token_ids[0])

    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            min_length=n,
            max_length=min(MAX_INPUT_LENGTH, n + MAX_ASSISTANT_LENGTH),
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bad_words_ids=[[tokenizer.unk_token_id]]
        )

    output = tokenizer.decode(output_ids.tolist()[0])
    formatted_output_all = format_output(output)
    response = f"Assistant:{formatted_output_all.split('応答:')[-1].strip()}"

    return formatted_output_all, response 

instruction = "あなたは何でも正確に答えられるAIです。"
questions = [
    "日本で一番高い山は?",
    "日本で一番広い湖は?",
    "世界で一番高い山は?",
    "世界で一番広い湖は?",
    "冗談を言ってください。",
]

# 各質問に対して応答を生成して表示
for question in questions:
    formatted_output_all, response = generate_response(instruction, question)
    print(response)

出力

Assistant:富士山
Assistant:琵琶湖
Assistant:エベレストです。
Assistant:それは、面積で最大であると同時に、深さでも最も深い湖でもあります。
Assistant:A.I.は、自分の意思で、そして、自分の考えで行動することができます。

ChatGPT4による説明

このコードは、GPT-2モデルを使って、指定された指示と入力に対して適切な応答を生成するAIアシスタントです。 まず、prepare_input関数でプロンプトを作成し、generate_response関数でモデルから応答を生成します。 生成された応答を整形し、質問ごとに結果を表示します。

評価

100回の「入力」のような質問を行い、それらに対する「応答」に正解の文字列が含まれるかで評価しています。 一番正答率が高い18エポック目のモデルを選択しました。(やり過ぎたかもしれないです。)

入力 応答 正答率[%]
日本で一番広い湖は? 琵琶湖 91
世界で一番高い山は? エベレスト 84

学習データのフォーマット

alpacaと同じように、以下のようなフォーマットにしています。

<s>
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
[SEP]
指示:
あなたは何でも正確に答えられるAIです。
[SEP]
入力:
User:日本で一番高い山は?
[SEP]
応答:
富士山
</s>

transformersのコードでtxtファイルを学習する場合、1データ1行のようなので改行コードを一旦\nに置き換えています。 学習データはdatabricks-dolly-15k-ja.txtです。

学習のハイパーパラメータ

学習時には以下のハイパーパラメータを使用: ※VRAMが足りない場合、optimをadafactorにするとVRAM使用量が減りました。

python.exe transformers/examples/pytorch/language-modeling/run_clm.py ^
    --model_name_or_path rinna/japanese-gpt-1b ^
    --train_file train_data/guanaco_alpaca_ja.txt ^
    --output_dir output ^
    --do_train ^
    --bf16 True ^
    --tf32 True ^
    --optim adamw_bnb_8bit ^
    --num_train_epochs 18 ^
    --save_steps 384 ^
    --logging_steps 38 ^
    --learning_rate 1e-07 ^
    --lr_scheduler_type constant ^
    --gradient_checkpointing ^
    --per_device_train_batch_size 8 ^
    --save_safetensors True ^
    --logging_dir logs

ライブラリのバージョン

  • Transformers 4.28.0.dev0
  • Pytorch 2.0.0+cu117
  • Tokenizers 0.13.3
  • bitsandbytes 0.37.2

ライセンス

MITで大丈夫そうです。

  • japanese-gpt-1b - mit
  • databricks-dolly-15k-ja - CC BY SA 3.0