Spaces:
Running
Running
feat: load from bucket
Browse files- src/dalle_mini/model/utils.py +26 -3
- tools/train/train.py +22 -9
src/dalle_mini/model/utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
import tempfile
|
|
|
3 |
|
4 |
import wandb
|
5 |
|
@@ -8,11 +9,13 @@ class PretrainedFromWandbMixin:
|
|
8 |
@classmethod
|
9 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
10 |
"""
|
11 |
-
Initializes from a wandb artifact, or delegates loading to the superclass.
|
12 |
"""
|
13 |
with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
|
14 |
-
if
|
15 |
-
pretrained_model_name_or_path
|
|
|
|
|
16 |
):
|
17 |
# wandb artifact
|
18 |
if wandb.run is not None:
|
@@ -20,7 +23,27 @@ class PretrainedFromWandbMixin:
|
|
20 |
else:
|
21 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
22 |
pretrained_model_name_or_path = artifact.download(tmp_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
25 |
pretrained_model_name_or_path, *model_args, **kwargs
|
26 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import tempfile
|
3 |
+
from pathlib import Path
|
4 |
|
5 |
import wandb
|
6 |
|
|
|
9 |
@classmethod
|
10 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
11 |
"""
|
12 |
+
Initializes from a wandb artifact, google bucket path or delegates loading to the superclass.
|
13 |
"""
|
14 |
with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
|
15 |
+
if (
|
16 |
+
":" in pretrained_model_name_or_path
|
17 |
+
and not os.path.isdir(pretrained_model_name_or_path)
|
18 |
+
and not pretrained_model_name_or_path.startswith("gs")
|
19 |
):
|
20 |
# wandb artifact
|
21 |
if wandb.run is not None:
|
|
|
23 |
else:
|
24 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
25 |
pretrained_model_name_or_path = artifact.download(tmp_dir)
|
26 |
+
if artifact.metadata.get("bucket_path"):
|
27 |
+
pretrained_model_name_or_path = artifact.metadata["bucket_path"]
|
28 |
+
|
29 |
+
if pretrained_model_name_or_path.startswith("gs://"):
|
30 |
+
copy_blobs(pretrained_model_name_or_path, tmp_dir)
|
31 |
+
pretrained_model_name_or_path = tmp_dir
|
32 |
|
33 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
34 |
pretrained_model_name_or_path, *model_args, **kwargs
|
35 |
)
|
36 |
+
|
37 |
+
|
38 |
+
def copy_blobs(source_path, dest_path):
    """
    Download every blob stored under a Google Cloud Storage folder into a local directory.

    Args:
        source_path: bucket location of the form "gs://<bucket>/<folder>".
            A bare "gs://<bucket>" is accepted and copies the whole bucket.
        dest_path: local directory receiving the files (it is not created here).

    Raises:
        ValueError: if `source_path` does not start with "gs://" or names no bucket.

    Files are written flat into `dest_path` under their base name, so blobs in
    different sub-folders that share a file name overwrite one another.
    """
    scheme = "gs://"
    # explicit check rather than `assert`, which is stripped when Python runs with -O
    if not source_path.startswith(scheme):
        raise ValueError(
            f"Expected a path starting with {scheme!r}, got {source_path!r}"
        )

    # blob names always use "/" as separator, so split the string directly
    # instead of going through the OS-dependent `Path`
    bucket_name, _, dir_path = source_path[len(scheme) :].partition("/")
    if not bucket_name:
        raise ValueError(f"No bucket name found in {source_path!r}")
    dir_path = dir_path.strip("/")
    # the trailing "/" keeps "model" from also matching a sibling such as "model-old/"
    prefix = f"{dir_path}/" if dir_path else None

    # kept function-local so importing this module does not require google-cloud-storage
    from google.cloud import storage

    client = storage.Client()
    bucket = client.bucket(bucket_name)
    for blob in client.list_blobs(bucket, prefix=prefix):
        if blob.name.endswith("/"):
            # zero-byte "folder" placeholder object, there is no file to download
            continue
        dest_name = str(Path(dest_path) / Path(blob.name).name)
        blob.download_to_filename(dest_name)
|
tools/train/train.py
CHANGED
@@ -135,8 +135,21 @@ class ModelArguments:
|
|
135 |
else:
|
136 |
artifact = wandb.Api().artifact(state_artifact)
|
137 |
artifact_dir = artifact.download(tmp_dir)
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
|
142 |
@dataclass
|
@@ -788,9 +801,7 @@ def main():
|
|
788 |
|
789 |
else:
|
790 |
# load opt_state
|
791 |
-
|
792 |
-
opt_state = from_bytes(opt_state_shape, opt_state_file.read())
|
793 |
-
opt_state_file.close()
|
794 |
|
795 |
# restore other attributes
|
796 |
attr_state = {
|
@@ -1060,7 +1071,7 @@ def main():
|
|
1060 |
client = storage.Client()
|
1061 |
bucket = client.bucket(bucket)
|
1062 |
for filename in Path(output_dir).glob("*"):
|
1063 |
-
blob_name = str(Path(dir_path) / filename.name)
|
1064 |
blob = bucket.blob(blob_name)
|
1065 |
blob.upload_from_filename(str(filename))
|
1066 |
tmp_dir.cleanup()
|
@@ -1068,7 +1079,7 @@ def main():
|
|
1068 |
# save state
|
1069 |
opt_state = jax.device_get(state.opt_state)
|
1070 |
if use_bucket:
|
1071 |
-
blob_name = str(Path(dir_path) / "opt_state.msgpack")
|
1072 |
blob = bucket.blob(blob_name)
|
1073 |
blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
|
1074 |
else:
|
@@ -1088,10 +1099,10 @@ def main():
|
|
1088 |
metadata["num_params"] = num_params
|
1089 |
if eval_metrics is not None:
|
1090 |
metadata["eval"] = eval_metrics
|
1091 |
-
if use_bucket:
|
1092 |
-
metadata["bucket_path"] = bucket_path
|
1093 |
|
1094 |
# create model artifact
|
|
|
|
|
1095 |
artifact = wandb.Artifact(
|
1096 |
name=f"model-{wandb.run.id}",
|
1097 |
type="DalleBart_model",
|
@@ -1113,6 +1124,8 @@ def main():
|
|
1113 |
wandb.run.log_artifact(artifact)
|
1114 |
|
1115 |
# create state artifact
|
|
|
|
|
1116 |
artifact_state = wandb.Artifact(
|
1117 |
name=f"state-{wandb.run.id}",
|
1118 |
type="DalleBart_state",
|
|
|
135 |
else:
|
136 |
artifact = wandb.Api().artifact(state_artifact)
|
137 |
artifact_dir = artifact.download(tmp_dir)
|
138 |
+
if artifact.metadata.get("bucket_path"):
|
139 |
+
self.restore_state = artifact.metadata["bucket_path"]
|
140 |
+
else:
|
141 |
+
self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
|
142 |
+
|
143 |
+
if self.restore_state.startswith("gs://"):
|
144 |
+
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
145 |
+
bucket, blob_name = str(bucket_path).split("/", 1)
|
146 |
+
client = storage.Client()
|
147 |
+
bucket = client.bucket(bucket)
|
148 |
+
blob = bucket.blob(blob_name)
|
149 |
+
return blob.download_as_bytes()
|
150 |
+
|
151 |
+
with Path(self.restore_state).open("rb") as f:
|
152 |
+
return f.read()
|
153 |
|
154 |
|
155 |
@dataclass
|
|
|
801 |
|
802 |
else:
|
803 |
# load opt_state
|
804 |
+
opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
|
|
|
|
|
805 |
|
806 |
# restore other attributes
|
807 |
attr_state = {
|
|
|
1071 |
client = storage.Client()
|
1072 |
bucket = client.bucket(bucket)
|
1073 |
for filename in Path(output_dir).glob("*"):
|
1074 |
+
blob_name = str(Path(dir_path) / "model" / filename.name)
|
1075 |
blob = bucket.blob(blob_name)
|
1076 |
blob.upload_from_filename(str(filename))
|
1077 |
tmp_dir.cleanup()
|
|
|
1079 |
# save state
|
1080 |
opt_state = jax.device_get(state.opt_state)
|
1081 |
if use_bucket:
|
1082 |
+
blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack")
|
1083 |
blob = bucket.blob(blob_name)
|
1084 |
blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
|
1085 |
else:
|
|
|
1099 |
metadata["num_params"] = num_params
|
1100 |
if eval_metrics is not None:
|
1101 |
metadata["eval"] = eval_metrics
|
|
|
|
|
1102 |
|
1103 |
# create model artifact
|
1104 |
+
if use_bucket:
|
1105 |
+
metadata["bucket_path"] = f"gs://{bucket_path}/model"
|
1106 |
artifact = wandb.Artifact(
|
1107 |
name=f"model-{wandb.run.id}",
|
1108 |
type="DalleBart_model",
|
|
|
1124 |
wandb.run.log_artifact(artifact)
|
1125 |
|
1126 |
# create state artifact
|
1127 |
+
if use_bucket:
|
1128 |
+
metadata["bucket_path"] = f"gs://{bucket_path}/state"
|
1129 |
artifact_state = wandb.Artifact(
|
1130 |
name=f"state-{wandb.run.id}",
|
1131 |
type="DalleBart_state",
|