File size: 2,580 Bytes
dc0d378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from transformers import GPT2Tokenizer
import gradio as gr
import tiktoken
import model_file
from dataclasses import dataclass
import time
import os
import torch.nn.functional as F

# Generation settings read by chat_fn: how many samples to draw per message,
# and the total token length (prompt included) each sample is grown to.
num_return_sequences, max_length = 1, 100


@dataclass
class GPTConfig:
    """Hyper-parameters for a GPT-style transformer.

    Not referenced by the code visible in this file (the model is built by
    ``model_file.get_model()``). NOTE(review): it may be defined here so that
    ``torch.load`` can unpickle a config object stored inside the checkpoint;
    confirm before removing or renaming it.
    """

    block_size: int = 1_024  # presumably the maximum context length in tokens
    vocab_size: int = 50_304  # above GPT-2's 50,257 BPE tokens; presumably padded
    n_layer: int = 12  # presumably the number of transformer blocks
    n_head: int = 12  # presumably attention heads per block
    n_embd: int = 768  # presumably the embedding / hidden width


# GPT-2 byte-pair encoding from tiktoken; chat_fn uses it to encode incoming
# messages and to decode generated token ids. (The transformers GPT2Tokenizer
# imported at the top of the file is not used by the visible code.)
tokenizer = tiktoken.get_encoding("gpt2")

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU. The hasattr
# guard keeps this working on torch builds that have no torch.backends.mps.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Build the model, restore weights from the checkpoint next to this file and
# switch to inference mode. Any failure is reported and then re-raised, so the
# app never starts with an unloaded model.
try:
    model = model_file.get_model().to(device)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On PyTorch versions where weights_only defaults to True, a
    # checkpoint that also stores non-tensor objects (e.g. a config instance)
    # may refuse to load; confirm against the installed torch version.
    checkpoint = torch.load(
        os.path.join(os.path.dirname(__file__), "model_00350.pt"),
        map_location=device,
    )
    # Strip the "_orig_mod." prefix from parameter names (presumably added
    # because the weights were saved from a torch.compile-wrapped model) so
    # that the keys match the plain model's state_dict.
    state_dict = {
        name.replace("_orig_mod.", ""): tensor
        for name, tensor in checkpoint["model"].items()
    }
    model.load_state_dict(state_dict=state_dict)
    model.eval()
    print("Model loaded successfully.")
except Exception as err:
    print(f"Error loading model: {err}")
    raise

# Example prompts passed to gr.ChatInterface when the UI is launched.
examples = ["Who are you?", "Write a Shakespeare short poem.",
            "Tell me a joke.", "What is the meaning of life?"]


def chat_fn(message, history):
    """Generate a sampled continuation of ``message`` for the Gradio chat UI.

    The message is encoded with the module-level ``tokenizer`` and then
    extended one token at a time with top-50 sampling until every one of the
    ``num_return_sequences`` rows holds ``max_length`` tokens. If the prompt
    already has ``max_length`` tokens or more, no sampling happens and the
    prompt is simply truncated. Each row is decoded back to text, with the
    prompt included, and yielded.

    Args:
        message: The user's latest chat message.
        history: Previous turns supplied by ``gr.ChatInterface``. It is not
            used, so every reply is conditioned on ``message`` alone.

    Yields:
        The decoded text of each generated sequence. An empty message (one
        that encodes to no tokens) yields a single empty string without
        running the model.
    """
    print(f"message: {message}")
    token_ids = tokenizer.encode(message)
    if not token_ids:
        # An empty prompt would give a (B, 0) tensor, and indexing the last
        # position (logits[:, -1, :]) below would then fail.
        yield ""
        return

    # Use torch.long so the prompt matches the int64 ids produced by
    # topk/gather below, and torch.cat never has to mix integer dtypes.
    x = torch.tensor(token_ids, dtype=torch.long, device=device)
    x = x.unsqueeze(0).repeat(num_return_sequences, 1)  # (B, T)

    with torch.no_grad():
        while x.size(1) < max_length:
            # assumes model(x) returns a tuple whose first item holds logits
            # of shape (B, T, vocab_size); TODO confirm against model_file
            logits = model(x)[0]
            logits = logits[:, -1, :]  # (B, vocab_size): last position only
            probs = F.softmax(logits, dim=-1)

            # Top-k sampling with k=50 (the Hugging Face default); both
            # results are (B, 50).
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
            # Pick one position inside each row's top-k: (B, 1).
            ix = torch.multinomial(topk_probs, num_samples=1)
            # Map that top-k position back to a vocabulary id: (B, 1).
            xcol = torch.gather(topk_indices, -1, ix)
            x = torch.cat((x, xcol), dim=1)

    for row in range(num_return_sequences):
        yield tokenizer.decode(x[row, :max_length].tolist())


# Wrap chat_fn in a Gradio chat UI with the example prompts and start it.
gr.ChatInterface(fn=chat_fn, examples=examples).launch()