# CLIP image/text feature extraction utility.
import os

import cv2
import numpy as np
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
class CLIPExtractor:
    """Extract CLIP embeddings for images and text via a Hugging Face model.

    Wraps ``CLIPModel``/``CLIPProcessor`` and moves computation to CUDA when
    available. All public methods return a 1-D ``numpy.ndarray`` feature vector.
    """

    def __init__(self, model_name="openai/clip-vit-large-patch14", cache_dir=None):
        """Load the CLIP model and processor.

        Args:
            model_name: Hugging Face model identifier.
            cache_dir: Directory for cached weights. When falsy, defaults to
                ``"models"``, falling back to ``"../models"`` if that exists
                instead (useful when running from a subdirectory).
        """
        if not cache_dir:
            cache_dir = "models"
            # Fall back to the parent directory's cache when the local one
            # is absent but a sibling checkout-level cache exists.
            if not os.path.exists(cache_dir) and os.path.exists("../models"):
                cache_dir = "../models"
        self.model = CLIPModel.from_pretrained(model_name, cache_dir=cache_dir)
        self.processor = CLIPProcessor.from_pretrained(model_name, cache_dir=cache_dir)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _encode_images(self, images):
        """Run the CLIP vision tower on a list of PIL images.

        Returns a ``numpy.ndarray`` of shape (len(images), feature_dim).
        """
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features.cpu().numpy()

    def extract_image(self, frame):
        """Extract a feature vector from an OpenCV BGR frame.

        Args:
            frame: Image array as produced by OpenCV (BGR channel order).

        Returns:
            1-D numpy feature vector.
        """
        # OpenCV delivers BGR; PIL (and CLIP preprocessing) expect RGB.
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return self._encode_images([image])[0]

    def extract_image_from_file(self, file_name):
        """Extract a feature vector from an image file on disk.

        Args:
            file_name: Path to the image file.

        Returns:
            1-D numpy feature vector.

        Raises:
            FileNotFoundError: If ``file_name`` does not exist.
        """
        if not os.path.exists(file_name):
            raise FileNotFoundError(f"File {file_name} not found.")
        image = Image.open(file_name).convert("RGB")
        return self._encode_images([image])[0]

    def extract_text(self, text):
        """Extract a feature vector for a text string.

        Args:
            text: Non-empty string to embed.

        Returns:
            1-D numpy feature vector.

        Raises:
            ValueError: If ``text`` is not a non-empty string.
        """
        if not isinstance(text, str) or not text:
            raise ValueError("Input text should be a non-empty string.")
        # CLIP's text encoder has a 77-token context window; truncate to fit.
        inputs = self.processor.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=77
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features.cpu().numpy()[0]
if __name__ == "__main__":
    extractor = CLIPExtractor()

    # Embed a sample image and a matching text description.
    image_vec = extractor.extract_image_from_file("images/狐狸.jpg")
    text_vec = extractor.extract_text("A photo of fox")

    # Cosine similarity between the image and text embeddings.
    denom = np.linalg.norm(image_vec) * np.linalg.norm(text_vec)
    similarity = np.dot(image_vec, text_vec) / denom
    print(similarity)