import os
import math
import random

import numpy as np
import gradio as gr

# The rwkv runtime picks these up when rwkv.model is imported, so set them first.
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0"  # set to "1" only if the custom CUDA kernel is available

from rwkv.model import RWKV
from rwkv.utils import PIPELINE

# Define the Node class for MCTS
class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.wins = 0

    def is_fully_expanded(self):
        # Fully expanded once every legal move from this state has its own child node.
        return len(self.children) >= len(self.state.get_legal_moves())

    def best_child(self, c_param=1.4):
        # UCT score: average reward plus an exploration bonus weighted by c_param (~sqrt(2)).
        choices_weights = [
            (child.wins / child.visits) + c_param * math.sqrt(math.log(self.visits) / child.visits)
            for child in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def expand(self, state):
        new_node = Node(state, self)
        self.children.append(new_node)
        return new_node

# Define the MCTS class
class MCTS:
    def __init__(self, simulation_limit=1000):
        self.root = None
        self.simulation_limit = simulation_limit

    def search(self, initial_state):
        self.root = Node(initial_state)
        for _ in range(self.simulation_limit):
            node = self.tree_policy(self.root)
            reward = self.default_policy(node.state)
            self.backpropagate(node, reward)
        return self.root.best_child(c_param=0).state

    def tree_policy(self, node):
        while not node.state.is_terminal():
            if not node.is_fully_expanded():
                return self.expand(node)
            else:
                node = node.best_child()
        return node

    def expand(self, node):
        # Sample child states until we find one that has not been tried yet
        # (GameState.__eq__ compares boards, so the membership test is meaningful).
        tried_states = [child.state for child in node.children]
        new_state = node.state.get_random_child_state()
        while new_state in tried_states:
            new_state = node.state.get_random_child_state()
        return node.expand(new_state)

    def default_policy(self, state):
        while not state.is_terminal():
            state = state.get_random_child_state()
        return state.get_reward()

    def backpropagate(self, node, reward):
        while node is not None:
            node.visits += 1
            node.wins += reward
            node = node.parent

# Define the Game State and Rules
class GameState:
    def __init__(self, board, player):
        self.board = board
        self.player = player

    def is_terminal(self):
        return self.check_win() or self.check_draw()

    def check_win(self):
        for row in self.board:
            if row.count(row[0]) == len(row) and row[0] != 0:
                return True
        for col in range(len(self.board)):
            if self.board[0][col] == self.board[1][col] == self.board[2][col] and self.board[0][col] != 0:
                return True
        if self.board[0][0] == self.board[1][1] == self.board[2][2] and self.board[0][0] != 0:
            return True
        if self.board[0][2] == self.board[1][1] == self.board[2][0] and self.board[0][2] != 0:
            return True
        return False

    def check_draw(self):
        return all(self.board[row][col] != 0 for row in range(len(self.board)) for col in range(len(self.board)))

    def get_legal_moves(self):
        size = len(self.board)
        return [(row, col) for row in range(size) for col in range(size) if self.board[row][col] == 0]

    def get_random_child_state(self):
        available_moves = self.get_legal_moves()
        if not available_moves:
            return self
        row, col = random.choice(available_moves)
        new_board = [list(r) for r in self.board]  # copy each row before mutating
        new_board[row][col] = self.player
        return GameState(new_board, 3 - self.player)

    def get_reward(self):
        # Turns alternate in get_random_child_state, so a win in this state was made by
        # the previous player, i.e. 3 - self.player; score it from player 1's perspective.
        if self.check_win():
            return 1 if (3 - self.player) == 1 else -1
        return 0

    def __eq__(self, other):
        # Board-content equality so MCTS.expand can detect already-tried child states.
        return isinstance(other, GameState) and self.board == other.board and self.player == other.player

    def __str__(self):
        return "\n".join(" ".join(str(cell) for cell in row) for row in self.board)

# Load the RWKV model and the generation pipeline.
# NOTE: rwkv.model.RWKV expects a path to a local .pth checkpoint (for example one
# downloaded from the BlinkDL/rwkv-4-raven repository on Hugging Face), not a repo id,
# and tokenization/generation goes through rwkv.utils.PIPELINE rather than a
# transformers tokenizer.
model_path = "RWKV-4-Raven.pth"  # placeholder: point this at your downloaded checkpoint
model = RWKV(model=model_path, strategy="cuda fp16")
pipeline = PIPELINE(model, "20B_tokenizer.json")  # tokenizer file used by the Raven models

# Generate a chain-of-thought explanation for a board state
def generate_cot(state):
    prompt = f"Current state: {state}\nWhat is the best move?"
    return pipeline.generate(prompt, token_count=100)

# Use CoT in MCTS
def mcts_with_cot(initial_state):
    mcts = MCTS(simulation_limit=1000)
    best_state = mcts.search(initial_state)
    cot = generate_cot(best_state)
    return best_state, cot

# Function to be called by Gradio
def run_mcts_cot(initial_board):
    initial_state = GameState(initial_board, 1)
    best_state, cot = mcts_with_cot(initial_state)
    return str(best_state), cot
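
# NOTE: gradio is imported and run_mcts_cot is described as the Gradio entry point, but no
# interface is wired up in the file as shown. The block below is a minimal sketch (an
# assumption, not part of the original file): it parses a 3x3 grid of 0/1/2 cells from a
# textbox, runs run_mcts_cot, and shows the chosen board plus the generated chain of thought.
def parse_board(board_text):
    # Hypothetical helper: three lines of space-separated 0/1/2 -> 3x3 list of ints.
    return [[int(cell) for cell in line.split()] for line in board_text.strip().splitlines()]

demo = gr.Interface(
    fn=lambda board_text: run_mcts_cot(parse_board(board_text)),
    inputs=gr.Textbox(lines=3, value="0 0 0\n0 0 0\n0 0 0", label="Board (0 = empty, 1 = X, 2 = O)"),
    outputs=[gr.Textbox(label="Best state found"), gr.Textbox(label="Chain of thought")],
    title="MCTS + RWKV chain-of-thought for tic-tac-toe",
)

if __name__ == "__main__":
    demo.launch()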