bwang0911 commited on
Commit
b67df15
1 Parent(s): cd40fae

refactor: clean up num_examples code

Browse files
Files changed (1) hide show
  1. custom_st.py +1 -5
custom_st.py CHANGED
@@ -104,11 +104,7 @@ class Transformer(nn.Module):
104
  adapter_mask = None
105
  if task_type:
106
  task_id = self._adaptation_map[task_type]
107
- num_examples = 1
108
- if isinstance(features['input_ids'][0], list):
109
- # If input_ids[0] is a list, it means multiple inputs (list of texts)
110
- num_examples = len(features['input_ids'])
111
-
112
  adapter_mask = torch.full(
113
  (num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
114
  )
 
104
  adapter_mask = None
105
  if task_type:
106
  task_id = self._adaptation_map[task_type]
107
+ num_examples = features['input_ids'].size(0)
 
 
 
 
108
  adapter_mask = torch.full(
109
  (num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
110
  )