# NOTE: page-scrape residue from the hosting site (Spaces status: Sleeping); not part of the module.
import ImageReward as RM | |
import torch | |
from rewards.base_reward import BaseRewardLoss | |
class ImageRewardLoss: | |
"""Image reward loss for optimization.""" | |
def __init__( | |
self, | |
weighting: float, | |
dtype: torch.dtype, | |
device: torch.device, | |
cache_dir: str, | |
memsave: bool = False, | |
): | |
self.name = "ImageReward" | |
self.weighting = weighting | |
self.dtype = dtype | |
self.imagereward_model = RM.load("ImageReward-v1.0", download_root=cache_dir) | |
self.imagereward_model = self.imagereward_model.to( | |
device=device, dtype=self.dtype | |
) | |
self.imagereward_model.eval() | |
BaseRewardLoss.freeze_parameters(self.imagereward_model.parameters()) | |
def __call__(self, image: torch.Tensor, prompt: str) -> torch.Tensor: | |
imagereward_score = self.score_diff(prompt, image) | |
return (2 - imagereward_score).mean() | |
def score_diff(self, prompt, image): | |
# text encode | |
text_input = self.imagereward_model.blip.tokenizer( | |
prompt, | |
padding="max_length", | |
truncation=True, | |
max_length=35, | |
return_tensors="pt", | |
).to(self.imagereward_model.device) | |
image_embeds = self.imagereward_model.blip.visual_encoder(image) | |
# text encode cross attention with image | |
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( | |
self.imagereward_model.device | |
) | |
text_output = self.imagereward_model.blip.text_encoder( | |
text_input.input_ids, | |
attention_mask=text_input.attention_mask, | |
encoder_hidden_states=image_embeds, | |
encoder_attention_mask=image_atts, | |
return_dict=True, | |
) | |
txt_features = text_output.last_hidden_state[:, 0, :].to( | |
self.imagereward_model.device, dtype=self.dtype | |
) | |
rewards = self.imagereward_model.mlp(txt_features) | |
rewards = (rewards - self.imagereward_model.mean) / self.imagereward_model.std | |
return rewards | |