# coding=utf-8
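"""Gradio demo for role-play chat with the Index-1.9B-Character model.

A user can upload a role dialogue corpus (.csv), have it converted into role data,
and then chat with that role; each turn retrieves relevant corpus snippets and
streams the model's answer back into the chat window.
"""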
from pathlib import Path
from threading import Thread
from typing import Dict, List, Tuple, Union

import logging
import os
import random
import shutil
import warnings

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)

from src.logger import LoggerFactory
from src.prompt_concat import GetManualTestSamples, CreateTestDataset
from src.utils import decode_csv_to_json, load_json, save_to_json
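# global setup: logger, tokenizer and model are created once at startup.
# MODEL_PATH / TOKENIZER_PATH can be overridden via environment variables and
# default to the IndexTeam/Index-1.9B-Character checkpoint on the Hugging Face Hub.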
logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

MODEL_PATH = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map="auto",
                                             trust_remote_code=True)

character_path = "./character"
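# helper to normalize user-supplied paths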
def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()
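# legacy local-config loading, superseded by the env-var based setup above and
# kept commented out for reference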
# logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
# warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
# config_data = load_json("config/config.json")
# model_path = config_data["huggingface_local_path"]
# character_path = "./character"
# tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto",
#                                              trust_remote_code=True)
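# retrieve role-corpus samples relevant to the current dialogue and write them to
# ./character/<role_file_path>_rag.json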
def generate_with_question(question, role_name, role_file_path):
    question_in = "\n".join(["\n".join(pair) for pair in question])
    g = GetManualTestSamples(
        role_name=role_name,
        role_data_path=f"./character/{role_file_path}.json",
        save_samples_dir="./character",
        save_samples_path=role_file_path + "_rag.json",
        prompt_path="./prompt/dataset_character.txt",
        max_seq_len=4000
    )
    g.get_qa_samples_by_query(
        questions_query=question_in,
        keep_retrieve_results_flag=True
    )
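# assemble the final prompt from the retrieved samples and save it to
# ./character/<role_file_path>_测试问题.json ("test questions")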
def create_datasets(role_name, role_file_path):
    testset = []
    role_samples_path = os.path.join("./character", role_file_path + "_rag.json")
    c = CreateTestDataset(role_name=role_name,
                          role_samples_path=role_samples_path,
                          role_data_path=role_samples_path,
                          prompt_path="./prompt/dataset_character.txt"
                          )
    res = c.load_samples()
    testset.extend(res)
    save_to_json(testset, f"./character/{role_file_path}_测试问题.json")
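# streaming generation: run model.generate on a background thread and yield the
# partial answer as tokens arrive via TextIteratorStreamer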
def hf_gen(dialog: List, role_name, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
    generate_with_question(dialog, role_name, role_file_path)
    create_datasets(role_name, role_file_path)
    json_data = load_json(f"{character_path}/{role_file_path}_测试问题.json")[0]
    text = json_data["input_text"]
    inputs = tokenizer(text, return_tensors="pt")
    if torch.cuda.is_available():
        model.to("cuda")
        inputs = inputs.to("cuda")
    streamer = TextIteratorStreamer(tokenizer, **tokenizer.init_kwargs)
    generation_kwargs = dict(
        inputs,
        do_sample=True,
        top_k=int(top_k),
        top_p=float(top_p),
        temperature=float(temperature),
        repetition_penalty=float(repetition_penalty),
        max_new_tokens=int(max_dec_len),
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    answer = ""
    for new_text in streamer:
        answer += new_text
        # the streamer echoes the prompt first, so strip it before yielding
        yield answer[len(text):]
def generate(chat_history: List, query, role_name, role_desc, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
    """Generate an answer after the "Submit" button is hit.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. List that stores all QA records.
        query (str): Query of the current round.
        role_name (str): Name of the role being played.
        role_desc (str): Description of the role.
        role_file_path (str): Basename of the role files under ./character.
        top_k (int): Number of highest-probability tokens kept for sampling.
        top_p (float): Only the smallest set of most probable tokens whose probabilities add up to top_p or higher is kept for generation.
        temperature (float): Strictly positive float used to modulate the logits distribution.
        repetition_penalty (float): Penalty applied to already generated tokens.
        max_dec_len (int): Maximum number of tokens to generate.

    Yields:
        List: [[q_1, a_1], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history plus the QA of the current round.
    """
    assert query != "", "Input must not be empty!"
    # apply chat template
    chat_history.append([f"user:{query}", ""])
    if role_name == "三三":
        # "三三" is the demo's built-in example role
        role_file_path = "三三"
    for answer in hf_gen(chat_history, role_name, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
        chat_history[-1][1] = role_name + ":" + answer
        yield gr.update(value=""), chat_history
def regenerate(chat_history: List, role_name, role_description, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
    """Re-generate the answer for the last round's query.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. List that stores all QA records.
        top_p (float): Only the smallest set of most probable tokens whose probabilities add up to top_p or higher is kept for generation.
        temperature (float): Strictly positive float used to modulate the logits distribution.
        max_dec_len (int): Maximum number of tokens to generate.

    Yields:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history with the last answer replaced.
    """
    assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!"
    if len(chat_history[-1]) > 1:
        chat_history[-1][1] = ""
    # apply chat template
    if role_name == "三三":
        role_file_path = "三三"
    for answer in hf_gen(chat_history, role_name, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
        chat_history[-1][1] = role_name + ":" + answer
        yield gr.update(value=""), chat_history
def clear_history():
    """Clear all chat history.

    Returns:
        List: empty chat history
    """
    torch.cuda.empty_cache()
    return []
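# remove every file generated for a user-uploaded role: the csv corpus, the json
# role data, the retrieved samples and the assembled prompt file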
def delete_current_user(user_role_path):
    try:
        role_upload_path = os.path.join(character_path, user_role_path + ".csv")
        role_path = os.path.join(character_path, user_role_path + ".json")
        rag_path = os.path.join(character_path, user_role_path + "_rag.json")
        question_path = os.path.join(character_path, user_role_path + "_测试问题.json")
        files_to_delete = [role_upload_path, role_path, rag_path, question_path]
        for file_path in files_to_delete:
            os.remove(file_path)
    except Exception as e:
        logger.warning(f"Failed to delete role files: {e}")
# launch gradio demo
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("""# Index-1.9B RolePlay Gradio Demo""")
    with gr.Row():
        with gr.Column(scale=1):
            top_k = gr.Slider(0, 10, value=5, step=1, label="top_k")
            top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
            temperature = gr.Slider(0.1, 2.0, value=0.85, step=0.1, label="temperature")
            repetition_penalty = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="repetition_penalty")
            max_dec_len = gr.Slider(1, 4096, value=512, step=1, label="max_dec_len")
            file_input = gr.File(label="上传角色对话语料(.csv)")  # upload the role dialogue corpus (.csv)
            role_description = gr.Textbox(label="Role Description", placeholder="输入角色描述", lines=2)  # placeholder: "enter the role description"
            upload_button = gr.Button("生成角色!")  # "create role!"
            new_path = gr.State()
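            # copy the uploaded csv under a randomized filename (to avoid collisions
            # between users) and convert it to the json role-data format used by the
            # retrieval pipeline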
            def generate_file(file_obj, role_info):
                random.seed()
                alphabet = 'abcdefghijklmnopqrstuvwxyz!@#$%^&*()'
                random_char = "".join(random.choice(alphabet) for _ in range(10))
                role_name = os.path.basename(file_obj).split(".")[0]
                new_path = role_name + random_char
                new_save_path = os.path.join(character_path, new_path + ".csv")
                shutil.copy(file_obj, new_save_path)
                new_file_path = os.path.join(character_path, new_path)
                decode_csv_to_json(os.path.join(character_path, new_path + ".csv"), role_name, role_info,
                                   new_file_path + ".json")
                gr.Info(f"{role_name}生成成功")  # "<role_name> created successfully"
                return new_path

            upload_button.click(generate_file, inputs=[file_input, role_description], outputs=new_path)
        with gr.Column(scale=10):
            chatbot = gr.Chatbot(bubble_full_width=False, height=400, label='Index-1.9B')
            with gr.Row():
                role_name = gr.Textbox(label="Role name", placeholder="Input your rolename here!", lines=2)
                user_input = gr.Textbox(label="User", placeholder="Input your query here!", lines=2)
            with gr.Row():
                submit = gr.Button("🚀 Submit")
                clear = gr.Button("🧹 Clear")
                regen = gr.Button("🔄 Regenerate")

    submit.click(generate,
                 inputs=[chatbot, user_input, role_name, role_description, new_path, top_k, top_p, temperature,
                         repetition_penalty, max_dec_len],
                 outputs=[user_input, chatbot])
    regen.click(regenerate,
                inputs=[chatbot, role_name, role_description, new_path, top_k, top_p, temperature, repetition_penalty,
                        max_dec_len],
                outputs=[user_input, chatbot])
    clear.click(clear_history, inputs=[], outputs=[chatbot])

demo.queue().launch()
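# demo.queue() lets the generator-based callbacks above stream partial answers
# into the Chatbot token by token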