|
import os
import sys
import zipfile

import yaml
|
|
|
# Release tag appended to each model's base name to form its "name" field
# in the generated config.yaml.
version = "20240803"

# Human-readable display names, keyed by the base name that prefixes the
# exported ONNX files ("<base_name>.<part>.onnx").
base_name_to_model_name = dict(
    sam2_hiera_tiny="Segment Anything 2 (Hiera-Tiny)",
    sam2_hiera_small="Segment Anything 2 (Hiera-Small)",
    sam2_hiera_base_plus="Segment Anything 2 (Hiera-Base+)",
    sam2_hiera_large="Segment Anything 2 (Hiera-Large)",
)
|
|
|
model_list = {} |
|
onnx_files = [f for f in os.listdir(".") if f.endswith(".onnx")] |
|
for onnx_file in onnx_files: |
|
model_base_name, model_part, _ = onnx_file.split(".") |
|
if model_base_name not in model_list: |
|
model_list[model_base_name] = {} |
|
model_list[model_base_name]["type"] = "segment_anything" |
|
model_list[model_base_name]["input_size"] = 1024 |
|
model_list[model_base_name]["max_width"] = 1024 |
|
model_list[model_base_name]["max_height"] = 1024 |
|
if model_part == "encoder": |
|
model_list[model_base_name]["encoder_model_path"] = onnx_file |
|
elif model_part == "decoder": |
|
model_list[model_base_name]["decoder_model_path"] = onnx_file |
|
model_list[model_base_name]["basename"] = model_base_name |
|
model_list[model_base_name]["name"] = model_base_name + "_" + version |
|
model_list[model_base_name]["display_name"] = base_name_to_model_name[model_base_name] |
|
|
|
# Package each model as "<basename>.zip" containing its encoder and decoder
# ONNX files plus a config.yaml with the model's metadata (the internal
# "basename" key is omitted from the config).
for model in model_list.values():
    config = {key: value for key, value in model.items() if key != "basename"}

    missing = [
        key
        for key in ("encoder_model_path", "decoder_model_path")
        if key not in config
    ]
    if missing:
        # Check before creating the archive: failing inside the `with` block
        # would leave a truncated zip behind and abort the remaining models.
        print(
            f"Skipping {model['basename']}: missing {', '.join(missing)}",
            file=sys.stderr,
        )
        continue

    output_zip = model["basename"] + ".zip"
    with zipfile.ZipFile(output_zip, "w") as z:
        z.write(config["encoder_model_path"])
        z.write(config["decoder_model_path"])
        # writestr() encodes str data as UTF-8, same as the explicit encode
        # the previous z.open(..., "w") version performed.
        z.writestr("config.yaml", yaml.dump(config))
|
|