stakelovelace committed on
Commit
2069fff
1 Parent(s): 339b8e7

commit from tesla

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -60,6 +60,12 @@ def train_model(model, tokenizer, data, device):
60
 
61
  trainer.train()
62
 
 
 
 
 
 
 
63
  # Perform any remaining steps such as logging, saving, etc.
64
  trainer.save_model()
65
 
@@ -70,8 +76,8 @@ def main(api_name, base_url):
70
  # Load the configuration for a specific model
71
  config = AutoConfig.from_pretrained('google/codegemma-2b')
72
  # Update the activation function
73
- config.hidden_act = '' # Set to use approximate GeLU gelu_pytorch_tanh
74
- config.hidden_activation = 'gelu' # Set to use GeLU
75
 
76
  model = AutoModelForCausalLM.from_pretrained('google/codegemma-2b', is_decoder=True)
77
  #model = BertLMHeadModel.from_pretrained('google/codegemma-2b', is_decoder=True)
@@ -80,7 +86,6 @@ def main(api_name, base_url):
80
  model.to(device) # Move model to the appropriate device
81
 
82
  train_model(model, tokenizer, data, device)
83
-
84
 
85
  model.save_pretrained("./fine_tuned_model")
86
  tokenizer.save_pretrained("./fine_tuned_model")
 
60
 
61
  trainer.train()
62
 
63
+ # Optionally clear cache if using GPU or MPS
64
+ if torch.cuda.is_available():
65
+ torch.cuda.empty_cache()
66
+ elif torch.has_mps:
67
+ torch.mps.empty_cache()
68
+
69
  # Perform any remaining steps such as logging, saving, etc.
70
  trainer.save_model()
71
 
 
76
  # Load the configuration for a specific model
77
  config = AutoConfig.from_pretrained('google/codegemma-2b')
78
  # Update the activation function
79
+ # config.hidden_act = '' # Set to use approximate GeLU gelu_pytorch_tanh
80
+ config.hidden_activation = 'gelu_pytorch_tanh' # Set to use approximate (tanh) GeLU
81
 
82
  model = AutoModelForCausalLM.from_pretrained('google/codegemma-2b', is_decoder=True)
83
  #model = BertLMHeadModel.from_pretrained('google/codegemma-2b', is_decoder=True)
 
86
  model.to(device) # Move model to the appropriate device
87
 
88
  train_model(model, tokenizer, data, device)
 
89
 
90
  model.save_pretrained("./fine_tuned_model")
91
  tokenizer.save_pretrained("./fine_tuned_model")