import os
from dataclasses import dataclass

import gradio as gr
import tiktoken
import torch
import torch.nn.functional as F

import model_file

num_return_sequences = 1
max_length = 100


@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768


# Alternative tokenizer:
# from transformers import GPT2Tokenizer
# tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer = tiktoken.get_encoding("gpt2")

# Pick the best available device.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
device = torch.device(device)

try:
    model = model_file.get_model().to(device)
    checkpoint = torch.load(
        os.path.join(os.path.dirname(__file__), "model_00350.pt"),
        map_location=device,
    )
    # Strip the "_orig_mod." prefix that torch.compile() adds to parameter names.
    state_dict = {
        key.replace("_orig_mod.", ""): value
        for key, value in checkpoint["model"].items()
    }
    model.load_state_dict(state_dict=state_dict)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

examples = [
    "Who are you?",
    "Write a Shakespeare short poem.",
    "Tell me a joke.",
    "What is the meaning of life?",
]


def chat_fn(message, history):
    # Tokenize the prompt and replicate it once per requested sequence.
    print(f"message: {message}")
    tokens = tokenizer.encode(message)
    # Use int64: embedding lookups and torch.cat with the int64 top-k
    # indices below both expect it.
    tokens = torch.tensor(tokens, dtype=torch.long)
    tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
    x = tokens.to(device)

    while x.size(1) < max_length:
        # Forward pass through the model to get logits.
        with torch.no_grad():
            logits = model(x)[0]  # (batch_size, T, vocab_size)
            logits = logits[:, -1, :]  # logits at the last position: (B, vocab_size)
            # Convert logits to probabilities.
            probs = F.softmax(logits, dim=-1)
            # Top-k sampling with k=50, the Hugging Face default.
            # topk_probs and topk_indices are both (B, 50).
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
            # Sample one position within the top-k for each row.
            ix = torch.multinomial(input=topk_probs, num_samples=1)  # (B, 1)
            # Map the sampled positions back to vocabulary token ids.
            xcol = torch.gather(input=topk_indices, dim=-1, index=ix)
            # Append the new token to the sequence.
            x = torch.cat([x, xcol], dim=1)

    for i in range(num_return_sequences):
        tokens = x[i, :max_length].tolist()
        decoded = tokenizer.decode(tokens)
        yield decoded


gr.ChatInterface(chat_fn, examples=examples).launch()
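
# ---------------------------------------------------------------------------
# For reference, the sampling step from the generation loop above, factored
# into a standalone helper. This is an illustrative sketch only: the helper
# name is hypothetical and the app above does not call it; k=50 mirrors the
# inline value used in chat_fn.
def sample_top_k(logits: torch.Tensor, k: int = 50) -> torch.Tensor:
    """Sample one token id per batch row from the top-k of softmax(logits)."""
    probs = F.softmax(logits, dim=-1)  # (B, vocab_size)
    topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # both (B, k)
    # multinomial accepts unnormalized weights, so no renormalization needed.
    ix = torch.multinomial(topk_probs, num_samples=1)  # (B, 1) positions in top-k
    return torch.gather(topk_indices, dim=-1, index=ix)  # (B, 1) vocab ids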