# ReNO/rewards/imagereward.py
import ImageReward as RM
import torch

from rewards.base_reward import BaseRewardLoss


class ImageRewardLoss:
    """Image reward loss for optimization."""

    def __init__(
        self,
        weighting: float,
        dtype: torch.dtype,
        device: torch.device,
        cache_dir: str,
        memsave: bool = False,
    ):
        self.name = "ImageReward"
        self.weighting = weighting
        self.dtype = dtype
        # Load the pretrained ImageReward-v1.0 model, move it to the target
        # device/dtype, and keep it frozen in eval mode.
        self.imagereward_model = RM.load("ImageReward-v1.0", download_root=cache_dir)
        self.imagereward_model = self.imagereward_model.to(
            device=device, dtype=self.dtype
        )
        self.imagereward_model.eval()
        BaseRewardLoss.freeze_parameters(self.imagereward_model.parameters())

    def __call__(self, image: torch.Tensor, prompt: str) -> torch.Tensor:
        # Higher ImageReward scores are better, so minimize (2 - score).
        imagereward_score = self.score_diff(prompt, image)
        return (2 - imagereward_score).mean()

    def score_diff(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
        # Tokenize the prompt for the BLIP text encoder.
        text_input = self.imagereward_model.blip.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(self.imagereward_model.device)

        # Encode the image with the BLIP visual encoder.
        image_embeds = self.imagereward_model.blip.visual_encoder(image)

        # Text encoding with cross-attention over the image embeddings.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.imagereward_model.device
        )
        text_output = self.imagereward_model.blip.text_encoder(
            text_input.input_ids,
            attention_mask=text_input.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        txt_features = text_output.last_hidden_state[:, 0, :].to(
            self.imagereward_model.device, dtype=self.dtype
        )

        # Map the [CLS] feature to a scalar reward and standardize it with the
        # model's precomputed mean and std.
        rewards = self.imagereward_model.mlp(txt_features)
        rewards = (rewards - self.imagereward_model.mean) / self.imagereward_model.std
        return rewards
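

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): shows how the
# loss could be evaluated on a differentiable image batch. The prompt string,
# the 224x224 resize, and the CLIP-style normalization constants below are
# assumptions chosen to match the BLIP visual encoder's usual input format.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torchvision.transforms.functional as TF

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_fn = ImageRewardLoss(
        weighting=1.0,
        dtype=torch.float32,
        device=device,
        cache_dir="./cache",
    )

    # A differentiable image in [0, 1], e.g. the output of a diffusion decoder.
    image = torch.rand(1, 3, 512, 512, device=device, requires_grad=True)

    # Resize and normalize to the expected BLIP input (assumed preprocessing).
    resized = TF.resize(image, [224, 224], antialias=True)
    normalized = TF.normalize(
        resized,
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )

    loss = loss_fn(normalized, "a photo of an astronaut riding a horse")
    loss.backward()  # the reward model is frozen; gradients flow to the image
    print(f"ImageReward loss: {loss.item():.4f}")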