# Streamlit app: render shogi positions from sample.tsv and compare
# model-generated commentary against the reference commentary.
import streamlit as st
import cshogi
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the pretrained shogi-commentary T5 model and its tokenizer.
tokenizer = T5Tokenizer.from_pretrained("pizzagatakasugi/shogi_t5", is_fast=True)
model = T5ForConditionalGeneration.from_pretrained("pizzagatakasugi/shogi_t5")
model.eval()

cnt = 0
with open("./sample.tsv", "r", encoding="utf-8") as f:
    for line in f:
        # Each TSV row: input text, reference output, player names, and the SFEN position.
        fields = line.strip().split("\t")
        input_text = fields[0]
        actual_text = fields[1]
        pre_name = fields[2]
        fow_name = fields[3]
        sfen = fields[4]

        # Render the board for this position as inline SVG.
        board = cshogi.Board(sfen=sfen)
        st.markdown(board.to_svg(), unsafe_allow_html=True)
        cnt += 1

        # Walk the first sentence backwards to extract the move text
        # (everything up to and including the file digit or "同").
        s = ""
        for ch in input_text.replace("\u3000", "").split("。")[0][::-1]:
            s += ch
            if ch in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "同"]:
                break

        tokenized_inputs = tokenizer.encode(
            input_text,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        output_ids = model.generate(
            input_ids=tokenized_inputs,
            max_length=512,
            repetition_penalty=10.0,  # penalty against repeating the same sentence
        )
        generated_text = tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        st.write("打ち手", s[::-1])
        st.write("predict", input_text.replace("\u3000", "").split("。")[1])
        st.write("generate", generated_text)
        st.write("actual", actual_text)

        # Only show the first five positions.
        if cnt == 5:
            break

del model