import os
import math
import random
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download

# The rwkv package reads these flags at import time, so set them before importing rwkv.model.
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0"

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# Define the Node class for MCTS
class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.wins = 0

    def is_fully_expanded(self):
        # A node is fully expanded once every legal move has its own child node.
        return len(self.children) >= len(self.state.get_legal_moves())

    def best_child(self, c_param=1.4):
        # UCT selection: exploitation (average reward) plus an exploration bonus.
        choices_weights = [
            (child.wins / child.visits)
            + c_param * (2 * math.log(self.visits) / child.visits) ** 0.5
            for child in self.children
        ]
        return self.children[np.argmax(choices_weights)]

    def expand(self, state):
        new_node = Node(state, self)
        self.children.append(new_node)
        return new_node
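# Worked example of the UCT score above (values are illustrative): for a parent with
# 10 visits and a child with 2 wins out of 3 visits,
#   2/3 + 1.4 * sqrt(2 * ln(10) / 3) ≈ 0.67 + 1.4 * 1.24 ≈ 2.40,
# so rarely visited children keep a large exploration bonus until they are sampled more.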
# Define the MCTS class
class MCTS:
    def __init__(self, simulation_limit=1000):
        self.root = None
        self.simulation_limit = simulation_limit

    def search(self, initial_state):
        self.root = Node(initial_state)
        for _ in range(self.simulation_limit):
            node = self.tree_policy(self.root)        # selection + expansion
            reward = self.default_policy(node.state)  # random rollout
            self.backpropagate(node, reward)          # update statistics up the tree
        # Final move choice: pure exploitation (no exploration bonus).
        return self.root.best_child(c_param=0).state

    def tree_policy(self, node):
        while not node.state.is_terminal():
            if not node.is_fully_expanded():
                return self.expand(node)
            else:
                node = node.best_child()
        return node

    def expand(self, node):
        # Compare boards rather than GameState objects (GameState defines no __eq__),
        # retrying until an untried move is drawn.
        tried_boards = [child.state.board for child in node.children]
        new_state = node.state.get_random_child_state()
        while new_state.board in tried_boards:
            new_state = node.state.get_random_child_state()
        return node.expand(new_state)

    def default_policy(self, state):
        # Play random moves to the end of the game and return the outcome.
        while not state.is_terminal():
            state = state.get_random_child_state()
        return state.get_reward()

    def backpropagate(self, node, reward):
        # Rewards are accumulated from player 1's perspective at every level.
        while node is not None:
            node.visits += 1
            node.wins += reward
            node = node.parent
# Define the Game State and Rules
class GameState:
    def __init__(self, board, player):
        self.board = board
        self.player = player  # the player whose turn it is to move (1 or 2)

    def is_terminal(self):
        return self.check_win() or self.check_draw()

    def check_win(self):
        # Rows
        for row in self.board:
            if row.count(row[0]) == len(row) and row[0] != 0:
                return True
        # Columns and diagonals (assumes a 3x3 tic-tac-toe board)
        for col in range(len(self.board)):
            if self.board[0][col] == self.board[1][col] == self.board[2][col] and self.board[0][col] != 0:
                return True
        if self.board[0][0] == self.board[1][1] == self.board[2][2] and self.board[0][0] != 0:
            return True
        if self.board[0][2] == self.board[1][1] == self.board[2][0] and self.board[0][2] != 0:
            return True
        return False

    def check_draw(self):
        return all(
            self.board[row][col] != 0
            for row in range(len(self.board))
            for col in range(len(self.board))
        )

    def get_legal_moves(self):
        return [
            (row, col)
            for row in range(len(self.board))
            for col in range(len(self.board))
            if self.board[row][col] == 0
        ]

    def get_random_child_state(self):
        available_moves = self.get_legal_moves()
        if not available_moves:
            return self
        row, col = random.choice(available_moves)
        new_board = [r.copy() for r in self.board]
        new_board[row][col] = self.player
        # Hand the turn to the other player (players are 1 and 2).
        return GameState(new_board, 3 - self.player)

    def get_reward(self):
        # self.player is the side to move, so a completed line was made by the
        # previous mover (3 - self.player): +1 if player 1 won, -1 if player 2 won.
        if self.check_win():
            return 1 if self.player == 2 else -1
        return 0

    def __str__(self):
        return "\n".join(" ".join(str(cell) for cell in row) for row in self.board)
# Initialize the RWKV model.
# NOTE: rwkv.model.RWKV expects a local path to a .pth checkpoint, not a Hub repo id,
# so a Raven checkpoint has to be downloaded first. The filename below is a placeholder;
# replace it with the exact file you want from the BlinkDL/rwkv-4-raven repo.
model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-raven", filename="RWKV-4-Raven-<variant>.pth")
model = RWKV(model=model_path, strategy="cuda fp16")

# rwkv.model.RWKV has no generate() method; text generation goes through rwkv.utils.PIPELINE,
# which also provides the matching RWKV tokenizer (a GPT-2 tokenizer does not share its vocabulary).
# This assumes the 20B_tokenizer.json file from the ChatRWKV repo is available locally.
pipeline = PIPELINE(model, "20B_tokenizer.json")

# Generate a chain-of-thought explanation for a board state
def generate_cot(state):
    input_text = f"Current state: {state}\nWhat is the best move?"
    cot = pipeline.generate(
        input_text,
        token_count=100,
        args=PIPELINE_ARGS(temperature=1.0, top_p=0.85),
    )
    return cot
# Use CoT in MCTS
def mcts_with_cot(initial_state):
    mcts = MCTS(simulation_limit=1000)
    best_state = mcts.search(initial_state)
    cot = generate_cot(best_state)
    return best_state, cot

# Function to be called by Gradio
def run_mcts_cot(initial_board):
    initial_state = GameState(initial_board, 1)
    best_state, cot = mcts_with_cot(initial_state)
    return str(best_state), cot
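
# A minimal sketch of the Gradio wiring for run_mcts_cot, assuming a JSON-encoded 3x3
# board as the input; the component choices, labels, and helper below are assumptions,
# not taken from the original source.
import json

def run_mcts_cot_from_json(board_json):
    # Parse the JSON board (e.g. "[[0, 0, 0], [0, 0, 0], [0, 0, 0]]") and run the search.
    board = json.loads(board_json)
    return run_mcts_cot(board)

demo = gr.Interface(
    fn=run_mcts_cot_from_json,
    inputs=gr.Textbox(value="[[0, 0, 0], [0, 0, 0], [0, 0, 0]]", label="Board (3x3, 0 = empty)"),
    outputs=[gr.Textbox(label="Best next board"), gr.Textbox(label="Model chain-of-thought")],
    title="MCTS + RWKV chain-of-thought tic-tac-toe",
)

if __name__ == "__main__":
    demo.launch()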