File size: 3,504 Bytes
9b2e755
0c7ef71
 
 
 
05bda40
f04f90e
8d502c8
0c7ef71
 
 
 
 
 
 
 
 
 
5ad4694
 
ae618a2
ecefacb
0c7ef71
 
 
 
ae618a2
ecefacb
0c7ef71
 
 
193f184
5c07fb7
 
193f184
 
5c07fb7
0c7ef71
 
 
 
5c07fb7
f04f90e
0c7ef71
 
f04f90e
0c7ef71
 
 
c2cc6bf
0c7ef71
 
 
 
 
 
9b2e755
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d502c8
9b2e755
 
0c7ef71
9b2e755
0c7ef71
8d502c8
0c7ef71
9b2e755
0c7ef71
9b2e755
 
 
 
 
 
 
 
05bda40
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from huggingface_hub import ModelFilter, snapshot_download
from huggingface_hub import ModelCard

import json
import time

from src.submission.check_validity import is_model_on_hub, check_model_card, get_model_tags
from src.envs import DYNAMIC_INFO_REPO, DYNAMIC_INFO_PATH, DYNAMIC_INFO_FILE_PATH, API, H4_TOKEN

def update_models(file_path, models):
    """
    Refresh the hub metadata of every model recorded in one JSON file.

    The file at ``file_path`` must contain a JSON object mapping model ids to
    info dicts. Each info dict is updated in place and the whole mapping is
    then written back to the same path (indented by 2 spaces).

    Args:
        file_path: Path of the JSON file to read and rewrite.
        models: Mapping from model id to a hub model-info object. The
            attributes read here are ``likes``, ``downloads``, ``created_at``
            and ``card_data`` (with ``license`` and ``base_model``).

    For ids missing from ``models`` the entry is marked ``still_on_hub=False``
    and its ``likes``, ``downloads`` and ``created_at`` are reset; other keys
    of such entries are left as they were.
    """
    # Parse the file and release the handle before the (potentially long)
    # loop of hub requests below, instead of keeping it open throughout.
    with open(file_path, "r", encoding="utf-8") as f:
        model_infos = json.load(f)

    for model_id, data in model_infos.items():
        if model_id not in models:
            data['still_on_hub'] = False
            data['likes'] = 0
            data['downloads'] = 0
            data['created_at'] = ""
            continue

        model_cfg = models[model_id]
        card_data = model_cfg.card_data
        data['likes'] = model_cfg.likes
        data['downloads'] = model_cfg.downloads
        data['created_at'] = str(model_cfg.created_at)
        data['license'] = card_data.license if card_data is not None else ""

        # Is the model still on the hub?
        # For adapters, we look at the parent model.
        model_name = model_id
        if card_data is not None and card_data.base_model is not None:
            model_name = card_data.base_model
        still_on_hub, _, _ = is_model_on_hub(
            model_name=model_name,
            revision=data.get("revision"),
            trust_remote_code=True,
            test_tokenizer=False,
            token=H4_TOKEN,
        )

        # If the model doesn't have a model card or a license, we consider
        # it's deleted. Note the card is checked on the model itself, not on
        # its base model.
        model_card = None
        if still_on_hub:
            try:
                status, _, model_card = check_model_card(model_id)
            except Exception:
                # Best effort: any failure while checking the card is treated
                # as "model not available" rather than aborting the update.
                model_card = None
                still_on_hub = False
            else:
                if status is False:
                    still_on_hub = False
        data['still_on_hub'] = still_on_hub

        # Tags are only computed for models that passed every check above.
        data["tags"] = get_model_tags(model_card, model_id) if still_on_hub else []

    with open(file_path, 'w', encoding="utf-8") as f:
        json.dump(model_infos, f, indent=2)

def update_dynamic_files():
    """Download the dynamic-info dataset, refresh its metadata file and push it back.

    Only models already present in the metadata file are updated; models that
    are missing from it are not added.
    """
    snapshot_download(
        repo_id=DYNAMIC_INFO_REPO,
        local_dir=DYNAMIC_INFO_PATH,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
    )
    print("UPDATE_DYNAMIC: Loaded snapshot")

    # Step 1: list the text-generation models and index them by id.
    listing_started = time.time()
    hub_models = API.list_models(
        filter=ModelFilter(task="text-generation"),
        full=False,
        cardData=True,
        fetch_config=True,
    )
    id_to_model = {}
    for hub_model in hub_models:
        id_to_model[hub_model.id] = hub_model
    listing_elapsed = time.time() - listing_started
    print(f"UPDATE_DYNAMIC: Downloaded list of models in {listing_elapsed:.2f} seconds")

    # Step 2: rewrite the local metadata file from that index.
    update_started = time.time()
    update_models(DYNAMIC_INFO_FILE_PATH, id_to_model)
    update_elapsed = time.time() - update_started
    print(f"UPDATE_DYNAMIC: updated in {update_elapsed:.2f} seconds")

    # Step 3: upload the refreshed file under its bare file name.
    remote_name = DYNAMIC_INFO_FILE_PATH.split("/")[-1]
    API.upload_file(
        path_or_fileobj=DYNAMIC_INFO_FILE_PATH,
        path_in_repo=remote_name,
        repo_id=DYNAMIC_INFO_REPO,
        repo_type="dataset",
        commit_message="Daily request file update.",
    )
    print("UPDATE_DYNAMIC: pushed to hub")