ant_test4 / git.diff
andrewzhang505's picture
Upload . with huggingface_hub
4cffbfd
raw
history blame
8.12 kB
diff --git a/docs/10-huggingface/huggingface.md b/docs/10-huggingface/huggingface.md
index 8846da73..1f1fae6f 100644
--- a/docs/10-huggingface/huggingface.md
+++ b/docs/10-huggingface/huggingface.md
@@ -77,10 +77,16 @@ You can also save a video of the model during evaluation to upload to the hub wi
- `--video_name`: The name of the video to save as. If `None`, will save to `replay.mp4` in your experiment directory
+You can also include instructions in the Hugging Face Hub model card on how to train and run ("enjoy") this model. The following parameters are optional:
+
+- `--train_script`: The module path for training this model
+
+- `--enjoy_script`: The module path for enjoying this model
+
For example:
```
-python -m sf_examples.mujoco.enjoy_mujoco --algo=APPO --env=mujoco_ant --experiment=<repo_name> --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=<username>/<hf_repo_name> --save_video --no_render
+python -m sf_examples.mujoco.enjoy_mujoco --algo=APPO --env=mujoco_ant --experiment=<repo_name> --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=<username>/<hf_repo_name> --save_video --no_render --enjoy_script=sf_examples.mujoco.enjoy_mujoco --train_script=sf_examples.mujoco.train_mujoco
```
#### Using the push_to_hub Script
@@ -95,4 +101,6 @@ The command line arguments are:
- `-r`: The repo_id to save on HF Hub. This is the same as `hf_repository` in the enjoy script and must be in the form `<hf_username>/<hf_repo_name>`
-- `-d`: The full path to your experiment directory to upload
\ No newline at end of file
+- `-d`: The full path to your experiment directory to upload
+
+The optional `--train_script` and `--enjoy_script` arguments can also be used. See the section above for more details.
\ No newline at end of file
diff --git a/sample_factory/cfg/arguments.py b/sample_factory/cfg/arguments.py
index 820efce6..f736342d 100644
--- a/sample_factory/cfg/arguments.py
+++ b/sample_factory/cfg/arguments.py
@@ -18,7 +18,7 @@ from sample_factory.cfg.cfg import (
)
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config
-from sample_factory.utils.utils import cfg_file, cfg_file_old, get_git_commit_hash, get_top_level_script, log
+from sample_factory.utils.utils import cfg_file, cfg_file_old, get_git_commit_hash, log
def parse_sf_args(
@@ -91,7 +91,6 @@ def postprocess_args(args, argv, parser) -> argparse.Namespace:
args.cli_args = vars(cli_args)
args.git_hash, args.git_repo_name = get_git_commit_hash()
- args.train_script = get_top_level_script()
return args
diff --git a/sample_factory/cfg/cfg.py b/sample_factory/cfg/cfg.py
index 43393da1..360e6895 100644
--- a/sample_factory/cfg/cfg.py
+++ b/sample_factory/cfg/cfg.py
@@ -675,6 +675,19 @@ def add_eval_args(parser):
help="False to sample from action distributions at test time. True to just use the argmax.",
)
+ parser.add_argument(
+ "--train_script",
+ default=None,
+ type=str,
+ help="Module name used to run training script. Used to generate HF model card",
+ )
+ parser.add_argument(
+ "--enjoy_script",
+ default=None,
+ type=str,
+ help="Module name used to run enjoy script. Used to generate HF model card",
+ )
+
def add_wandb_args(p: ArgumentParser):
"""Weights and Biases experiment monitoring."""
diff --git a/sample_factory/enjoy.py b/sample_factory/enjoy.py
index 341b537b..b620c532 100644
--- a/sample_factory/enjoy.py
+++ b/sample_factory/enjoy.py
@@ -21,7 +21,7 @@ from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config, StatusCode
-from sample_factory.utils.utils import debug_log_every_n, experiment_dir, get_top_level_script, log
+from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log
def visualize_policy_inputs(normalized_obs: Dict[str, Tensor]) -> None:
@@ -260,9 +260,8 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]:
generate_replay_video(experiment_dir(cfg=cfg), video_frames, fps)
if cfg.push_to_hub:
- enjoy_name = get_top_level_script()
generate_model_card(
- experiment_dir(cfg=cfg), cfg.algo, cfg.env, cfg.hf_repository, reward_list, enjoy_name, cfg.train_script
+ experiment_dir(cfg=cfg), cfg.algo, cfg.env, cfg.hf_repository, reward_list, cfg.enjoy_script, cfg.train_script
)
push_to_hf(experiment_dir(cfg=cfg), cfg.hf_repository)
diff --git a/sample_factory/huggingface/huggingface_utils.py b/sample_factory/huggingface/huggingface_utils.py
index 90184da7..5b4a6b14 100644
--- a/sample_factory/huggingface/huggingface_utils.py
+++ b/sample_factory/huggingface/huggingface_utils.py
@@ -57,8 +57,10 @@ python -m sample_factory.huggingface.load_from_hub -r {repo_id}
```\n
"""
- if enjoy_name is not None:
- readme += f"""
+ if enjoy_name is None:
+ enjoy_name = "<path.to.enjoy.module>"
+
+ readme += f"""
## Using the model\n
To run the model after download, use the `enjoy` script corresponding to this environment:
```
@@ -67,17 +69,19 @@ python -m {enjoy_name} --algo={algo} --env={env} --train_dir=./train_dir --exper
\n
You can also upload models to the Hugging Face Hub using the same script with the `--push_to_hub` flag.
See https://www.samplefactory.dev/10-huggingface/huggingface/ for more details
- """
+ """
- if train_name is not None:
- readme += f"""
+ if train_name is None:
+ train_name = "<path.to.train.module>"
+
+ readme += f"""
## Training with this model\n
To continue training with this model, use the `train` script corresponding to this environment:
```
python -m {train_name} --algo={algo} --env={env} --train_dir=./train_dir --experiment={repo_name} --restart_behavior=resume --train_for_env_steps=10000000000
```\n
Note, you may have to adjust `--train_for_env_steps` to a suitably high number as the experiment will resume at the number of steps it concluded at.
- """
+ """
with open(readme_path, "w", encoding="utf-8") as f:
f.write(readme)
diff --git a/sample_factory/huggingface/push_to_hub.py b/sample_factory/huggingface/push_to_hub.py
index dbd5c382..d67806ad 100644
--- a/sample_factory/huggingface/push_to_hub.py
+++ b/sample_factory/huggingface/push_to_hub.py
@@ -16,6 +16,18 @@ def main():
type=str,
)
parser.add_argument("-d", "--experiment_dir", help="Path to your experiment directory", type=str)
+ parser.add_argument(
+ "--train_script",
+ default=None,
+ type=str,
+ help="Module name used to run training script. Used to generate HF model card",
+ )
+ parser.add_argument(
+ "--enjoy_script",
+ default=None,
+ type=str,
+ help="Module name used to run enjoy script. Used to generate HF model card",
+ )
args = parser.parse_args()
cfg_file = os.path.join(args.experiment_dir, "config.json")
@@ -34,7 +46,7 @@ def main():
json_params = json.load(json_file)
cfg = AttrDict(json_params)
- generate_model_card(args.experiment_dir, cfg.algo, cfg.env, args.hf_repository)
+ generate_model_card(args.experiment_dir, cfg.algo, cfg.env, args.hf_repository, enjoy_name=args.enjoy_script, train_name=args.train_script)
push_to_hf(args.experiment_dir, args.hf_repository)
diff --git a/sample_factory/utils/utils.py b/sample_factory/utils/utils.py
index 99db3c10..fcd335c5 100644
--- a/sample_factory/utils/utils.py
+++ b/sample_factory/utils/utils.py
@@ -493,5 +493,5 @@ def debug_log_every_n(n, msg, *args, **kwargs):
log_every_n(n, logging.DEBUG, msg, *args, **kwargs)
-def get_top_level_script():
- return argv[0].split("sample-factory/")[-1][:-3].replace("/", ".")
+# def get_top_level_script():
+# return argv[0].split("sample-factory/")[-1][:-3].replace("/", ".")