czl committed
Commit 93ceca7
Parent: 3159b65

update with 219M models

Files changed (3)
  1. app.py +128 -16
  2. vocab219/idx2word.json +0 -0
  3. vocab219/word2idx.json +0 -0
app.py CHANGED
@@ -60,15 +60,6 @@ def lookup_words(idx2word, indices):
     return [idx2word[str(idx)] for idx in indices]
 
 
-params = {'input_dim': len(word2idx),
-          'emb_dim': 128,
-          'enc_hid_dim': 256,
-          'dec_hid_dim': 256,
-          'dropout': 0.5,
-          'attn_dim': 32,
-          'teacher_forcing_ratio': 0.5,
-          'epochs': 35}
-
 class Encoder(nn.Module):
     """
     GRU RNN Encoder
@@ -292,8 +283,15 @@ class Seq2Seq(nn.Module):
             output = trg[t] if teacher_force else top1
 
         return outputs
-
 
+params = {'input_dim': len(word2idx),
+          'emb_dim': 128,
+          'enc_hid_dim': 256,
+          'dec_hid_dim': 256,
+          'dropout': 0.5,
+          'attn_dim': 32,
+          'teacher_forcing_ratio': 0.5,
+          'epochs': 35}
 
 enc = Encoder(input_dim=params['input_dim'], emb_dim=params['emb_dim'], enc_hid_dim=params['enc_hid_dim'], dec_hid_dim=params['dec_hid_dim'], dropout=params['dropout'])
 attn = Attention(enc_hid_dim=params['enc_hid_dim'], dec_hid_dim=params['dec_hid_dim'], attn_dim=params['attn_dim'])
@@ -308,9 +306,50 @@ norm_model = Seq2Seq(encoder=enc, decoder=dec, device=device)
 norm_model.load_state_dict(torch.load('NormSeq2Seq-188M_epoch35.pt', map_location=torch.device('cpu')))
 norm_model.to(device)
 
-models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model}
+with open('vocab219/word2idx.json', 'r') as f:
+    word2idx2 = json.load(f)
+with open('vocab219/idx2word.json', 'r') as f:
+    idx2word2 = json.load(f)
 
-def generateAttn(sentence, history, max_len=12,
+params219 = {'input_dim': len(word2idx2),
+             'emb_dim': 192,
+             'enc_hid_dim': 256,
+             'dec_hid_dim': 256,
+             'dropout': 0.5,
+             'attn_dim': 64,
+             'teacher_forcing_ratio': 0.5,
+             'epochs': 35}
+
+enc = Encoder(input_dim=params219['input_dim'], emb_dim=params219['emb_dim'],
+              enc_hid_dim=params219['enc_hid_dim'], dec_hid_dim=params219['dec_hid_dim'],
+              dropout=params219['dropout'])
+attn = Attention(enc_hid_dim=params219['enc_hid_dim'], dec_hid_dim=params219['dec_hid_dim'],
+                 attn_dim=params219['attn_dim'])
+dec = AttnDecoder(output_dim=params219['input_dim'], emb_dim=params219['emb_dim'],
+                  enc_hid_dim=params219['enc_hid_dim'], dec_hid_dim=params219['dec_hid_dim'],
+                  attention=attn, dropout=params219['dropout'])
+attn_model219 = Seq2Seq(encoder=enc, decoder=dec, device=device)
+attn_model219.load_state_dict(torch.load('AttnSeq2Seq-219M_epoch35.pt',
+                                         map_location=torch.device('cpu')))
+attn_model219.to(device)
+
+enc = Encoder(input_dim=params219['input_dim'], emb_dim=params219['emb_dim'],
+              enc_hid_dim=params219['enc_hid_dim'],
+              dec_hid_dim=params219['dec_hid_dim'], dropout=params219['dropout'])
+dec = Decoder(output_dim=params219['input_dim'], emb_dim=params219['emb_dim'],
+              enc_hid_dim=params219['enc_hid_dim'],
+              dec_hid_dim=params219['dec_hid_dim'],
+              dropout=params219['dropout'])
+norm_model219 = Seq2Seq(encoder=enc, decoder=dec, device=device)
+norm_model219.load_state_dict(torch.load('NormSeq2Seq-219M_epoch35.pt',
+                                         map_location=torch.device('cpu')))
+norm_model219.to(device)
+
+models_dict = {'AttentionSeq2Seq-188M': attn_model, 'NormalSeq2Seq-188M': norm_model,
+               'AttentionSeq2Seq-219M': attn_model219,
+               'NormalSeq2Seq-219M': norm_model219}
+
+def generateAttn188(sentence, history, max_len=12,
                  word2idx=word2idx, idx2word=idx2word,
                  device=device, tokenize=tokenize, preprocess_text=preprocess_text,
                  lookup_words=lookup_words, models_dict=models_dict):
@@ -343,7 +382,7 @@ def generateAttn(sentence, history, max_len=12,
     response = lookup_words(idx2word, outputs)
     return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
 
-def generateNorm(sentence, history, max_len=12,
+def generateNorm188(sentence, history, max_len=12,
                  word2idx=word2idx, idx2word=idx2word,
                  device=device, tokenize=tokenize, preprocess_text=preprocess_text,
                  lookup_words=lookup_words, models_dict=models_dict):
@@ -376,13 +415,86 @@ def generateNorm(sentence, history, max_len=12,
     response = lookup_words(idx2word, outputs)
     return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
 
-# demo = gr.ChatInterface(generate, title="AttentionSeq2Seq-188M")
+def generateAttn219(sentence, history, max_len=12,
+                    word2idx=word2idx2, idx2word=idx2word2,
+                    device=device, tokenize=tokenize, preprocess_text=preprocess_text,
+                    lookup_words=lookup_words, models_dict=models_dict):
+    """
+    Generate response
+    :param model: model
+    :param sentence: sentence
+    :param max_len: maximum length of sequence
+    :param word2idx: word to index mapping
+    :param idx2word: index to word mapping
+    :return: response
+    """
+    history = history
+    model = models_dict['AttentionSeq2Seq-219M']
+    model.eval()
+    sentence = preprocess_text(sentence)
+    tokens = tokenize(sentence)
+    tokens = [word2idx[token] if token in word2idx else word2idx['<unk>'] for token in tokens]
+    tokens = [word2idx['<bos>']] + tokens + [word2idx['<eos>']]
+    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)
+    outputs = [word2idx['<bos>']]
+    with torch.no_grad():
+        encoder_outputs, hidden = model.encoder(tokens)
+        for t in range(max_len):
+            output, hidden = model.decoder(torch.tensor([outputs[-1]], dtype=torch.long).to(device), hidden, encoder_outputs)
+            top1 = output.max(1)[1]
+            outputs.append(top1.item())
+            if top1.item() == word2idx['<eos>']:
+                break
+    response = lookup_words(idx2word, outputs)
+    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
+
+def generateNorm219(sentence, history, max_len=12,
+                    word2idx=word2idx2, idx2word=idx2word2,
+                    device=device, tokenize=tokenize, preprocess_text=preprocess_text,
+                    lookup_words=lookup_words, models_dict=models_dict):
+    """
+    Generate response
+    :param model: model
+    :param sentence: sentence
+    :param max_len: maximum length of sequence
+    :param word2idx: word to index mapping
+    :param idx2word: index to word mapping
+    :return: response
+    """
+    history = history
+    model = models_dict['NormalSeq2Seq-219M']
+    model.eval()
+    sentence = preprocess_text(sentence)
+    tokens = tokenize(sentence)
+    tokens = [word2idx[token] if token in word2idx else word2idx['<unk>'] for token in tokens]
+    tokens = [word2idx['<bos>']] + tokens + [word2idx['<eos>']]
+    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(1).to(device)
+    outputs = [word2idx['<bos>']]
+    with torch.no_grad():
+        encoder_outputs, hidden = model.encoder(tokens)
+        for t in range(max_len):
+            output, hidden = model.decoder(torch.tensor([outputs[-1]], dtype=torch.long).to(device), hidden, encoder_outputs)
+            top1 = output.max(1)[1]
+            outputs.append(top1.item())
+            if top1.item() == word2idx['<eos>']:
+                break
+    response = lookup_words(idx2word, outputs)
+    return ' '.join(response).replace('<bos>', '').replace('<eos>', '').strip()
 
 with gr.Blocks() as demo:
-    gr.ChatInterface(generateNorm,
+    with gr.Row():
+        gr.ChatInterface(generateNorm188,
                      title="NormalSeq2Seq-188M")
-    gr.ChatInterface(generateAttn,
+        gr.ChatInterface(generateAttn188,
                      title="AttentionSeq2Seq-188M")
+    gr.Markdown("""
+    # Seq2Seq Generative Chatbot with 219M parameters
+    """)
+    with gr.Row():
+        gr.ChatInterface(generateNorm219,
+                         title="NormalSeq2Seq-219M")
+        gr.ChatInterface(generateAttn219,
+                         title="AttentionSeq2Seq-219M")
 
 if __name__ == "__main__":
     demo.launch()
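
For a quick local check outside the Gradio UI, the 219M generators added in this commit can also be called directly. A minimal sketch, assuming app.py is importable from the repo root (importing it loads all four checkpoints) and that the vocab219/ files and the *_epoch35.pt weights are present:

# Hypothetical smoke test; generateAttn219 and generateNorm219 are the
# functions introduced by this commit in app.py.
from app import generateAttn219, generateNorm219

prompt = "how are you"
print(generateAttn219(prompt, history=[]))  # reply from AttentionSeq2Seq-219M
print(generateNorm219(prompt, history=[]))  # reply from NormalSeq2Seq-219M

Each function greedily decodes up to max_len=12 tokens and strips the <bos>/<eos> markers before returning the reply, matching the 188M generators above.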
vocab219/idx2word.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab219/word2idx.json ADDED
The diff for this file is too large to render. See raw diff
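
The two vocab219 files are too large to render inline, but the way app.py consumes them fixes their shape: word2idx.json maps token strings to integer ids (word2idx[token]), while idx2word.json maps stringified ids back to tokens (lookup_words indexes it with str(idx)). A small consistency check, written as an illustrative sketch under those assumptions:

import json

# Hypothetical sanity check for the added vocab219 files; the expected shapes
# are inferred from how app.py indexes them, not from the raw diff itself.
with open('vocab219/word2idx.json') as f:
    word2idx2 = json.load(f)
with open('vocab219/idx2word.json') as f:
    idx2word2 = json.load(f)

assert all(isinstance(i, int) for i in word2idx2.values())            # token -> int id
assert all(idx2word2[str(i)] == tok for tok, i in word2idx2.items())  # id -> token round trip
print(len(word2idx2), "tokens")  # this count is what params219['input_dim'] is set to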