Spaces:
Sleeping
Sleeping
import huggingface_hub | |
import torch | |
from hpsv2.src.open_clip import create_model, get_tokenizer | |
from rewards.base_reward import BaseRewardLoss | |
class HPSLoss(BaseRewardLoss):
    """HPS reward loss function for optimization.

    Wraps an open_clip ``ViT-H-14`` model loaded with the
    ``HPS_v2.1_compressed.pt`` checkpoint from the ``xswu/HPSv2`` hub
    repository, and exposes image/text feature extraction plus a loss
    computed from the image-text similarity.
    """

    def __init__(
        self,
        weighting: float,
        dtype: torch.dtype,
        device: torch.device,
        cache_dir: str,
        memsave: bool = False,
    ) -> None:
        """Build the HPS model, load its weights and freeze it.

        Args:
            weighting: Forwarded to the base-class constructor together with
                the reward name ``"HPS"``.
            dtype: Precision used when creating the model and when moving it
                to ``device``.
            device: Device the model is created on and moved to.
            cache_dir: Directory used for both the open_clip weights and the
                hub checkpoint download.
            memsave: When true, the model is converted with
                ``memsave_torch.nn.convert_to_memory_saving`` (imported
                lazily so the package is only required when requested).
        """
        self.hps_model = create_model(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision=dtype,
            device=device,
            cache_dir=cache_dir,
        )
        checkpoint_path = huggingface_hub.hf_hub_download(
            "xswu/HPSv2", "HPS_v2.1_compressed.pt", cache_dir=cache_dir
        )
        # NOTE(review): torch.load unpickles the downloaded file; consider
        # ``weights_only=True`` once confirmed the checkpoint supports it.
        self.hps_model.load_state_dict(
            torch.load(checkpoint_path, map_location=device)["state_dict"]
        )
        self.hps_tokenizer = get_tokenizer("ViT-H-14")
        if memsave:
            import memsave_torch.nn

            self.hps_model = memsave_torch.nn.convert_to_memory_saving(self.hps_model)
        self.hps_model = self.hps_model.to(device, dtype=dtype)
        self.hps_model.eval()
        self.freeze_parameters(self.hps_model.parameters())
        super().__init__("HPS", weighting)
        self.hps_model.set_grad_checkpointing(True)

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """Return the HPS model's image encoding of ``image``."""
        return self.hps_model.encode_image(image)

    def get_text_features(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` and return the HPS model's text encoding.

        The tokens are moved to the device the model's parameters live on
        rather than a hard-coded ``"cuda"``, so the method also works on CPU
        or on a non-default GPU.
        """
        model_device = next(self.hps_model.parameters()).device
        hps_text = self.hps_tokenizer(prompt).to(model_device)
        return self.hps_model.encode_text(hps_text)

    def compute_loss(
        self, image_features: torch.Tensor, text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return ``1 - similarity`` for the first image/text pair.

        The similarity is the first diagonal entry of
        ``image_features @ text_features.T``; any further pairs are ignored.
        NOTE(review): presumably the features are already normalized by the
        caller - confirm, nothing here normalizes them.
        """
        logits_per_image = image_features @ text_features.T
        return 1 - torch.diagonal(logits_per_image)[0]