Update modeling_mpt.py
Browse files- modeling_mpt.py +0 -1
modeling_mpt.py
CHANGED
@@ -46,7 +46,6 @@ class MPTModel(MPTPreTrainedModel):
|
|
46 |
self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
|
47 |
self.norm_f = norm_class(config.d_model, device=config.init_device)
|
48 |
if config.init_device != 'meta':
|
49 |
-
print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
|
50 |
self.apply(self.param_init_fn)
|
51 |
self.is_causal = not self.prefix_lm
|
52 |
self._attn_bias_initialized = False
|
|
|
46 |
self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
|
47 |
self.norm_f = norm_class(config.d_model, device=config.init_device)
|
48 |
if config.init_device != 'meta':
|
|
|
49 |
self.apply(self.param_init_fn)
|
50 |
self.is_causal = not self.prefix_lm
|
51 |
self._attn_bias_initialized = False
|