pizzagatakasugi commited on
Commit
b03e5a6
1 Parent(s): 0dc8f8c

Delete pre_app.py

Browse files
Files changed (1) hide show
  1. pre_app.py +0 -130
pre_app.py DELETED
@@ -1,130 +0,0 @@
1
- import streamlit as st
2
- import cshogi
3
- from IPython.display import display
4
- from transformers import T5ForConditionalGeneration, T5Tokenizer
5
- import pandas as pd
6
-
7
- #モデルの読み込み
8
- tokenizer = T5Tokenizer.from_pretrained("pizzagatakasugi/shogi_t5", is_fast=True)
9
- model = T5ForConditionalGeneration.from_pretrained("pizzagatakasugi/shogi_t5")
10
- model.eval()
11
-
12
- st.title("将棋解説文の自動生成")
13
- df = pd.read_csv("./dataset10.csv")
14
- num = st.text_input("0から9の数字を入力")
15
-
16
- KIFU_TO_SQUARE_NAMES = [
17
- '1一', '1二', '1三', '1四', '1五', '1六', '1七', '1八', '1九',
18
- '2一', '2二', '2三', '2四', '2五', '2六', '2七', '2八', '2九',
19
- '3一', '3二', '3三', '3四', '3五', '3六', '3七', '3八', '3九',
20
- '4一', '4二', '4三', '4四', '4五', '4六', '4七', '4八', '4九',
21
- '5一', '5二', '5三', '5四', '5五', '5六', '5七', '5八', '5九',
22
- '6一', '6二', '6三', '6四', '6五', '6六', '6七', '6八', '6九',
23
- '7一', '7二', '7三', '7四', '7五', '7六', '7七', '7八', '7九',
24
- '8一', '8二', '8三', '8四', '8五', '8六', '8七', '8八', '8九',
25
- '9一', '9二', '9三', '9四', '9五', '9六', '9七', '9八', '9九',
26
- ]
27
- KIFU_FROM_SQUARE_NAMES = [
28
- '11', '12', '13', '14', '15', '16', '17', '18', '19',
29
- '21', '22', '23', '24', '25', '26', '27', '28', '29',
30
- '31', '32', '33', '34', '35', '36', '37', '38', '39',
31
- '41', '42', '43', '44', '45', '46', '47', '48', '49',
32
- '51', '52', '53', '54', '55', '56', '57', '58', '59',
33
- '61', '62', '63', '64', '65', '66', '67', '68', '69',
34
- '71', '72', '73', '74', '75', '76', '77', '78', '79',
35
- '81', '82', '83', '84', '85', '86', '87', '88', '89',
36
- '91', '92', '93', '94', '95', '96', '97', '98', '99',
37
- ]
38
-
39
- if num in [str(x) for x in list(range(10))]:
40
- df = df.iloc[int(num)]
41
- st.write(df["game_type"],df["precedence_name"],df["follower_name"])
42
- sfen = df["sfen"].split("\n")
43
- bestlist = eval(df["bestlist"])
44
- best2list = eval(df["best2list"])
45
- te = []
46
- te_sf = []
47
- movelist = []
48
-
49
- #文字の正規化
50
- for x in range(len(sfen)):
51
- if x < 2:
52
- continue
53
- if len(sfen[x]) > 30:
54
- te_sf.append(sfen[x])
55
- else:
56
- #te.append(sfen[x])
57
- temp = sfen[x].split()
58
- num = temp[1][0] + temp[1][1]
59
- for y in range(len(KIFU_FROM_SQUARE_NAMES)):
60
- if num == KIFU_FROM_SQUARE_NAMES[y]:
61
- sq = KIFU_TO_SQUARE_NAMES[y]
62
- word = sq+temp[1][2:]
63
- word = word.replace("竜","龍").replace("成銀","全").replace("成桂","圭").replace("成香","杏")
64
- if sfen[x].split()[1] not in ["投了" , "千日手" , "持将棋" , "反則勝ち"]:
65
- te.append(temp[0]+" "+word)
66
- movelist.append(word)
67
- else:
68
- movelist.append(sfen[x].split()[1])
69
-
70
- #盤面表示
71
- s = st.selectbox(label="手数を選択",options=te)
72
-
73
- with st.expander("parameter"):
74
- temp = st.slider("temperature",min_value=0.0,max_value=1.0,step=0.01,value=0.3,key=1)
75
- beams = st.slider("num_beams",min_value=1,max_value=5,step=1,value=1,key=2)
76
- tokens = st.slider("min_new_tokens",min_value=0,max_value=50,value=20,key=3)
77
-
78
- reload = st.button('盤面生成',key=0)
79
- if s in te and reload == True:
80
- reload = False
81
- idx = te.index(s)
82
- board = cshogi.Board(sfen=te_sf[idx+1])
83
- st.markdown(board.to_svg(),unsafe_allow_html=True)
84
-
85
- #入力文作成
86
- kifs=""
87
- cnt = 0
88
- for kif in movelist:
89
- if cnt > idx:
90
- break
91
- kif = kif.split("(")[0]
92
- kifs += kif
93
- cnt += 1
94
-
95
- best = ""
96
- for x in bestlist[idx]:
97
- best += x.split("(")[0]
98
-
99
- best2 = ""
100
- for y in best2list[idx]:
101
- best2 += y.split("(")[0]
102
-
103
- #st.write(idx,"入力",input)
104
- with st.spinner("推論中です..."):
105
- input = sfen[0]+sfen[1]+kifs+"最善手の予測手順は"+best+"次善手の予測手順は"+best2
106
- tokenized_inputs = tokenizer.encode(
107
- input, max_length= 512, truncation=True,
108
- padding="max_length", return_tensors="pt"
109
- )
110
-
111
- output_ids = model.generate(input_ids=tokenized_inputs,
112
- max_length=512,
113
- repetition_penalty=10.0, # 同じ文の繰り返しへのペナルティ
114
- temperature = temp,
115
- num_beams = beams,
116
- min_new_tokens = tokens,
117
- )
118
-
119
- output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True,
120
- clean_up_tokenization_spaces=False)
121
- st.write(output_text)
122
-
123
-
124
- # temperature = st.slider("temperature",min_value=0.0,max_value=1.0,step=0.01,value=0.3,key=1)
125
- # num_beams = st.slider("num_beams",min_value=1,max_value=5,step=1,value=1,key=2)
126
- # min_new_tokens = st.slider("min_new_tokens",min_value=0,max_value=100,value=30,key=3)
127
-
128
-
129
-
130
-