Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,712 Bytes
4893ce0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
"""
Test script for S3DIS 6-fold cross validation
Gather the Area_X.pth record from the result folder of each area's experiment
and arrange the six files as follows:
|- RECORDS_PATH
  |- Area_1.pth
  |- Area_2.pth
  |- Area_3.pth
  |- Area_4.pth
  |- Area_5.pth
  |- Area_6.pth
Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import argparse
import os
import torch
import numpy as np
import glob
from pointcept.utils.logger import get_root_logger
# Human-readable class labels for the per-class report. The list position is
# the class index: evaluation() logs entry i of this list next to element i of
# the per-class IoU / accuracy arrays, and main() sizes its accumulators with
# len(CLASS_NAMES).
CLASS_NAMES = [
    "ceiling",
    "floor",
    "wall",
    "beam",
    "column",
    "window",
    "door",
    "table",
    "chair",
    "sofa",
    "bookcase",
    "board",
    "clutter",
]
def evaluation(intersection, union, target, logger=None):
    """Compute segmentation metrics from per-class counts and log them.

    Per-class IoU is ``intersection / union`` and per-class accuracy is
    ``intersection / target``; a small epsilon in each denominator avoids a
    division by zero. The means of both (mIoU, mAcc) and the overall accuracy
    ``sum(intersection) / sum(target)`` are reported as well.

    Args:
        intersection: Per-class counts supporting element-wise arithmetic
            (main() accumulates them in NumPy integer arrays).
        union: Per-class counts, same layout as ``intersection``.
        target: Per-class counts, same layout as ``intersection``
            (presumably the number of ground-truth points per class -
            TODO confirm against the code that writes the records).
        logger: Object with an ``info(str)`` method. When ``None`` the
            metrics are still computed but nothing is reported.

    Returns:
        None. The results are only emitted through ``logger``.
    """
    eps = 1e-10
    per_class_iou = intersection / (union + eps)
    per_class_acc = intersection / (target + eps)
    mean_iou = np.mean(per_class_iou)
    mean_acc = np.mean(per_class_acc)
    overall_acc = sum(intersection) / (sum(target) + eps)

    # Without a logger there is nowhere to report to.
    if logger is None:
        return

    summary_template = "Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}"
    logger.info(summary_template.format(mean_iou, mean_acc, overall_acc))

    class_template = (
        "Class_{idx} - {name} Result: iou/accuracy {iou:.4f}/{accuracy:.4f}"
    )
    for class_idx, class_name in enumerate(CLASS_NAMES):
        logger.info(
            class_template.format(
                idx=class_idx,
                name=class_name,
                iou=per_class_iou[class_idx],
                accuracy=per_class_acc[class_idx],
            )
        )
def main():
    """Aggregate per-area S3DIS records into a 6-fold cross-validation report.

    Reads ``--record_root`` from the command line, collects the
    ``Area_*.pth`` records stored directly in that folder, logs the metrics
    of every area and finally the metrics computed from the class-wise sums
    over all areas. The log goes to ``<record_root>/6-fold.log``, which is
    opened with ``file_mode="w"``.

    Exits through ``parser.error`` (usage message, status 2) unless exactly
    six records are found. The check runs before the logger is created so
    that a wrong path neither creates nor overwrites a log file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--record_root",
        required=True,
        help="Path to the S3DIS record of each split",
    )
    config = parser.parse_args()

    # Sorted so the areas are always reported in the same order.
    records = sorted(glob.glob(os.path.join(config.record_root, "Area_*.pth")))
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and give no hint about what is missing.
    if len(records) != 6:
        found = ", ".join(os.path.basename(path) for path in records) or "none"
        parser.error(
            f"expected exactly 6 'Area_*.pth' records in "
            f"{config.record_root!r}, found {len(records)} ({found})"
        )

    logger = get_root_logger(
        log_file=os.path.join(config.record_root, "6-fold.log"),
        file_mode="w",
    )

    # Class-wise running sums over all areas.
    num_classes = len(CLASS_NAMES)
    intersection_ = np.zeros(num_classes, dtype=int)
    union_ = np.zeros(num_classes, dtype=int)
    target_ = np.zeros(num_classes, dtype=int)

    for record in records:
        area = os.path.basename(record).split(".")[0]
        # NOTE(review): torch.load unpickles the file, so only point this
        # script at records you produced yourself. Recent PyTorch releases
        # default to ``weights_only=True``; if the records hold NumPy arrays
        # that default may reject them - confirm against the installed version.
        info = torch.load(record)
        logger.info(f"<<<<<<<<<<<<<<<<< Parsing {area} <<<<<<<<<<<<<<<<<")
        intersection = info["intersection"]
        union = info["union"]
        target = info["target"]
        evaluation(intersection, union, target, logger=logger)
        intersection_ += intersection
        union_ += union
        target_ += target

    logger.info("<<<<<<<<<<<<<<<<< Parsing 6-fold <<<<<<<<<<<<<<<<<")
    evaluation(intersection_, union_, target_, logger=logger)


if __name__ == "__main__":
    main()
|