Spaces:
Sleeping
Sleeping
File size: 1,666 Bytes
873e677 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from functools import partial
from typing import Callable, Dict, List
import numpy as np
import torch
from torchmetrics.functional.multimodal import clip_score
from torchmetrics.image.inception import InceptionScore
# Seed for torch's global RNG; applied once, at import time, just below.
SEED = 0
# Module-level Inception Score metric object shared by every caller of this module.
# NOTE(review): per the torchmetrics docs, normalize=True expects float images in
# [0, 1], and a Metric accumulates state across update() calls until reset() is
# called — confirm against the installed torchmetrics version.
inception_score_fn = InceptionScore(normalize=True)
# Seed torch's global RNG once at import; nothing in this section reseeds it later.
torch.manual_seed(SEED)
# CLIP score bound to a fixed CLIP checkpoint; called as clip_score_fn(images, prompts).
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
def compute_main_metrics(images: np.ndarray, prompts: List[str]) -> Dict:
    """Compute the Inception Score and CLIP score for one batch of images.

    Args:
        images: Image batch whose last axis is moved to position 1 before
            scoring (presumably ``(N, H, W, C)``), holding float values that
            are presumably in ``[0, 1]`` since they are scaled by 255 for CLIP.
        prompts: Text prompts paired with ``images`` for the CLIP score.

    Returns:
        ``{"inception_score (⬆️)": {"mean": m, "std": s}, "clip_score (⬆️)": c}``
        with every value rounded to 4 decimal places.
    """
    # Move the last (channel) axis to position 1: NHWC -> NCHW.
    images_nchw = torch.from_numpy(images).permute(0, 3, 1, 2)
    try:
        inception_score_fn.update(images_nchw)
        inception_score = inception_score_fn.compute()
    finally:
        # The metric object lives at module level and accumulates state across
        # update() calls; reset it so each call scores only its own images and
        # a failed call cannot leak images into the next one.
        inception_score_fn.reset()

    # Clip before casting so out-of-range floats cannot wrap around in uint8;
    # values already in [0, 1] are unaffected.
    images_uint8 = (np.clip(images, 0.0, 1.0) * 255).astype("uint8")
    # Named clip_value to avoid shadowing the imported clip_score function.
    clip_value = clip_score_fn(
        torch.from_numpy(images_uint8).permute(0, 3, 1, 2), prompts
    ).detach()

    return {
        "inception_score (⬆️)": {
            "mean": round(float(inception_score[0]), 4),
            "std": round(float(inception_score[1]), 4),
        },
        "clip_score (⬆️)": round(float(clip_value), 4),
    }
def compute_psnr_or_ssim(
    fn: Callable, images_dict: Dict, original_scheduler_name: str
) -> Dict:
    """Score every scheduler's images against the reference scheduler's images.

    Args:
        fn: Metric called as ``fn(candidate, reference)`` on channels-first
            tensors (e.g. a PSNR or SSIM function); its result must be
            convertible with ``float()``.
        images_dict: Maps scheduler name to an image array whose last axis is
            moved to position 1 before scoring (presumably NHWC -> NCHW).
        original_scheduler_name: Key of the reference entry in ``images_dict``.

    Returns:
        Dict mapping every non-reference scheduler name (in ``images_dict``
        order) to its metric value rounded to 4 decimal places.

    Raises:
        KeyError: If ``original_scheduler_name`` is not in ``images_dict``.
    """

    def to_channels_first(array: np.ndarray) -> torch.Tensor:
        # Move the last axis to position 1.
        return torch.from_numpy(array).permute(0, 3, 1, 2)

    reference = to_channels_first(images_dict[original_scheduler_name])
    return {
        name: round(float(fn(to_channels_first(images), reference)), 4)
        for name, images in images_dict.items()
        if name != original_scheduler_name
    }
|