boris committed on
Commit
50498e6
1 Parent(s): 34cf91c

feat(train): save to bucket

Browse files
Files changed (1) hide show
  1. tools/train/train.py +70 -47
tools/train/train.py CHANGED
@@ -18,7 +18,7 @@ Training DALL·E Mini.
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
21
- import json
22
  import logging
23
  import os
24
  import sys
@@ -41,6 +41,7 @@ from flax.core.frozen_dict import FrozenDict, freeze
41
  from flax.serialization import from_bytes, to_bytes
42
  from flax.training import train_state
43
  from flax.training.common_utils import onehot
 
44
  from jax.experimental import PartitionSpec, maps
45
  from jax.experimental.compilation_cache import compilation_cache as cc
46
  from jax.experimental.pjit import pjit, with_sharding_constraint
@@ -59,7 +60,6 @@ cc.initialize_cache(
59
  "/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
60
  )
61
 
62
-
63
  logger = logging.getLogger(__name__)
64
 
65
 
@@ -123,17 +123,20 @@ class ModelArguments:
123
  else:
124
  return dict()
125
 
126
- def get_opt_state(self, tmp_dir):
127
- if self.restore_state is True:
128
- # wandb artifact
129
- state_artifact = self.model_name_or_path.replace("/model-", "/state-", 1)
130
- if jax.process_index() == 0:
131
- artifact = wandb.run.use_artifact(state_artifact)
132
- else:
133
- artifact = wandb.Api().artifact(state_artifact)
134
- artifact_dir = artifact.download(tmp_dir)
135
- self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
136
- return Path(self.restore_state).open("rb")
 
 
 
137
 
138
 
139
  @dataclass
@@ -785,10 +788,9 @@ def main():
785
 
786
  else:
787
  # load opt_state
788
- with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
789
- opt_state_file = model_args.get_opt_state(tmp_dir)
790
- opt_state = from_bytes(opt_state_shape, opt_state_file.read())
791
- opt_state_file.close()
792
 
793
  # restore other attributes
794
  attr_state = {
@@ -1034,42 +1036,60 @@ def main():
1034
 
1035
  def run_save_model(state, eval_metrics=None):
1036
  if jax.process_index() == 0:
 
 
 
 
 
 
 
 
 
 
1037
  params = jax.device_get(state.params)
1038
- # save model locally
1039
  model.save_pretrained(
1040
- training_args.output_dir,
1041
  params=params,
1042
  )
1043
 
1044
  # save tokenizer
1045
- tokenizer.save_pretrained(training_args.output_dir)
 
 
 
 
 
 
 
 
 
 
1046
 
1047
  # save state
1048
  opt_state = jax.device_get(state.opt_state)
1049
- with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
1050
- f.write(to_bytes(opt_state))
1051
- state_dict = {
1052
- k: jax.device_get(getattr(state, k)).item()
1053
- for k in ["step", "epoch", "train_time", "train_samples"]
1054
- }
1055
- with (Path(training_args.output_dir) / "training_state.json").open(
1056
- "w"
1057
- ) as f:
1058
- json.dump(
1059
- state_dict,
1060
- f,
1061
- )
1062
 
1063
  # save to W&B
1064
  if training_args.log_model:
1065
  # save some space
1066
  c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
1067
- c.cleanup(wandb.util.from_human_size("10GB"))
1068
 
1069
- metadata = dict(state_dict)
 
 
 
1070
  metadata["num_params"] = num_params
1071
  if eval_metrics is not None:
1072
  metadata["eval"] = eval_metrics
 
 
1073
 
1074
  # create model artifact
1075
  artifact = wandb.Artifact(
@@ -1077,16 +1097,19 @@ def main():
1077
  type="DalleBart_model",
1078
  metadata=metadata,
1079
  )
1080
- for filename in [
1081
- "config.json",
1082
- "flax_model.msgpack",
1083
- "merges.txt",
1084
- "special_tokens_map.json",
1085
- "tokenizer.json",
1086
- "tokenizer_config.json",
1087
- "vocab.json",
1088
- ]:
1089
- artifact.add_file(f"{Path(training_args.output_dir) / filename}")
 
 
 
1090
  wandb.run.log_artifact(artifact)
1091
 
1092
  # create state artifact
@@ -1095,9 +1118,9 @@ def main():
1095
  type="DalleBart_state",
1096
  metadata=metadata,
1097
  )
1098
- for filename in ["opt_state.msgpack", "training_state.json"]:
1099
  artifact_state.add_file(
1100
- f"{Path(training_args.output_dir) / filename}"
1101
  )
1102
  wandb.run.log_artifact(artifact_state)
1103
 
 
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
21
+ import io
22
  import logging
23
  import os
24
  import sys
 
41
  from flax.serialization import from_bytes, to_bytes
42
  from flax.training import train_state
43
  from flax.training.common_utils import onehot
44
+ from google.cloud import storage
45
  from jax.experimental import PartitionSpec, maps
46
  from jax.experimental.compilation_cache import compilation_cache as cc
47
  from jax.experimental.pjit import pjit, with_sharding_constraint
 
60
  "/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
61
  )
62
 
 
63
  logger = logging.getLogger(__name__)
64
 
65
 
 
123
  else:
124
  return dict()
125
 
126
+ def get_opt_state(self):
127
+ with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
128
+ if self.restore_state is True:
129
+ # wandb artifact
130
+ state_artifact = self.model_name_or_path.replace(
131
+ "/model-", "/state-", 1
132
+ )
133
+ if jax.process_index() == 0:
134
+ artifact = wandb.run.use_artifact(state_artifact)
135
+ else:
136
+ artifact = wandb.Api().artifact(state_artifact)
137
+ artifact_dir = artifact.download(tmp_dir)
138
+ self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
139
+ return Path(self.restore_state).open("rb")
140
 
141
 
142
  @dataclass
 
788
 
789
  else:
790
  # load opt_state
791
+ opt_state_file = model_args.get_opt_state()
792
+ opt_state = from_bytes(opt_state_shape, opt_state_file.read())
793
+ opt_state_file.close()
 
794
 
795
  # restore other attributes
796
  attr_state = {
 
1036
 
1037
  def run_save_model(state, eval_metrics=None):
1038
  if jax.process_index() == 0:
1039
+
1040
+ output_dir = training_args.output_dir
1041
+ use_bucket = output_dir.startswith("gs://")
1042
+ if use_bucket:
1043
+ bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
1044
+ bucket, dir_path = str(bucket_path).split("/", 1)
1045
+ tmp_dir = tempfile.TemporaryDirectory()
1046
+ output_dir = tmp_dir.name
1047
+
1048
+ # save model
1049
  params = jax.device_get(state.params)
 
1050
  model.save_pretrained(
1051
+ output_dir,
1052
  params=params,
1053
  )
1054
 
1055
  # save tokenizer
1056
+ tokenizer.save_pretrained(output_dir)
1057
+
1058
+ # copy to bucket
1059
+ if use_bucket:
1060
+ client = storage.Client()
1061
+ bucket = client.bucket(bucket)
1062
+ for filename in Path(output_dir).glob("*"):
1063
+ blob_name = str(Path(dir_path) / filename.name)
1064
+ blob = bucket.blob(blob_name)
1065
+ blob.upload_from_filename(str(filename))
1066
+ tmp_dir.cleanup()
1067
 
1068
  # save state
1069
  opt_state = jax.device_get(state.opt_state)
1070
+ if use_bucket:
1071
+ blob_name = str(Path(dir_path) / "opt_state.msgpack")
1072
+ blob = bucket.blob(blob_name)
1073
+ blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
1074
+ else:
1075
+ with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
1076
+ f.write(to_bytes(opt_state))
 
 
 
 
 
 
1077
 
1078
  # save to W&B
1079
  if training_args.log_model:
1080
  # save some space
1081
  c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
1082
+ c.cleanup(wandb.util.from_human_size("20GB"))
1083
 
1084
+ metadata = {
1085
+ k: jax.device_get(getattr(state, k)).item()
1086
+ for k in ["step", "epoch", "train_time", "train_samples"]
1087
+ }
1088
  metadata["num_params"] = num_params
1089
  if eval_metrics is not None:
1090
  metadata["eval"] = eval_metrics
1091
+ if use_bucket:
1092
+ metadata["bucket_path"] = bucket_path
1093
 
1094
  # create model artifact
1095
  artifact = wandb.Artifact(
 
1097
  type="DalleBart_model",
1098
  metadata=metadata,
1099
  )
1100
+ if not use_bucket:
1101
+ for filename in [
1102
+ "config.json",
1103
+ "flax_model.msgpack",
1104
+ "merges.txt",
1105
+ "special_tokens_map.json",
1106
+ "tokenizer.json",
1107
+ "tokenizer_config.json",
1108
+ "vocab.json",
1109
+ ]:
1110
+ artifact.add_file(
1111
+ f"{Path(training_args.output_dir) / filename}"
1112
+ )
1113
  wandb.run.log_artifact(artifact)
1114
 
1115
  # create state artifact
 
1118
  type="DalleBart_state",
1119
  metadata=metadata,
1120
  )
1121
+ if not use_bucket:
1122
  artifact_state.add_file(
1123
+ f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
1124
  )
1125
  wandb.run.log_artifact(artifact_state)
1126