Spaces:
Sleeping
Sleeping
pizzagatakasugi
commited on
Commit
•
b03e5a6
1
Parent(s):
0dc8f8c
Delete pre_app.py
Browse files- pre_app.py +0 -130
pre_app.py
DELETED
@@ -1,130 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import cshogi
|
3 |
-
from IPython.display import display
|
4 |
-
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
5 |
-
import pandas as pd
|
6 |
-
|
7 |
-
#モデルの読み込み
|
8 |
-
tokenizer = T5Tokenizer.from_pretrained("pizzagatakasugi/shogi_t5", is_fast=True)
|
9 |
-
model = T5ForConditionalGeneration.from_pretrained("pizzagatakasugi/shogi_t5")
|
10 |
-
model.eval()
|
11 |
-
|
12 |
-
st.title("将棋解説文の自動生成")
|
13 |
-
df = pd.read_csv("./dataset10.csv")
|
14 |
-
num = st.text_input("0から9の数字を入力")
|
15 |
-
|
16 |
-
KIFU_TO_SQUARE_NAMES = [
|
17 |
-
'1一', '1二', '1三', '1四', '1五', '1六', '1七', '1八', '1九',
|
18 |
-
'2一', '2二', '2三', '2四', '2五', '2六', '2七', '2八', '2九',
|
19 |
-
'3一', '3二', '3三', '3四', '3五', '3六', '3七', '3八', '3九',
|
20 |
-
'4一', '4二', '4三', '4四', '4五', '4六', '4七', '4八', '4九',
|
21 |
-
'5一', '5二', '5三', '5四', '5五', '5六', '5七', '5八', '5九',
|
22 |
-
'6一', '6二', '6三', '6四', '6五', '6六', '6七', '6八', '6九',
|
23 |
-
'7一', '7二', '7三', '7四', '7五', '7六', '7七', '7八', '7九',
|
24 |
-
'8一', '8二', '8三', '8四', '8五', '8六', '8七', '8八', '8九',
|
25 |
-
'9一', '9二', '9三', '9四', '9五', '9六', '9七', '9八', '9九',
|
26 |
-
]
|
27 |
-
KIFU_FROM_SQUARE_NAMES = [
|
28 |
-
'11', '12', '13', '14', '15', '16', '17', '18', '19',
|
29 |
-
'21', '22', '23', '24', '25', '26', '27', '28', '29',
|
30 |
-
'31', '32', '33', '34', '35', '36', '37', '38', '39',
|
31 |
-
'41', '42', '43', '44', '45', '46', '47', '48', '49',
|
32 |
-
'51', '52', '53', '54', '55', '56', '57', '58', '59',
|
33 |
-
'61', '62', '63', '64', '65', '66', '67', '68', '69',
|
34 |
-
'71', '72', '73', '74', '75', '76', '77', '78', '79',
|
35 |
-
'81', '82', '83', '84', '85', '86', '87', '88', '89',
|
36 |
-
'91', '92', '93', '94', '95', '96', '97', '98', '99',
|
37 |
-
]
|
38 |
-
|
39 |
-
if num in [str(x) for x in list(range(10))]:
|
40 |
-
df = df.iloc[int(num)]
|
41 |
-
st.write(df["game_type"],df["precedence_name"],df["follower_name"])
|
42 |
-
sfen = df["sfen"].split("\n")
|
43 |
-
bestlist = eval(df["bestlist"])
|
44 |
-
best2list = eval(df["best2list"])
|
45 |
-
te = []
|
46 |
-
te_sf = []
|
47 |
-
movelist = []
|
48 |
-
|
49 |
-
#文字の正規化
|
50 |
-
for x in range(len(sfen)):
|
51 |
-
if x < 2:
|
52 |
-
continue
|
53 |
-
if len(sfen[x]) > 30:
|
54 |
-
te_sf.append(sfen[x])
|
55 |
-
else:
|
56 |
-
#te.append(sfen[x])
|
57 |
-
temp = sfen[x].split()
|
58 |
-
num = temp[1][0] + temp[1][1]
|
59 |
-
for y in range(len(KIFU_FROM_SQUARE_NAMES)):
|
60 |
-
if num == KIFU_FROM_SQUARE_NAMES[y]:
|
61 |
-
sq = KIFU_TO_SQUARE_NAMES[y]
|
62 |
-
word = sq+temp[1][2:]
|
63 |
-
word = word.replace("竜","龍").replace("成銀","全").replace("成桂","圭").replace("成香","杏")
|
64 |
-
if sfen[x].split()[1] not in ["投了" , "千日手" , "持将棋" , "反則勝ち"]:
|
65 |
-
te.append(temp[0]+" "+word)
|
66 |
-
movelist.append(word)
|
67 |
-
else:
|
68 |
-
movelist.append(sfen[x].split()[1])
|
69 |
-
|
70 |
-
#盤面表示
|
71 |
-
s = st.selectbox(label="手数を選択",options=te)
|
72 |
-
|
73 |
-
with st.expander("parameter"):
|
74 |
-
temp = st.slider("temperature",min_value=0.0,max_value=1.0,step=0.01,value=0.3,key=1)
|
75 |
-
beams = st.slider("num_beams",min_value=1,max_value=5,step=1,value=1,key=2)
|
76 |
-
tokens = st.slider("min_new_tokens",min_value=0,max_value=50,value=20,key=3)
|
77 |
-
|
78 |
-
reload = st.button('盤面生成',key=0)
|
79 |
-
if s in te and reload == True:
|
80 |
-
reload = False
|
81 |
-
idx = te.index(s)
|
82 |
-
board = cshogi.Board(sfen=te_sf[idx+1])
|
83 |
-
st.markdown(board.to_svg(),unsafe_allow_html=True)
|
84 |
-
|
85 |
-
#入力文作成
|
86 |
-
kifs=""
|
87 |
-
cnt = 0
|
88 |
-
for kif in movelist:
|
89 |
-
if cnt > idx:
|
90 |
-
break
|
91 |
-
kif = kif.split("(")[0]
|
92 |
-
kifs += kif
|
93 |
-
cnt += 1
|
94 |
-
|
95 |
-
best = ""
|
96 |
-
for x in bestlist[idx]:
|
97 |
-
best += x.split("(")[0]
|
98 |
-
|
99 |
-
best2 = ""
|
100 |
-
for y in best2list[idx]:
|
101 |
-
best2 += y.split("(")[0]
|
102 |
-
|
103 |
-
#st.write(idx,"入力",input)
|
104 |
-
with st.spinner("推論中です..."):
|
105 |
-
input = sfen[0]+sfen[1]+kifs+"最善手の予測手順は"+best+"次善手の予測手順は"+best2
|
106 |
-
tokenized_inputs = tokenizer.encode(
|
107 |
-
input, max_length= 512, truncation=True,
|
108 |
-
padding="max_length", return_tensors="pt"
|
109 |
-
)
|
110 |
-
|
111 |
-
output_ids = model.generate(input_ids=tokenized_inputs,
|
112 |
-
max_length=512,
|
113 |
-
repetition_penalty=10.0, # 同じ文の繰り返しへのペナルティ
|
114 |
-
temperature = temp,
|
115 |
-
num_beams = beams,
|
116 |
-
min_new_tokens = tokens,
|
117 |
-
)
|
118 |
-
|
119 |
-
output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True,
|
120 |
-
clean_up_tokenization_spaces=False)
|
121 |
-
st.write(output_text)
|
122 |
-
|
123 |
-
|
124 |
-
# temperature = st.slider("temperature",min_value=0.0,max_value=1.0,step=0.01,value=0.3,key=1)
|
125 |
-
# num_beams = st.slider("num_beams",min_value=1,max_value=5,step=1,value=1,key=2)
|
126 |
-
# min_new_tokens = st.slider("min_new_tokens",min_value=0,max_value=100,value=30,key=3)
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|