Spaces:
Sleeping
Sleeping
from typing import Any, List | |
import torch | |
from torchvision.transforms import (CenterCrop, Compose, InterpolationMode, | |
Normalize, Resize) | |
from transformers import AutoProcessor | |
from rewards.aesthetic import AestheticLoss | |
from rewards.base_reward import BaseRewardLoss | |
from rewards.clip import CLIPLoss | |
from rewards.hps import HPSLoss | |
from rewards.imagereward import ImageRewardLoss | |
from rewards.pickscore import PickScoreLoss | |
def get_reward_losses(
    args: Any, dtype: torch.dtype, device: torch.device, cache_dir: str
) -> List[BaseRewardLoss]:
    """Instantiate every reward loss that is switched on in ``args``.

    Each ``args.enable_<name>`` flag controls whether the matching loss is
    built; its weight is read from ``args.<name>_weighting`` and
    ``args.memsave`` is forwarded to every constructor as a keyword.

    Args:
        args: Object exposing the ``enable_*`` flags, the ``*_weighting``
            values and ``memsave``.
        dtype: Passed through unchanged to each loss constructor.
        device: Passed through unchanged to each loss constructor.
        cache_dir: Passed to each loss constructor and used as the
            ``cache_dir`` for the processor download.

    Returns:
        The enabled losses, always in the order HPS, ImageReward, CLIP,
        PickScore, Aesthetic. Empty when no flag is set.
    """
    # CLIP and PickScore share one processor, so load it at most once and
    # only when at least one of them is requested.
    processor = None
    if args.enable_clip or args.enable_pickscore:
        processor = AutoProcessor.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", cache_dir=cache_dir
        )

    # Positional arguments that follow the weighting in every constructor.
    shared = (dtype, device, cache_dir)
    enabled: List[BaseRewardLoss] = []

    if args.enable_hps:
        hps = HPSLoss(args.hps_weighting, *shared, memsave=args.memsave)
        enabled.append(hps)

    if args.enable_imagereward:
        image_reward = ImageRewardLoss(
            args.imagereward_weighting, *shared, memsave=args.memsave
        )
        enabled.append(image_reward)

    if args.enable_clip:
        clip = CLIPLoss(
            args.clip_weighting, *shared, processor, memsave=args.memsave
        )
        enabled.append(clip)

    if args.enable_pickscore:
        pickscore = PickScoreLoss(
            args.pickscore_weighting, *shared, processor, memsave=args.memsave
        )
        enabled.append(pickscore)

    if args.enable_aesthetic:
        aesthetic = AestheticLoss(
            args.aesthetic_weighting, *shared, memsave=args.memsave
        )
        enabled.append(aesthetic)

    return enabled
def clip_img_transform(size: int = 224):
    """Build the resize / center-crop / normalize pipeline for image inputs.

    Args:
        size: Target size handed to both ``Resize`` and ``CenterCrop``.

    Returns:
        A ``Compose`` that applies, in order, a bicubic ``Resize``, a
        ``CenterCrop`` and a three-channel ``Normalize``.
    """
    # Per-channel statistics passed to Normalize as (mean, std).
    # NOTE(review): presumably the standard CLIP preprocessing constants;
    # confirm against the reward models that consume this transform.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    resize = Resize(size, interpolation=InterpolationMode.BICUBIC)
    crop = CenterCrop(size)
    normalize = Normalize(channel_mean, channel_std)
    return Compose([resize, crop, normalize])