ReNO / rewards /base_reward.py
fffiloni's picture
Upload 24 files
ca25718 verified
raw
history blame
1.41 kB
from abc import ABC, abstractmethod
import torch
class BaseRewardLoss(ABC):
    """
    Base class for reward functions implementing a differentiable reward function for optimization.

    Subclasses supply the feature extractors (``get_image_features``,
    ``get_text_features``) and the loss itself (``compute_loss``). Calling an
    instance runs the full pipeline: extract both feature sets, L2-normalize
    them along the last dimension, and hand them to ``compute_loss``.
    """

    def __init__(self, name: str, weighting: float):
        """
        Store the reward's identifying name and its weighting.

        Args:
            name: Label for this reward.
            weighting: Weight associated with this reward. It is only stored
                here; ``__call__`` returns the loss without applying it.
        """
        self.name = name
        self.weighting = weighting

    @staticmethod
    def freeze_parameters(params: torch.nn.ParameterList):
        """
        Disable gradient tracking on every parameter in ``params``.

        Args:
            params: Parameters to freeze. The collection is only iterated,
                and each element has ``requires_grad`` set to ``False``.
        """
        for param in params:
            param.requires_grad = False

    @abstractmethod
    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Return the (un-normalized) feature tensor for ``image``."""
        pass

    @abstractmethod
    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Return the (un-normalized) feature tensor for ``prompt``."""
        pass

    @abstractmethod
    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Return the loss for a pair of feature tensors.

        When invoked through ``__call__`` both arguments have already been
        passed through ``process_features``.
        """
        pass

    def process_features(self, features: torch.Tensor) -> torch.Tensor:
        """
        L2-normalize ``features`` along the last dimension.

        Args:
            features: Floating-point feature tensor.

        Returns:
            A tensor of the same shape in which each vector along the last
            dimension has been divided by its L2 norm. All-zero vectors are
            returned as zeros instead of NaN.
        """
        norm = features.norm(dim=-1, keepdim=True)
        # Guard against 0/0 -> NaN for all-zero vectors. Clamping to the
        # smallest positive normal number of the dtype leaves every
        # non-degenerate norm untouched, so ordinary inputs are unaffected.
        safe_norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)
        return features / safe_norm

    def __call__(self, image: torch.Tensor, prompt: str) -> torch.Tensor:
        """
        Compute the loss between ``image`` and ``prompt``.

        Args:
            image: Image tensor forwarded to ``get_image_features``.
            prompt: Text prompt forwarded to ``get_text_features``.

        Returns:
            The value returned by ``compute_loss`` on the normalized image
            and text features.
        """
        image_features = self.get_image_features(image)
        text_features = self.get_text_features(prompt)

        image_features_normed = self.process_features(image_features)
        text_features_normed = self.process_features(text_features)

        loss = self.compute_loss(image_features_normed, text_features_normed)
        return loss