mohdelgaar commited on
Commit
e048c03
1 Parent(s): 20b7679

updating model loading

Browse files
Files changed (2) hide show
  1. app.py +30 -10
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import nltk
2
- nltk.download('wordnet')
3
  import spacy
4
- spacy.cli.download('en_core_web_sm')
 
 
5
  from const import name_map
6
  from demo import run_gradio
7
- from model import EncoderDecoderVAE
8
  from options import parse_args
9
  import numpy as np
10
  from transformers import T5Tokenizer
11
  import torch
12
  import joblib
13
- import pandas as pd
14
 
15
 
16
  def process_examples(samples, full_names):
@@ -24,19 +24,39 @@ def process_examples(samples, full_names):
24
  return list(samples)
25
 
26
  args, args_list, lng_names = parse_args(ckpt='./ckpt/model.pt')
 
 
27
 
28
  tokenizer = T5Tokenizer.from_pretrained(args.model_name)
29
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
30
 
31
- scaler = joblib.load('assets/scaler.bin')
32
  full_names = [name_map[x] for x in lng_names]
33
- samples = joblib.load('assets/samples.bin')
34
- examples = process_examples(samples, full_names)
35
- ling_collection = np.load('assets/ling_collection.npy')
 
 
 
36
 
37
- model = EncoderDecoderVAE(args, tokenizer.pad_token_id, tokenizer.get_vocab()['</s>']).to(device)
38
  state = torch.load(args.ckpt, map_location=torch.device('cpu'))
39
- model.load_state_dict(state['model'], strict=False)
40
  model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  run_gradio(model, tokenizer, scaler, ling_collection, examples, full_names)
 
1
  import nltk
 
2
  import spacy
3
+ # nltk.download('wordnet')
4
+ # spacy.cli.download('en_core_web_sm')
5
+
6
  from const import name_map
7
  from demo import run_gradio
8
+ from model import get_model
9
  from options import parse_args
10
  import numpy as np
11
  from transformers import T5Tokenizer
12
  import torch
13
  import joblib
 
14
 
15
 
16
  def process_examples(samples, full_names):
 
24
  return list(samples)
25
 
26
  args, args_list, lng_names = parse_args(ckpt='./ckpt/model.pt')
27
+ print(args)
28
+ exit()
29
 
30
  tokenizer = T5Tokenizer.from_pretrained(args.model_name)
31
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
32
 
 
33
  full_names = [name_map[x] for x in lng_names]
34
+ # samples = joblib.load('assets/samples.bin')
35
+ # examples = process_examples(samples, full_names)
36
+ # ling_collection = np.load('assets/ling_collection.npy')
37
+
38
+ scaler = joblib.load('assets/scaler.bin')
39
+ model, ling_disc, sem_emb = get_model(args, tokenizer, device)
40
 
 
41
  state = torch.load(args.ckpt, map_location=torch.device('cpu'))
42
+ model.load_state_dict(state['model'], strict=True)
43
  model.eval()
44
+ print(model is not None, ling_disc is not None, sem_emb is not None)
45
+ exit()
46
+
47
+ if args.disc_type == 't5':
48
+ state = torch.load(args.disc_ckpt)
49
+ if 'model' in state:
50
+ ling_disc.load_state_dict(state['model'], strict=False)
51
+ else:
52
+ ling_disc.load_state_dict(state, strict=False)
53
+ ling_disc.eval()
54
+
55
+ state = torch.load(args.sem_ckpt)
56
+ if 'model' in state:
57
+ sem_emb.load_state_dict(state['model'], strict=False)
58
+ else:
59
+ sem_emb.load_state_dict(state, strict=False)
60
+ sem_emb.eval()
61
 
62
  run_gradio(model, tokenizer, scaler, ling_collection, examples, full_names)
requirements.txt CHANGED
@@ -8,3 +8,4 @@ scikit-learn
8
  tqdm
9
  spacy
10
  sentencepiece
 
 
8
  tqdm
9
  spacy
10
  sentencepiece
11
+ lftk