shogiapp / app.py
pizzagatakasugi's picture
Update app.py
eb576ee
raw
history blame
2.12 kB
import streamlit as st
import cshogi
from IPython.display import display
from transformers import T5ForConditionalGeneration, T5Tokenizer
# board = cshogi.Board()
# st.write(display(board.to_svg()),unsafe_allow_html=True)
# st.write(board,unsafe_allow_html=True)
# st.write(print(board), unsafe_allow_html=True)
# Demo script: for the first five rows of sample.tsv, render the shogi
# position and compare the fine-tuned T5 model's generated commentary
# against the reference commentary.
#
# sample.tsv columns (tab-separated), per the original indexing:
#   0: input text   1: expected output   2: pre_name   3: fow_name   4: SFEN
tokenizer = T5Tokenizer.from_pretrained("pizzagatakasugi/shogi_t5", is_fast=True)
model = T5ForConditionalGeneration.from_pretrained("pizzagatakasugi/shogi_t5")
model.eval()

cnt = 0
with open("./sample.tsv", "r", encoding="utf-8") as f:
    for line in f:
        fields = line.strip().split("\t")
        if len(fields) < 5:
            # Skip malformed rows instead of raising IndexError.
            continue
        # Renamed from `input` to avoid shadowing the builtin; the useless
        # `intput = f"{input}"` typo line was removed.
        source_text, expected, pre_name, fow_name, sfen = fields[:5]

        # Render the board position for this row as inline SVG.
        board = cshogi.Board(sfen=sfen)
        st.markdown(board.to_svg(), unsafe_allow_html=True)
        cnt += 1

        # Extract the move from the tail of the first sentence: walk the
        # sentence backwards, collecting characters until a digit or 同
        # (the "same square" marker) is reached.
        # NOTE(review): this tests ASCII digits 1-9, while kifu text often
        # uses fullwidth digits — confirm against the data format.
        move = ""
        for ch in source_text.replace("\u3000", "").split("。")[0][::-1]:
            move += ch
            if ch in "123456789同":
                break

        tokenized_inputs = tokenizer.encode(
            source_text, max_length=512, truncation=True,
            padding="max_length", return_tensors="pt",
        )
        output_ids = model.generate(
            input_ids=tokenized_inputs,
            max_length=512,
            repetition_penalty=10.0,  # penalize repeating the same phrase
        )
        output_text = tokenizer.decode(
            output_ids[0], skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        # `move` was collected in reverse, so un-reverse it for display.
        st.write("打ち手", move[::-1])
        st.write("predict", source_text.replace("\u3000", "").split("。")[1])
        st.write("generate", output_text)
        st.write("actual", expected)

        # Only show the first five sample rows.
        if cnt == 5:
            break

del model  # release the model once the demo rows are rendered