File size: 6,224 Bytes
293421b
 
 
 
 
 
 
 
 
 
 
 
ee8c78e
293421b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bee8c6
604dd69
 
 
293421b
 
 
 
 
 
 
 
 
 
 
8717155
293421b
 
 
 
 
 
8717155
 
 
 
 
 
293421b
 
 
8717155
293421b
8717155
 
 
 
293421b
 
 
 
 
8717155
293421b
 
 
 
d1a91c6
293421b
 
 
 
 
 
 
 
 
 
 
 
 
4333018
293421b
8717155
d1a91c6
4333018
293421b
 
8717155
 
293421b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import streamlit as st
import cshogi
from IPython.display import display
from transformers import T5ForConditionalGeneration, T5Tokenizer
import pandas as pd

# Load the tokenizer and model once per server process. Streamlit re-runs this
# whole script on every widget interaction (text input, slider, button), so
# without caching the T5 weights would be reloaded each time anything changes.
@st.cache_resource
def _load_model():
    """Return the ``(tokenizer, model)`` pair, with the model in eval mode."""
    # NOTE(review): the tokenizer comes from "shogi_t5" while the model comes
    # from "shogi_t5_v2" — confirm both share the same vocabulary.
    # NOTE(review): T5Tokenizer is the slow tokenizer class; passing
    # is_fast=True does not turn it into T5TokenizerFast. Kept unchanged.
    tok = T5Tokenizer.from_pretrained("pizzagatakasugi/shogi_t5", is_fast=True)
    mdl = T5ForConditionalGeneration.from_pretrained("pizzagatakasugi/shogi_t5_v2")
    mdl.eval()
    return tok, mdl


tokenizer, model = _load_model()

# Page header and demo data. Each CSV row is one game; the user picks a row
# by typing its index (0-9).
st.title("将棋解説文の自動生成")
df = pd.read_csv("./demo.csv")
num = st.text_input("0から9の数字を入力")

# Both tables list the 81 board squares in the same file-major order:
# all nine ranks of file 1, then file 2, ..., up to file 9. Index i in one
# table therefore names the same square as index i in the other.
_KANJI_RANKS = "一二三四五六七八九"

# Kifu-style square names: Arabic file digit followed by a kanji rank ("7六").
KIFU_TO_SQUARE_NAMES = [
    f"{suji}{dan}" for suji in range(1, 10) for dan in _KANJI_RANKS
]

# The same squares written as two Arabic digits, file then rank ("76"); this
# is the form matched against the first two characters of each raw move.
KIFU_FROM_SQUARE_NAMES = [
    f"{suji}{dan}" for suji in range(1, 10) for dan in range(1, 10)
]

# Move words that end the game instead of moving a piece; they carry no square.
_TERMINAL_MOVES = ("投了", "千日手", "持将棋", "反則勝ち")

# Two-digit square ("76") -> kifu square name ("7六"). A dict replaces the
# per-move linear scan over the two parallel tables.
_SQUARE_NAME_BY_DIGITS = dict(zip(KIFU_FROM_SQUARE_NAMES, KIFU_TO_SQUARE_NAMES))


def _parse_kifu(lines):
    """Split a game record into selectable moves, SFEN positions and model moves.

    ``lines`` is the ``sfen`` CSV cell split on newlines. The first two lines
    are skipped (NOTE(review): presumably a header — confirm against
    demo.csv). Lines longer than 30 characters are treated as SFEN positions.
    Every other non-blank line is expected to look like ``"<number> <move>"``,
    where the move starts with a two-digit square such as ``"76"``.

    Returns:
        ``(te, te_sf, movelist)``:
        te       -- ``"<number> <move>"`` labels for non-terminal moves, with
                    the square rewritten to kanji form (``"7六..."``).
        te_sf    -- the SFEN lines, in order.
        movelist -- normalised move text for every move, followed by any
                    terminal word (e.g. ``"投了"``) verbatim.
    """
    te, te_sf, movelist = [], [], []
    square = ""  # last recognised square; see the NOTE below
    for line in lines[2:]:
        if len(line) > 30:
            te_sf.append(line)
            continue
        fields = line.split()
        if len(fields) < 2:
            # Blank or truncated line (e.g. a trailing newline in the CSV
            # cell). Previously this raised IndexError.
            continue
        number, move = fields[0], fields[1]
        if move in _TERMINAL_MOVES:
            movelist.append(move)
            continue
        # NOTE(review): when the first two characters are not a known square,
        # the previous move's square is reused (same as the original code).
        # Confirm this is intended, e.g. for "同" moves.
        square = _SQUARE_NAME_BY_DIGITS.get(move[:2], square)
        # Normalise piece spellings: 竜 -> 龍 and promoted silver/knight/lance
        # to their single-character forms.
        word = (
            (square + move[2:])
            .replace("竜", "龍")
            .replace("成銀", "全")
            .replace("成桂", "圭")
            .replace("成香", "杏")
        )
        te.append(number + " " + word)
        movelist.append(word)
    return te, te_sf, movelist


def _build_prompt(movelist, best_moves, second_moves, idx):
    """Build the T5 input text for the position reached after move ``idx``.

    The result is ``解説文生成:<moves 0..idx>。最善手は<best>。次善手は<second>``.
    <best> is at most the first three entries of ``best_moves``, with sides
    alternating from the side to move. <second> is the first entry of
    ``second_moves`` ("" if it is empty). Both are prefixed with ▲/△.
    Everything from the first "(" onwards in each move is dropped
    (presumably the origin-square annotation).
    """
    played = [
        kif.split("(")[0].replace("▲", "").replace("△", "")
        for kif in movelist[: idx + 1]
    ]
    # ▲ moves first, so after an even number of moves it is ▲'s turn again.
    side = "▲" if len(played) % 2 == 0 else "△"
    opponent = "△" if side == "▲" else "▲"
    # zip against a 3-tuple caps the best line at three moves.
    best = "".join(
        mover + move.split("(")[0]
        for mover, move in zip((side, opponent, side), best_moves)
    )
    best2 = (side + second_moves[0].split("(")[0]) if second_moves else ""
    return "解説文生成:" + "".join(played) + "。最善手は" + best + "。次善手は" + best2


if num in {str(digit) for digit in range(10)}:
    row = df.iloc[int(num)]
    st.write(row["game_type"], row["precedence_name"], row["follower_name"])
    sfen = row["sfen"].split("\n")
    # NOTE(security): eval() executes arbitrary Python stored in the CSV. It is
    # only acceptable while demo.csv is a trusted local file. If these columns
    # hold plain list literals, ast.literal_eval is the safe replacement.
    bestlist = eval(row["bestlist"])
    best2list = eval(row["best2list"])

    # Normalise the record: kanji square names and single-character piece names.
    te, te_sf, movelist = _parse_kifu(sfen)

    # Move picker and generation parameters.
    s = st.selectbox(label="手数を選択", options=te)

    with st.expander("parameter"):
        beams = st.slider("num_beams", min_value=1, max_value=10, step=1, value=5, key=2)
        tokens = st.slider("min_new_tokens", min_value=0, max_value=50, step=1, value=20, key=3)
        top_p = st.slider("top_p", min_value=0.50, max_value=1.00, value=0.90, step=0.01)
        top_k = st.slider("top_k", min_value=5, max_value=50, value=30, step=1)

    reload = st.button('盤面生成', key=0)
    if reload and s in te:
        idx = te.index(s)
        # NOTE(review): te_sf[0] is presumably the starting position, making
        # te_sf[idx + 1] the position after move idx — confirm against data.
        board = cshogi.Board(sfen=te_sf[idx + 1])
        st.markdown(board.to_svg(), unsafe_allow_html=True)

        prompt = _build_prompt(movelist, bestlist[idx], best2list[idx], idx)
        with st.spinner("推論中です..."):
            tokenized_inputs = tokenizer.encode(
                prompt,
                max_length=512,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
            output_ids = model.generate(
                input_ids=tokenized_inputs,
                max_length=512,
                repetition_penalty=10.0,  # discourage repeating the same text
                do_sample=True,
                num_beams=beams,
                min_new_tokens=tokens,
                top_p=top_p,
                top_k=top_k,
                num_return_sequences=beams,
            )
            # num_return_sequences == beams, so this decodes every candidate.
            output_list = tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            st.write(output_list)