import gradio as gr from transformers import AutoTokenizer, AutoModelForCausalLM import warnings class Chatbot(): def __init__(self): self.tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b') special_tokens_dict = {'additional_special_tokens': ['', '', '', '#@이름#', '#@계정#', '#@신원#', '#@전번#', '#@금융#', '#@번호#', '#@주소#', '#@소속#', '#@기타#']} num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict) self.model = AutoModelForCausalLM.from_pretrained("/workspace/test_trainer/checkpoint-10000") self.model.resize_token_embeddings(len(self.tokenizer)) self.model = self.model.cuda() self.info = None self.talk = [] def initialize(self, topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex): def encode(age): if age < 20: age = "20대 미만" elif age >= 70: age = "70대 이상" else: age = str(age // 10 * 10) + "대" return age bot_age = encode(bot_age) my_age = encode(my_age) self.info = f"일상 대화 {topic}P01:{my_addr} {my_age} {my_sex}P02:{bot_addr} {bot_age} {bot_sex}" return self.info_check() def info_check(self): return self.info.replace('', '\n').replace('P01', '당신').replace('P02', '챗봇') def reset_talk(self): self.talk = [] def test(self, myinp): state = None inp = "P01" + myinp + "" self.talk.append(inp) self.talk.append("P02") while True: now_inp = self.info + "".join(self.talk) inputs = self.tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt') seq_len = inputs.input_ids.size(1) if seq_len > 512 * 0.8: state = f"<주의> 현재 대화 길이가 곧 최대 길이에 도달합니다. ({seq_len} / 512)" if seq_len >= 512: state = "<주의> 대화 길이가 너무 길어졌기 때문에, 이후 대화는 맨 앞의 발화를 조금씩 지우면서 진행됩니다." talk = talk[1:] else: break out = self.model.generate( inputs=inputs.input_ids.cuda(), attention_mask=inputs.attention_mask.cuda(), max_length=512, do_sample=True, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.encode('')[0] ) out = self.tokenizer.batch_decode(out) real_out = out[0][len(now_inp):-5] self.talk[-1] += out[0][len(now_inp):] return [(self.talk[i][8:-5], self.talk[i+1][8:-5]) for i in range(0, len(self.talk)-1, 2)] if __name__ == "__main__": warnings.filterwarnings("ignore") chatbot = Chatbot() demo = gr.Blocks() with demo: gr.Markdown("#
MINDs Lab Brain's Fast Neural Chit-Chatbot
") with gr.Row(): with gr.Column(): topic = gr.Radio(label="Topic", choices=['여가 생활', '시사/교육', '미용과 건강', '식음료', '상거래(쇼핑)', '일과 직업', '주거와 생활', '개인 및 관계', '행사']) with gr.Column(): gr.Markdown(f"Bot's persona") bot_addr = gr.Dropdown(label="지역", choices=['서울특별시', '경기도', '부산광역시', '대전광역시', '광주광역시', '울산광역시', '경상남도', '인천광역시', '충청북도', '제주도', '강원도', '충청남도', '전라북도', '대구광역시', '전라남도', '경상북도', '세종특별자치시', '기타']) bot_age = gr.Slider(label="나이", minimum=10, maximum=80, value=45, step=1) bot_sex = gr.Radio(label="성별", choices=["남성", "여성"]) with gr.Column(): gr.Markdown(f"Your persona") my_addr = gr.Dropdown(label="지역", choices=['서울특별시', '경기도', '부산광역시', '대전광역시', '광주광역시', '울산광역시', '경상남도', '인천광역시', '충청북도', '제주도', '강원도', '충청남도', '전라북도', '대구광역시', '전라남도', '경상북도', '세종특별자치시', '기타']) my_age = gr.Slider(label="나이", minimum=10, maximum=80, value=45, step=1) my_sex = gr.Radio(label="성별", choices=["남성", "여성"]) with gr.Row(): btn = gr.Button(label="적용") state = gr.Textbox(label="상태") btn.click( fn=chatbot.initialize, inputs=[topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex], outputs=state ) with gr.Column(): screen = gr.Chatbot(label="익명의 상대") with gr.Row(): speak = gr.Textbox(label="입력창") btn = gr.Button(label="Talk") btn.click( fn=chatbot.test, inputs=speak, outputs=screen ) demo.launch(share=True)