# nitro/chat_nitro.py
import json

import torch
import torch.nn as nn
# Simple whitespace tokenizer backed by a JSON vocabulary.
class CustomTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab
        # Reverse lookup table so decode() does not depend on dict ordering
        # or on the ids being contiguous.
        self.id_to_token = {token_id: token for token, token_id in vocab.items()}

    def encode(self, text):
        tokens = text.split()
        ids = [self.vocab.get(token, self.vocab['[UNK]']) for token in tokens]
        return ids

    def decode(self, ids):
        tokens = [self.id_to_token[i] for i in ids
                  if i != self.vocab['[PAD]'] and i in self.id_to_token]
        return ' '.join(tokens)

    def pad_sequence(self, sequence, max_length):
        # Pad with [PAD] up to max_length, or truncate if too long.
        if len(sequence) < max_length:
            sequence = sequence + [self.vocab['[PAD]']] * (max_length - len(sequence))
        else:
            sequence = sequence[:max_length]
        return sequence
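
# Round-trip sketch with a hypothetical toy vocab (illustrative only; the real
# ids come from vocab2.json below):
#   tok = CustomTokenizer({'hello': 0, 'world': 1, '[PAD]': 2, '[UNK]': 3})
#   tok.encode('hello world')    -> [0, 1]
#   tok.pad_sequence([0, 1], 4)  -> [0, 1, 2, 2]
#   tok.decode([0, 1, 2, 2])     -> 'hello world'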
# Sample language model: embedding -> single-layer GRU -> linear head over the vocab.
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        embedded = self.embedding(x)
        output, hidden = self.rnn(embedded, hidden)
        output = self.fc(output)  # per-position logits over the vocabulary
        return output, hidden
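
# Shape sketch (batch_first=True), using toy sizes for illustration:
#   m = LanguageModel(vocab_size=10, embed_size=8, hidden_size=16)
#   logits, h = m(torch.zeros(1, 5, dtype=torch.long))
#   logits.shape -> torch.Size([1, 5, 10])   # (batch, seq_len, vocab_size)
#   h.shape      -> torch.Size([1, 1, 16])   # (num_layers, batch, hidden_size)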
# Load the vocab from the JSON file and make sure the special tokens exist.
with open('vocab2.json', 'r') as f:
    vocab = json.load(f)

special_tokens = ['[PAD]', '[UNK]']
for token in special_tokens:
    if token not in vocab:
        # max(...) + 1 guarantees the new id cannot collide with an existing one,
        # even if the stored ids are non-contiguous.
        vocab[token] = max(vocab.values()) + 1

tokenizer = CustomTokenizer(vocab)
# Model parameters (must match the architecture the checkpoint was trained with)
embed_size = 900
hidden_size = 900
vocab_size = max(vocab.values()) + 1
# Load the trained weights; map_location='cpu' lets this run without a GPU.
model = LanguageModel(vocab_size, embed_size, hidden_size)
model.load_state_dict(torch.load('language_model.nit', map_location='cpu'))
model.eval()
def generate_response(input_text, model, tokenizer, max_length=1000):
    encoded_input = tokenizer.encode(input_text)
    padded_input = tokenizer.pad_sequence(encoded_input, max_length)
    input_tensor = torch.tensor(padded_input).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        outputs, _ = model(input_tensor)
    # Greedy choice: take the highest-scoring token at every position.
    predicted_ids = torch.argmax(outputs, dim=2).squeeze().tolist()
    predicted_text = tokenizer.decode(predicted_ids)
    return predicted_text
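
# Note: this is a single forward pass with greedy argmax -- one output token per
# input position -- not autoregressive generation. The reply therefore tracks
# the padded input length (up to max_length tokens before [PAD] stripping).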
# Simple REPL: read a line, run it through the model, and print the response.
while True:
    test_text = input(">>> ")
    response = generate_response(test_text, model, tokenizer)
    print("Input:", test_text)
    print("Response:", response)