#!/usr/bin/env python
# coding: utf-8
"""Katakana sequence generator: GRU decoder definition, checkpoint and vocabulary."""

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class DecoderGRU(nn.Module):
    """Two-layer GRU decoder whose initial hidden state is derived from a latent vector.

    The latent sample passes through ``Linear -> ReLU -> Linear`` to produce
    ``2 * hidden_size`` values per batch element, which are reshaped into the
    ``(2, batch, hidden_size)`` initial hidden state of the 2-layer GRU.
    """

    def __init__(self, latent_size, hidden_size, output_size):
        """Build the projection, embedding, GRU and output layers.

        Args:
            latent_size: Width of the latent vector fed to ``forward``.
            hidden_size: Width of the embedding and of each GRU layer.
            output_size: Number of token ids (embedding rows / output logits).
        """
        super().__init__()
        # Attribute names are kept unchanged: a pickled module or state dict
        # refers to the sub-modules by these names.
        self.proj1 = nn.Linear(latent_size, latent_size)
        self.proj_activation = nn.ReLU()
        self.proj2 = nn.Linear(latent_size, 2 * hidden_size)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(
        self,
        encoder_sample: torch.Tensor,
        target_tensor: Optional[torch.Tensor] = None,
        max_length: int = 16,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode token log-probabilities from a batch of latent vectors.

        Args:
            encoder_sample: Latent batch of shape ``(batch, latent_size)``.
            target_tensor: Optional ``(batch, seq_len)`` tensor of token ids.
                When given, it is fed to the GRU as the whole input sequence
                in a single pass and ``max_length`` is ignored.
            max_length: Number of greedy decoding steps when ``target_tensor``
                is ``None``.

        Returns:
            ``(log_probs, hidden)``: log-softmax scores of shape
            ``(batch, steps, output_size)`` and the final GRU hidden state of
            shape ``(2, batch, hidden_size)``.
        """
        batch_size = encoder_sample.size(0)

        # Latent vector -> initial hidden state, one slice per GRU layer.
        projected = self.proj2(self.proj_activation(self.proj1(encoder_sample)))
        decoder_hidden = projected.view(batch_size, 2, -1).permute(1, 0, 2).contiguous()

        if target_tensor is not None:
            decoder_outputs, decoder_hidden = self.forward_step(target_tensor, decoder_hidden)
        else:
            # Create the start token on the same device as the input so the
            # module also works after being moved off the CPU.
            decoder_input = torch.full(
                (batch_size, 1),
                SOS_token,
                dtype=torch.long,
                device=encoder_sample.device,
            )
            step_outputs = []
            for _ in range(max_length):
                step_logits, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                step_outputs.append(step_logits)
                # Greedy choice: the most likely token becomes the next input.
                decoder_input = step_logits.argmax(dim=-1).detach()
            decoder_outputs = torch.cat(step_outputs, dim=1)

        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden

    def forward_step(self, input, hidden):
        """Run the embedding, ReLU, GRU and output projection on ``input``.

        Args:
            input: Tensor of token ids, ``(batch, steps)``. The name shadows
                the builtin but is kept so keyword callers keep working.
            hidden: GRU hidden state of shape ``(2, batch, hidden_size)``.

        Returns:
            ``(logits, hidden)``: unnormalised scores of shape
            ``(batch, steps, output_size)`` and the updated hidden state.
        """
        embedded = F.relu(self.embedding(input))
        gru_output, hidden = self.gru(embedded, hidden)
        return self.out(gru_output), hidden


# Token ids used to start decoding and to mark the end of a sequence.
SOS_token = 1
EOS_token = 2

katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
# Ids 0-2 map to empty strings; ids 1 and 2 are SOS_token and EOS_token.
# NOTE(review): id 0 is presumably padding - confirm against the training code.
vocab = ['', '', ''] + katakana
vocab_dict = {v: k for k, v in enumerate(vocab)}

# Width of the latent vectors sampled for generation, and decoding step limit.
h_latent = 64
max_len = 40

if __name__ == "__main__":
    # SECURITY: torch.load unpickles a complete module object, which can run
    # arbitrary code; only load checkpoints from a trusted source.
    # weights_only=False is explicit because PyTorch 2.6+ defaults to
    # weights_only=True, which rejects pickled modules.
    dec = torch.load('decoder.pt', map_location='cpu', weights_only=False)
# Number of latent vectors sampled (candidate names) per round.
names = 16


def detokenize(tokens):
    """Turn a list of token ids into a string, stopping at the first EOS.

    Args:
        tokens: List of integer token ids.

    Returns:
        The characters of every token before the first ``EOS_token`` joined
        together, or ``None`` when the sequence contains no ``EOS_token``.
    """
    try:
        end = tokens.index(EOS_token)
    except ValueError:
        # No end marker was produced: treat the sequence as unfinished.
        return None
    return ''.join(vocab[token] for token in tokens[:end])


def _generate_names():
    """Sample ``names`` latent vectors and decode each one into text (or None)."""
    # No gradients are needed for sampling.
    with torch.no_grad():
        latent = torch.randn(names, h_latent)
        log_probs = dec(latent, max_length=max_len)[0]
    # argmax over the vocabulary axis keeps the (batch, steps) layout even when
    # names == 1, where a bare squeeze() would also drop the batch axis.
    token_rows = log_probs.argmax(dim=-1).tolist()
    return [detokenize(row) for row in token_rows]


def main():
    """Print batches of generated names until the user stops the prompt."""
    while True:
        print('generating names...')
        for name in _generate_names():
            if name is not None:
                print(name)
        try:
            input("press enter to continue generation...")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt ends the session quietly.
            print()
            break


if __name__ == "__main__":
    main()