
t5-base-korean-chit-chat

This model is a fine-tuned version of paust/pko-t5-base, trained on the AIHUB "한국어 SNS" (Korean SNS) dataset. Given a conversation of the kind seen on social media, it infers the next turn of that conversation.
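
The usage example below passes the conversation to the model as alternating "A:" and "B:" turns on a single line, ending with an open "B:" for the model to complete. A minimal sketch of assembling such a prompt from a list of turns follows; format_dialogue and its parameters are hypothetical names chosen for illustration, not part of the model or of Transformers.

```python
def format_dialogue(turns, next_speaker="B"):
    """Join (speaker, utterance) pairs into one "A: ... B: ..." line.

    Hypothetical helper: the A/B labels follow the usage example;
    nothing here is an API of the model itself.
    """
    parts = [f"{speaker}: {utterance.strip()}" for speaker, utterance in turns]
    # Leave the final speaker label open so the model fills in the reply
    parts.append(f"{next_speaker}:")
    return " ".join(parts)


turns = [("A", "쇼핑하러 갈까?"), ("B", "응 좋아."), ("A", "언제 갈까?")]
print(format_dialogue(turns))
```

The resulting string can be passed to the tokenizer in place of the hand-written text in the usage example.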

Usage


from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import nltk

# sent_tokenize (used below) needs the Punkt sentence tokenizer data;
# newer NLTK releases may ask for 'punkt_tab' instead
nltk.download('punkt')


model_dir = "lcw99/t5-base-korean-chit-chat"

max_input_length = 1024

text = """
A: 쇼핑하러 갈까? B: 응 좋아. A: 언제 갈까? B:
"""

inputs = [text]

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt")
# Beam search with sampling; return 3 candidate replies
output = model.generate(**inputs, num_beams=3, do_sample=True, min_length=20, max_length=500, num_return_sequences=3)
for i, candidate in enumerate(output):
    print("---", i)
    decoded_output = tokenizer.decode(candidate, skip_special_tokens=True)
    # Split the decoded reply into sentences
    sentences = nltk.sent_tokenize(decoded_output)
    print(sentences)

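# Interactive chat; reuses the tokenizer and model loaded above
chat_history = []
# Run for up to 100 turns, keeping only the 5 most recent lines as context
for step in range(100):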
    print("")
    user_input = input(">> User: ")
    chat_history.append("A: " + user_input)
    while len(chat_history) > 5:
        chat_history.pop(0)
    hist = ""
    for chat in chat_history:
        hist += "\n" + chat
    hist += "\nB: "
    new_user_input_ids = tokenizer.encode(hist, return_tensors='pt')

    bot_input_ids = new_user_input_ids

    # Generate a response of at most 200 tokens
    chat_history_ids = model.generate(
        bot_input_ids, max_length=200,
        pad_token_id=tokenizer.eos_token_id,  
        do_sample=True, 
        #top_k=100, 
        #top_p=0.7,
        #temperature = 0.1
    )

    bot_text = tokenizer.decode(chat_history_ids[0], skip_special_tokens=True).replace("#@이름#", "OOO")
    bot_text = bot_text.replace("\n", " / ")
    chat_history.append("B: " + bot_text)
    
    # Print the bot's reply
    print("Bot: {}".format(bot_text))
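
The history handling in the chat loop above can be pulled into two small pure functions, which makes it checkable without loading the model. This is a sketch under the same conventions as the loop (user lines prefixed with "A: ", bot lines with "B: ", at most 5 lines kept); build_prompt and clean_reply are hypothetical names, not functions from Transformers.

```python
def build_prompt(chat_history, max_lines=5):
    """Return (trimmed_history, prompt) the way the chat loop builds them."""
    # Keep only the most recent lines, like the loop's repeated pop(0)
    trimmed = chat_history[-max_lines:] if max_lines > 0 else []
    # Each line is preceded by a newline, then an open "B: " turn is appended
    prompt = "".join("\n" + line for line in trimmed) + "\nB: "
    return trimmed, prompt


def clean_reply(bot_text):
    """Mask the "#@이름#" placeholder and flatten newlines, as in the loop."""
    return bot_text.replace("#@이름#", "OOO").replace("\n", " / ")


history = ["A: one", "B: two", "A: three", "B: four", "A: five", "B: six", "A: seven"]
trimmed, prompt = build_prompt(history)
print(repr(prompt))
print(clean_reply("#@이름# 안녕\n반가워"))
```

Inside the loop, the prompt string would be passed to tokenizer.encode and the decoded output to clean_reply before it is appended to the history.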

Framework versions

  • Transformers 4.22.1
  • TensorFlow 2.10.0
  • Datasets 2.5.1
  • Tokenizers 0.12.1

Model size: 276M params
Tensor type: F32 (Safetensors)