# rwkv-mcts-cot / main.py
import os
import math
import random
import numpy as np
import gradio as gr

# The rwkv package reads these environment variables when rwkv.model is
# imported, so set them before the import.
os.environ["RWKV_JIT_ON"] = "1"

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# Define the Node class for MCTS: wraps a game state and tracks
# visit/reward statistics for UCT selection.
class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.wins = 0

    def is_fully_expanded(self):
        # Fully expanded only once every legal move has a corresponding child.
        return len(self.children) >= len(self.state.get_legal_moves())

    def best_child(self, c_param=1.4):
        # UCT score: average reward plus an exploration bonus that shrinks
        # as a child accumulates visits.
        choices_weights = [
            (child.wins / child.visits)
            + c_param * (2 * math.log(self.visits) / child.visits) ** 0.5
            for child in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def expand(self, state):
        new_node = Node(state, self)
        self.children.append(new_node)
        return new_node
# Define the MCTS class: selection, expansion, simulation (rollout),
# and backpropagation.
class MCTS:
    def __init__(self, simulation_limit=1000):
        self.root = None
        self.simulation_limit = simulation_limit

    def search(self, initial_state):
        self.root = Node(initial_state)
        for _ in range(self.simulation_limit):
            node = self.tree_policy(self.root)
            reward = self.default_policy(node.state)
            self.backpropagate(node, reward)
        # Final move choice: pure exploitation, no exploration bonus.
        return self.root.best_child(c_param=0).state

    def tree_policy(self, node):
        # Descend until we hit a node with an untried move or a terminal state.
        while not node.state.is_terminal():
            if not node.is_fully_expanded():
                return self.expand(node)
            node = node.best_child()
        return node

    def expand(self, node):
        # Sample moves until we find one not already tried from this node.
        # Boards are compared by value; identity comparison on freshly
        # constructed GameState objects would never match.
        tried_boards = [child.state.board for child in node.children]
        new_state = node.state.get_random_child_state()
        while new_state.board in tried_boards:
            new_state = node.state.get_random_child_state()
        return node.expand(new_state)

    def default_policy(self, state):
        # Random rollout to a terminal state, scored from the perspective of
        # the player who moved into `state` (the player credited at this node).
        perspective = 3 - state.player
        while not state.is_terminal():
            state = state.get_random_child_state()
        return state.get_reward(perspective)

    def backpropagate(self, node, reward):
        # Flip the sign at each level: a good result for the player who moved
        # into a node is a bad result for the mover at the node above it.
        while node is not None:
            node.visits += 1
            node.wins += reward
            reward = -reward
            node = node.parent
# Define the Game State and Rules. `board` is a 3x3 grid of 0 (empty), 1, or 2;
# `player` is the side to move next.
class GameState:
    def __init__(self, board, player):
        self.board = board
        self.player = player

    def is_terminal(self):
        return self.check_win() or self.check_draw()

    def check_win(self):
        # Rows
        for row in self.board:
            if row.count(row[0]) == len(row) and row[0] != 0:
                return True
        # Columns
        for col in range(len(self.board)):
            if self.board[0][col] == self.board[1][col] == self.board[2][col] and self.board[0][col] != 0:
                return True
        # Diagonals
        if self.board[0][0] == self.board[1][1] == self.board[2][2] and self.board[0][0] != 0:
            return True
        if self.board[0][2] == self.board[1][1] == self.board[2][0] and self.board[0][2] != 0:
            return True
        return False

    def check_draw(self):
        return all(
            self.board[row][col] != 0
            for row in range(len(self.board))
            for col in range(len(self.board))
        )

    def get_legal_moves(self):
        return [
            (row, col)
            for row in range(len(self.board))
            for col in range(len(self.board))
            if self.board[row][col] == 0
        ]

    def get_random_child_state(self):
        available_moves = self.get_legal_moves()
        if not available_moves:  # safety guard; a full board is terminal
            return self
        row, col = random.choice(available_moves)
        new_board = [r.copy() for r in self.board]
        new_board[row][col] = self.player
        return GameState(new_board, 3 - self.player)

    def get_reward(self, perspective):
        # The winner of a finished game is the player who made the last move,
        # i.e. 3 - self.player.
        if self.check_win():
            winner = 3 - self.player
            return 1 if winner == perspective else -1
        return 0  # draw

    def __str__(self):
        return "\n".join(" ".join(str(cell) for cell in row) for row in self.board)
# Initialize the RWKV model and tokenizer. Note: rwkv.model.RWKV expects a
# path to a local .pth checkpoint (e.g. one downloaded from the
# BlinkDL/rwkv-4-raven repo), not a Hugging Face repo id, and generation goes
# through the package's PIPELINE helper rather than a transformers-style
# `generate` call.
MODEL_PATH = "path/to/rwkv-4-raven-checkpoint.pth"  # placeholder: set to your local checkpoint
model = RWKV(model=MODEL_PATH, strategy="cuda fp16")
# "20B_tokenizer.json" is the tokenizer file used by the RWKV-4 examples;
# it must be available locally next to this script.
pipeline = PIPELINE(model, "20B_tokenizer.json")

# Generate a chain-of-thought explanation for a board state.
def generate_cot(state):
    prompt = f"Current state:\n{state}\nWhat is the best move?"
    args = PIPELINE_ARGS(temperature=1.0, top_p=0.7)
    return pipeline.generate(prompt, token_count=100, args=args)
# Run MCTS, then ask the language model for a chain-of-thought explanation of
# the chosen state (the CoT is generated after the search; it does not guide it).
def mcts_with_cot(initial_state):
    mcts = MCTS(simulation_limit=1000)
    best_state = mcts.search(initial_state)
    cot = generate_cot(best_state)
    return best_state, cot
# Function to be called by Gradio. `initial_board` is a 3x3 list of lists with
# 0 = empty, 1 = player 1, 2 = player 2; player 1 is assumed to move next.
def run_mcts_cot(initial_board):
    initial_state = GameState(initial_board, 1)
    best_state, cot = mcts_with_cot(initial_state)
    return str(best_state), cot
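
# The script imports gradio but never wires up an interface. The sketch below
# is an assumption about the intended UI, not part of the original file: it
# takes the board as nine comma-separated cells (row-major) and displays the
# chosen state alongside the model's chain of thought.
def parse_board(text):
    # "0,0,0,0,1,0,0,0,2" -> [[0, 0, 0], [0, 1, 0], [0, 0, 2]]
    cells = [int(c.strip()) for c in text.split(",")]
    return [cells[0:3], cells[3:6], cells[6:9]]

demo = gr.Interface(
    fn=lambda board_text: run_mcts_cot(parse_board(board_text)),
    inputs=gr.Textbox(label="Board (9 comma-separated cells, row-major)", value="0,0,0,0,0,0,0,0,0"),
    outputs=[gr.Textbox(label="Best next state"), gr.Textbox(label="Chain of thought")],
)

if __name__ == "__main__":
    demo.launch()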