from typing import Any, List

import torch
from torchvision.transforms import (CenterCrop, Compose, InterpolationMode,
                                    Normalize, Resize)
from transformers import AutoProcessor

from rewards.aesthetic import AestheticLoss
from rewards.base_reward import BaseRewardLoss
from rewards.clip import CLIPLoss
from rewards.hps import HPSLoss
from rewards.imagereward import ImageRewardLoss
from rewards.pickscore import PickScoreLoss


def get_reward_losses(
    args: Any, dtype: torch.dtype, device: torch.device, cache_dir: str
) -> List[BaseRewardLoss]:
    """Instantiate the reward losses enabled via ``args``, each with its weighting."""
    # CLIPLoss and PickScoreLoss share the same CLIP ViT-H/14 processor,
    # so load it once up front if either is enabled.
    if args.enable_clip or args.enable_pickscore:
        tokenizer = AutoProcessor.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", cache_dir=cache_dir
        )
    reward_losses = []
    if args.enable_hps:
        reward_losses.append(
            HPSLoss(args.hps_weighting, dtype, device, cache_dir, memsave=args.memsave)
        )
    if args.enable_imagereward:
        reward_losses.append(
            ImageRewardLoss(
                args.imagereward_weighting,
                dtype,
                device,
                cache_dir,
                memsave=args.memsave,
            )
        )
    if args.enable_clip:
        reward_losses.append(
            CLIPLoss(
                args.clip_weighting,
                dtype,
                device,
                cache_dir,
                tokenizer,
                memsave=args.memsave,
            )
        )
    if args.enable_pickscore:
        reward_losses.append(
            PickScoreLoss(
                args.pickscore_weighting,
                dtype,
                device,
                cache_dir,
                tokenizer,
                memsave=args.memsave,
            )
        )
    if args.enable_aesthetic:
        reward_losses.append(
            AestheticLoss(
                args.aesthetic_weighting, dtype, device, cache_dir, memsave=args.memsave
            )
        )
    return reward_losses


def clip_img_transform(size: int = 224):
    """Standard CLIP image preprocessing: bicubic resize of the shorter side,
    center crop, and normalization with the CLIP mean/std. Expects image
    tensors already scaled to [0, 1]."""
    return Compose(
        [
            Resize(size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(size),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
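

if __name__ == "__main__":
    # Minimal usage sketch of the two helpers above. The flag and weighting
    # values here are illustrative assumptions, not defaults from this
    # repository, and building the losses requires network access to
    # download the corresponding reward-model weights.
    from argparse import Namespace

    args = Namespace(
        enable_hps=True,
        hps_weighting=5.0,
        enable_imagereward=True,
        imagereward_weighting=1.0,
        enable_clip=False,
        clip_weighting=0.01,
        enable_pickscore=False,
        pickscore_weighting=0.05,
        enable_aesthetic=False,
        aesthetic_weighting=0.0,
        memsave=False,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    losses = get_reward_losses(args, torch.float16, device, cache_dir="./cache")
    print([type(loss).__name__ for loss in losses])

    # The transform operates on [0, 1]-scaled (C, H, W) tensors: the shorter
    # side is resized to 224, then center-cropped and normalized.
    transform = clip_img_transform(224)
    dummy = torch.rand(3, 512, 512)
    print(transform(dummy).shape)  # torch.Size([3, 224, 224])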