"""Minimal interactive demo: load a GRU language model and a custom
whitespace tokenizer, then greedily predict a token for every position
of the user's input."""

import json

import torch
import torch.nn as nn

class CustomTokenizer:
    """Whitespace tokenizer backed by a fixed token-to-id vocabulary."""

    def __init__(self, vocab):
        self.vocab = vocab
        # Precompute the inverse mapping for decoding instead of relying on
        # the dict's insertion order lining up with the token ids.
        self.id_to_token = {token_id: token for token, token_id in vocab.items()}

    def encode(self, text):
        # Map each whitespace-separated token to its id, falling back to [UNK].
        tokens = text.split()
        return [self.vocab.get(token, self.vocab['[UNK]']) for token in tokens]

    def decode(self, ids):
        # Drop padding ids and any id outside the vocabulary.
        tokens = [
            self.id_to_token[token_id]
            for token_id in ids
            if token_id != self.vocab['[PAD]'] and token_id in self.id_to_token
        ]
        return ' '.join(tokens)

    def pad_sequence(self, sequence, max_length):
        # Right-pad with [PAD] up to max_length, truncating longer sequences.
        if len(sequence) < max_length:
            return sequence + [self.vocab['[PAD]']] * (max_length - len(sequence))
        return sequence[:max_length]

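# Illustrative round trip (actual ids depend on the loaded vocab):
#   ids = tokenizer.encode("hello world")   # unknown words map to [UNK]
#   text = tokenizer.decode(ids)            # padding ids are skipped
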
class LanguageModel(nn.Module):
    """Embedding -> single-layer GRU -> linear projection to vocab logits."""

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        embedded = self.embedding(x)
        output, hidden = self.rnn(embedded, hidden)
        output = self.fc(output)
        return output, hidden

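# Shape note: for input of shape (batch, seq_len), forward returns logits of
# shape (batch, seq_len, vocab_size) and, since the GRU is single-layer and
# unidirectional, a final hidden state of shape (1, batch, hidden_size).
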
# Load the vocabulary and make sure the special tokens are present.
with open('vocab2.json', 'r') as f:
    vocab = json.load(f)

special_tokens = ['[PAD]', '[UNK]']
for token in special_tokens:
    if token not in vocab:
        # Use max id + 1 so a new id never collides with an existing one,
        # even if the ids in the JSON file are not contiguous.
        vocab[token] = max(vocab.values(), default=-1) + 1

tokenizer = CustomTokenizer(vocab)

# These sizes must match the ones used when the checkpoint was trained.
embed_size = 900
hidden_size = 900
vocab_size = max(vocab.values()) + 1

model = LanguageModel(vocab_size, embed_size, hidden_size)
model.load_state_dict(torch.load('language_model.nit', map_location='cpu'))
model.eval()  # evaluation mode for inference

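# Note: generate_response does a single forward pass and takes the argmax at
# every position, so the reply is a per-position greedy prediction over the
# padded input rather than autoregressive, token-by-token sampling.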
def generate_response(input_text, model, tokenizer, max_length=1000):
    # Encode and pad the prompt to a fixed length, then add a batch dimension.
    encoded_input = tokenizer.encode(input_text)
    padded_input = tokenizer.pad_sequence(encoded_input, max_length)
    input_tensor = torch.tensor(padded_input).unsqueeze(0)

    # No gradients are needed at inference time.
    with torch.no_grad():
        outputs, _ = model(input_tensor)

    # Pick the most likely token at each position and decode, skipping [PAD].
    predicted_ids = torch.argmax(outputs, dim=2).squeeze().tolist()
    predicted_text = tokenizer.decode(predicted_ids)

    return predicted_text

# Simple REPL: read a line, run it through the model, print the prediction.
while True:
    test_text = input(">>> ")
    response = generate_response(test_text, model, tokenizer)
    print("Input:", test_text)
    print("Response:", response)