skytnt commited on
Commit
48914a6
1 Parent(s): 6acc8e8

use_cache=False to avoid warning

Browse files
Files changed (1) hide show
  1. midi_model.py +4 -2
midi_model.py CHANGED
@@ -28,11 +28,13 @@ class MIDIModelConfig:
28
  net_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
29
  hidden_size=n_embd, num_attention_heads=n_head,
30
  num_hidden_layers=n_layer, intermediate_size=n_inner,
31
- pad_token_id=tokenizer.pad_id, max_position_embeddings=4096)
 
32
  net_token_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
33
  hidden_size=n_embd, num_attention_heads=n_head // 4,
34
  num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
35
- pad_token_id=tokenizer.pad_id, max_position_embeddings=4096)
 
36
  return MIDIModelConfig(tokenizer, net_config, net_token_config)
37
 
38
  @staticmethod
 
28
  net_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
29
  hidden_size=n_embd, num_attention_heads=n_head,
30
  num_hidden_layers=n_layer, intermediate_size=n_inner,
31
+ pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
32
+ use_cache=False)
33
  net_token_config = LlamaConfig(vocab_size=tokenizer.vocab_size,
34
  hidden_size=n_embd, num_attention_heads=n_head // 4,
35
  num_hidden_layers=n_layer // 4, intermediate_size=n_inner // 4,
36
+ pad_token_id=tokenizer.pad_id, max_position_embeddings=4096,
37
+ use_cache=False)
38
  return MIDIModelConfig(tokenizer, net_config, net_token_config)
39
 
40
  @staticmethod