File size: 1,602 Bytes
cd5a2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import os
import sys
import zipfile

import yaml
version = "20240803"
base_name_to_model_name = {
"sam2_hiera_tiny": "Segment Anything 2 (Hiera-Tiny)",
"sam2_hiera_small": "Segment Anything 2 (Hiera-Small)",
"sam2_hiera_base_plus": "Segment Anything 2 (Hiera-Base+)",
"sam2_hiera_large": "Segment Anything 2 (Hiera-Large)",
}
model_list = {}
onnx_files = [f for f in os.listdir(".") if f.endswith(".onnx")]
for onnx_file in onnx_files:
model_base_name, model_part, _ = onnx_file.split(".")
if model_base_name not in model_list:
model_list[model_base_name] = {}
model_list[model_base_name]["type"] = "segment_anything"
model_list[model_base_name]["input_size"] = 1024
model_list[model_base_name]["max_width"] = 1024
model_list[model_base_name]["max_height"] = 1024
if model_part == "encoder":
model_list[model_base_name]["encoder_model_path"] = onnx_file
elif model_part == "decoder":
model_list[model_base_name]["decoder_model_path"] = onnx_file
model_list[model_base_name]["basename"] = model_base_name
model_list[model_base_name]["name"] = model_base_name + "_" + version
model_list[model_base_name]["display_name"] = base_name_to_model_name[model_base_name]
for model in model_list.values():
output_zip = model.pop("basename") + ".zip"
with zipfile.ZipFile(output_zip, "w") as z:
z.write(model["encoder_model_path"])
z.write(model["decoder_model_path"])
# Save config in yaml
with z.open("config.yaml", "w") as f:
f.write(yaml.dump(model).encode("utf-8"))
|