File size: 3,593 Bytes
1c300e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e417b0c
1c300e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e417b0c
1c300e3
1ced76b
1c300e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e417b0c
 
1c300e3
 
 
 
 
 
 
 
 
 
 
 
1ced76b
e417b0c
 
 
 
1c300e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ced76b
1c300e3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Script to run sagemaker training jobs for whisper finetuning jobs."""

import logging
import os
from pprint import pprint

import boto3
import sagemaker
from sagemaker.huggingface import HuggingFace


# Toggle between a quick smoke-test run and the full training run.
TEST = True


# SageMaker instance types to try, in order, for the smoke-test run. Each entry
# maps an instance type to how many instances to launch and how many GPUs each
# instance has.
test_sm_instances = {
    "ml.g4dn.xlarge":
        {
            "num_instances": 1,
            "num_gpus": 1
        }
}

# Instance types to try, in order, for the full run (currently identical to the
# test configuration).
full_sm_instances = {
    "ml.g4dn.xlarge":
        {
            "num_instances": 1,
            "num_gpus": 1
        }
}

sm_instances = test_sm_instances if TEST else full_sm_instances

# Entry point handed to the HuggingFace estimator.
ENTRY_POINT = "run_sm.py"
# Shell script whose name is passed to the entry point as the "entrypoint"
# hyperparameter.
RUN_SCRIPT = "test_run.sh" if TEST else "run.sh"
# Custom training image stored in ECR (eu-west-1).
IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
# `not` (rather than `is None`) also rejects an empty-string placeholder.
if not IMAGE_URI:
    raise ValueError("IMAGE_URI variable not set, please update script.")

# Pin the region BEFORE creating any boto3 client or SageMaker session so that
# every client resolves the same region (the image above lives in eu-west-1).
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
iam = boto3.client("iam")
role = iam.get_role(RoleName="whisper-sagemaker-role")["Role"]["Arn"]
# NOTE(review): the estimator is not given this session, so it presumably builds
# its own; this call may only serve as an early configuration check — confirm
# before removing.
_ = sagemaker.Session()
sm_client = boto3.client("sagemaker")


def set_creds(path="creds.txt"):
    """Load ``KEY=value`` credentials from *path* into ``os.environ``.

    Blank lines and lines starting with ``#`` are ignored. Only the first
    ``=`` separates key from value, so values may themselves contain ``=``
    (e.g. base64-encoded secrets). Surrounding whitespace, including the line
    terminator, is stripped from both key and value.

    Args:
        path: Credentials file to read. Defaults to ``creds.txt`` in the
            current working directory.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank, non-comment line contains no ``=``.
    """
    with open(path, encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            key, sep, value = line.partition("=")
            if not sep:
                # Deliberately omit the line content: it may be a secret.
                raise ValueError(f"{path}:{line_no}: expected a KEY=value line")
            os.environ[key.strip()] = value.strip()


def parse_run_script(path=None):
    """Parse the run script to get the hyperparameters.

    Each ``--key=value`` argument line of the shell script becomes one entry,
    and a bare ``--flag`` maps to the string ``"True"``. Line-continuation
    backslashes, the leading ``--`` and double quotes are removed. Blank lines,
    comment/shebang lines and the ``python ...`` command line are skipped.
    Only the first ``=`` splits key from value, so values may contain ``=``.

    Args:
        path: Script to parse. Defaults to the module-level ``RUN_SCRIPT``.

    Returns:
        dict mapping argument name to its value as a string. If present,
        ``model_index_name`` is re-wrapped in double quotes (presumably so it
        survives shell word-splitting downstream — confirm).
    """
    if path is None:
        path = RUN_SCRIPT
    hyperparameters = {}
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith(("#", "python")):
                continue
            # Drop the shell line-continuation marker and the option prefix.
            if line.endswith("\\"):
                line = line[:-1].rstrip()
            if line.startswith("--"):
                line = line[2:]
            if not line:
                continue
            key, sep, value = line.partition("=")
            # A bare flag such as ``--do_train`` is treated as enabled.
            hyperparameters[key.strip()] = value.replace('"', "") if sep else "True"
    if "model_index_name" in hyperparameters:
        hyperparameters["model_index_name"] = f'"{hyperparameters["model_index_name"]}"'
    return hyperparameters


set_creds()
# NOTE: parse_run_script() is currently unused; the estimator only receives the
# repository URL and the name of the run script as hyperparameters.

hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set")

# Environment forwarded to the training container. The SageMaker API expects
# string values, so optional credentials missing from creds.txt are omitted
# rather than sent as None.
env_vars = {
    key: value
    for key, value in {
        "HF_TOKEN": hf_token,
        "EMAIL_ADDRESS": os.environ.get("EMAIL_ADDRESS"),
        "EMAIL_PASSWORD": os.environ.get("EMAIL_PASSWORD"),
        "WANDB_TOKEN": os.environ.get("WANDB_TOKEN"),
    }.items()
    if value is not None
}
# Show which variables are forwarded without echoing the secrets themselves.
pprint({key: "***" for key in env_vars})

# Hugging Face repository named after the current working directory.
repo = f"https://huggingface.co/marinone94/{os.path.basename(os.getcwd())}"
hyperparameters = {
    "repo": repo,
    "entrypoint": RUN_SCRIPT
}

# Try each configured instance type in order; move on to the next one only
# when SageMaker reports ResourceLimitExceeded.
for sm_instance_name, sm_instance_values in sm_instances.items():
    num_instances = int(sm_instance_values["num_instances"])
    hf_estimator = HuggingFace(
        entry_point=ENTRY_POINT,
        instance_type=sm_instance_name,
        instance_count=num_instances,
        role=role,
        py_version="py38",
        image_uri=IMAGE_URI,
        hyperparameters=hyperparameters,
        environment=env_vars,
        git_config={"repo": repo, "branch": "main"},
    )
    try:
        hf_estimator.fit()
    except sm_client.exceptions.ClientError as e_0:
        # Match on the error code rather than a modeled exception class, so the
        # check does not depend on which client/session raised the error.
        if e_0.response.get("Error", {}).get("Code") != "ResourceLimitExceeded":
            raise
        logging.warning("Instance error %s\nRetrying with new instance", e_0)
        continue
    break
else:
    raise RuntimeError(
        "Every configured instance type hit ResourceLimitExceeded; "
        "no training job was started."
    )