refactor: clean up num_examples code (#14)
Browse files- refactor: clean up num_examples code (b67df152d4d5143d84ff7c023a245d301ecccb22)
- custom_st.py +1 -5
custom_st.py
CHANGED
@@ -104,11 +104,7 @@ class Transformer(nn.Module):
|
|
104 |
adapter_mask = None
|
105 |
if task_type:
|
106 |
task_id = self._adaptation_map[task_type]
|
107 |
-
num_examples =
|
108 |
-
if isinstance(features['input_ids'][0], list):
|
109 |
-
# If input_ids[0] is a list, it means multiple inputs (list of texts)
|
110 |
-
num_examples = len(features['input_ids'])
|
111 |
-
|
112 |
adapter_mask = torch.full(
|
113 |
(num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
|
114 |
)
|
|
|
104 |
adapter_mask = None
|
105 |
if task_type:
|
106 |
task_id = self._adaptation_map[task_type]
|
107 |
+
num_examples = features['input_ids'].size(0)
|
|
|
|
|
|
|
|
|
108 |
adapter_mask = torch.full(
|
109 |
(num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
|
110 |
)
|