Spaces:
Runtime error
Runtime error
from PIL import Image | |
import base64 | |
from io import BytesIO | |
import os | |
from openai import OpenAI | |
import json | |
class Captioner:
    """Caption images with the Yi vision model, caching responses in a JSONL file.

    Responses are keyed by the image path passed to :meth:`caption`. By default
    they are loaded from and persisted to ``datas/caption_history.jsonl``
    (relative to the current working directory).
    """

    def __init__(self, api_key_path=None, proxy=None, api_base="https://api.lingyiwanwu.com/v1"):
        """Create the OpenAI-compatible client and load any cached captions.

        Args:
            api_key_path: Optional path to a text file holding the API key. It is
                read only when the ``YI_VL_KEY`` environment variable is unset or
                empty; the environment variable takes precedence.
            proxy: Accepted for backward compatibility but currently ignored;
                configure proxies through the environment instead.
            api_base: Base URL of the OpenAI-compatible chat-completions endpoint.
        """
        self.api_key = os.getenv('YI_VL_KEY')
        if not self.api_key and api_key_path is not None:
            # api_key_path used to be silently ignored; honour it as a fallback.
            self.api_key = self.load_access_token(api_key_path)
        self.api_base = api_base
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.api_base,
        )
        self.history = {}  # image path -> raw model response text
        self.history_file = None  # default save_history() target, set by load_history()
        self.load_history()

    def load_access_token(self, file_path):
        """Return the whitespace-stripped contents of *file_path* (e.g. an API key file)."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read().strip()

    def image2base64(self, image_path, max_height=480):
        """Load an image and return it as a JPEG ``data:`` URL.

        Images taller than *max_height* pixels are downscaled to that height,
        preserving the aspect ratio. Non-RGB/greyscale images (e.g. RGBA or
        palette PNGs) are converted to RGB, because JPEG cannot store them.

        Args:
            image_path: Path of the image file to encode.
            max_height: Maximum output height in pixels (default 480).

        Returns:
            A string of the form ``"data:image/jpeg;base64,<payload>"``.
        """
        with Image.open(image_path) as img:
            if img.height > max_height:
                aspect_ratio = img.width / img.height
                # Clamp to 1px so extremely tall, thin images don't yield a 0 width.
                new_width = max(1, int(max_height * aspect_ratio))
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
                img = img.resize((new_width, max_height), Image.LANCZOS)
            if img.mode not in ("RGB", "L"):
                # JPEG cannot store alpha or palette modes; saving them would raise.
                img = img.convert("RGB")
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
        encoded = base64.b64encode(buffered.getvalue()).decode('utf-8')
        return "data:image/jpeg;base64," + encoded

    def load_history(self, jsonl_file_name=None):
        """Merge cached captions from a JSONL file into ``self.history``.

        Each non-blank line must be a JSON object with ``file_name`` and
        ``response`` keys. The file becomes the default target of
        :meth:`save_history`, even if it does not exist yet.

        Args:
            jsonl_file_name: History file path; defaults to
                ``datas/caption_history.jsonl``.
        """
        if jsonl_file_name is None:
            jsonl_file_name = "datas/caption_history.jsonl"
        self.history_file = jsonl_file_name
        if not os.path.exists(jsonl_file_name):
            return
        with open(jsonl_file_name, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (hand edits, extra trailing newlines).
                    continue
                data = json.loads(line)
                self.history[data['file_name']] = data['response']

    def search_from_history(self, file_name):
        """Return the cached response for *file_name*, or ``None`` if absent."""
        return self.history.get(file_name, None)

    def save_history(self, jsonl_file_name=None):
        """Write the entire caption history to a JSONL file.

        Defaults to the file last passed to :meth:`load_history`. Does nothing
        if no target is known. The data is written to a ``.tmp`` sibling and
        then swapped in with ``os.replace``, so an interrupted write does not
        truncate the existing history file.

        Args:
            jsonl_file_name: Optional explicit output path.
        """
        if jsonl_file_name is None:
            jsonl_file_name = self.history_file
        if not jsonl_file_name:
            return
        directory = os.path.dirname(jsonl_file_name)
        if directory:
            # The default "datas/" directory may not exist on a fresh checkout.
            os.makedirs(directory, exist_ok=True)
        tmp_path = jsonl_file_name + ".tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            for file_name, response in self.history.items():
                json.dump({'file_name': file_name, 'response': response}, f, ensure_ascii=False)
                f.write('\n')
        os.replace(tmp_path, jsonl_file_name)

    def add_to_history(self, file_name, response):
        """Record *response* as the cached caption for *file_name* (in memory only)."""
        self.history[file_name] = response

    def caption(self, image_name):
        """Return the model's JSON-formatted analysis of the image at *image_name*.

        A cached response is returned when one exists. Empty or ``None``
        cached entries are treated as misses and re-queried. Otherwise the
        image is sent to the ``yi-vision`` model, and the reply is cached and
        persisted via :meth:`save_history`.

        Args:
            image_name: Path to the image; also used as the cache key.

        Returns:
            The raw text content of the model's reply.
        """
        cached_response = self.search_from_history(image_name)
        if cached_response:
            return cached_response
        # The Chinese_name field asks (in Chinese) for the Chinese name of the main object.
        prompt = (
            "Analyze the image and output in JSON format, including the following fields:\n"
            '- "detailed_description": A detailed description of the image content.\n'
            '- "major_object": Determine the main object/scene in the image based on the description, output with a simple word\n'
            '- "Chinese_name": 判断图片中主要物体的中文名\n'
            '- "real_or_composite": Determine whether this image was taken with a camera or created/modifed by a computer, output with real or composite.'
        )
        img_base64 = self.image2base64(image_name)
        completion = self.client.chat.completions.create(
            model="yi-vision",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": img_base64}},
                    ],
                }
            ],
            stream=False,
        )
        response = completion.choices[0].message.content
        self.add_to_history(image_name, response)
        # Persist immediately so captions survive a crash mid-batch.
        self.save_history()
        return response
if __name__ == "__main__":
    import sys

    # Local debugging harness: route API traffic through a local proxy, but
    # don't overwrite a proxy the caller already configured.
    os.environ.setdefault('HTTP_PROXY', 'http://localhost:8234')
    os.environ.setdefault('HTTPS_PROXY', 'http://localhost:8234')
    captioner = Captioner()
    # Caption the image given on the command line, or fall back to the sample image.
    test_image = sys.argv[1] if len(sys.argv) > 1 else "temp_images/3zjz9b3l.jpg"
    print(captioner.caption(test_image))