# Hugging Face Spaces demo (page residue removed): shogi commentary generation app.
# Third-party dependencies: Streamlit for the UI, cshogi for board
# rendering, and a fine-tuned T5 model for shogi commentary generation.
import streamlit as st
import cshogi
from IPython.display import display
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the pretrained shogi-commentary T5 model and its tokenizer once at
# app startup (downloaded from the Hugging Face Hub on first run).
tokenizer = T5Tokenizer.from_pretrained("pizzagatakasugi/shogi_t5", is_fast=True)
model = T5ForConditionalGeneration.from_pretrained("pizzagatakasugi/shogi_t5")
model.eval()  # inference only — disable dropout / training-mode layers
# Render up to five sample positions with model-generated commentary.
cnt = 0  # number of samples rendered so far
with open("./sample.tsv", "r", encoding="utf-8") as f:
    for line in f:
        # Each TSV row: source text, reference commentary, previous player
        # name, following player name, SFEN position string.
        fields = line.strip().split("\t")
        if len(fields) < 5:
            # Skip malformed rows instead of crashing with IndexError.
            continue
        src_text = fields[0]   # renamed: original shadowed builtin `input`
        expected = fields[1]
        pre_name = fields[2]
        fow_name = fields[3]
        sfen = fields[4]

        # Render the board position as inline SVG.
        board = cshogi.Board(sfen=sfen)
        st.markdown(board.to_svg(), unsafe_allow_html=True)
        cnt += 1

        # Extract the move notation from the tail of the first sentence:
        # walk it backwards, collecting characters up to and including the
        # first digit or "同" (the move's destination marker).
        move = ""
        for ch in src_text.replace("\u3000", "").split("。")[0][::-1]:
            move += ch
            if ch in "123456789同":
                break

        token_ids = tokenizer.encode(
            src_text, max_length=512, truncation=True,
            padding="max_length", return_tensors="pt",
        )
        generated_ids = model.generate(
            input_ids=token_ids,
            max_length=512,
            repetition_penalty=10.0,  # penalize repeating the same phrase
        )
        generated_text = tokenizer.decode(
            generated_ids[0], skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        st.write("打ち手", move[::-1])  # un-reverse the collected move string
        # NOTE(review): assumes the source text has at least two sentences
        # separated by "。" — confirm against sample.tsv contents.
        st.write("predict", src_text.replace("\u3000", "").split("。")[1])
        st.write("generate", generated_text)
        st.write("actual", expected)

        if cnt == 5:  # only display the first five samples
            break

del model  # release the model weights once rendering is done