Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import spaces | |
import sys | |
import gradio as gr | |
from llama_cpp import Llama | |
import configparser | |
from functools import partial | |
from utils.dl_utils import dl_guff_model | |
import io | |
import tempfile | |
import csv | |
# 定数 | |
DEFAULT_INI_FILE = 'settings.ini' | |
MODEL_FILE_EXTENSION = '.gguf' | |
# パスの設定 | |
if getattr(sys, 'frozen', False): | |
BASE_PATH = os.path.dirname(sys.executable) | |
MODEL_DIR = os.path.join(os.path.dirname(BASE_PATH), "AI-NovelAssistant", "models") | |
else: | |
BASE_PATH = os.path.dirname(os.path.abspath(__file__)) | |
MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models") | |
# モデルディレクトリが存在しない場合は作成 | |
if not os.path.exists("models"): | |
os.makedirs("models") | |
# 使用するモデルのファイル名を指定 | |
model_filename = "EZO-Common-9B-gemma-2-it.f16.gguf" | |
model_path = os.path.join("models", model_filename) | |
# モデルファイルが存在しない場合はダウンロード | |
if not os.path.exists(model_path): | |
dl_guff_model("models", f"https://huggingface.co/MCZK/EZO-Common-9B-gemma-2-it-GGUF/resolve/main/{model_filename}") | |
class ConfigManager: | |
def load_settings(filename): | |
config = configparser.ConfigParser() | |
config.read(filename, encoding='utf-8') | |
return config | |
def save_settings(config, filename): | |
with open(filename, 'w', encoding='utf-8') as configfile: | |
config.write(configfile) | |
def update_setting(section, key, value, filename): | |
config = ConfigManager.load_settings(filename) | |
config[section][key] = value | |
ConfigManager.save_settings(config, filename) | |
return f"設定を更新しました: [{section}] {key} = {value}" | |
def create_default_settings(filename): | |
config = configparser.ConfigParser() | |
config['Character'] = { | |
'gen_author_description': 'あなたは新進気鋭の和風伝奇ミステリー小説家で、細やかな筆致と巧みな構成で若い世代にとても人気があります。' | |
} | |
config['Models'] = { | |
'DEFAULT_GEN_MODEL': 'EZO-Common-9B-gemma-2-it.f16.gguf' | |
} | |
config['GenerateParameters'] = { | |
'n_gpu_layers': '-1', | |
'temperature': '0.35', | |
'top_p': '0.9', | |
'top_k': '40', | |
'repetition_penalty': '1.2', | |
'n_ctx': '10000' | |
} | |
ConfigManager.save_settings(config, filename) | |
print(f"デフォルト設定ファイル {filename} を作成しました。") | |
class ModelManager: | |
def get_model_files(): | |
return [f for f in os.listdir(MODEL_DIR) if f.endswith(MODEL_FILE_EXTENSION)] | |
class Settings: | |
def _parse_config(config): | |
settings = {} | |
if 'Character' in config: | |
settings['gen_author_description'] = config['Character'].get('gen_author_description', '') | |
if 'Models' in config: | |
settings['DEFAULT_GEN_MODEL'] = config['Models'].get('DEFAULT_GEN_MODEL', '') | |
if 'GenerateParameters' in config: | |
settings['gen_n_gpu_layers'] = int(config['GenerateParameters'].get('n_gpu_layers', '-1')) | |
settings['gen_temperature'] = float(config['GenerateParameters'].get('temperature', '0.35')) | |
settings['gen_top_p'] = float(config['GenerateParameters'].get('top_p', '0.9')) | |
settings['gen_top_k'] = int(config['GenerateParameters'].get('top_k', '40')) | |
settings['gen_rep_pen'] = float(config['GenerateParameters'].get('repetition_penalty', '1.2')) | |
settings['gen_n_ctx'] = int(config['GenerateParameters'].get('n_ctx', '10000')) | |
return settings | |
def load_from_ini(filename): | |
config = ConfigManager.load_settings(filename) | |
return Settings._parse_config(config) | |
class TextGenerator: | |
def __init__(self): | |
self.llm = None | |
self.settings = None | |
self.current_model = None | |
def load_model(self): | |
if self.llm: | |
del self.llm | |
self.llm = None | |
try: | |
model_path = os.path.join(MODEL_DIR, self.settings['DEFAULT_GEN_MODEL']) | |
n_gpu_layers = self.settings['gen_n_gpu_layers'] | |
self.llm = Llama(model_path=model_path, n_ctx=self.settings['gen_n_ctx'], n_gpu_layers=n_gpu_layers) | |
self.current_model = 'GEN' | |
print(f"GEN モデル {model_path} のロードが完了しました。(n_gpu_layers: {n_gpu_layers})") | |
except Exception as e: | |
print(f"GEN モデルのロード中にエラーが発生しました: {str(e)}") | |
def generate_text(self, text, gen_characters, gen_token_multiplier, instruction): | |
if not self.llm: | |
self.load_model() | |
if not self.llm: | |
return "モデルのロードに失敗しました。設定を確認してください。" | |
author_description = self.settings.get('gen_author_description', '') | |
max_tokens = int(gen_characters * gen_token_multiplier) | |
messages = [ | |
{"role": "user", "content": f"{author_description}\n\n以下の指示に従ってテキストを生成してください:"}, | |
{"role": "assistant", "content": "はい、承知しました。指示に従ってテキストを生成いたします。"}, | |
{"role": "user", "content": f"{instruction}\n\n生成するテキスト(目安は{gen_characters}文字):\n\n{text}"} | |
] | |
try: | |
response = self.llm.create_chat_completion( | |
messages=messages, | |
max_tokens=max_tokens, | |
temperature=self.settings['gen_temperature'], | |
top_p=self.settings['gen_top_p'], | |
top_k=self.settings['gen_top_k'], | |
repeat_penalty=self.settings['gen_rep_pen'], | |
) | |
generated_text = response["choices"][0]["message"]["content"].strip() | |
return generated_text | |
except Exception as e: | |
print(f"テキスト生成中にエラーが発生しました: {str(e)}") | |
return "テキスト生成中にエラーが発生しました。設定を確認してください。" | |
def load_settings(self, filename): | |
self.settings = Settings.load_from_ini(filename) | |
# グローバル変数 | |
text_generator = TextGenerator() | |
model_files = ModelManager.get_model_files() | |
# Gradioインターフェース | |
def build_gradio_interface(): | |
with gr.Blocks() as iface: | |
gr.HTML(""" | |
<style> | |
#output { | |
resize: both; | |
overflow: auto; | |
min-height: 100px; | |
max-height: 80vh; | |
} | |
</style> | |
""") | |
with gr.Tab("文章生成"): | |
with gr.Row(): | |
with gr.Column(scale=2): | |
instruction_type = gr.Dropdown( | |
choices=["自由入力", "推敲", "プロット作成", "あらすじ作成", "地の文追加"], | |
label="指示タイプ", | |
value="自由入力" | |
) | |
gen_instruction = gr.Textbox( | |
label="指示", | |
value="", | |
lines=3 | |
) | |
gen_input_text = gr.Textbox(lines=5, label="処理されるテキストを入力してください") | |
gen_input_char_count = gr.HTML(value="文字数: 0") | |
with gr.Column(scale=1): | |
gen_characters = gr.Slider(minimum=10, maximum=10000, value=500, step=10, label="出力文字数", info="出力文字数の目安") | |
gen_token_multiplier = gr.Slider(minimum=0.5, maximum=3, value=1.75, step=0.01, label="文字/トークン数倍率", info="文字/最大トークン数倍率") | |
generate_button = gr.Button("文章生成開始") | |
generated_output = gr.Textbox(label="生成された文章", elem_id="output") | |
generate_button.click( | |
text_generator.generate_text, | |
inputs=[gen_input_text, gen_characters, gen_token_multiplier, gen_instruction], | |
outputs=[generated_output] | |
) | |
def update_instruction(choice): | |
instructions = { | |
"自由入力": "", | |
"推敲": "以下のテキストを推敲してください。原文の文体や特徴的な表現は保持しつつ、必要に応じて微調整を加えてください。文章の流れを自然にし、表現を洗練させることが目標ですが、元の雰囲気や個性を損なわないよう注意してください。", | |
"プロット作成": "以下のテキストをプロットにしてください。起承転結に分割すること。", | |
"あらすじ作成": "以下のテキストをあらすじにして、簡潔にまとめて下さい。", | |
"地の文追加": "以下のテキストに地の文を加え、情景を豊かにしてください。会話文は絶対に追加・改変しないでください。" | |
} | |
return instructions.get(choice, "") | |
instruction_type.change( | |
update_instruction, | |
inputs=[instruction_type], | |
outputs=[gen_instruction] | |
) | |
def update_char_count(text): | |
return f"文字数: {len(text)}" | |
gen_input_text.change( | |
update_char_count, | |
inputs=[gen_input_text], | |
outputs=[gen_input_char_count] | |
) | |
with gr.Tab("設定"): | |
output = gr.Textbox(label="更新状態") | |
config = ConfigManager.load_settings(DEFAULT_INI_FILE) | |
with gr.Column(): | |
gr.Markdown("### モデル設定") | |
model_dropdown = gr.Dropdown( | |
label="DEFAULT_GEN_MODEL", | |
choices=ModelManager.get_model_files(), | |
value=config['Models'].get('DEFAULT_GEN_MODEL', '') | |
) | |
model_dropdown.change( | |
partial(ConfigManager.update_setting, 'Models', 'DEFAULT_GEN_MODEL'), | |
inputs=[model_dropdown], | |
outputs=[output] | |
) | |
gr.Markdown("### 文章生成設定") | |
gen_author_description = gr.TextArea( | |
label="gen_author_description", | |
value=config['Character'].get('gen_author_description', ''), | |
lines=5 | |
) | |
gen_author_description.change( | |
partial(ConfigManager.update_setting, 'Character', 'gen_author_description'), | |
inputs=[gen_author_description], | |
outputs=[output] | |
) | |
gr.Markdown("### 文章生成パラメータ設定") | |
for key in ['n_gpu_layers', 'temperature', 'top_p', 'top_k', 'repetition_penalty', 'n_ctx']: | |
value = config['GenerateParameters'].get(key, '0') | |
if key == 'n_gpu_layers': | |
input_component = gr.Slider(label=key, value=int(value), minimum=-1, maximum=255, step=1) | |
elif key in ['temperature', 'top_p', 'repetition_penalty']: | |
input_component = gr.Slider(label=key, value=float(value), minimum=0.0, maximum=1.0, step=0.05) | |
elif key == 'top_k': | |
input_component = gr.Slider(label=key, value=int(value), minimum=1, maximum=200, step=1) | |
elif key == 'n_ctx': | |
input_component = gr.Slider(label=key, value=int(value), minimum=10000, maximum=100000, step=1000) | |
else: | |
input_component = gr.Textbox(label=key, value=value) | |
input_component.change( | |
partial(ConfigManager.update_setting, 'GenerateParameters', key), | |
inputs=[input_component], | |
outputs=[output] | |
) | |
apply_settings_button = gr.Button("設定を適用") | |
apply_settings_button.click( | |
lambda: text_generator.load_settings(DEFAULT_INI_FILE), | |
outputs=[output] | |
) | |
return iface | |
if __name__ == "__main__": | |
if not os.path.exists(DEFAULT_INI_FILE): | |
print(f"{DEFAULT_INI_FILE} が見つかりません。デフォルト設定で作成します。") | |
ConfigManager.create_default_settings(DEFAULT_INI_FILE) | |
text_generator.load_settings(DEFAULT_INI_FILE) | |
demo = build_gradio_interface() | |
demo.launch(share=True) |