File size: 1,666 Bytes
873e677
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from functools import partial
from typing import Callable, Dict, List

import numpy as np
import torch
from torchmetrics.functional.multimodal import clip_score
from torchmetrics.image.inception import InceptionScore

# Seed handed to torch.manual_seed() below when this module is imported.
SEED = 0

# Single module-level InceptionScore instance used by compute_main_metrics().
# NOTE(review): normalize=True presumably means inputs are float images in
# [0, 1] rather than uint8 — confirm against the torchmetrics documentation.
inception_score_fn = InceptionScore(normalize=True)
# Seeds torch's global RNG as an import-time side effect (after the metric
# object above has been constructed).
torch.manual_seed(SEED)
# clip_score with the CLIP checkpoint pre-bound, so callers only pass the
# images and the prompts.
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")


def compute_main_metrics(images: np.ndarray, prompts: List[str]) -> Dict:
    """Compute the Inception Score and the CLIP score for one batch of images.

    Args:
        images: Batch of images as a NumPy array. Axis 3 is moved to
            position 1 before the metrics are called, and the array is
            multiplied by 255 and cast to uint8 for the CLIP score.
            NOTE(review): this looks like float NHWC data in [0, 1] —
            confirm against the caller.
        prompts: Text prompts passed to the CLIP score together with the
            images.

    Returns:
        A dict with two entries: ``"inception_score (⬆️)"`` holding the
        ``"mean"`` and ``"std"`` of the Inception Score, and
        ``"clip_score (⬆️)"`` holding the CLIP score. All values are Python
        floats rounded to 4 decimals.
    """
    # The metric object is shared at module level and accumulates every batch
    # passed to update(). Clear it first so the result describes only
    # `images`, not the images from earlier calls.
    inception_score_fn.reset()
    inception_score_fn.update(torch.from_numpy(images).permute(0, 3, 1, 2))
    inception_result = inception_score_fn.compute()

    # The CLIP score is fed the uint8 version of the same batch.
    images_int = (images * 255).astype("uint8")
    # Named `clip_value` so the imported `clip_score` function is not shadowed.
    clip_value = clip_score_fn(
        torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts
    ).detach()
    return {
        "inception_score (⬆️)": {
            "mean": round(float(inception_result[0]), 4),
            "std": round(float(inception_result[1]), 4),
        },
        "clip_score (⬆️)": round(float(clip_value), 4),
    }


def compute_psnr_or_ssim(
    fn: Callable, images_dict: Dict, original_scheduler_name: str
) -> Dict:
    """Score every entry of ``images_dict`` against a reference entry.

    Args:
        fn: Metric callable invoked as ``fn(candidate, reference)`` on torch
            tensors; its result must be convertible with ``float()``.
        images_dict: Mapping from a name to a 4-D NumPy image batch. Axis 3
            of each batch is moved to position 1 before ``fn`` is called
            (presumably NHWC -> NCHW; confirm against the caller).
        original_scheduler_name: Key of the reference batch in
            ``images_dict``; it is excluded from the result.

    Returns:
        A dict mapping every other key to ``fn``'s value rounded to 4
        decimals, in the iteration order of ``images_dict``.
    """

    def _to_channels_first(batch):
        # Wrap the NumPy array as a tensor and move axis 3 to position 1.
        return torch.from_numpy(batch).permute(0, 3, 1, 2)

    reference = _to_channels_first(images_dict[original_scheduler_name])
    return {
        name: round(float(fn(_to_channels_first(batch), reference)), 4)
        for name, batch in images_dict.items()
        if name != original_scheduler_name
    }