File size: 8,149 Bytes
ac6acf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import datetime
import numpy as np
import os
from PIL import Image
import pytest
from pytest import fixture
from typing import Tuple, List

from cv2 import imread, cvtColor, COLOR_BGR2RGB
from skimage.metrics import structural_similarity as ssim


"""

This test suite compares images in 2 directories by file name

The directories are specified by the command line arguments --baseline_dir and --test_dir



"""
# ssim: Structural Similarity Index
# Returns a tuple of (ssim, diff_image)
def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]:
    score, diff = ssim(img0, img1, channel_axis=-1, full=True)
    # rescale the difference image to 0-255 range
    diff = (diff * 255).astype("uint8")
    return score, diff
    
# Metrics must return a tuple of (score, diff_image)
METRICS = {"ssim": ssim_score}
METRICS_PASS_THRESHOLD = {"ssim": 0.95}


class TestCompareImageMetrics:
    @fixture(scope="class")
    def test_file_names(self, args_pytest):
        test_dir = args_pytest['test_dir']
        fnames = self.gather_file_basenames(test_dir)  
        yield fnames
        del fnames

    @fixture(scope="class", autouse=True)
    def teardown(self, args_pytest):
        yield
        # Runs after all tests are complete
        # Aggregate output files into a grid of images
        baseline_dir = args_pytest['baseline_dir']
        test_dir = args_pytest['test_dir']
        img_output_dir = args_pytest['img_output_dir']
        metrics_file = args_pytest['metrics_file']

        grid_dir = os.path.join(img_output_dir, "grid")
        os.makedirs(grid_dir, exist_ok=True)

        for metric_dir in METRICS.keys():
            metric_path = os.path.join(img_output_dir, metric_dir)
            for file in os.listdir(metric_path):
                if file.endswith(".png"):
                    score = self.lookup_score_from_fname(file, metrics_file)
                    image_file_list = []
                    image_file_list.append([
                                            os.path.join(baseline_dir, file),
                                            os.path.join(test_dir, file),
                                            os.path.join(metric_path, file)
                                            ])
                    # Create grid
                    image_list = [[Image.open(file) for file in files] for files in image_file_list]
                    grid = self.image_grid(image_list)
                    grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}"))
    
    # Tests run for each baseline file name
    @fixture()
    def fname(self, baseline_fname):
        yield baseline_fname
        del baseline_fname
    
    def test_directories_not_empty(self, args_pytest):
        baseline_dir = args_pytest['baseline_dir']
        test_dir = args_pytest['test_dir']
        assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty"
        assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty"

    def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest):
        # Check that all files in baseline_dir have a file in test_dir with matching metadata
        baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname)
        file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names]
        file_match = self.find_file_match(baseline_file_path, file_paths)
        assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}"

    # For a baseline image file, finds the corresponding file name in test_dir and 
    # compares the images using the metrics in METRICS
    @pytest.mark.parametrize("metric", METRICS.keys())
    def test_pipeline_compare(

        self,

        args_pytest,

        fname,

        test_file_names,

        metric,

    ):
        baseline_dir = args_pytest['baseline_dir']
        test_dir = args_pytest['test_dir']
        metrics_output_file = args_pytest['metrics_file']
        img_output_dir = args_pytest['img_output_dir']
        
        baseline_file_path = os.path.join(baseline_dir, fname)

        # Find file match
        file_paths = [os.path.join(test_dir, f) for f in test_file_names]
        test_file = self.find_file_match(baseline_file_path, file_paths)

        # Run metrics
        sample_baseline = self.read_img(baseline_file_path)
        sample_secondary = self.read_img(test_file)
        
        score, metric_img = METRICS[metric](sample_baseline, sample_secondary)
        metric_status = score > METRICS_PASS_THRESHOLD[metric]

        # Save metric values
        with open(metrics_output_file, 'a') as f:
            run_info = os.path.splitext(fname)[0]
            metric_status_str = "PASS ✅" if metric_status else "FAIL ❌"
            date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n")

        # Save metric image
        metric_img_dir = os.path.join(img_output_dir, metric)
        os.makedirs(metric_img_dir, exist_ok=True)
        output_filename = f'{fname}'
        Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename))

        assert score > METRICS_PASS_THRESHOLD[metric]

    def read_img(self, filename: str) -> np.ndarray:
        cvImg = imread(filename)
        cvImg = cvtColor(cvImg, COLOR_BGR2RGB)
        return cvImg

    def image_grid(self, img_list: list[list[Image.Image]]):
        # imgs is a 2D list of images
        # Assumes the input images are a rectangular grid of equal sized images
        rows = len(img_list)
        cols = len(img_list[0])

        w, h = img_list[0][0].size
        grid = Image.new('RGB', size=(cols*w, rows*h))
        
        for i, row in enumerate(img_list):
            for j, img in enumerate(row):
                grid.paste(img, box=(j*w, i*h))
        return grid

    def lookup_score_from_fname(self,

                                fname: str,

                                metrics_output_file: str

        ) -> float:
        fname_basestr = os.path.splitext(fname)[0]
        with open(metrics_output_file, 'r') as f:
            for line in f:
                if fname_basestr in line:
                    score = float(line.split('|')[5])
                    return score
        raise ValueError(f"Could not find score for {fname} in {metrics_output_file}")

    def gather_file_basenames(self, directory: str):
        files = []
        for file in os.listdir(directory):
            if file.endswith(".png"):
                files.append(file)
        return files

    def read_file_prompt(self, fname:str) -> str:
        # Read prompt from image file metadata
        img = Image.open(fname)
        img.load()
        return img.info['prompt']
    
    def find_file_match(self, baseline_file: str, file_paths: List[str]):
        # Find a file in file_paths with matching metadata to baseline_file
        baseline_prompt = self.read_file_prompt(baseline_file)

        # Do not match empty prompts
        if baseline_prompt is None or baseline_prompt == "":
            return None

        # Find file match
        # Reorder test_file_names so that the file with matching name is first
        # This is an optimization because matching file names are more likely 
        # to have matching metadata if they were generated with the same script
        basename = os.path.basename(baseline_file)
        file_path_basenames = [os.path.basename(f) for f in file_paths]
        if basename in file_path_basenames:
            match_index = file_path_basenames.index(basename)
            file_paths.insert(0, file_paths.pop(match_index))

        for f in file_paths:
            test_file_prompt = self.read_file_prompt(f)
            if baseline_prompt == test_file_prompt:
                return f