from typing import Any, List

import torch
from torchvision.transforms import (CenterCrop, Compose, InterpolationMode,
                                    Normalize, Resize)
from transformers import AutoProcessor

from rewards.aesthetic import AestheticLoss
from rewards.base_reward import BaseRewardLoss
from rewards.clip import CLIPLoss
from rewards.hps import HPSLoss
from rewards.imagereward import ImageRewardLoss
from rewards.pickscore import PickScoreLoss


def get_reward_losses(
    args: Any, dtype: torch.dtype, device: torch.device, cache_dir: str
) -> List[BaseRewardLoss]:
    """Build the list of enabled reward losses from the parsed CLI args.

    The CLIP tokenizer is shared between the CLIP and PickScore losses,
    so it is loaded once up front if either of them is enabled.
    """
    if args.enable_clip or args.enable_pickscore:
        tokenizer = AutoProcessor.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", cache_dir=cache_dir
        )
    reward_losses: List[BaseRewardLoss] = []
    # Human Preference Score reward
    if args.enable_hps:
        reward_losses.append(
            HPSLoss(args.hps_weighting, dtype, device, cache_dir, memsave=args.memsave)
        )
    # ImageReward preference model
    if args.enable_imagereward:
        reward_losses.append(
            ImageRewardLoss(
                args.imagereward_weighting,
                dtype,
                device,
                cache_dir,
                memsave=args.memsave,
            )
        )
    # CLIP text-image similarity (shares the tokenizer loaded above)
    if args.enable_clip:
        reward_losses.append(
            CLIPLoss(
                args.clip_weighting,
                dtype,
                device,
                cache_dir,
                tokenizer,
                memsave=args.memsave,
            )
        )
    # PickScore preference model (shares the tokenizer loaded above)
    if args.enable_pickscore:
        reward_losses.append(
            PickScoreLoss(
                args.pickscore_weighting,
                dtype,
                device,
                cache_dir,
                tokenizer,
                memsave=args.memsave,
            )
        )
    # Aesthetic score predictor (prompt-independent)
    if args.enable_aesthetic:
        reward_losses.append(
            AestheticLoss(
                args.aesthetic_weighting, dtype, device, cache_dir, memsave=args.memsave
            )
        )
    return reward_losses
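

# Usage sketch (illustrative, not part of the original module). The flag and
# weighting values below are placeholders, and the per-loss call form is
# assumed from the BaseRewardLoss interface rather than confirmed:
#
#     import argparse
#     args = argparse.Namespace(
#         enable_hps=True, enable_clip=True, enable_imagereward=False,
#         enable_pickscore=False, enable_aesthetic=False,
#         hps_weighting=5.0, clip_weighting=0.01, memsave=False,
#     )
#     losses = get_reward_losses(args, torch.float16, torch.device("cuda"), "cache")
#     total = sum(loss(image, prompt) for loss in losses)  # hypothetical call form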


def clip_img_transform(size: int = 224):
    """Standard CLIP preprocessing: bicubic resize, center crop, and
    normalization with the CLIP mean/std. Operates directly on image
    tensors in [0, 1] (there is deliberately no ToTensor step), so the
    transform stays differentiable for reward-gradient computation."""
    return Compose(
        [
            Resize(size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(size),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
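

# Usage sketch (illustrative): preprocess decoded images before a CLIP-based
# reward; the (1, 3, 512, 512) shape is a placeholder for the decoder output.
#
#     transform = clip_img_transform()
#     images = torch.rand(1, 3, 512, 512)   # stand-in for images in [0, 1]
#     clip_input = transform(images)        # -> (1, 3, 224, 224), CLIP-normalized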