fix: make sure data and adapter on same device (#11)
Browse files- fix: make sure data and adapter on same device (08577bc2e88cb6d2e7ffa9fb2c45ba7c16c02836)
- custom_st.py +1 -2
custom_st.py
CHANGED
@@ -55,7 +55,6 @@ class Transformer(nn.Module):
|
|
55 |
|
56 |
config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
57 |
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
|
58 |
-
self.device = next(self.auto_model.parameters()).device
|
59 |
|
60 |
self._lora_adaptations = config.lora_adaptations
|
61 |
if (
|
@@ -111,7 +110,7 @@ class Transformer(nn.Module):
|
|
111 |
num_examples = len(features['input_ids'])
|
112 |
|
113 |
adapter_mask = torch.full(
|
114 |
-
(num_examples,), task_id, dtype=torch.int32, device=
|
115 |
)
|
116 |
|
117 |
lora_arguments = (
|
|
|
55 |
|
56 |
config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
|
57 |
self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)
|
|
|
58 |
|
59 |
self._lora_adaptations = config.lora_adaptations
|
60 |
if (
|
|
|
110 |
num_examples = len(features['input_ids'])
|
111 |
|
112 |
adapter_mask = torch.full(
|
113 |
+
(num_examples,), task_id, dtype=torch.int32, device=features['input_ids'].device
|
114 |
)
|
115 |
|
116 |
lora_arguments = (
|