File size: 1,602 Bytes
cd5a2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import zipfile
import yaml

# Release tag; the packaging loop appends it to each model's base name to form
# the "name" entry (e.g. "sam2_hiera_tiny_20240803").
version = "20240803"

# Human-readable display names, keyed by the ONNX file base name (the part of
# the file name in front of ".encoder.onnx" / ".decoder.onnx").
base_name_to_model_name = dict(
    sam2_hiera_tiny="Segment Anything 2 (Hiera-Tiny)",
    sam2_hiera_small="Segment Anything 2 (Hiera-Small)",
    sam2_hiera_base_plus="Segment Anything 2 (Hiera-Base+)",
    sam2_hiera_large="Segment Anything 2 (Hiera-Large)",
)

# Group the ONNX files in the current directory by model. Files are expected
# to be named "<base_name>.<part>.onnx"; files sharing a <base_name> belong to
# the same model, and <part> "encoder" / "decoder" selects which path is set.
# Any other <part> still creates the model entry but records no path.
model_list = {}
# Sorted so the grouping (and the order archives are written in) does not
# depend on the arbitrary order returned by os.listdir.
onnx_files = sorted(f for f in os.listdir(".") if f.endswith(".onnx"))
for onnx_file in onnx_files:
    # rsplit from the right keeps dots inside the base name intact
    # (e.g. "sam2.1_hiera_tiny.encoder.onnx"); a plain split(".") would
    # produce more than three pieces for such names.
    parts = onnx_file.rsplit(".", 2)
    if len(parts) != 3:
        raise ValueError(
            f"Expected an ONNX file named '<base_name>.<part>.onnx', got {onnx_file!r}"
        )
    model_base_name, model_part, _ = parts
    if model_base_name not in model_list:
        # Fixed SAM preprocessing settings plus per-model identifiers; set
        # once when the model is first seen.
        model_list[model_base_name] = {
            "type": "segment_anything",
            "input_size": 1024,
            "max_width": 1024,
            "max_height": 1024,
            "basename": model_base_name,
            "name": model_base_name + "_" + version,
            # Raises KeyError for a base name missing from the display-name table.
            "display_name": base_name_to_model_name[model_base_name],
        }
    if model_part == "encoder":
        model_list[model_base_name]["encoder_model_path"] = onnx_file
    elif model_part == "decoder":
        model_list[model_base_name]["decoder_model_path"] = onnx_file

# Package each model as "<basename>.zip" holding its encoder and decoder ONNX
# files plus a config.yaml built from the model's remaining metadata entries.
for model in model_list.values():
    # "basename" only names the archive; pop it so it is not written into
    # config.yaml.
    output_zip = model.pop("basename") + ".zip"
    # Validate before opening the archive: an exception raised inside the
    # `with` block would still leave a partially written zip on disk.
    missing = [
        key
        for key in ("encoder_model_path", "decoder_model_path")
        if key not in model
    ]
    if missing:
        raise KeyError(
            f"{output_zip}: no ONNX file found for {', '.join(missing)}"
        )
    with zipfile.ZipFile(output_zip, "w") as z:
        z.write(model["encoder_model_path"])
        z.write(model["decoder_model_path"])
        # Save config in yaml; the zip entry stream takes bytes, so encode the
        # dumped text as UTF-8.
        with z.open("config.yaml", "w") as f:
            f.write(yaml.dump(model).encode("utf-8"))