Commit
•
3f9f573
1
Parent(s):
89047a4
Update device
Browse files- modeling_custom.py +3 -3
modeling_custom.py
CHANGED
@@ -140,11 +140,11 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
|
|
140 |
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
141 |
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
142 |
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
143 |
-
sequence_lengths = sequence_lengths.to(
|
144 |
else:
|
145 |
sequence_lengths = -1
|
146 |
|
147 |
-
dummy_iterator = torch.arange(batch_size, device=
|
148 |
hidden_states = tokens_hidden_states[dummy_iterator, sequence_lengths]
|
149 |
assert hidden_states.shape == (batch_size, self.config.hidden_size)
|
150 |
rewards = self.regression_layer(hidden_states)
|
@@ -163,4 +163,4 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
|
|
163 |
gating_output=gating_output,
|
164 |
score=score,
|
165 |
logits=score,
|
166 |
-
)
|
|
|
140 |
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
141 |
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
142 |
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
143 |
+
sequence_lengths = sequence_lengths.to(self.device)
|
144 |
else:
|
145 |
sequence_lengths = -1
|
146 |
|
147 |
+
dummy_iterator = torch.arange(batch_size, device=self.device)
|
148 |
hidden_states = tokens_hidden_states[dummy_iterator, sequence_lengths]
|
149 |
assert hidden_states.shape == (batch_size, self.config.hidden_size)
|
150 |
rewards = self.regression_layer(hidden_states)
|
|
|
163 |
gating_output=gating_output,
|
164 |
score=score,
|
165 |
logits=score,
|
166 |
+
)
|