gpt-omni commited on
Commit
5f2e2de
1 Parent(s): 8b673c6

Update litgpt/model.py

Browse files
Files changed (1) hide show
  1. litgpt/model.py +3 -1
litgpt/model.py CHANGED
@@ -16,10 +16,12 @@ from litgpt.config import Config
16
 
17
 
18
  class GPT(nn.Module):
19
- def __init__(self, config: Config) -> None:
20
  super().__init__()
21
  assert config.padded_vocab_size is not None
22
  self.config = config
 
 
23
  if self.config.asr_adapter == "mlp":
24
  print("Using MLP adapter for ASR feature")
25
  self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
 
16
 
17
 
18
  class GPT(nn.Module):
19
+ def __init__(self, config: Config, device=None) -> None:
20
  super().__init__()
21
  assert config.padded_vocab_size is not None
22
  self.config = config
23
+ if device is not None:
24
+ self.device = device
25
  if self.config.asr_adapter == "mlp":
26
  print("Using MLP adapter for ASR feature")
27
  self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)