|
"""Script to create the model artifact |
|
|
|
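Trains a linear classifier (scikit-learn's SGDClassifier with its default hinge
loss, i.e. a linear SVM rather than a logistic regression) using grid search
on a synthetic dataset, and stores the best estimator in a joblib file.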
|
|
|
""" |
|
|
|
import joblib |
|
from sklearn.datasets import make_classification |
|
from sklearn.linear_model import SGDClassifier |
|
from sklearn.model_selection import GridSearchCV |
|
|
|
|
|
# Seed shared by the synthetic data generator and the estimator so runs are reproducible.
SEED: int = 0
# Output path for the serialized best estimator produced by main().
FILENAME: str = "sklearn_model.joblib"
|
|
|
|
|
def get_data():
    """Return a reproducible synthetic classification dataset as ``(X, y)``.

    Draws 2000 samples from ``make_classification`` seeded with ``SEED``.
    """
    features, labels = make_classification(n_samples=2000, random_state=SEED)
    return features, labels
|
|
|
|
|
def get_model(**kwargs):
    """Return an unfitted ``SGDClassifier`` seeded with ``SEED``.

    Keyword arguments are forwarded to ``set_params``, so they can set other
    estimator parameters or override ``random_state``.
    """
    # set_params returns the estimator itself, so the call can be chained.
    return SGDClassifier(random_state=SEED).set_params(**kwargs)
|
|
|
|
|
def get_hparams():
    """Return the hyperparameter grid explored by the grid search.

    A fresh dict (with fresh lists) is built on every call, so callers may
    mutate the result without affecting later calls.
    """
    return dict(
        penalty=['l1', 'l2'],
        alpha=[1e-5, 1e-4, 1e-3],
    )
|
|
|
|
|
def grid_search(model, X, y, hparams, *, cv=5, scoring='accuracy'):
    """Run an exhaustive cross-validated search over ``hparams`` and return it fitted.

    Args:
        model: Estimator to tune.
        X: Feature matrix passed to ``fit``.
        y: Target labels passed to ``fit``.
        hparams: Mapping of parameter name -> list of candidate values.
        cv: Cross-validation setting forwarded to ``GridSearchCV``
            (default 5, matching the previous hard-coded value).
        scoring: Scoring setting forwarded to ``GridSearchCV``
            (default ``'accuracy'``, matching the previous hard-coded value).

    Returns:
        The fitted ``GridSearchCV`` object; ``train`` reads its
        ``best_score_``, ``best_params_`` and ``best_estimator_``.
    """
    # cv/scoring are keyword-only so existing positional callers are unaffected.
    search = GridSearchCV(model, hparams, cv=cv, scoring=scoring)
    search.fit(X, y)
    return search
|
|
|
|
|
def train(model, X, y, hparams):
    """Tune ``model`` over ``hparams`` and return the best estimator found.

    Prints the best cross-validated score as a percentage and the winning
    parameter combination before returning.
    """
    result = grid_search(model, X, y, hparams=hparams)
    best_score, best_params = result.best_score_, result.best_params_
    print(f"Best accuracy: {100 * best_score:.1f}%")
    print(f"Best parameters: {best_params}")
    return result.best_estimator_
|
|
|
|
|
def save_model(model, filename):
    """Serialize ``model`` to ``filename`` with ``joblib.dump`` and report the path."""
    joblib.dump(model, filename)
    message = f"Stored model in '{filename}'"
    print(message)
|
|
|
|
|
def main():
    """Build the dataset, grid-search the model, and persist the best estimator to ``FILENAME``."""
    features, labels = get_data()
    # Arguments evaluate left to right: get_model() runs before get_hparams(),
    # the same order as calling them on separate lines.
    best_model = train(get_model(), features, labels, hparams=get_hparams())
    save_model(best_model, FILENAME)
|
|
|
|
|
# Run the training pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|