Make sure download metrics work

#2
by nielsr HF staff - opened

This PR simplifies loading the model from the Hub and makes download metrics work.

Basically, when the model class inherits from PyTorchModelHubMixin and the model is defined like so:

```python
import torch
from huggingface_hub import hf_hub_download

# define model (`seemore` and `cfg` are assumed to come from the SeemoRe codebase)
model = seemore.SeemoRe(scale=cfg.model.scale, in_chans=cfg.model.in_chans,
                        num_experts=cfg.model.num_experts, num_layers=cfg.model.num_layers,
                        embedding_dim=cfg.model.embedding_dim, img_range=cfg.model.img_range,
                        use_shuffle=cfg.model.use_shuffle, global_kernel_size=cfg.model.global_kernel_size,
                        recursive=cfg.model.recursive, lr_space=cfg.model.lr_space, topk=cfg.model.topk)
# load weights: torch.load reads the checkpoint; load_state_dict expects the resulting dict, not a path
filepath = hf_hub_download(repo_id="eduardzamfir/SeemoRe-T", filename="SeemoRe_T_X4.pth", local_dir="./")
model.load_state_dict(torch.load(filepath, map_location="cpu"))
```
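For reference, `nn.Module.load_state_dict` takes a dictionary of tensors rather than a file path, and it has no `map_location` argument; that argument belongs to `torch.load`. The sketch below shows the two-step pattern in isolation. `TinyNet` and `load_checkpoint` are hypothetical names standing in for `seemore.SeemoRe` and the loading code, and a temporary file plays the role of the downloaded `SeemoRe_T_X4.pth`.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-in for seemore.SeemoRe; any nn.Module behaves the same way here.
class TinyNet(nn.Module):
    def __init__(self, channels: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


def load_checkpoint(model: nn.Module, filepath: str) -> nn.Module:
    # Step 1: torch.load turns the file into a dict of tensors (map_location goes here).
    state_dict = torch.load(filepath, map_location="cpu")
    # Step 2: load_state_dict copies those tensors into the model's parameters.
    model.load_state_dict(state_dict)
    return model


if __name__ == "__main__":
    source = TinyNet()
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "tiny.pth")
        torch.save(source.state_dict(), path)
        restored = load_checkpoint(TinyNet(), path)
    same = all(
        torch.equal(a, b)
        for a, b in zip(source.state_dict().values(), restored.state_dict().values())
    )
    print(same)  # True
```

Passing the file path directly to `load_state_dict` fails, because the method expects a mapping of parameter names to tensors.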

You can do things like:

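```python
# push to the hub (requires being logged in, e.g. via `huggingface-cli login`)
model.push_to_hub("eduardzamfir/SeemoRe-T")
# save to a local folder
model.save_pretrained(".")
# load from the hub: from_pretrained is a classmethod that returns a new model instance
model = seemore.SeemoRe.from_pretrained("eduardzamfir/SeemoRe-T")
```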

This would ensure that download metrics work, and it would also give the repository better tags, etc.

See the PyTorchModelHubMixin docs for more info: https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin

nielsr changed pull request title from "Leverage PyTorchModelHubMixin" to "Make sure download metrics work"
nielsr changed pull request status to open

Let me know if you can push the model to "eduardzamfir/SeemoRe-T" as shown above :) This requires `pip install huggingface_hub`.

Hi @eduardzamfir, could you take a look at this, please?

Ready to merge
This branch is ready to get merged automatically.
