#!/usr/bin/env python
# coding: utf-8
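# inference.py for isekai-rolename-vae: samples random Katakana role names
# from the trained VAE decoder stored in decoder.pt.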
import torch
import torch.nn as nn
import torch.nn.functional as F
class DecoderGRU(nn.Module):
    # GRU decoder: maps a VAE latent vector to a sequence of Katakana tokens.
    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        # Small MLP that projects the latent vector to the initial hidden
        # states of the 2-layer GRU (hence the 2 * hidden_size output).
        self.proj1 = nn.Linear(latent_size, latent_size)
        self.proj_activation = nn.ReLU()
        self.proj2 = nn.Linear(latent_size, 2 * hidden_size)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_sample, target_tensor=None, max_length=16):
        batch_size = encoder_sample.size(0)
        # Project the latent sample and reshape it to the GRU hidden-state
        # layout (num_layers, batch, hidden_size).
        decoder_hidden = self.proj1(encoder_sample)
        decoder_hidden = self.proj_activation(decoder_hidden)
        decoder_hidden = self.proj2(decoder_hidden)
        decoder_hidden = decoder_hidden.view(batch_size, 2, -1).permute(1, 0, 2).contiguous()
        if target_tensor is not None:
            # Teacher forcing: feed the whole target sequence in one step.
            decoder_input = target_tensor
            decoder_outputs, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
        else:
            # Autoregressive greedy decoding, starting from <sos>
            # (SOS_token is defined at module level below).
            decoder_input = torch.full((batch_size, 1), SOS_token, dtype=torch.long)
            decoder_outputs = []
            for _ in range(max_length):
                decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs.append(decoder_output)
                # The highest-scoring token becomes the next input.
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()
            decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden

    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)  # raw logits over the vocabulary
        return output, hidden
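
# A quick shape sanity check for the decoder above (a sketch, not part of the
# original pipeline; the values below are only illustrative). Greedy decoding
# should return log-probs of shape (batch, max_length, vocab_size) plus a
# final hidden state of shape (num_layers, batch, hidden_size).
def _decoder_shape_check(decoder, latent_size, batch=2, max_length=5):
    with torch.no_grad():
        log_probs, hidden = decoder(torch.randn(batch, latent_size), max_length=max_length)
    assert log_probs.shape[:2] == (batch, max_length)
    assert hidden.shape == (2, batch, decoder.gru.hidden_size)
# e.g., after loading the model below: _decoder_shape_check(dec, h_latent)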
# Load the trained decoder saved as a full module with torch.save(model, ...).
# weights_only=False is needed on PyTorch >= 2.6, where torch.load defaults
# to weights-only deserialization and would refuse a pickled nn.Module.
dec = torch.load('decoder.pt', map_location='cpu', weights_only=False)
dec.eval()
SOS_token = 1  # index of '<sos>' in the vocabulary below
EOS_token = 2  # index of '<eos>'
katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
vocab = ['<pad>', '<sos>', '<eos>'] + katakana
vocab_dict = {ch: idx for idx, ch in enumerate(vocab)}  # character -> token id
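
# Hypothetical helper (not in the original script): the rough inverse of
# detokenize below. It assumes decoder inputs start with <sos>, as in the
# greedy loop above, and is meant for teacher-forced scoring through the
# target_tensor argument of DecoderGRU.forward.
def tokenize(name, add_sos=True):
    ids = [vocab_dict[ch] for ch in name]
    if add_sos:
        ids = [SOS_token] + ids
    return torch.tensor([ids], dtype=torch.long)  # shape (1, seq_len)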
h_latent = 64  # dimensionality of the latent vectors fed to the decoder
max_len = 40   # maximum decoded name length
names = 16     # names generated per round
def detokenize(tokens):
    # Turn token ids back into a string, cutting at the first <eos>.
    # Returns None when no <eos> appears, i.e. the name never terminated.
    if EOS_token in tokens:
        return ''.join(vocab[token] for token in tokens[:tokens.index(EOS_token)])
    return None
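
# Greedy argmax decoding (used in the loop below) always yields the same name
# for a given latent. As a sketch of an alternative (not in the original
# script), this hypothetical helper samples each character from the softmax
# distribution with a temperature knob, driving forward_step directly.
def sample_name(decoder, latent, temperature=1.0, max_length=40):
    with torch.no_grad():
        # Mirror the latent-to-hidden projection done in DecoderGRU.forward.
        hidden = decoder.proj2(decoder.proj_activation(decoder.proj1(latent)))
        hidden = hidden.view(1, 2, -1).permute(1, 0, 2).contiguous()
        token = torch.full((1, 1), SOS_token, dtype=torch.long)
        chars = []
        for _ in range(max_length):
            logits, hidden = decoder.forward_step(token, hidden)
            probs = F.softmax(logits.squeeze(1) / temperature, dim=-1)
            token = torch.multinomial(probs, 1)  # shape (1, 1)
            if token.item() == EOS_token:
                return ''.join(chars)
            chars.append(vocab[token.item()])
    return None  # never produced <eos> within max_length
# e.g.: sample_name(dec, torch.randn(1, h_latent), temperature=0.8)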
while True:
    print('generating names...')
    with torch.no_grad():
        # Decode `names` random latent vectors greedily; the per-step argmax
        # over the log-probs recovers the generated token sequences.
        log_probs, _ = dec(torch.randn(names, h_latent), max_length=max_len)
    for name in map(detokenize, log_probs.argmax(-1).tolist()):
        if name is not None:
            print(name)
    input("press enter to continue generation...")