import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
class DecoderGRU(nn.Module):
    """GRU decoder that maps a latent vector to a sequence of token logits."""

    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        # Small MLP that projects the latent sample to the initial hidden
        # states of the two GRU layers (hence the 2 * hidden_size output).
        self.proj1 = nn.Linear(latent_size, latent_size)
        self.proj_activation = nn.ReLU()
        self.proj2 = nn.Linear(latent_size, 2 * hidden_size)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_sample, target_tensor=None, max_length=16):
        batch_size = encoder_sample.size(0)
        # Turn the latent sample into a (num_layers, batch, hidden) tensor
        # that seeds the GRU hidden state.
        decoder_hidden = self.proj1(encoder_sample)
        decoder_hidden = self.proj_activation(decoder_hidden)
        decoder_hidden = self.proj2(decoder_hidden)
        decoder_hidden = decoder_hidden.view(batch_size, 2, -1).permute(1, 0, 2).contiguous()
        if target_tensor is not None:
            # Teacher forcing: process the whole target sequence in one call.
            decoder_outputs, decoder_hidden = self.forward_step(target_tensor, decoder_hidden)
        else:
            # Free-running generation: start from SOS (a module-level constant
            # defined below) and feed each greedy argmax back in.
            decoder_input = torch.full(
                (batch_size, 1), SOS_token,
                dtype=torch.long, device=encoder_sample.device,
            )
            decoder_outputs = []
            for _ in range(max_length):
                decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs.append(decoder_output)
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach before reuse as input
            decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden

    def forward_step(self, input_ids, hidden):
        # Embed the current token(s), run the GRU, and project to vocab logits.
        output = self.embedding(input_ids)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden
|
|
|
# The checkpoint is a fully pickled module, so the DecoderGRU class above
# must be defined before loading. On PyTorch >= 2.6, weights_only=False is
# required because weights-only loading became the default.
dec = torch.load('decoder.pt', map_location='cpu', weights_only=False)
dec.eval()
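
# A more portable pattern (sketch, assuming the weights had instead been
# saved with torch.save(model.state_dict(), 'decoder_state.pt'), a
# hypothetical file, and that the constructor sizes are known) would be:
#
#     dec = DecoderGRU(latent_size=h_latent, hidden_size=..., output_size=len(vocab))
#     dec.load_state_dict(torch.load('decoder_state.pt', map_location='cpu'))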
|
|
|
SOS_token = 1  # index of '<sos>' in vocab below
EOS_token = 2  # index of '<eos>' in vocab below

# Vocabulary: three special tokens followed by the Unicode katakana block.
katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
vocab = ['<pad>', '<sos>', '<eos>'] + katakana
vocab_dict = {v: k for k, v in enumerate(vocab)}  # character -> token id
|
|
|
h_latent = 64  # latent dimensionality expected by the decoder
max_len = 40   # maximum number of tokens to generate per name
names = 16     # number of names sampled per batch
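
# Shape sanity check (illustrative sketch only: a freshly initialised decoder
# with an assumed hidden_size of 32, not the trained checkpoint). A
# (batch, h_latent) latent should decode to (batch, max_len, len(vocab))
# log-probabilities, with a final hidden state of (num_layers, batch, hidden).
_check_dec = DecoderGRU(latent_size=h_latent, hidden_size=32, output_size=len(vocab))
_check_out, _check_hidden = _check_dec(torch.randn(4, h_latent), max_length=max_len)
assert _check_out.shape == (4, max_len, len(vocab))
assert _check_hidden.shape == (2, 4, 32)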
|
|
|
def detokenize(tokens):
    """Map token ids back to a string, or None if the model never emitted EOS."""
    if EOS_token in tokens:
        return ''.join(vocab[token] for token in tokens[:tokens.index(EOS_token)])
    return None
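
# Inverse of detokenize (illustrative sketch; the original script only
# decodes, but vocab_dict above already maps characters back to token ids,
# so the round trip detokenize(tokenize(name)) == name holds).
def tokenize(name):
    return [vocab_dict[ch] for ch in name] + [EOS_token]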
|
|
|
while True:
    print('generating names...')
    # Sample latent vectors, greedily decode them, and print every sequence
    # that produced an EOS token within max_len steps.
    with torch.no_grad():
        log_probs, _ = dec(torch.randn(names, h_latent), max_length=max_len)
    sequences = log_probs.topk(1)[1].squeeze(-1).tolist()  # (names, max_len) token ids
    for name in (detokenize(seq) for seq in sequences):
        if name is not None:
            print(name)
    input("press enter to continue generation...")
|
|