MohamedRashad commited on
Commit
7264ba7
1 Parent(s): 6d41dc3

chore: Update TashkeelModelEO and TashkeelModelED loading in app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -31,8 +31,10 @@ print('Creating Model...')
31
  eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False)
32
  ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False)
33
 
34
- eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device)).eval().to(device)
35
- ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device)).eval().to(device)
 
 
36
 
37
  @spaces.GPU()
38
  def infer_catt(input_text, choose_model):
 
31
  eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False)
32
  ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False)
33
 
34
+ eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))
35
+ eo_model.eval().to(device)
36
+ ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))
37
+ ed_model.eval().to(device)
38
 
39
  @spaces.GPU()
40
  def infer_catt(input_text, choose_model):