# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Concet projector.""" | |
import pickle | |
import numpy as np | |
import torch | |
from torch import nn | |
class ConceptProjector(nn.Module):
    """Encode and decode concepts using CLIP."""

    def __init__(self, src_weights=None, tgt_weights=None):
        super(ConceptProjector, self).__init__()
        self.reset_weights(src_weights, tgt_weights)

    def reset_weights(self, src_weights=None, tgt_weights=None):
        """Reset the normalized projection weights."""
        if src_weights:
            with open(src_weights, "rb") as f:
                self.src_weights, self.concepts = pickle.load(f)
            self.src_weights = torch.from_numpy(self.src_weights)
            self.concepts = np.array(self.concepts)
        if tgt_weights:
            with open(tgt_weights, "rb") as f:
                self.tgt_weights, self.concepts = pickle.load(f)
            self.tgt_weights = torch.from_numpy(self.tgt_weights)
            self.concepts = np.array(self.concepts)
    @staticmethod
    def maybe_convert(embeds, proj):
        """Convert inputs for safe projection."""
        # Project in float32 and keep the weights on the embedding device.
        if embeds.dtype != torch.float32:
            embeds = embeds.float()
        if embeds.device != proj.device:
            proj = proj.to(device=embeds.device)
        return embeds, proj
    def encode_src(self, src_embeds, logpi=True):
        """Encode source visual embedding via concept projection."""
        src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
        logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
        return nn.functional.log_softmax(logits, dim=-1) if logpi else logits

    def encode_tgt(self, tgt_embeds):
        """Encode target visual embedding via concept projection."""
        tgt_embeds, self.tgt_weights = self.maybe_convert(tgt_embeds, self.tgt_weights)
        logits = nn.functional.normalize(tgt_embeds, dim=-1) @ self.tgt_weights
        return nn.functional.log_softmax(logits, dim=-1)
    def decode(self, src_embeds, k=1, return_index=False, return_prob=False):
        """Return the top-k concepts of source visual embedding."""
        src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
        logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
        probs = nn.functional.softmax(logits, dim=-1)
        if return_prob:
            return probs.cpu().numpy()
        score, index = [x.cpu().numpy() for x in probs.topk(k, -1)]
        return (index if return_index else self.concepts[index]), score
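

# A minimal usage sketch, not part of the original module. It assumes the
# pickled weight files store a (embed_dim, num_concepts) array of L2-normalized
# CLIP text embeddings plus the matching concept names, and that the visual
# embeddings come from the paired CLIP image encoder. A tiny fake concept bank
# is built in memory so the sketch runs standalone; real use would pass the
# pickle paths to the constructor instead.
if __name__ == "__main__":
    projector = ConceptProjector()
    embed_dim, concepts = 8, ["cat", "dog", "car"]
    # Columns are unit-norm "text" embeddings standing in for a real concept bank.
    projector.src_weights = nn.functional.normalize(torch.randn(embed_dim, len(concepts)), dim=0)
    projector.concepts = np.array(concepts)
    visual_embeds = torch.randn(2, embed_dim)  # stand-in for CLIP visual features
    log_probs = projector.encode_src(visual_embeds)  # (2, 3) log-probabilities
    names, scores = projector.decode(visual_embeds, k=2)  # top-2 concept names and scores
    print(names, scores)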