File size: 3,593 Bytes
1c300e3 e417b0c 1c300e3 e417b0c 1c300e3 1ced76b 1c300e3 e417b0c 1c300e3 1ced76b e417b0c 1c300e3 1ced76b 1c300e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
"""Script to run sagemaker training jobs for whisper finetuning jobs."""
import logging
import os
from pprint import pprint
import boto3
import sagemaker
from sagemaker.huggingface import HuggingFace
# Set to False to launch the full run (full_sm_instances + run.sh) instead of
# the quick test configuration (test_sm_instances + test_run.sh).
TEST = True

# Candidate instance types, tried in insertion order by the launch loop below.
# Each entry records how many instances to request plus a GPU count.
test_sm_instances = {
    "ml.g4dn.xlarge": {"num_instances": 1, "num_gpus": 1},
}
full_sm_instances = {
    "ml.g4dn.xlarge": {"num_instances": 1, "num_gpus": 1},
}

if TEST:
    sm_instances = test_sm_instances
    RUN_SCRIPT = "test_run.sh"
else:
    sm_instances = full_sm_instances
    RUN_SCRIPT = "run.sh"

ENTRY_POINT = "run_sm.py"

# Pre-built training image in ECR (eu-west-1); replace it if the image is rebuilt.
IMAGE_URI = "116817510867.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
if IMAGE_URI is None:
    raise ValueError("IMAGE_URI variable not set, please update script.")
# The training image (IMAGE_URI) lives in eu-west-1, so pin the region before
# any AWS client or SageMaker session is created; every client below then
# resolves the same region instead of whatever the local profile defaults to.
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"

iam = boto3.client("iam")
# Execution role the training job runs under.
role = iam.get_role(RoleName="whisper-sagemaker-role")["Role"]["Arn"]

# NOTE(review): this session is never used by name and is not passed to the
# estimator below. It is kept only for whatever side effects its construction
# has (e.g. resolving credentials/region early); likely safe to remove.
_ = sagemaker.Session()

# Used by the launch loop to reach botocore's exception classes.
sm_client = boto3.client("sagemaker")
def set_creds(path="creds.txt"):
    """Load ``KEY=VALUE`` credentials from *path* into ``os.environ``.

    Blank lines and lines starting with ``#`` are ignored. Each remaining
    line is split on the first ``=`` only, so values may themselves contain
    ``=`` (e.g. base64-padded tokens). Surrounding whitespace, including
    Windows CRLF line endings, is stripped from both key and value.

    Args:
        path: Credentials file to read. Defaults to ``creds.txt`` in the
            current working directory.

    Raises:
        ValueError: If a non-blank, non-comment line contains no ``=``.
            The offending line is not echoed, since it may hold a secret.
    """
    with open(path, encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            key, sep, value = line.partition("=")
            if not sep:
                raise ValueError(f"{path}:{lineno}: expected a KEY=VALUE line")
            os.environ[key.strip()] = value.strip()
def parse_run_script(path=None):
    """Parse a shell run script into a hyperparameter dict.

    Each ``--flag=value`` argument becomes ``{"flag": "value"}`` and each bare
    ``--flag`` becomes ``{"flag": "True"}``. The ``python ...`` invocation line,
    blank lines and ``#`` comment/shebang lines are skipped. A trailing
    backslash line continuation and all double quotes are removed. Values are
    split on the first ``=`` only and are otherwise kept intact, so they may
    contain ``--``, ``=`` or backslashes.

    Args:
        path: Run script to parse. Defaults to the module-level ``RUN_SCRIPT``.

    Returns:
        Mapping of flag name (without the leading ``--``) to its string value.
        If a ``model_index_name`` entry is present, its value is re-wrapped in
        double quotes (presumably so a multi-word name stays a single argument
        downstream -- TODO confirm).
    """
    if path is None:
        path = RUN_SCRIPT
    hyperparameters = {}
    with open(path, encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            # Drop only the shell line-continuation, not backslashes in values.
            if line.endswith("\\"):
                line = line[:-1].rstrip()
            if not line or line.startswith(("#", "python")):
                continue
            line = line.replace('"', "")
            # Remove only the leading flag marker, not "--" inside values.
            if line.startswith("--"):
                line = line[2:]
            key, sep, value = line.partition("=")
            hyperparameters[key] = value if sep else "True"
    if "model_index_name" in hyperparameters:
        index_name = hyperparameters["model_index_name"]
        hyperparameters["model_index_name"] = f'"{index_name}"'
    return hyperparameters
set_creds()
# To forward every flag of RUN_SCRIPT as a SageMaker hyperparameter instead of
# only the repo/entrypoint pair below, build the dict with parse_run_script().

hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set")

# Environment forwarded to the training container. Optional credentials that
# are not set are left out rather than sent as None: the SageMaker Environment
# map takes string values, so a None would presumably be rejected when the job
# is submitted.
env_vars = {"HF_TOKEN": hf_token}
env_vars.update(
    (name, os.environ[name])
    for name in ("EMAIL_ADDRESS", "EMAIL_PASSWORD", "WANDB_TOKEN")
    if name in os.environ
)
# Show which variables are forwarded without echoing the secrets themselves.
pprint({name: "***" for name in env_vars})

# Hugging Face Hub repo named after the current directory. It is also the git
# repo SageMaker checks out (git_config in the launch loop).
project_name = os.path.basename(os.path.normpath(os.getcwd()))
repo = f"https://huggingface.co/marinone94/{project_name}"
hyperparameters = {
    "repo": repo,
    "entrypoint": RUN_SCRIPT,
}
# Try each candidate instance type in order and stop at the first whose fit()
# call returns without error. Only a "ResourceLimitExceeded" error moves on to
# the next candidate; any other AWS error propagates. If every candidate is
# exhausted, fail loudly instead of exiting as if training had happened.
# (The "num_gpus" config value is not used here.)
for sm_instance_name, sm_instance_values in sm_instances.items():
    num_instances = int(sm_instance_values["num_instances"])
    hf_estimator = HuggingFace(
        entry_point=ENTRY_POINT,
        instance_type=sm_instance_name,
        instance_count=num_instances,
        role=role,
        py_version="py38",
        image_uri=IMAGE_URI,
        hyperparameters=hyperparameters,
        environment=env_vars,
        git_config={"repo": repo, "branch": "main"},
    )
    try:
        hf_estimator.fit()
    except sm_client.exceptions.ClientError as err:
        # Match on the error code rather than on
        # sm_client.exceptions.ResourceLimitExceeded. That modeled class is
        # tied to sm_client's session, and the estimator is not given that
        # session, so the class may not match what fit() actually raises.
        if err.response.get("Error", {}).get("Code") != "ResourceLimitExceeded":
            raise
        logging.warning("Instance error %s\nRetrying with new instance", err)
    else:
        break
else:
    raise RuntimeError(
        "No SageMaker instance type could be used; tried: "
        + ", ".join(sm_instances)
    )
|