|
|
|
import numpy as np |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
# Configure the browser tab title and the page layout. The title is Korean
# for "Translator". This call precedes every other Streamlit command in the
# script.
st.set_page_config(
    page_title="번역기",
    layout="wide",
    initial_sidebar_state="expanded",
)
|
|
|
# allow_output_mutation=True tells Streamlit to hand back the cached object
# itself rather than hashing/copying it on each rerun to detect mutation,
# which is needlessly expensive for an object as large as a model.
@st.cache(allow_output_mutation=True)
def load_model(model_name: str):
    """Load a pretrained sequence-to-sequence model, cached across reruns.

    Args:
        model_name: Identifier forwarded unchanged to
            ``AutoModelForSeq2SeqLM.from_pretrained``.

    Returns:
        The model object returned by ``from_pretrained``.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
|
|
KO2EN_MODEL_NAME = "QuoQA-NLP/KE-T5-Ko2En-Base"
EN2KO_MODEL_NAME = "QuoQA-NLP/KE-T5-En2Ko-Base"


# Streamlit re-executes the whole script on every interaction, so the
# tokenizer is cached the same way the models are instead of being rebuilt
# on each rerun.
@st.cache(allow_output_mutation=True)
def _load_tokenizer(model_name: str):
    """Return the tokenizer for ``model_name``, cached across reruns."""
    return AutoTokenizer.from_pretrained(model_name)


# NOTE(review): a single tokenizer (from the Ko->En checkpoint) is used for
# both translation directions -- presumably both checkpoints share one
# vocabulary; confirm against the model cards.
tokenizer = _load_tokenizer(KO2EN_MODEL_NAME)
ko2en_model = load_model(KO2EN_MODEL_NAME)
en2ko_model = load_model(EN2KO_MODEL_NAME)
|
|
|
|
|
# Page heading followed by usage instructions in Korean and in English.
# NOTE(review): the emoji were recovered from a partially damaged byte
# sequence; confirm they match the intended ones.
st.title("🤖 번역기")
st.write("좌측에 번역 모드를 선택하고, CTRL+Enter(CMD+Enter)를 누르세요 🤗")
st.write("Select Translation Mode at the left and press CTRL+Enter(CMD+Enter)🤗")

# Sidebar choices: index 0 is Korean -> English, index 1 is English -> Korean.
# Later code compares the selection against translation_list[0].
translation_list = [
    "한국어에서 영어 | Korean to English",
    "영어에서 한국어 | English to Korean",
]
translation_mode = st.sidebar.radio(
    "번역 모드를 선택(Translation Mode):",
    translation_list,
)
|
|
|
|
|
# Sample Korean sentence pre-filled in the input box. Written as adjacent
# literals so the single logical string stays within the line-length limit.
default_value = (
    '신한카드 관계자는 "과거 내놓은 상품의 경우 출시 2개월 만에 적금 가입이 '
    '4만여 좌에 달할 정도로 인기를 끌었다"면서 "금리 인상에 따라 적금 금리를 '
    '더 올려 많은 고객이 몰릴 것으로 예상하고 있다"고 말했다.'
)

# Input box for the text to translate; the label is Korean for
# "Enter the sentence you want to translate:".
src_text = st.text_area(
    "번역하고 싶은 문장을 입력하세요:",
    default_value,
    height=300,
    max_chars=200,
)

# Echo the submitted text to the server process's stdout.
print(src_text)
|
|
|
|
|
|
|
def _translate(text, model):
    """Translate ``text`` with ``model`` and return the decoded string.

    Uses the module-level ``tokenizer`` for both encoding and decoding.
    The input is padded/truncated to 64 tokens and the generated output is
    capped at 64 tokens, so longer inputs are silently cut off.

    Args:
        text: Source sentence to translate.
        model: Seq2seq model whose ``generate`` method produces the output ids.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=64,
    )
    generated = model.generate(
        **encoded,
        max_length=64,
        num_beams=5,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3,
        num_return_sequences=1,
    )
    # num_return_sequences=1, so the first entry is the only candidate.
    return tokenizer.decode(
        generated[0],
        clean_up_tokenization_spaces=True,
        skip_special_tokens=True,
    )


if not src_text.strip():
    # Nothing to translate: warn and skip generation instead of running the
    # model on empty (or whitespace-only) input.
    st.warning("Please **enter text** for translation")
else:
    # The first sidebar option selects Korean -> English; anything else is
    # English -> Korean.
    model = ko2en_model if translation_mode == translation_list[0] else en2ko_model
    translation_result = _translate(src_text, model)

    print(f"{src_text} -> {translation_result}")

    st.write(translation_result)
    print(translation_result)
|
|