boris committed on
Commit
1c4e839
1 Parent(s): 50498e6

feat: load from bucket

Browse files
src/dalle_mini/model/utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import tempfile
 
3
 
4
  import wandb
5
 
@@ -8,11 +9,13 @@ class PretrainedFromWandbMixin:
8
  @classmethod
9
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
10
  """
11
- Initializes from a wandb artifact, or delegates loading to the superclass.
12
  """
13
  with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
14
- if ":" in pretrained_model_name_or_path and not os.path.isdir(
15
- pretrained_model_name_or_path
 
 
16
  ):
17
  # wandb artifact
18
  if wandb.run is not None:
@@ -20,7 +23,27 @@ class PretrainedFromWandbMixin:
20
  else:
21
  artifact = wandb.Api().artifact(pretrained_model_name_or_path)
22
  pretrained_model_name_or_path = artifact.download(tmp_dir)
 
 
 
 
 
 
23
 
24
  return super(PretrainedFromWandbMixin, cls).from_pretrained(
25
  pretrained_model_name_or_path, *model_args, **kwargs
26
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import tempfile
3
+ from pathlib import Path
4
 
5
  import wandb
6
 
 
9
  @classmethod
10
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
11
  """
12
+ Initializes from a wandb artifact, google bucket path or delegates loading to the superclass.
13
  """
14
  with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
15
+ if (
16
+ ":" in pretrained_model_name_or_path
17
+ and not os.path.isdir(pretrained_model_name_or_path)
18
+ and not pretrained_model_name_or_path.startswith("gs")
19
  ):
20
  # wandb artifact
21
  if wandb.run is not None:
 
23
  else:
24
  artifact = wandb.Api().artifact(pretrained_model_name_or_path)
25
  pretrained_model_name_or_path = artifact.download(tmp_dir)
26
+ if artifact.metadata.get("bucket_path"):
27
+ pretrained_model_name_or_path = artifact.metadata["bucket_path"]
28
+
29
+ if pretrained_model_name_or_path.startswith("gs://"):
30
+ copy_blobs(pretrained_model_name_or_path, tmp_dir)
31
+ pretrained_model_name_or_path = tmp_dir
32
 
33
  return super(PretrainedFromWandbMixin, cls).from_pretrained(
34
  pretrained_model_name_or_path, *model_args, **kwargs
35
  )
36
+
37
+
38
def copy_blobs(source_path, dest_path):
    """Download every blob under a Google Cloud Storage directory into a local directory.

    Files are written flat into ``dest_path`` using only the last component of
    each blob name, so blobs nested deeper under the prefix end up side by side
    in ``dest_path``.

    Args:
        source_path: Location of the form ``gs://<bucket>/<directory>``.
            Leading/trailing slashes around the directory part are ignored.
        dest_path: Existing local directory that receives the downloaded files.

    Raises:
        ValueError: If ``source_path`` does not start with ``gs://`` or does not
            name a directory inside the bucket.
    """
    gs_scheme = "gs://"
    if not source_path.startswith(gs_scheme):
        # explicit check instead of `assert`, which is stripped under `python -O`
        raise ValueError(
            f"Expected a path starting with '{gs_scheme}', got: {source_path!r}"
        )

    # Blob names always use "/" as separator, so split with plain string
    # operations: pathlib would produce "\\" separators on Windows.
    bucket_name, _, dir_path = source_path[len(gs_scheme):].partition("/")
    dir_path = dir_path.strip("/")
    if not bucket_name or not dir_path:
        raise ValueError(
            f"Expected 'gs://<bucket>/<directory>', got: {source_path!r}"
        )

    # Imported lazily so google-cloud-storage is only needed when a bucket is used.
    from google.cloud import storage

    client = storage.Client()
    bucket = client.bucket(bucket_name)
    for blob in client.list_blobs(bucket, prefix=f"{dir_path}/"):
        file_name = blob.name.rsplit("/", 1)[-1]
        if not file_name:
            # names ending with "/" are directory placeholders, not files
            continue
        blob.download_to_filename(str(Path(dest_path) / file_name))
tools/train/train.py CHANGED
@@ -135,8 +135,21 @@ class ModelArguments:
135
  else:
136
  artifact = wandb.Api().artifact(state_artifact)
137
  artifact_dir = artifact.download(tmp_dir)
138
- self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
139
- return Path(self.restore_state).open("rb")
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
 
142
  @dataclass
@@ -788,9 +801,7 @@ def main():
788
 
789
  else:
790
  # load opt_state
791
- opt_state_file = model_args.get_opt_state()
792
- opt_state = from_bytes(opt_state_shape, opt_state_file.read())
793
- opt_state_file.close()
794
 
795
  # restore other attributes
796
  attr_state = {
@@ -1060,7 +1071,7 @@ def main():
1060
  client = storage.Client()
1061
  bucket = client.bucket(bucket)
1062
  for filename in Path(output_dir).glob("*"):
1063
- blob_name = str(Path(dir_path) / filename.name)
1064
  blob = bucket.blob(blob_name)
1065
  blob.upload_from_filename(str(filename))
1066
  tmp_dir.cleanup()
@@ -1068,7 +1079,7 @@ def main():
1068
  # save state
1069
  opt_state = jax.device_get(state.opt_state)
1070
  if use_bucket:
1071
- blob_name = str(Path(dir_path) / "opt_state.msgpack")
1072
  blob = bucket.blob(blob_name)
1073
  blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
1074
  else:
@@ -1088,10 +1099,10 @@ def main():
1088
  metadata["num_params"] = num_params
1089
  if eval_metrics is not None:
1090
  metadata["eval"] = eval_metrics
1091
- if use_bucket:
1092
- metadata["bucket_path"] = bucket_path
1093
 
1094
  # create model artifact
 
 
1095
  artifact = wandb.Artifact(
1096
  name=f"model-{wandb.run.id}",
1097
  type="DalleBart_model",
@@ -1113,6 +1124,8 @@ def main():
1113
  wandb.run.log_artifact(artifact)
1114
 
1115
  # create state artifact
 
 
1116
  artifact_state = wandb.Artifact(
1117
  name=f"state-{wandb.run.id}",
1118
  type="DalleBart_state",
 
135
  else:
136
  artifact = wandb.Api().artifact(state_artifact)
137
  artifact_dir = artifact.download(tmp_dir)
138
+ if artifact.metadata.get("bucket_path"):
139
+ self.restore_state = artifact.metadata["bucket_path"]
140
+ else:
141
+ self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
142
+
143
+ if self.restore_state.startswith("gs://"):
144
+ bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
145
+ bucket, blob_name = str(bucket_path).split("/", 1)
146
+ client = storage.Client()
147
+ bucket = client.bucket(bucket)
148
+ blob = bucket.blob(blob_name)
149
+ return blob.download_as_bytes()
150
+
151
+ with Path(self.restore_state).open("rb") as f:
152
+ return f.read()
153
 
154
 
155
  @dataclass
 
801
 
802
  else:
803
  # load opt_state
804
+ opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
 
 
805
 
806
  # restore other attributes
807
  attr_state = {
 
1071
  client = storage.Client()
1072
  bucket = client.bucket(bucket)
1073
  for filename in Path(output_dir).glob("*"):
1074
+ blob_name = str(Path(dir_path) / "model" / filename.name)
1075
  blob = bucket.blob(blob_name)
1076
  blob.upload_from_filename(str(filename))
1077
  tmp_dir.cleanup()
 
1079
  # save state
1080
  opt_state = jax.device_get(state.opt_state)
1081
  if use_bucket:
1082
+ blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack")
1083
  blob = bucket.blob(blob_name)
1084
  blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
1085
  else:
 
1099
  metadata["num_params"] = num_params
1100
  if eval_metrics is not None:
1101
  metadata["eval"] = eval_metrics
 
 
1102
 
1103
  # create model artifact
1104
+ if use_bucket:
1105
+ metadata["bucket_path"] = f"gs://{bucket_path}/model"
1106
  artifact = wandb.Artifact(
1107
  name=f"model-{wandb.run.id}",
1108
  type="DalleBart_model",
 
1124
  wandb.run.log_artifact(artifact)
1125
 
1126
  # create state artifact
1127
+ if use_bucket:
1128
+ metadata["bucket_path"] = f"gs://{bucket_path}/state"
1129
  artifact_state = wandb.Artifact(
1130
  name=f"state-{wandb.run.id}",
1131
  type="DalleBart_state",