AINovelChat / app.py
tori29umai's picture
Update app.py
477abe1 verified
raw
history blame
13.1 kB
import os
import spaces
import sys
import gradio as gr
from llama_cpp import Llama
import configparser
from functools import partial
from utils.dl_utils import dl_guff_model
import io
import tempfile
import csv
# 定数
DEFAULT_INI_FILE = 'settings.ini'
MODEL_FILE_EXTENSION = '.gguf'
# パスの設定
if getattr(sys, 'frozen', False):
BASE_PATH = os.path.dirname(sys.executable)
MODEL_DIR = os.path.join(os.path.dirname(BASE_PATH), "AI-NovelAssistant", "models")
else:
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
# モデルディレクトリが存在しない場合は作成
if not os.path.exists("models"):
os.makedirs("models")
# 使用するモデルのファイル名を指定
model_filename = "EZO-Common-9B-gemma-2-it.f16.gguf"
model_path = os.path.join("models", model_filename)
# モデルファイルが存在しない場合はダウンロード
if not os.path.exists(model_path):
dl_guff_model("models", f"https://huggingface.co/MCZK/EZO-Common-9B-gemma-2-it-GGUF/resolve/main/{model_filename}")
class ConfigManager:
@staticmethod
def load_settings(filename):
config = configparser.ConfigParser()
config.read(filename, encoding='utf-8')
return config
@staticmethod
def save_settings(config, filename):
with open(filename, 'w', encoding='utf-8') as configfile:
config.write(configfile)
@staticmethod
def update_setting(section, key, value, filename):
config = ConfigManager.load_settings(filename)
config[section][key] = value
ConfigManager.save_settings(config, filename)
return f"設定を更新しました: [{section}] {key} = {value}"
@staticmethod
def create_default_settings(filename):
config = configparser.ConfigParser()
config['Character'] = {
'gen_author_description': 'あなたは新進気鋭の和風伝奇ミステリー小説家で、細やかな筆致と巧みな構成で若い世代にとても人気があります。'
}
config['Models'] = {
'DEFAULT_GEN_MODEL': 'EZO-Common-9B-gemma-2-it.f16.gguf'
}
config['GenerateParameters'] = {
'n_gpu_layers': '-1',
'temperature': '0.35',
'top_p': '0.9',
'top_k': '40',
'repetition_penalty': '1.2',
'n_ctx': '10000'
}
ConfigManager.save_settings(config, filename)
print(f"デフォルト設定ファイル {filename} を作成しました。")
class ModelManager:
@staticmethod
def get_model_files():
return [f for f in os.listdir(MODEL_DIR) if f.endswith(MODEL_FILE_EXTENSION)]
class Settings:
@staticmethod
def _parse_config(config):
settings = {}
if 'Character' in config:
settings['gen_author_description'] = config['Character'].get('gen_author_description', '')
if 'Models' in config:
settings['DEFAULT_GEN_MODEL'] = config['Models'].get('DEFAULT_GEN_MODEL', '')
if 'GenerateParameters' in config:
settings['gen_n_gpu_layers'] = int(config['GenerateParameters'].get('n_gpu_layers', '-1'))
settings['gen_temperature'] = float(config['GenerateParameters'].get('temperature', '0.35'))
settings['gen_top_p'] = float(config['GenerateParameters'].get('top_p', '0.9'))
settings['gen_top_k'] = int(config['GenerateParameters'].get('top_k', '40'))
settings['gen_rep_pen'] = float(config['GenerateParameters'].get('repetition_penalty', '1.2'))
settings['gen_n_ctx'] = int(config['GenerateParameters'].get('n_ctx', '10000'))
return settings
@staticmethod
def load_from_ini(filename):
config = ConfigManager.load_settings(filename)
return Settings._parse_config(config)
class TextGenerator:
def __init__(self):
self.llm = None
self.settings = None
self.current_model = None
@spaces.GPU(duration=120)
def load_model(self):
if self.llm:
del self.llm
self.llm = None
try:
model_path = os.path.join(MODEL_DIR, self.settings['DEFAULT_GEN_MODEL'])
n_gpu_layers = self.settings['gen_n_gpu_layers']
self.llm = Llama(model_path=model_path, n_ctx=self.settings['gen_n_ctx'], n_gpu_layers=n_gpu_layers)
self.current_model = 'GEN'
print(f"GEN モデル {model_path} のロードが完了しました。(n_gpu_layers: {n_gpu_layers})")
except Exception as e:
print(f"GEN モデルのロード中にエラーが発生しました: {str(e)}")
def generate_text(self, text, gen_characters, gen_token_multiplier, instruction):
if not self.llm:
self.load_model()
if not self.llm:
return "モデルのロードに失敗しました。設定を確認してください。"
author_description = self.settings.get('gen_author_description', '')
max_tokens = int(gen_characters * gen_token_multiplier)
messages = [
{"role": "user", "content": f"{author_description}\n\n以下の指示に従ってテキストを生成してください:"},
{"role": "assistant", "content": "はい、承知しました。指示に従ってテキストを生成いたします。"},
{"role": "user", "content": f"{instruction}\n\n生成するテキスト(目安は{gen_characters}文字):\n\n{text}"}
]
try:
response = self.llm.create_chat_completion(
messages=messages,
max_tokens=max_tokens,
temperature=self.settings['gen_temperature'],
top_p=self.settings['gen_top_p'],
top_k=self.settings['gen_top_k'],
repeat_penalty=self.settings['gen_rep_pen'],
)
generated_text = response["choices"][0]["message"]["content"].strip()
return generated_text
except Exception as e:
print(f"テキスト生成中にエラーが発生しました: {str(e)}")
return "テキスト生成中にエラーが発生しました。設定を確認してください。"
def load_settings(self, filename):
self.settings = Settings.load_from_ini(filename)
# グローバル変数
text_generator = TextGenerator()
model_files = ModelManager.get_model_files()
# Gradioインターフェース
def build_gradio_interface():
with gr.Blocks() as iface:
gr.HTML("""
<style>
#output {
resize: both;
overflow: auto;
min-height: 100px;
max-height: 80vh;
}
</style>
""")
with gr.Tab("文章生成"):
with gr.Row():
with gr.Column(scale=2):
instruction_type = gr.Dropdown(
choices=["自由入力", "推敲", "プロット作成", "あらすじ作成", "地の文追加"],
label="指示タイプ",
value="自由入力"
)
gen_instruction = gr.Textbox(
label="指示",
value="",
lines=3
)
gen_input_text = gr.Textbox(lines=5, label="処理されるテキストを入力してください")
gen_input_char_count = gr.HTML(value="文字数: 0")
with gr.Column(scale=1):
gen_characters = gr.Slider(minimum=10, maximum=10000, value=500, step=10, label="出力文字数", info="出力文字数の目安")
gen_token_multiplier = gr.Slider(minimum=0.5, maximum=3, value=1.75, step=0.01, label="文字/トークン数倍率", info="文字/最大トークン数倍率")
generate_button = gr.Button("文章生成開始")
generated_output = gr.Textbox(label="生成された文章", elem_id="output")
generate_button.click(
text_generator.generate_text,
inputs=[gen_input_text, gen_characters, gen_token_multiplier, gen_instruction],
outputs=[generated_output]
)
def update_instruction(choice):
instructions = {
"自由入力": "",
"推敲": "以下のテキストを推敲してください。原文の文体や特徴的な表現は保持しつつ、必要に応じて微調整を加えてください。文章の流れを自然にし、表現を洗練させることが目標ですが、元の雰囲気や個性を損なわないよう注意してください。",
"プロット作成": "以下のテキストをプロットにしてください。起承転結に分割すること。",
"あらすじ作成": "以下のテキストをあらすじにして、簡潔にまとめて下さい。",
"地の文追加": "以下のテキストに地の文を加え、情景を豊かにしてください。会話文は絶対に追加・改変しないでください。"
}
return instructions.get(choice, "")
instruction_type.change(
update_instruction,
inputs=[instruction_type],
outputs=[gen_instruction]
)
def update_char_count(text):
return f"文字数: {len(text)}"
gen_input_text.change(
update_char_count,
inputs=[gen_input_text],
outputs=[gen_input_char_count]
)
with gr.Tab("設定"):
output = gr.Textbox(label="更新状態")
config = ConfigManager.load_settings(DEFAULT_INI_FILE)
with gr.Column():
gr.Markdown("### モデル設定")
model_dropdown = gr.Dropdown(
label="DEFAULT_GEN_MODEL",
choices=ModelManager.get_model_files(),
value=config['Models'].get('DEFAULT_GEN_MODEL', '')
)
model_dropdown.change(
partial(ConfigManager.update_setting, 'Models', 'DEFAULT_GEN_MODEL'),
inputs=[model_dropdown],
outputs=[output]
)
gr.Markdown("### 文章生成設定")
gen_author_description = gr.TextArea(
label="gen_author_description",
value=config['Character'].get('gen_author_description', ''),
lines=5
)
gen_author_description.change(
partial(ConfigManager.update_setting, 'Character', 'gen_author_description'),
inputs=[gen_author_description],
outputs=[output]
)
gr.Markdown("### 文章生成パラメータ設定")
for key in ['n_gpu_layers', 'temperature', 'top_p', 'top_k', 'repetition_penalty', 'n_ctx']:
value = config['GenerateParameters'].get(key, '0')
if key == 'n_gpu_layers':
input_component = gr.Slider(label=key, value=int(value), minimum=-1, maximum=255, step=1)
elif key in ['temperature', 'top_p', 'repetition_penalty']:
input_component = gr.Slider(label=key, value=float(value), minimum=0.0, maximum=1.0, step=0.05)
elif key == 'top_k':
input_component = gr.Slider(label=key, value=int(value), minimum=1, maximum=200, step=1)
elif key == 'n_ctx':
input_component = gr.Slider(label=key, value=int(value), minimum=10000, maximum=100000, step=1000)
else:
input_component = gr.Textbox(label=key, value=value)
input_component.change(
partial(ConfigManager.update_setting, 'GenerateParameters', key),
inputs=[input_component],
outputs=[output]
)
apply_settings_button = gr.Button("設定を適用")
apply_settings_button.click(
lambda: text_generator.load_settings(DEFAULT_INI_FILE),
outputs=[output]
)
return iface
if __name__ == "__main__":
if not os.path.exists(DEFAULT_INI_FILE):
print(f"{DEFAULT_INI_FILE} が見つかりません。デフォルト設定で作成します。")
ConfigManager.create_default_settings(DEFAULT_INI_FILE)
text_generator.load_settings(DEFAULT_INI_FILE)
demo = build_gradio_interface()
demo.launch(share=True)