yuhuili commited on
Commit
6e01670
1 Parent(s): 9647dd6

Update model/ea_model.py

Browse files
Files changed (1) hide show
  1. model/ea_model.py +1 -1
model/ea_model.py CHANGED
@@ -69,7 +69,7 @@ class EaModel(nn.Module):
69
 
70
 
71
  device = base_model.model.layers[-1].self_attn.q_proj.weight.device
72
- self.ea_layer.to(self.base_model.dtype).to(device)
73
  self.ea_layer.init_tree()
74
 
75
 
 
69
 
70
 
71
  device = base_model.model.layers[-1].self_attn.q_proj.weight.device
72
+ self.ea_layer.to(torch.float16).to(device)
73
  self.ea_layer.init_tree()
74
 
75