update with 219M models
Browse files- app.py +128 -16
- vocab219/idx2word.json +0 -0
- vocab219/word2idx.json +0 -0
app.py
CHANGED
@@ -60,15 +60,6 @@ def lookup_words(idx2word, indices):
|
|
60 |
return [idx2word[str(idx)] for idx in indices]
|
61 |
|
62 |
|
63 |
-
params = {'input_dim': len(word2idx),
|
64 |
-
'emb_dim': 128,
|
65 |
-
'enc_hid_dim': 256,
|
66 |
-
'dec_hid_dim': 256,
|
67 |
-
'dropout': 0.5,
|
68 |
-
'attn_dim': 32,
|
69 |
-
'teacher_forcing_ratio': 0.5,
|
70 |
-
'epochs': 35}
|
71 |
-
|
72 |
class Encoder(nn.Module):
|
73 |
"""
|
74 |
GRU RNN Encoder
|
@@ -292,8 +283,15 @@ class Seq2Seq(nn.Module):
|
|
292 |
output = trg[t] if teacher_force else top1
|
293 |
|
294 |
return outputs
|
295 |
-
|
296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
|
298 |
enc = Encoder(input_dim=params['input_dim'], emb_dim=params['emb_dim'], enc_hid_dim=params['enc_hid_dim'], dec_hid_dim=params['dec_hid_dim'], dropout=params['dropout'])
|
299 |
attn = Attention(enc_hid_dim=params['enc_hid_dim'], dec_hid_dim=params['dec_hid_dim'], attn_dim=params['attn_dim'])
|
@@ -308,9 +306,50 @@ norm_model = Seq2Seq(encoder=enc, decoder=dec, device=device)
|
|
308 |
norm_model.load_state_dict(torch.load('NormSeq2Seq-188M_epoch35.pt', map_location=torch.device('cpu')))
|
309 |
norm_model.to(device)
|
310 |
|
311 |
-
|
|
|
|
|
|
|
312 |
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
word2idx=word2idx, idx2word=idx2word,
|
315 |
device=device, tokenize=tokenize, preprocess_text=preprocess_text,
|
316 |
lookup_words=lookup_words, models_dict=models_dict):
|
@@ -343,7 +382,7 @@ def generateAttn(sentence, history, max_len=12,
|
|
343 |
response = lookup_words(idx2word, outputs)
|
344 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
345 |
|
346 |
-
def
|
347 |
word2idx=word2idx, idx2word=idx2word,
|
348 |
device=device, tokenize=tokenize, preprocess_text=preprocess_text,
|
349 |
lookup_words=lookup_words, models_dict=models_dict):
|
@@ -376,13 +415,86 @@ def generateNorm(sentence, history, max_len=12,
|
|
376 |
response = lookup_words(idx2word, outputs)
|
377 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
378 |
|
379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
|
381 |
with gr.Blocks() as demo:
|
382 |
-
gr.
|
|
|
383 |
title="NormalSeq2Seq-188M")
|
384 |
-
|
385 |
title="AttentionSeq2Seq-188M")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
386 |
|
387 |
if __name__ == "__main__":
|
388 |
demo.launch()
|
|
|
60 |
return [idx2word[str(idx)] for idx in indices]
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
class Encoder(nn.Module):
|
64 |
"""
|
65 |
GRU RNN Encoder
|
|
|
283 |
output = trg[t] if teacher_force else top1
|
284 |
|
285 |
return outputs
|
|
|
286 |
|
287 |
+
params = {'input_dim': len(word2idx),
|
288 |
+
'emb_dim': 128,
|
289 |
+
'enc_hid_dim': 256,
|
290 |
+
'dec_hid_dim': 256,
|
291 |
+
'dropout': 0.5,
|
292 |
+
'attn_dim': 32,
|
293 |
+
'teacher_forcing_ratio': 0.5,
|
294 |
+
'epochs': 35}
|
295 |
|
296 |
enc = Encoder(input_dim=params['input_dim'], emb_dim=params['emb_dim'], enc_hid_dim=params['enc_hid_dim'], dec_hid_dim=params['dec_hid_dim'], dropout=params['dropout'])
|
297 |
attn = Attention(enc_hid_dim=params['enc_hid_dim'], dec_hid_dim=params['dec_hid_dim'], attn_dim=params['attn_dim'])
|
|
|
306 |
norm_model.load_state_dict(torch.load('NormSeq2Seq-188M_epoch35.pt', map_location=torch.device('cpu')))
|
307 |
norm_model.to(device)
|
308 |
|
309 |
+
with open('vocab219/word2idx.json', 'r') as f:
|
310 |
+
word2idx2 = json.load(f)
|
311 |
+
with open('vocab219/idx2word.json', 'r') as f:
|
312 |
+
idx2word2 = json.load(f)
|
313 |
|
314 |
+
params219 = {'input_dim': len(word2idx2),
|
315 |
+
'emb_dim': 192,
|
316 |
+
'enc_hid_dim': 256,
|
317 |
+
'dec_hid_dim': 256,
|
318 |
+
'dropout': 0.5,
|
319 |
+
'attn_dim': 64,
|
320 |
+
'teacher_forcing_ratio': 0.5,
|
321 |
+
'epochs': 35}
|
322 |
+
|
323 |
+
enc = Encoder(input_dim=params219['input_dim'], emb_dim=params219['emb_dim'],
|
324 |
+
enc_hid_dim=params219['enc_hid_dim'], dec_hid_dim=params219['dec_hid_dim'],
|
325 |
+
dropout=params219['dropout'])
|
326 |
+
attn = Attention(enc_hid_dim=params219['enc_hid_dim'], dec_hid_dim=params219['dec_hid_dim'],
|
327 |
+
attn_dim=params219['attn_dim'])
|
328 |
+
dec = AttnDecoder(output_dim=params219['input_dim'], emb_dim=params219['emb_dim'],
|
329 |
+
enc_hid_dim=params219['enc_hid_dim'], dec_hid_dim=params219['dec_hid_dim'],
|
330 |
+
attention=attn, dropout=params219['dropout'])
|
331 |
+
attn_model219 = Seq2Seq(encoder=enc, decoder=dec, device=device)
|
332 |
+
attn_model219.load_state_dict(torch.load('AttnSeq2Seq-219M_epoch35.pt',
|
333 |
+
map_location=torch.device('cpu')))
|
334 |
+
attn_model219.to(device)
|
335 |
+
|
336 |
+
enc = Encoder(input_dim=params219['input_dim'], emb_dim=params219['emb_dim'],
|
337 |
+
enc_hid_dim=params219['enc_hid_dim'],
|
338 |
+
dec_hid_dim=params219['dec_hid_dim'], dropout=params219['dropout'])
|
339 |
+
dec = Decoder(output_dim=params219['input_dim'], emb_dim=params219['emb_dim'],
|
340 |
+
enc_hid_dim=params219['enc_hid_dim'],
|
341 |
+
dec_hid_dim=params219['dec_hid_dim'],
|
342 |
+
dropout=params219['dropout'])
|
343 |
+
norm_model219 = Seq2Seq(encoder=enc, decoder=dec, device=device)
|
344 |
+
norm_model219.load_state_dict(torch.load('NormSeq2Seq-219M_epoch35.pt',
|
345 |
+
map_location=torch.device('cpu')))
|
346 |
+
norm_model219.to(device)
|
347 |
+
|
348 |
+
models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model,
|
349 |
+
'AttentionSeq2Seq-219M': attn_model219,
|
350 |
+
'NormalSeq2Seq-219M': norm_model219}
|
351 |
+
|
352 |
+
def generateAttn188(sentence, history, max_len=12,
|
353 |
word2idx=word2idx, idx2word=idx2word,
|
354 |
device=device, tokenize=tokenize, preprocess_text=preprocess_text,
|
355 |
lookup_words=lookup_words, models_dict=models_dict):
|
|
|
382 |
response = lookup_words(idx2word, outputs)
|
383 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
384 |
|
385 |
+
def generateNorm188(sentence, history, max_len=12,
|
386 |
word2idx=word2idx, idx2word=idx2word,
|
387 |
device=device, tokenize=tokenize, preprocess_text=preprocess_text,
|
388 |
lookup_words=lookup_words, models_dict=models_dict):
|
|
|
415 |
response = lookup_words(idx2word, outputs)
|
416 |
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
417 |
|
418 |
+
def generateAttn219(sentence, history, max_len=12,
|
419 |
+
word2idx=word2idx2, idx2word=idx2word2,
|
420 |
+
device=device, tokenize=tokenize, preprocess_text=preprocess_text,
|
421 |
+
lookup_words=lookup_words, models_dict=models_dict):
|
422 |
+
"""
|
423 |
+
Generate response
|
424 |
+
:param model: model
|
425 |
+
:param sentence: sentence
|
426 |
+
:param max_len: maximum length of sequence
|
427 |
+
:param word2idx: word to index mapping
|
428 |
+
:param idx2word: index to word mapping
|
429 |
+
:return: response
|
430 |
+
"""
|
431 |
+
history = history
|
432 |
+
model = models_dict['AttentionSeq2Seq-219M']
|
433 |
+
model.eval()
|
434 |
+
sentence = preprocess_text(sentence)
|
435 |
+
tokens = tokenize(sentence)
|
436 |
+
tokens = [word2idx[token] if token in word2idx else word2idx['<unk>'] for token in tokens]
|
437 |
+
tokens = [word2idx['<bos>']] + tokens + [word2idx['<eos>']]
|
438 |
+
tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)
|
439 |
+
outputs = [word2idx['<bos>']]
|
440 |
+
with torch.no_grad():
|
441 |
+
encoder_outputs, hidden = model.encoder(tokens)
|
442 |
+
for t in range(max_len):
|
443 |
+
output, hidden = model.decoder(torch.tensor([outputs[-1]], dtype=torch.long).to(device), hidden, encoder_outputs)
|
444 |
+
top1 = output.max(1)[1]
|
445 |
+
outputs.append(top1.item())
|
446 |
+
if top1.item() == word2idx['<eos>']:
|
447 |
+
break
|
448 |
+
response = lookup_words(idx2word, outputs)
|
449 |
+
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
450 |
+
|
451 |
+
def generateNorm219(sentence, history, max_len=12,
|
452 |
+
word2idx=word2idx2, idx2word=idx2word2,
|
453 |
+
device=device, tokenize=tokenize, preprocess_text=preprocess_text,
|
454 |
+
lookup_words=lookup_words, models_dict=models_dict):
|
455 |
+
"""
|
456 |
+
Generate response
|
457 |
+
:param model: model
|
458 |
+
:param sentence: sentence
|
459 |
+
:param max_len: maximum length of sequence
|
460 |
+
:param word2idx: word to index mapping
|
461 |
+
:param idx2word: index to word mapping
|
462 |
+
:return: response
|
463 |
+
"""
|
464 |
+
history = history
|
465 |
+
model = models_dict['NormalSeq2Seq-219M']
|
466 |
+
model.eval()
|
467 |
+
sentence = preprocess_text(sentence)
|
468 |
+
tokens = tokenize(sentence)
|
469 |
+
tokens = [word2idx[token] if token in word2idx else word2idx['<unk>'] for token in tokens]
|
470 |
+
tokens = [word2idx['<bos>']] + tokens + [word2idx['<eos>']]
|
471 |
+
tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)
|
472 |
+
outputs = [word2idx['<bos>']]
|
473 |
+
with torch.no_grad():
|
474 |
+
encoder_outputs, hidden = model.encoder(tokens)
|
475 |
+
for t in range(max_len):
|
476 |
+
output, hidden = model.decoder(torch.tensor([outputs[-1]], dtype=torch.long).to(device), hidden, encoder_outputs)
|
477 |
+
top1 = output.max(1)[1]
|
478 |
+
outputs.append(top1.item())
|
479 |
+
if top1.item() == word2idx['<eos>']:
|
480 |
+
break
|
481 |
+
response = lookup_words(idx2word, outputs)
|
482 |
+
return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
|
483 |
|
484 |
with gr.Blocks() as demo:
|
485 |
+
with gr.Row():
|
486 |
+
gr.ChatInterface(generateNorm188,
|
487 |
title="NormalSeq2Seq-188M")
|
488 |
+
gr.ChatInterface(generateAttn188,
|
489 |
title="AttentionSeq2Seq-188M")
|
490 |
+
gr.Markdown("""
|
491 |
+
# Seq2Seq Generative Chatbot with 219M parameters
|
492 |
+
""")
|
493 |
+
with gr.Row():
|
494 |
+
gr.ChatInterface(generateNorm219,
|
495 |
+
title="NormalSeq2Seq-219M")
|
496 |
+
gr.ChatInterface(generateAttn219,
|
497 |
+
title="AttentionSeq2Seq-219M")
|
498 |
|
499 |
if __name__ == "__main__":
|
500 |
demo.launch()
|
vocab219/idx2word.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
vocab219/word2idx.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|