# HuggingFace Spaces banner text ("Spaces: Running on Zero") — page-scrape
# residue, kept here only as a comment so the module remains importable.
import os | |
os.environ['CUDA_VISIBLE_DEVICES'] = '' | |
import spaces | |
import sys | |
import time | |
import socket | |
import gradio as gr | |
from llama_cpp import Llama | |
import datetime | |
from jinja2 import Template | |
import configparser | |
from functools import partial | |
import threading | |
import asyncio | |
import csv | |
from utils.dl_utils import dl_guff_model | |
import tempfile | |
# Constants
DEFAULT_INI_FILE = 'settings.ini'
MODEL_FILE_EXTENSION = '.gguf'

# Path setup — build MODEL_DIR from BASE_PATH instead of recomputing dirname twice.
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_PATH, "models")

# Create the model directory if it does not exist. Use the absolute MODEL_DIR
# rather than a cwd-relative "models" so startup works regardless of the
# directory the script is launched from.
os.makedirs(MODEL_DIR, exist_ok=True)

# Fetch the two default GGUF models into MODEL_DIR.
# (assumes dl_guff_model skips files that are already present — TODO confirm)
dl_guff_model(MODEL_DIR, "https://huggingface.co/MCZK/EZO-Common-9B-gemma-2-it-GGUF/resolve/main/EZO-Common-9B-gemma-2-it.Q8_0.gguf")
dl_guff_model(MODEL_DIR, "https://huggingface.co/second-state/Mistral-Nemo-Instruct-2407-GGUF/resolve/main/Mistral-Nemo-Instruct-2407-Q8_0.gguf")
class ConfigManager:
    """Thin helpers around configparser for reading/writing the settings ini file.

    All methods are stateless; they are declared @staticmethod because every
    call site uses them as ConfigManager.method(...) (the original omitted the
    decorator, which would break instance-level calls).
    """

    @staticmethod
    def load_settings(filename):
        """Parse `filename` (UTF-8 ini) and return the ConfigParser object."""
        config = configparser.ConfigParser()
        config.read(filename, encoding='utf-8')
        return config

    @staticmethod
    def save_settings(config, filename):
        """Write `config` back to `filename` as UTF-8 ini."""
        with open(filename, 'w', encoding='utf-8') as configfile:
            config.write(configfile)

    @staticmethod
    def update_setting(section, key, value, filename):
        """Load the ini, set [section] key = value, persist, and return a status message."""
        config = ConfigManager.load_settings(filename)
        config[section][key] = value
        ConfigManager.save_settings(config, filename)
        return f"設定を更新しました: [{section}] {key} = {value}"
class ModelManager:
    """Helpers for enumerating local .gguf model files under MODEL_DIR."""

    @staticmethod
    def get_model_files():
        """Return the list of *.gguf filenames currently present in MODEL_DIR."""
        return [f for f in os.listdir(MODEL_DIR) if f.endswith(MODEL_FILE_EXTENSION)]

    @staticmethod
    def update_model_dropdown(config, section, key):
        """Return (choices, current_value, message) for a model dropdown.

        If the configured model file is missing on disk it is still inserted at
        the top of the choices so the UI shows the configured value, together
        with a "please download" message.
        """
        current_value = config[section][key]
        model_files = ModelManager.get_model_files()
        if current_value not in model_files:
            download_message = f"現在の{key}({current_value})が見つかりません。ダウンロードしてください。"
            model_files.insert(0, current_value)
        else:
            download_message = ""
        return model_files, current_value, download_message
class NetworkUtils:
    """Small stateless helpers: local IP discovery and free-TCP-port search."""

    @staticmethod
    def get_ip_address():
        """Best-effort LAN IP of this host; falls back to loopback on failure."""
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            try:
                # Connecting a UDP socket sends no packets; it merely selects the
                # outgoing interface, whose address we read back.
                s.connect(('10.255.255.255', 1))
                return s.getsockname()[0]
            except Exception:
                return '127.0.0.1'

    @staticmethod
    def find_available_port(starting_port):
        """Return the first port >= starting_port with nothing listening on it."""
        port = starting_port
        while NetworkUtils.is_port_in_use(port):
            print(f"Port {port} is in use, trying next one.")
            port += 1
        return port

    @staticmethod
    def is_port_in_use(port):
        """True if a connection to localhost:port succeeds (something is listening)."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(('localhost', port)) == 0
class Settings:
    """Converts between the on-disk ini file and the flat settings dict used at runtime.

    All methods are stateless class-level utilities; @staticmethod is added
    because every call site uses Settings.method(...) (the original omitted it).
    """

    @staticmethod
    def _parse_config(config):
        """Flatten a ConfigParser into a typed dict.

        Sections that are absent are skipped entirely, so the returned dict may
        be partial; numeric values fall back to the documented defaults.
        """
        settings = {}
        if 'Character' in config:
            settings['chat_author_description'] = config['Character'].get('chat_author_description', '')
            settings['chat_instructions'] = config['Character'].get('chat_instructions', '')
            # example_qa is stored newline-joined in the ini; split back to a list.
            settings['example_qa'] = config['Character'].get('example_qa', '').split('\n')
            settings['gen_author_description'] = config['Character'].get('gen_author_description', '')
        if 'Models' in config:
            settings['DEFAULT_CHAT_MODEL'] = config['Models'].get('DEFAULT_CHAT_MODEL', '')
            settings['DEFAULT_GEN_MODEL'] = config['Models'].get('DEFAULT_GEN_MODEL', '')
        if 'ChatParameters' in config:
            settings['chat_n_gpu_layers'] = int(config['ChatParameters'].get('n_gpu_layers', '-1'))
            settings['chat_temperature'] = float(config['ChatParameters'].get('temperature', '0.35'))
            settings['chat_top_p'] = float(config['ChatParameters'].get('top_p', '0.9'))
            settings['chat_top_k'] = int(config['ChatParameters'].get('top_k', '40'))
            settings['chat_rep_pen'] = float(config['ChatParameters'].get('repetition_penalty', '1.2'))
            settings['chat_n_ctx'] = int(config['ChatParameters'].get('n_ctx', '10000'))
        if 'GenerateParameters' in config:
            settings['gen_n_gpu_layers'] = int(config['GenerateParameters'].get('n_gpu_layers', '-1'))
            settings['gen_temperature'] = float(config['GenerateParameters'].get('temperature', '0.35'))
            settings['gen_top_p'] = float(config['GenerateParameters'].get('top_p', '0.9'))
            settings['gen_top_k'] = int(config['GenerateParameters'].get('top_k', '40'))
            settings['gen_rep_pen'] = float(config['GenerateParameters'].get('repetition_penalty', '1.2'))
            settings['gen_n_ctx'] = int(config['GenerateParameters'].get('n_ctx', '10000'))
        return settings

    @staticmethod
    def save_to_ini(settings, filename):
        """Serialize a flat settings dict into the four ini sections and write it out."""
        config = configparser.ConfigParser()
        config['Character'] = {
            'chat_author_description': settings.get('chat_author_description', ''),
            'chat_instructions': settings.get('chat_instructions', ''),
            'example_qa': '\n'.join(settings.get('example_qa', [])),
            'gen_author_description': settings.get('gen_author_description', '')
        }
        config['Models'] = {
            'DEFAULT_CHAT_MODEL': settings.get('DEFAULT_CHAT_MODEL', ''),
            'DEFAULT_GEN_MODEL': settings.get('DEFAULT_GEN_MODEL', '')
        }
        config['ChatParameters'] = {
            'n_gpu_layers': str(settings.get('chat_n_gpu_layers', -1)),
            'temperature': str(settings.get('chat_temperature', 0.35)),
            'top_p': str(settings.get('chat_top_p', 0.9)),
            'top_k': str(settings.get('chat_top_k', 40)),
            'repetition_penalty': str(settings.get('chat_rep_pen', 1.2)),
            'n_ctx': str(settings.get('chat_n_ctx', 10000))
        }
        config['GenerateParameters'] = {
            'n_gpu_layers': str(settings.get('gen_n_gpu_layers', -1)),
            'temperature': str(settings.get('gen_temperature', 0.35)),
            'top_p': str(settings.get('gen_top_p', 0.9)),
            'top_k': str(settings.get('gen_top_k', 40)),
            'repetition_penalty': str(settings.get('gen_rep_pen', 1.2)),
            'n_ctx': str(settings.get('gen_n_ctx', 10000))
        }
        ConfigManager.save_settings(config, filename)

    @staticmethod
    def create_default_ini(filename):
        """Write the built-in default settings (Japanese writing-assistant persona) to `filename`."""
        default_settings = {
            'chat_author_description': "あなたは優秀な小説執筆アシスタントです。三幕構造や起承転結、劇中劇などのあらゆる小説理論や小説技法にも通じています。",
            'chat_instructions': "丁寧な敬語でアイディアのヒアリングしてください。物語をより面白くする提案、キャラクター造形の考察、世界観を膨らませる手伝いなどをお願いします。求められた時以外は基本、聞き役に徹してユーザー自身に言語化させるよう促してください。ユーザーのことは『ユーザー』と呼んでください。",
            'example_qa': [
                "user: キャラクターの設定について悩んでいます。",
                "assistant: キャラクター設定は物語の核となる重要な要素ですね。ユーザーが現在考えているキャラクターについて、簡単にご説明いただけますでしょうか?",
                "user: どんな設定を説明をしたらいいでしょうか?",
                "assistant: 例えば、年齢、性別、職業、性格の特徴などから始めていただけると、より具体的なアドバイスができるかと思います。",
                "user: プロットを書き出したいので、ヒアリングお願いします。",
                "assistant: 承知しました。ではまず『起承転結』の起から考えていきましょう。",
                "user: 読者を惹きこむ為のコツを提案してください",
                "assistant: 諸説ありますが、『謎・ピンチ・意外性』を冒頭に持ってくることが重要だと言います。",
                "user: プロットが面白いか自信がないので、考察のお手伝いをお願いします。",
                "assistant: プロットについてコメントをする前に、まずこの物語の『売り』について簡単に説明してください",
            ],
            'gen_author_description': 'あなたは新進気鋭の和風伝奇ミステリー小説家で、細やかな筆致と巧みな構成で若い世代にとても人気があります。',
            'DEFAULT_CHAT_MODEL': 'EZO-Common-9B-gemma-2-it.Q8_0.gguf',
            'DEFAULT_GEN_MODEL': 'EZO-Common-9B-gemma-2-it.Q8_0.gguf',
            'chat_n_gpu_layers': -1,
            'chat_temperature': 0.35,
            'chat_top_p': 0.9,
            'chat_top_k': 40,
            'chat_rep_pen': 1.2,
            'chat_n_ctx': 10000,
            'gen_n_gpu_layers': -1,
            'gen_temperature': 0.35,
            'gen_top_p': 0.9,
            'gen_top_k': 40,
            'gen_rep_pen': 1.2,
            'gen_n_ctx': 10000
        }
        Settings.save_to_ini(default_settings, filename)

    @staticmethod
    def load_from_ini(filename):
        """Load `filename` and return the parsed flat settings dict."""
        config = ConfigManager.load_settings(filename)
        return Settings._parse_config(config)
class GenTextParams:
    """Mutable holder for the sampling parameters of the two model roles.

    Attributes are prefixed `chat_` (dialogue tab) and `gen_` (writing tab);
    both sets start from identical defaults and are overwritten wholesale by
    the two update methods.
    """

    #: shared defaults applied to both the chat_ and gen_ attribute families
    _DEFAULTS = {
        "n_gpu_layers": -1,
        "temperature": 0.35,
        "top_p": 0.9,
        "top_k": 40,
        "rep_pen": 1.2,
        "n_ctx": 10000,
    }

    def __init__(self):
        for prefix in ("gen", "chat"):
            for name, default in self._DEFAULTS.items():
                setattr(self, f"{prefix}_{name}", default)

    def update_generate_parameters(self, n_gpu_layers, temperature, top_p, top_k, rep_pen, n_ctx):
        """Replace all text-generation sampling parameters at once."""
        (self.gen_n_gpu_layers, self.gen_temperature, self.gen_top_p,
         self.gen_top_k, self.gen_rep_pen, self.gen_n_ctx) = (
            n_gpu_layers, temperature, top_p, top_k, rep_pen, n_ctx)

    def update_chat_parameters(self, n_gpu_layers, temperature, top_p, top_k, rep_pen, n_ctx):
        """Replace all chat sampling parameters at once."""
        (self.chat_n_gpu_layers, self.chat_temperature, self.chat_top_p,
         self.chat_top_k, self.chat_rep_pen, self.chat_n_ctx) = (
            n_gpu_layers, temperature, top_p, top_k, rep_pen, n_ctx)
class LlamaAdapter:
    """Wrapper around llama_cpp.Llama exposing the three call styles used by this app."""

    def __init__(self, model_path, params, n_gpu_layers):
        # Keep the construction arguments around for debugging/reload decisions.
        self.model_path = model_path
        self.params = params
        self.n_gpu_layers = n_gpu_layers
        # Context size comes from the chat parameters, matching the original setup.
        self.llm = Llama(model_path=model_path, n_ctx=params.chat_n_ctx, n_gpu_layers=n_gpu_layers)

    def generate_text(self, text, author_description, gen_characters, gen_token_multiplier, instruction):
        """Chat-completion-style generation used by the writing tab.

        The token budget is derived from the requested character count times
        the characters-per-token multiplier.
        """
        token_budget = int(gen_characters * gen_token_multiplier)
        conversation = [
            {"role": "system", "content": author_description},
            {"role": "user", "content": f"以下の指示に従ってテキストを生成してください:\n\n{instruction}\n\n生成するテキスト(目安は{gen_characters}文字):\n\n{text}"}
        ]
        reply = self.llm.create_chat_completion(
            messages=conversation,
            max_tokens=token_budget,
            temperature=self.params.gen_temperature,
            top_p=self.params.gen_top_p,
            top_k=self.params.gen_top_k,
            repeat_penalty=self.params.gen_rep_pen,
        )
        return reply["choices"][0]["message"]["content"].strip()

    def generate(self, prompt, max_new_tokens=10000, temperature=None, top_p=None, top_k=None, repeat_penalty=None):
        """Raw-prompt completion; any unset sampling argument falls back to the chat defaults."""
        temperature = self.params.chat_temperature if temperature is None else temperature
        top_p = self.params.chat_top_p if top_p is None else top_p
        top_k = self.params.chat_top_k if top_k is None else top_k
        repeat_penalty = self.params.chat_rep_pen if repeat_penalty is None else repeat_penalty
        response = self.llm(
            prompt,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repeat_penalty=repeat_penalty,
            stop=["user:", "・会話履歴", "<END>"]
        )
        # llama_cpp has returned either a completion dict or a bare string
        # across versions, so accept both shapes.
        if isinstance(response, dict) and "choices" in response:
            return response["choices"][0]["text"]
        if isinstance(response, str):
            return response
        raise ValueError(f"Unexpected response format: {type(response)}")

    def create_chat_completion(self, messages, max_tokens, temperature, top_p, top_k, repeat_penalty):
        """Direct passthrough to the underlying model's chat-completion API."""
        return self.llm.create_chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repeat_penalty=repeat_penalty
        )
def chat_or_gen(text, gen_characters, gen_token_multiplier, instruction, mode):
    """Dispatch a request to the global character_maker.

    mode == "chat" → conversational reply; mode == "gen" → long-form text
    generation. Raises ValueError for an unknown mode (the original silently
    returned None, hiding caller bugs).
    """
    if mode == "chat":
        return character_maker.generate_response(text)
    elif mode == "gen":
        return character_maker.generate_text(text, gen_characters, gen_token_multiplier, instruction)
    raise ValueError(f"Unknown mode: {mode!r} (expected 'chat' or 'gen')")
class CharacterMaker:
    """Owns the model lifecycle and conversation state for both chat and generation.

    A single LlamaAdapter is shared between the two roles; loading is lazy
    (triggered on each request) and skipped whenever the effective settings
    for that role have not changed since the last load.
    """

    def __init__(self):
        self.llama = None                      # active LlamaAdapter, or None until first successful load
        self.history = []                      # raw-prompt history: [{"user": ..., "assistant": ...}]
        self.chat_history = []                 # chat-format history: [{"role": ..., "content": ...}]
        self.settings = None                   # flat settings dict (populated via load_character)
        self.model_loaded = threading.Event()  # set when a load attempt finished (success OR failure)
        self.current_model = None              # 'CHAT', 'GEN' or 'SHARED'
        self.model_lock = threading.Lock()     # serializes load/reload across concurrent requests
        self.use_chat_format = False           # True → use create_chat_completion instead of raw prompts
        self.last_loaded_settings = None       # NOTE(review): never written elsewhere in this file — appears unused
        self.chat_model_settings = None        # settings snapshot from the last CHAT load
        self.gen_model_settings = None         # settings snapshot from the last GEN load

    def load_model(self, model_type):
        """Ensure the model for `model_type` ('CHAT' or 'GEN') is loaded; skip no-op reloads."""
        with self.model_lock:
            new_settings = self.get_current_settings(model_type)
            if model_type == 'CHAT':
                if self.chat_model_settings == new_settings:
                    print("CHATモデルの設定に変更がないため、リロードをスキップします。")
                    return
                self.chat_model_settings = new_settings
            else:  # GEN
                if self.gen_model_settings == new_settings:
                    print("GENモデルの設定に変更がないため、リロードをスキップします。")
                    return
                self.gen_model_settings = new_settings
            if self.are_models_identical():
                # Both roles resolve to the same model + parameters: load once, share.
                if self.llama and self.current_model == 'SHARED':
                    print("CHATモデルとGENモデルの設定が同じで、既にロードされています。リロードをスキップします。")
                    return
                print("CHATモデルとGENモデルの設定が同じです。共有モデルとしてロードします。")
                self.reload_model('SHARED', new_settings)
            else:
                print(f"{model_type}モデルをロードします。")
                self.reload_model(model_type, new_settings)

    def reload_model(self, model_type, settings):
        """Tear down the current model (if any) and load a fresh one from `settings`."""
        if self.llama:
            del self.llama
            self.llama = None
            self.model_loaded.clear()
        try:
            model_path = os.path.join(MODEL_DIR, settings['model_path'])
            self.llama = LlamaAdapter(model_path, params, settings['n_gpu_layers'])
            self.current_model = model_type
            self.model_loaded.set()
            print(f"{model_type}モデルをロードしました。モデルパス: {model_path}、GPUレイヤー数: {settings['n_gpu_layers']}")
        except Exception as e:
            print(f"{model_type}モデルのロード中にエラーが発生しました: {e}")
            import traceback
            traceback.print_exc()
            # Set the event even on failure so waiters don't block for the full timeout;
            # callers detect failure via self.llama being None.
            self.model_loaded.set()

    def get_current_settings(self, model_type):
        """Extract the model path + sampling settings for one role from the flat settings dict."""
        return {
            'model_path': self.settings[f'DEFAULT_{model_type.upper()}_MODEL'],
            'n_gpu_layers': self.settings[f'{model_type.lower()}_n_gpu_layers'],
            'temperature': self.settings[f'{model_type.lower()}_temperature'],
            'top_p': self.settings[f'{model_type.lower()}_top_p'],
            'top_k': self.settings[f'{model_type.lower()}_top_k'],
            'rep_pen': self.settings[f'{model_type.lower()}_rep_pen'],
            'n_ctx': self.settings[f'{model_type.lower()}_n_ctx']
        }

    def are_models_identical(self):
        """True when CHAT and GEN resolve to identical settings (so one model can be shared)."""
        return self.chat_model_settings == self.gen_model_settings

    def generate_response(self, input_str):
        """Produce one chat reply to `input_str`, appending the turn to the stored history."""
        self.load_model('CHAT')
        if not self.model_loaded.wait(timeout=30) or not self.llama:
            return "モデルのロードに失敗しました。設定を確認してください。"
        try:
            if self.use_chat_format:
                # Chat-completion path: system persona + accumulated turns + new user message.
                chat_messages = [{"role": "system", "content": self.settings.get('chat_author_description', '')}]
                chat_messages.extend(self.chat_history)
                chat_messages.append({"role": "user", "content": input_str})
                response = self.llama.llm.create_chat_completion(
                    messages=chat_messages,
                    max_tokens=1000,
                    temperature=self.llama.params.chat_temperature,
                    top_p=self.llama.params.chat_top_p,
                    top_k=self.llama.params.chat_top_k,
                    repeat_penalty=self.llama.params.chat_rep_pen,
                )
                res_text = response["choices"][0]["message"]["content"].strip()
                self.chat_history.append({"role": "user", "content": input_str})
                self.chat_history.append({"role": "assistant", "content": res_text})
            else:
                # Raw-prompt path: render the Jinja template and complete it.
                prompt = self._generate_prompt(input_str)
                res_text = self.llama.generate(prompt, max_new_tokens=1000)
                self.history.append({"user": input_str, "assistant": res_text})
            return res_text
        except Exception as e:
            print(f"レスポンス生成中にエラーが発生しました: {str(e)}")
            return "レスポンス生成中にエラーが発生しました。設定を確認してください。"

    def generate_text(self, text, gen_characters, gen_token_multiplier, instruction):
        """Long-form generation for the writing tab; token budget = characters × multiplier."""
        self.load_model('GEN')
        if not self.model_loaded.wait(timeout=30) or not self.llama:
            return "モデルのロードに失敗しました。設定を確認してください。"
        author_description = self.settings.get('gen_author_description', '')
        max_tokens = int(gen_characters * gen_token_multiplier)
        try:
            if self.use_chat_format:
                messages = [
                    {"role": "system", "content": author_description},
                    {"role": "user", "content": f"以下の指示に従ってテキストを生成してください:\n\n{instruction}\n\n生成するテキスト(目安は{gen_characters}文字):\n\n{text}"}
                ]
                response = self.llama.create_chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=self.llama.params.gen_temperature,
                    top_p=self.llama.params.gen_top_p,
                    top_k=self.llama.params.gen_top_k,
                    repeat_penalty=self.llama.params.gen_rep_pen,
                )
                generated_text = response["choices"][0]["message"]["content"].strip()
            else:
                prompt = f"{author_description}\n\n以下の指示に従ってテキストを生成してください:\n\n{instruction}\n\n生成するテキスト(目安は{gen_characters}文字):\n\n{text}\n\n生成されたテキスト:"
                generated_text = self.llama.generate(
                    prompt,
                    max_new_tokens=max_tokens
                )
            return generated_text
        except Exception as e:
            print(f"テキスト生成中にエラーが発生しました: {str(e)}")
            return "テキスト生成中にエラーが発生しました。設定を確認してください。"

    def set_chat_format(self, use_chat_format):
        """Switch between chat-completion format and raw-prompt format."""
        self.use_chat_format = use_chat_format

    def make_prompt(self, input_str: str):
        """Render the raw chat prompt (persona, instructions, examples, history, new input)."""
        # The literal template text below is part of runtime behavior — do not edit.
        prompt_template = """{{chat_author_description}}
{{chat_instructions}}
・キャラクターの回答例
{% for qa in example_qa %}
{{qa}}
{% endfor %}
・会話履歴
{% for history in histories %}
user: {{history.user}}
assistant: {{history.assistant}}
{% endfor %}
user: {{input_str}}
assistant:"""
        template = Template(prompt_template)
        return template.render(
            chat_author_description=self.settings.get('chat_author_description', ''),
            chat_instructions=self.settings.get('chat_instructions', ''),
            example_qa=self.settings.get('example_qa', []),
            histories=self.history,
            input_str=input_str
        )

    def _generate_prompt(self, input_str: str):
        """Alias kept for the internal call site in generate_response."""
        return self.make_prompt(input_str)

    def load_character(self, filename):
        """Load persona settings from an ini file (accepts a single path or a list of paths)."""
        if isinstance(filename, list):
            filename = filename[0] if filename else ""
        self.settings = Settings.load_from_ini(filename)

    def reset(self):
        """Clear both history formats and fall back to raw-prompt mode."""
        self.history = []
        self.chat_history = []
        self.use_chat_format = False
# Global singletons shared by all the Gradio event handlers below.
params = GenTextParams()            # live sampling parameters (chat + generation)
character_maker = CharacterMaker()  # model lifecycle / conversation facade
model_files = ModelManager.get_model_files()  # initial model list for the UI
# --- Chat handlers -----------------------------------------------------------
def chat_with_character(message, history):
    """Return one complete (non-streaming) reply, syncing UI history into the model first."""
    if character_maker.use_chat_format:
        # Flatten [[user, assistant], ...] pairs and tag alternating roles.
        flattened = [turn for pair in history for turn in pair]
        roles = ("user", "assistant")
        character_maker.chat_history = [
            {"role": roles[idx % 2], "content": turn}
            for idx, turn in enumerate(flattened)
        ]
    else:
        character_maker.history = [{"user": u, "assistant": a} for u, a in history]
    return chat_or_gen(text=message, gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
def chat_with_character_stream(message, history):
    """Stream the reply character-by-character for a typing effect."""
    if character_maker.use_chat_format:
        flattened = [turn for pair in history for turn in pair]
        roles = ("user", "assistant")
        character_maker.chat_history = [
            {"role": roles[idx % 2], "content": turn}
            for idx, turn in enumerate(flattened)
        ]
    else:
        character_maker.history = [{"user": u, "assistant": a} for u, a in history]
    full_reply = chat_or_gen(text=message, gen_characters=None, gen_token_multiplier=None, instruction=None, mode="chat")
    for end in range(1, len(full_reply) + 1):
        time.sleep(0.05)  # pacing between displayed characters
        yield full_reply[:end]
# --- Text-generation handler -------------------------------------------------
def generate_text_wrapper(text, gen_characters, gen_token_multiplier, instruction):
    """Adapter from the generation tab's click inputs to chat_or_gen's 'gen' mode."""
    return chat_or_gen(
        text=text,
        gen_characters=gen_characters,
        gen_token_multiplier=gen_token_multiplier,
        instruction=instruction,
        mode="gen",
    )
# --- Chat-log helpers --------------------------------------------------------
def load_chat_log(file_name):
    """Read a Role/Message CSV from the app's logs/ directory into chatbot pairs.

    Consecutive user → assistant rows are merged into one [user, assistant]
    pair; an assistant row with no pending user turn becomes [None, message].
    """
    log_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs", file_name)
    pairs = []
    with open(log_path, 'r', encoding='utf-8') as csvfile:
        rows = csv.reader(csvfile)
        next(rows)  # skip the "Role,Message" header
        for row in rows:
            if len(row) != 2:
                continue  # silently drop malformed rows, as before
            role, message = row
            if role == "user":
                pairs.append([message, None])
            elif role == "assistant":
                if pairs and pairs[-1][1] is None:
                    pairs[-1][1] = message
                else:
                    pairs.append([None, message])
    return pairs
def save_and_download_chat_log(chat_history):
    """Write the chatbot history to a timestamped CSV in the OS temp directory.

    Returns (file_path, filename); None entries in a pair are skipped.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    filename = f"chat_log_{stamp}.csv"
    file_path = os.path.join(tempfile.gettempdir(), filename)
    with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Role", "Message"])
        for user_message, assistant_message in chat_history:
            for role, message in (("user", user_message), ("assistant", assistant_message)):
                if message:
                    writer.writerow([role, message])
    return file_path, filename  # path for gr.File, name for display
def cleanup_temp_file(file_path):
    """Best-effort removal of a temporary file; failures are logged, never raised."""
    try:
        os.remove(file_path)
    except Exception as e:
        print(f"Error removing temporary file {file_path}: {e}")
    else:
        print(f"Temporary file {file_path} has been removed.")
# Click handler for the chat tab's download button.
# NOTE(review): build_gradio_interface defines an identical nested copy of this
# function and wires that one up; this module-level version appears unused.
def update_download_output(chat_history):
    # Write the log to a temp CSV, then reveal the File component pointing at it.
    file_path, filename = save_and_download_chat_log(chat_history)
    return gr.File(value=file_path, visible=True, label="ダウンロード準備完了"), file_path
def resume_chat_from_log(chat_history):
    """Restore both the chatbot UI and the LLM-side history from a loaded log."""
    # Only complete [user, assistant] pairs are replayed into the model history.
    character_maker.history = [
        {"user": user_turn, "assistant": assistant_turn}
        for user_turn, assistant_turn in chat_history
        if user_turn is not None and assistant_turn is not None
    ]
    return gr.update(value=chat_history)
# In-memory staging area for settings edits, flushed to the ini by apply_settings().
temp_settings = {}


def update_temp_setting(section, key, value):
    """Stage one settings change; it only reaches the ini file when 適用 is pressed."""
    temp_settings.setdefault(section, {})[key] = value
    return f"{section}セクションの{key}を更新しました。適用ボタンを押すと設定が保存されます。"
def build_model_settings(config, section, output):
    """Build one dropdown + refresh button + status row per model key in `section`.

    Returns the flat list of created Gradio components. Dropdown changes are
    staged via update_temp_setting; the status messages land in `output`.
    """
    model_settings = []
    for key in ['DEFAULT_CHAT_MODEL', 'DEFAULT_GEN_MODEL']:
        if key in config[section]:
            with gr.Row():
                dropdown = gr.Dropdown(
                    label=key,
                    choices=ModelManager.get_model_files(),
                    value=config[section][key]
                )
                refresh_button = gr.Button("更新", size="sm")
            status_message = gr.Markdown()

            # Bind `key` as a default argument. Without this the closure reads
            # the loop variable at click time, so every refresh button would
            # report the *last* key of the loop (late-binding closure bug).
            def update_dropdown(current_value, key=key):
                model_files = ModelManager.get_model_files()
                if current_value not in model_files:
                    model_files.insert(0, current_value)
                    status = f"現在の{key}({current_value})が見つかりません。ダウンロードしてください。"
                else:
                    status = "モデルリストを更新しました。"
                return gr.update(choices=model_files, value=current_value), status

            refresh_button.click(
                fn=update_dropdown,
                inputs=[dropdown],
                outputs=[dropdown, status_message]
            )
            # partial() already freezes `key` here, so no late-binding issue.
            dropdown.change(
                partial(update_temp_setting, 'Models', key),
                inputs=[dropdown],
                outputs=[output]
            )
            model_settings.extend([dropdown, refresh_button, status_message])
    return model_settings
def apply_settings():
    """Flush staged temp_settings to the ini file and push them into the live objects.

    Returns a user-facing status message. When at least one value actually
    changed, the ini is re-read, character_maker.settings and the sampling
    params are refreshed, and the cached model-settings snapshots are cleared
    so the next request triggers a reload.
    """
    global temp_settings
    settings_changed = False
    # Read the on-disk config once, instead of re-parsing the ini file for
    # every single key inside the inner loop as the original did. Each
    # (section, key) is compared and written at most once, so the hoist does
    # not change which updates are applied.
    current_config = ConfigManager.load_settings(DEFAULT_INI_FILE)
    for section, settings in temp_settings.items():
        for key, value in settings.items():
            old_value = current_config[section].get(key)
            if str(value) != str(old_value):
                ConfigManager.update_setting(section, key, str(value), DEFAULT_INI_FILE)
                settings_changed = True
    if not settings_changed:
        return "設定に変更はありませんでした。"
    # Re-read the ini now that it includes the staged changes.
    new_config = ConfigManager.load_settings(DEFAULT_INI_FILE)
    # Refresh the flat settings dict used by the model layer.
    character_maker.settings = Settings._parse_config(new_config)
    # Refresh live sampling parameters.
    if 'ChatParameters' in new_config:
        params.update_chat_parameters(
            int(new_config['ChatParameters'].get('n_gpu_layers', '-1')),
            float(new_config['ChatParameters'].get('temperature', '0.35')),
            float(new_config['ChatParameters'].get('top_p', '0.9')),
            int(new_config['ChatParameters'].get('top_k', '40')),
            float(new_config['ChatParameters'].get('repetition_penalty', '1.2')),
            int(new_config['ChatParameters'].get('n_ctx', '10000'))
        )
    if 'GenerateParameters' in new_config:
        params.update_generate_parameters(
            int(new_config['GenerateParameters'].get('n_gpu_layers', '-1')),
            float(new_config['GenerateParameters'].get('temperature', '0.35')),
            float(new_config['GenerateParameters'].get('top_p', '0.9')),
            int(new_config['GenerateParameters'].get('top_k', '40')),
            float(new_config['GenerateParameters'].get('repetition_penalty', '1.2')),
            int(new_config['GenerateParameters'].get('n_ctx', '10000'))
        )
    # Clear the snapshots so the next chat/gen request reloads the model lazily.
    character_maker.chat_model_settings = None
    character_maker.gen_model_settings = None
    # Drop the staged edits now that they are persisted.
    temp_settings.clear()
    return "設定をiniファイルに保存し、アプリケーションに反映しました。次回の操作時に新しい設定が適用されます。"
# Gradio interface
def build_gradio_interface():
    """Assemble and return the Gradio Blocks app.

    Four tabs: chat, long-form generation, log viewer, and settings. All event
    handlers close over the module-level singletons (character_maker, params,
    temp_settings).
    """
    global temp_settings

    def apply_settings_wrapper():
        # Thin local wrapper; NOTE(review): the apply button below wires
        # apply_settings directly, so this wrapper appears unused.
        return apply_settings()

    def update_temp_setting(section, key, value):
        # NOTE(review): shadows the identical module-level update_temp_setting;
        # both mutate the same global temp_settings dict.
        global temp_settings
        if section not in temp_settings:
            temp_settings[section] = {}
        temp_settings[section][key] = value
        return f"{section}セクションの{key}を更新しました。適用ボタンを押すと設定が保存されます。"

    with gr.Blocks() as iface:
        # Notice banner recommending the NSFW-capable model swap.
        gr.HTML("""
        <div style="background-color: #f0f0f0; padding: 10px; margin-bottom: 10px; border-radius: 5px;">
        <strong>注意:</strong> 一応念のため、NSFW創作用途の場合はモデルを設定タブから、「EZO-Common-9B-gemma-2-it.Q8_0.gguf」→「Mistral-Nemo-Instruct-2407-Q8_0.gguf」に変更推奨です。
        <br>
        <a href="https://note.com/eurekachan/n/nd05d6307fead" target="_blank" style="color: #007bff; text-decoration: underline;">
        参考情報はこちら
        </a>
        </div>
        """)
        # Make both chat panes user-resizable.
        gr.HTML("""
        <style>
        #chatbot, #chatbot_read {
        resize: both;
        overflow: auto;
        min-height: 100px;
        max-height: 80vh;
        }
        </style>
        """)
        tabs = gr.Tabs()
        with tabs:
            with gr.Tab("チャット", id="chat_tab") as chat_tab:
                chatbot = gr.Chatbot(elem_id="chatbot")
                chat_interface = gr.ChatInterface(
                    chat_with_character_stream,
                    chatbot=chatbot,
                    textbox=gr.Textbox(placeholder="メッセージを入力してください...", container=False, scale=7),
                    theme="soft",
                    submit_btn="送信",
                    stop_btn="停止",
                    retry_btn="もう一度生成",
                    undo_btn="前のメッセージを取り消す",
                    clear_btn="チャットをクリア",
                )
                with gr.Row():
                    download_log_button = gr.Button("チャットログをダウンロード")
                    download_log_output = gr.File(label="ダウンロード", visible=False)
                temp_file_path = gr.State()  # holds the temp CSV path between events

                def update_download_output(chat_history):
                    # Write the history to a temp CSV and reveal the file widget.
                    file_path, filename = save_and_download_chat_log(chat_history)
                    return gr.File(value=file_path, visible=True, label="ダウンロード準備完了"), file_path

                download_log_button.click(
                    update_download_output,
                    inputs=[chatbot],
                    outputs=[download_log_output, temp_file_path]
                )
                # Delete the temp file after the download widget updates.
                # NOTE(review): .change fires as soon as the value is set, which
                # may remove the file before the user actually clicks it — confirm.
                download_log_output.change(
                    cleanup_temp_file,
                    inputs=[temp_file_path]
                )
            with gr.Tab("文章生成"):
                with gr.Row():
                    with gr.Column(scale=2):
                        instruction_type = gr.Dropdown(
                            choices=["自由入力", "推敲", "プロット作成", "あらすじ作成", "地の文追加"],
                            label="指示タイプ",
                            value="自由入力"
                        )
                        gen_instruction = gr.Textbox(
                            label="指示",
                            value="",
                            lines=3
                        )
                        gen_input_text = gr.Textbox(lines=5, label="処理されるテキストを入力してください")
                        gen_input_char_count = gr.HTML(value="文字数: 0")
                    with gr.Column(scale=1):
                        gen_characters = gr.Slider(minimum=10, maximum=10000, value=500, step=10, label="出力文字数", info="出力文字数の目安")
                        gen_token_multiplier = gr.Slider(minimum=0.35, maximum=3, value=1.75, step=0.01, label="文字/トークン数倍率", info="文字/最大トークン数倍率")
                generate_button = gr.Button("文章生成開始")
                generated_output = gr.Textbox(label="生成された文章")
                generate_button.click(
                    generate_text_wrapper,
                    inputs=[gen_input_text, gen_characters, gen_token_multiplier, gen_instruction],
                    outputs=[generated_output]
                )

                def update_instruction(choice):
                    # Map the preset dropdown to a full instruction string.
                    instructions = {
                        "自由入力": "",
                        "推敲": "以下のテキストを推敲してください。原文の文体や特徴的な表現は保持しつつ、必要に応じて微調整を加えてください。文章の流れを自然にし、表現を洗練させることが目標ですが、元の雰囲気や個性を損なわないよう注意してください。",
                        "プロット作成": "以下のテキストをプロットにしてください。起承転結に分割すること。",
                        "あらすじ作成": "以下のテキストをあらすじにして、簡潔にまとめて下さい。",
                        "地の文追加": "以下のテキストの地の文を増やして、描写を膨らませるように推敲してください。文章の流れを自然にし、表現を洗練させることが目標ですが、なるべく元の文の意味や流れは残してください、",
                    }
                    return instructions.get(choice, "")

                instruction_type.change(
                    update_instruction,
                    inputs=[instruction_type],
                    outputs=[gen_instruction]
                )

                def update_char_count(text):
                    # Live character counter under the input box.
                    return f"文字数: {len(text)}"

                gen_input_text.change(
                    update_char_count,
                    inputs=[gen_input_text],
                    outputs=[gen_input_char_count]
                )
            # Log-viewer tab
            with gr.Tab("ログ閲覧", id="log_view_tab") as log_view_tab:
                gr.Markdown("## チャットログ閲覧")
                chatbot_read = gr.Chatbot(elem_id="chatbot_read")
                log_file_upload = gr.File(label="ログファイルをアップロード", file_types=[".csv"])
                resume_chat_button = gr.Button("選択したログから会話を再開")

                def load_and_display_uploaded_chat_log(file):
                    # Parse an uploaded Role/Message CSV into chatbot pairs.
                    # NOTE(review): duplicates load_chat_log's parsing logic.
                    if file is None:
                        return gr.update(value=[])
                    chat_history = []
                    with open(file.name, 'r', encoding='utf-8') as csvfile:
                        reader = csv.reader(csvfile)
                        next(reader)  # Skip header
                        for row in reader:
                            if len(row) == 2:
                                role, message = row
                                if role == "user":
                                    chat_history.append([message, None])
                                elif role == "assistant":
                                    if chat_history and chat_history[-1][1] is None:
                                        chat_history[-1][1] = message
                                    else:
                                        chat_history.append([None, message])
                    return gr.update(value=chat_history)

                log_file_upload.change(
                    load_and_display_uploaded_chat_log,
                    inputs=[log_file_upload],
                    outputs=[chatbot_read]
                )

                def resume_chat_and_switch_tab(chat_history):
                    # Push the log into the chat tab and focus it.
                    chatbot_ui = resume_chat_from_log(chat_history)
                    return chatbot_ui, gr.update(selected="chat_tab")

                resume_chat_button.click(
                    resume_chat_and_switch_tab,
                    inputs=[chatbot_read],
                    outputs=[chatbot, tabs]
                )
            with gr.Tab("設定"):
                output = gr.Textbox(label="更新状態")
                config = ConfigManager.load_settings(DEFAULT_INI_FILE)
                with gr.Column():
                    gr.Markdown("### モデル設定")
                    model_settings = build_model_settings(config, "Models", output)
                    gr.Markdown("### チャット設定")
                    # partial() freezes `key` per iteration, so these handlers
                    # do not suffer the late-binding-closure problem.
                    for key in ['chat_author_description', 'chat_instructions', 'example_qa']:
                        if key == 'example_qa':
                            input_component = gr.TextArea(label=key, value=config['Character'].get(key, ''), lines=10)
                        else:
                            input_component = gr.TextArea(label=key, value=config['Character'].get(key, ''), lines=5)
                        input_component.change(
                            partial(update_temp_setting, 'Character', key),
                            inputs=[input_component],
                            outputs=[output]
                        )
                    gr.Markdown("### 文章生成設定")
                    key = 'gen_author_description'
                    input_component = gr.TextArea(label=key, value=config['Character'].get(key, ''), lines=5)
                    input_component.change(
                        partial(update_temp_setting, 'Character', key),
                        inputs=[input_component],
                        outputs=[output]
                    )
                    gr.Markdown("### チャットパラメータ設定")
                    for key in ['n_gpu_layers', 'temperature', 'top_p', 'top_k', 'repetition_penalty', 'n_ctx']:
                        value = config['ChatParameters'].get(key, '0')
                        if key == 'n_gpu_layers':
                            input_component = gr.Slider(label=key, value=int(value), minimum=-1, maximum=255, step=1)
                        elif key in ['temperature', 'top_p', 'repetition_penalty']:
                            input_component = gr.Slider(label=key, value=float(value), minimum=0.0, maximum=1.0, step=0.05)
                        elif key == 'top_k':
                            input_component = gr.Slider(label=key, value=int(value), minimum=1, maximum=200, step=1)
                        elif key == 'n_ctx':
                            input_component = gr.Slider(label=key, value=int(value), minimum=10000, maximum=100000, step=1000)
                        else:
                            input_component = gr.Textbox(label=key, value=value)
                        input_component.change(
                            partial(update_temp_setting, 'ChatParameters', key),
                            inputs=[input_component],
                            outputs=[output]
                        )
                    gr.Markdown("### 文章生成パラメータ設定")
                    for key in ['n_gpu_layers', 'temperature', 'top_p', 'top_k', 'repetition_penalty', 'n_ctx']:
                        value = config['GenerateParameters'].get(key, '0')
                        if key == 'n_gpu_layers':
                            input_component = gr.Slider(label=key, value=int(value), minimum=-1, maximum=255, step=1)
                        elif key in ['temperature', 'top_p', 'repetition_penalty']:
                            input_component = gr.Slider(label=key, value=float(value), minimum=0.0, maximum=1.0, step=0.05)
                        elif key == 'top_k':
                            input_component = gr.Slider(label=key, value=int(value), minimum=1, maximum=200, step=1)
                        elif key == 'n_ctx':
                            input_component = gr.Slider(label=key, value=int(value), minimum=10000, maximum=100000, step=1000)
                        else:
                            input_component = gr.Textbox(label=key, value=value)
                        input_component.change(
                            partial(update_temp_setting, 'GenerateParameters', key),
                            inputs=[input_component],
                            outputs=[output]
                        )
                    apply_ini_settings_button = gr.Button("設定を適用")
                    apply_ini_settings_button.click(
                        apply_settings,
                        outputs=[output]
                    )
    return iface
async def start_gradio():
    """Initialize settings and parameters, build the UI, and launch the server.

    NOTE(review): declared async but contains no await; asyncio.run() below
    simply drives it to completion synchronously.
    """
    # The ini file is unconditionally rewritten with defaults on every launch.
    print(f"{DEFAULT_INI_FILE} をデフォルト設定で上書きします。")
    Settings.create_default_ini(DEFAULT_INI_FILE)
    config = ConfigManager.load_settings(DEFAULT_INI_FILE)
    settings = Settings._parse_config(config)
    character_maker.settings = settings
    # load_character re-reads the same ini; redundant with the assignment above
    # but kept as-is.
    character_maker.load_character(DEFAULT_INI_FILE)
    # Seed the live sampling parameters from the freshly written defaults.
    params.update_chat_parameters(
        settings['chat_n_gpu_layers'],
        settings['chat_temperature'],
        settings['chat_top_p'],
        settings['chat_top_k'],
        settings['chat_rep_pen'],
        settings['chat_n_ctx']
    )
    params.update_generate_parameters(
        settings['gen_n_gpu_layers'],
        settings['gen_temperature'],
        settings['gen_top_p'],
        settings['gen_top_k'],
        settings['gen_rep_pen'],
        settings['gen_n_ctx']
    )
    demo = build_gradio_interface()
    ip_address = NetworkUtils.get_ip_address()
    starting_port = 7860
    port = NetworkUtils.find_available_port(starting_port)
    print(f"サーバーのアドレス: http://{ip_address}:{port}")
    demo.queue()
    demo.launch(
        server_name='0.0.0.0',
        server_port=port,
        share=True,
        # NOTE(review): favicon_path points at an .html file — confirm intended.
        favicon_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "custom.html")
    )


if __name__ == "__main__":
    asyncio.run(start_gradio())