duzx16 commited on
Commit
44011aa
1 Parent(s): 911e84c

Fix default dtype of classifier head

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -1131,7 +1131,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1131
  self.num_labels = config.num_labels
1132
  self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1133
 
1134
- self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1135
  if config.classifier_dropout is not None:
1136
  self.dropout = nn.Dropout(config.classifier_dropout)
1137
  else:
 
1131
  self.num_labels = config.num_labels
1132
  self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1133
 
1134
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=config.torch_dtype)
1135
  if config.classifier_dropout is not None:
1136
  self.dropout = nn.Dropout(config.classifier_dropout)
1137
  else: