File size: 6,892 Bytes
73a0669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import json
import shutil
import sys
from datetime import datetime
from pathlib import Path

import torch

sys.path.append(str(Path(__file__).parent.parent.parent))  # add utils/ to path
from utils.general import colorstr, xywh2xyxy

try:
    import wandb
except ImportError:
    wandb = None
    print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")

WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'


def remove_prefix(from_string, prefix):
    return from_string[len(prefix):]


class WandbLogger():
    def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
        self.wandb = wandb
        self.wandb_run = wandb.init(config=opt, resume="allow",
                                    project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                                    name=name,
                                    job_type=job_type,
                                    id=run_id) if self.wandb else None

        if job_type == 'Training':
            self.setup_training(opt, data_dict)
            if opt.bbox_interval == -1:
                opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs
            if opt.save_period == -1:
                opt.save_period = (opt.epochs // 10) if opt.epochs > 10 else opt.epochs

    def setup_training(self, opt, data_dict):
        self.log_dict = {}
        self.train_artifact_path, self.trainset_artifact = \
            self.download_dataset_artifact(data_dict['train'], opt.artifact_alias)
        self.test_artifact_path, self.testset_artifact = \
            self.download_dataset_artifact(data_dict['val'], opt.artifact_alias)
        self.result_artifact, self.result_table, self.weights = None, None, None
        if self.train_artifact_path is not None:
            train_path = Path(self.train_artifact_path) / 'data/images/'
            data_dict['train'] = str(train_path)
        if self.test_artifact_path is not None:
            test_path = Path(self.test_artifact_path) / 'data/images/'
            data_dict['val'] = str(test_path)
            self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
            self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
        if opt.resume_from_artifact:
            modeldir, _ = self.download_model_artifact(opt.resume_from_artifact)
            if modeldir:
                self.weights = Path(modeldir) / "best.pt"
                opt.weights = self.weights

    def download_dataset_artifact(self, path, alias):
        if path.startswith(WANDB_ARTIFACT_PREFIX):
            dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias)
            assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'"
            datadir = dataset_artifact.download()
            labels_zip = Path(datadir) / "data/labels.zip"
            shutil.unpack_archive(labels_zip, Path(datadir) / 'data/labels', 'zip')
            print("Downloaded dataset to : ", datadir)
            return datadir, dataset_artifact
        return None, None

    def download_model_artifact(self, name):
        model_artifact = wandb.use_artifact(name + ":latest")
        assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist'
        modeldir = model_artifact.download()
        print("Downloaded model to : ", modeldir)
        return modeldir, model_artifact

    def log_model(self, path, opt, epoch):
        datetime_suffix = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
        model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
            'original_url': str(path),
            'epoch': epoch + 1,
            'save period': opt.save_period,
            'project': opt.project,
            'datetime': datetime_suffix
        })
        model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
        model_artifact.add_file(str(path / 'best.pt'), name='best.pt')
        wandb.log_artifact(model_artifact)
        print("Saving model artifact on epoch ", epoch + 1)

    def log_dataset_artifact(self, dataset, class_to_id, name='dataset'):
        artifact = wandb.Artifact(name=name, type="dataset")
        image_path = dataset.path
        artifact.add_dir(image_path, name='data/images')
        table = wandb.Table(columns=["id", "train_image", "Classes"])
        class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
        for si, (img, labels, paths, shapes) in enumerate(dataset):
            height, width = shapes[0]
            labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4)))
            labels[:, 2:] *= torch.Tensor([width, height, width, height])
            box_data = []
            img_classes = {}
            for cls, *xyxy in labels[:, 1:].tolist():
                cls = int(cls)
                box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                                 "class_id": cls,
                                 "box_caption": "%s" % (class_to_id[cls]),
                                 "scores": {"acc": 1},
                                 "domain": "pixel"})
                img_classes[cls] = class_to_id[cls]
            boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}}  # inference-space
            table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes))
        artifact.add(table, name)
        labels_path = 'labels'.join(image_path.rsplit('images', 1))
        zip_path = Path(labels_path).parent / (name + '_labels.zip')
        if not zip_path.is_file():  # make_archive won't check if file exists
            shutil.make_archive(zip_path.with_suffix(''), 'zip', labels_path)
        artifact.add_file(str(zip_path), name='data/labels.zip')
        wandb.log_artifact(artifact)
        print("Saving data to W&B...")

    def log(self, log_dict):
        if self.wandb_run:
            for key, value in log_dict.items():
                self.log_dict[key] = value

    def end_epoch(self):
        if self.wandb_run and self.log_dict:
            wandb.log(self.log_dict)
        self.log_dict = {}

    def finish_run(self):
        if self.wandb_run:
            if self.result_artifact:
                print("Add Training Progress Artifact")
                self.result_artifact.add(self.result_table, 'result')
                train_results = wandb.JoinedTable(self.testset_artifact.get("val"), self.result_table, "id")
                self.result_artifact.add(train_results, 'joined_result')
                wandb.log_artifact(self.result_artifact)
            if self.log_dict:
                wandb.log(self.log_dict)
            wandb.run.finish()