File size: 1,082 Bytes
7d23b62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import Optional

import numpy as np
import torch
from transformers import GPT2LMHeadModel


def load_model(model_name: str = "snoop2head/Gomoku-GPT2") -> GPT2LMHeadModel:
    """Load a pretrained GPT-2 language-model-head checkpoint.

    Args:
        model_name: Checkpoint identifier or path handed straight to
            ``GPT2LMHeadModel.from_pretrained``.

    Returns:
        The loaded ``GPT2LMHeadModel`` instance.
    """
    return GPT2LMHeadModel.from_pretrained(model_name)


# Special-token ids. PAD and EOS are forwarded to ``model.generate``;
# BOS is not referenced within these helpers.
# NOTE(review): presumably these must match the default checkpoint's
# vocabulary — confirm against its tokenizer/config.
BOS_TOKEN_ID, PAD_TOKEN_ID, EOS_TOKEN_ID = 401, 402, 403


def generate_gpt2(
    model: GPT2LMHeadModel,
    input_ids: torch.LongTensor,
    temperature: float = 0.7,
    *,
    max_length: int = 128,
    num_beams: int = 5,
    do_sample: Optional[bool] = None,
) -> list:
    """Continue ``input_ids`` with ``model.generate`` (beam search by default).

    Args:
        model: GPT-2 LM-head model used for generation.
        input_ids: ``[batch_size, seq_len]`` LongTensor of prompt token ids.
        temperature: Sampling temperature. NOTE: Hugging Face ``generate``
            only applies ``temperature`` in sampling modes. With plain beam
            search (sampling not enabled by the model's default generation
            settings) it has no effect. Pass ``do_sample=True`` to make it
            count.
        max_length: Maximum total length (prompt + generated tokens) of each
            output sequence.
        num_beams: Number of beams used by beam search.
        do_sample: When given, forwarded to ``generate``. ``None`` (default)
            leaves the choice to the model's generation settings, matching
            the previous behaviour.

    Returns:
        A flat ``[seq_len]`` list of token ids when ``batch_size`` is 1,
        otherwise a nested ``[batch_size, seq_len]`` list.
    """
    generate_kwargs = {
        "max_length": max_length,
        "num_beams": num_beams,
        "temperature": temperature,
        "pad_token_id": PAD_TOKEN_ID,
        "eos_token_id": EOS_TOKEN_ID,
    }
    # Only override sampling when the caller asks; otherwise keep whatever
    # the model's generation config decides (backward compatible).
    if do_sample is not None:
        generate_kwargs["do_sample"] = do_sample

    output_ids = model.generate(input_ids, **generate_kwargs)

    # Index the batch dimension explicitly rather than squeeze(): squeeze()
    # would also collapse a length-1 sequence and return a bare int.
    if output_ids.shape[0] == 1:
        return output_ids[0].tolist()
    return output_ids.tolist()


def change_to_1d_coordinate(board: np.ndarray, x: int, y: int) -> int:
    """Flatten a 2-D ``(x, y)`` position into a row-major 1-D index.

    Args:
        board: Array whose ``shape[1]`` gives the row width.
        x: Row index.
        y: Column index.

    Returns:
        ``x * board.shape[1] + y``. No bounds checking is performed.
    """
    row_width = board.shape[1]
    return row_width * x + y


def change_to_2d_coordinate(board: np.ndarray, coordinate: int) -> tuple:
    """Split a row-major 1-D index back into a 2-D ``(row, col)`` pair.

    Inverse of ``change_to_1d_coordinate`` for the same ``board``.

    Args:
        board: Array whose ``shape[1]`` gives the row width.
        coordinate: Flattened index.

    Returns:
        ``(coordinate // board.shape[1], coordinate % board.shape[1])``.
    """
    row, col = divmod(coordinate, board.shape[1])
    return (row, col)