Spaces:
Running
Running
feat(train): save to bucket
Browse files- tools/train/train.py +70 -47
tools/train/train.py
CHANGED
@@ -18,7 +18,7 @@ Training DALL·E Mini.
|
|
18 |
Script adapted from run_summarization_flax.py
|
19 |
"""
|
20 |
|
21 |
-
import
|
22 |
import logging
|
23 |
import os
|
24 |
import sys
|
@@ -41,6 +41,7 @@ from flax.core.frozen_dict import FrozenDict, freeze
|
|
41 |
from flax.serialization import from_bytes, to_bytes
|
42 |
from flax.training import train_state
|
43 |
from flax.training.common_utils import onehot
|
|
|
44 |
from jax.experimental import PartitionSpec, maps
|
45 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
46 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
@@ -59,7 +60,6 @@ cc.initialize_cache(
|
|
59 |
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
|
60 |
)
|
61 |
|
62 |
-
|
63 |
logger = logging.getLogger(__name__)
|
64 |
|
65 |
|
@@ -123,17 +123,20 @@ class ModelArguments:
|
|
123 |
else:
|
124 |
return dict()
|
125 |
|
126 |
-
def get_opt_state(self
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
137 |
|
138 |
|
139 |
@dataclass
|
@@ -785,10 +788,9 @@ def main():
|
|
785 |
|
786 |
else:
|
787 |
# load opt_state
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
-
opt_state_file.close()
|
792 |
|
793 |
# restore other attributes
|
794 |
attr_state = {
|
@@ -1034,42 +1036,60 @@ def main():
|
|
1034 |
|
1035 |
def run_save_model(state, eval_metrics=None):
|
1036 |
if jax.process_index() == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1037 |
params = jax.device_get(state.params)
|
1038 |
-
# save model locally
|
1039 |
model.save_pretrained(
|
1040 |
-
|
1041 |
params=params,
|
1042 |
)
|
1043 |
|
1044 |
# save tokenizer
|
1045 |
-
tokenizer.save_pretrained(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1046 |
|
1047 |
# save state
|
1048 |
opt_state = jax.device_get(state.opt_state)
|
1049 |
-
|
1050 |
-
|
1051 |
-
|
1052 |
-
|
1053 |
-
|
1054 |
-
|
1055 |
-
|
1056 |
-
"w"
|
1057 |
-
) as f:
|
1058 |
-
json.dump(
|
1059 |
-
state_dict,
|
1060 |
-
f,
|
1061 |
-
)
|
1062 |
|
1063 |
# save to W&B
|
1064 |
if training_args.log_model:
|
1065 |
# save some space
|
1066 |
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
1067 |
-
c.cleanup(wandb.util.from_human_size("
|
1068 |
|
1069 |
-
metadata =
|
|
|
|
|
|
|
1070 |
metadata["num_params"] = num_params
|
1071 |
if eval_metrics is not None:
|
1072 |
metadata["eval"] = eval_metrics
|
|
|
|
|
1073 |
|
1074 |
# create model artifact
|
1075 |
artifact = wandb.Artifact(
|
@@ -1077,16 +1097,19 @@ def main():
|
|
1077 |
type="DalleBart_model",
|
1078 |
metadata=metadata,
|
1079 |
)
|
1080 |
-
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
|
1085 |
-
|
1086 |
-
|
1087 |
-
|
1088 |
-
|
1089 |
-
|
|
|
|
|
|
|
1090 |
wandb.run.log_artifact(artifact)
|
1091 |
|
1092 |
# create state artifact
|
@@ -1095,9 +1118,9 @@ def main():
|
|
1095 |
type="DalleBart_state",
|
1096 |
metadata=metadata,
|
1097 |
)
|
1098 |
-
|
1099 |
artifact_state.add_file(
|
1100 |
-
f"{Path(training_args.output_dir) /
|
1101 |
)
|
1102 |
wandb.run.log_artifact(artifact_state)
|
1103 |
|
|
|
18 |
Script adapted from run_summarization_flax.py
|
19 |
"""
|
20 |
|
21 |
+
import io
|
22 |
import logging
|
23 |
import os
|
24 |
import sys
|
|
|
41 |
from flax.serialization import from_bytes, to_bytes
|
42 |
from flax.training import train_state
|
43 |
from flax.training.common_utils import onehot
|
44 |
+
from google.cloud import storage
|
45 |
from jax.experimental import PartitionSpec, maps
|
46 |
from jax.experimental.compilation_cache import compilation_cache as cc
|
47 |
from jax.experimental.pjit import pjit, with_sharding_constraint
|
|
|
60 |
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
|
61 |
)
|
62 |
|
|
|
63 |
logger = logging.getLogger(__name__)
|
64 |
|
65 |
|
|
|
123 |
else:
|
124 |
return dict()
|
125 |
|
126 |
+
def get_opt_state(self):
    """Open the serialized optimizer state for reading.

    If ``self.restore_state`` is ``True``, the state is downloaded from the
    W&B artifact whose name is ``self.model_name_or_path`` with the first
    ``"/model-"`` replaced by ``"/state-"``. Otherwise ``self.restore_state``
    is treated as the path of the file to open.

    ``self.restore_state`` is left untouched, so the method can be called
    more than once.

    Returns:
        A binary file object opened with mode ``"rb"``. The caller is
        responsible for closing it.
    """
    if self.restore_state is not True:
        # Explicit file path: no download (and no temporary directory) needed.
        return Path(self.restore_state).open("rb")

    with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
        # wandb artifact
        state_artifact = self.model_name_or_path.replace("/model-", "/state-", 1)
        if jax.process_index() == 0:
            artifact = wandb.run.use_artifact(state_artifact)
        else:
            artifact = wandb.Api().artifact(state_artifact)
        artifact_dir = artifact.download(tmp_dir)
        # Keep the path local: tmp_dir is deleted when this block exits, so
        # storing it on self would leave a dangling path behind.
        state_path = Path(artifact_dir) / "opt_state.msgpack"
        # NOTE(review): the handle is opened before tmp_dir is removed and is
        # read by the caller afterwards; this presumably relies on POSIX
        # semantics (an unlinked file stays readable while open) -- confirm
        # if this ever needs to run on another platform.
        return state_path.open("rb")
|
140 |
|
141 |
|
142 |
@dataclass
|
|
|
788 |
|
789 |
else:
|
790 |
# load opt_state
|
791 |
+
opt_state_file = model_args.get_opt_state()
|
792 |
+
opt_state = from_bytes(opt_state_shape, opt_state_file.read())
|
793 |
+
opt_state_file.close()
|
|
|
794 |
|
795 |
# restore other attributes
|
796 |
attr_state = {
|
|
|
1036 |
|
1037 |
def run_save_model(state, eval_metrics=None):
|
1038 |
if jax.process_index() == 0:
|
1039 |
+
|
1040 |
+
output_dir = training_args.output_dir
|
1041 |
+
use_bucket = output_dir.startswith("gs://")
|
1042 |
+
if use_bucket:
|
1043 |
+
bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
|
1044 |
+
bucket, dir_path = str(bucket_path).split("/", 1)
|
1045 |
+
tmp_dir = tempfile.TemporaryDirectory()
|
1046 |
+
output_dir = tmp_dir.name
|
1047 |
+
|
1048 |
+
# save model
|
1049 |
params = jax.device_get(state.params)
|
|
|
1050 |
model.save_pretrained(
|
1051 |
+
output_dir,
|
1052 |
params=params,
|
1053 |
)
|
1054 |
|
1055 |
# save tokenizer
|
1056 |
+
tokenizer.save_pretrained(output_dir)
|
1057 |
+
|
1058 |
+
# copy to bucket
|
1059 |
+
if use_bucket:
|
1060 |
+
client = storage.Client()
|
1061 |
+
bucket = client.bucket(bucket)
|
1062 |
+
for filename in Path(output_dir).glob("*"):
|
1063 |
+
blob_name = str(Path(dir_path) / filename.name)
|
1064 |
+
blob = bucket.blob(blob_name)
|
1065 |
+
blob.upload_from_filename(str(filename))
|
1066 |
+
tmp_dir.cleanup()
|
1067 |
|
1068 |
# save state
|
1069 |
opt_state = jax.device_get(state.opt_state)
|
1070 |
+
if use_bucket:
|
1071 |
+
blob_name = str(Path(dir_path) / "opt_state.msgpack")
|
1072 |
+
blob = bucket.blob(blob_name)
|
1073 |
+
blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
|
1074 |
+
else:
|
1075 |
+
with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
|
1076 |
+
f.write(to_bytes(opt_state))
|
|
|
|
|
|
|
|
|
|
|
|
|
1077 |
|
1078 |
# save to W&B
|
1079 |
if training_args.log_model:
|
1080 |
# save some space
|
1081 |
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
1082 |
+
c.cleanup(wandb.util.from_human_size("20GB"))
|
1083 |
|
1084 |
+
metadata = {
|
1085 |
+
k: jax.device_get(getattr(state, k)).item()
|
1086 |
+
for k in ["step", "epoch", "train_time", "train_samples"]
|
1087 |
+
}
|
1088 |
metadata["num_params"] = num_params
|
1089 |
if eval_metrics is not None:
|
1090 |
metadata["eval"] = eval_metrics
|
1091 |
+
if use_bucket:
|
1092 |
+
metadata["bucket_path"] = bucket_path
|
1093 |
|
1094 |
# create model artifact
|
1095 |
artifact = wandb.Artifact(
|
|
|
1097 |
type="DalleBart_model",
|
1098 |
metadata=metadata,
|
1099 |
)
|
1100 |
+
if not use_bucket:
|
1101 |
+
for filename in [
|
1102 |
+
"config.json",
|
1103 |
+
"flax_model.msgpack",
|
1104 |
+
"merges.txt",
|
1105 |
+
"special_tokens_map.json",
|
1106 |
+
"tokenizer.json",
|
1107 |
+
"tokenizer_config.json",
|
1108 |
+
"vocab.json",
|
1109 |
+
]:
|
1110 |
+
artifact.add_file(
|
1111 |
+
f"{Path(training_args.output_dir) / filename}"
|
1112 |
+
)
|
1113 |
wandb.run.log_artifact(artifact)
|
1114 |
|
1115 |
# create state artifact
|
|
|
1118 |
type="DalleBart_state",
|
1119 |
metadata=metadata,
|
1120 |
)
|
1121 |
+
if not use_bucket:
|
1122 |
artifact_state.add_file(
|
1123 |
+
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
|
1124 |
)
|
1125 |
wandb.run.log_artifact(artifact_state)
|
1126 |
|