Spaces:
Runtime error
Runtime error
import os | |
from glob import glob | |
try: | |
from src.Database import Database | |
from src.Captioner import Captioner | |
from src.ImageBase import Imagebase | |
from src.RemoteDatabase import RemoteDatabase | |
from src.get_major_object import get_major_object, verify_keyword_in_base | |
from src.generate_cultivation import generate_cultivation_with_rag | |
except: | |
from Database import Database | |
from Captioner import Captioner | |
from ImageBase import Imagebase | |
from RemoteDatabase import RemoteDatabase | |
from get_major_object import get_major_object, verify_keyword_in_base | |
from generate_cultivation import generate_cultivation_with_rag | |
class GameMaster:
    """Tie together feature extraction, captioning and the remote database.

    The class offers a cheap lookup path (`search_with_path`) that matches an
    image against stored features, and an expensive path
    (`generate_cultivation_data`) that captions the image and, when needed,
    generates and stores a brand-new entry.
    """

    def __init__(self):
        """Create the local databases, the captioner and the remote client."""
        self.textdb = self.init_textdb()
        # Reuse the CLIP extractor owned by the text database for image features.
        self.clip_extractor = self.textdb.clip_extractor
        self.imgdb = self.init_imgdb()
        self.captioner = Captioner()
        # Default similarity an image match must exceed to count as a hit.
        self.minimal_image_threshold = 0.9
        self.remote = RemoteDatabase()

    def init_textdb(self):
        """Build a `Database` with its BGE and CLIP extractors initialised."""
        text_db = Database()
        text_db.init_bge_extractor()
        text_db.init_clip_extractor()
        return text_db

    def init_imgdb(self):
        """Build and return a fresh `Imagebase`."""
        return Imagebase()

    def random_image_text_data(self, n=12):
        """Sample `n` image records and pair each image path with a description.

        Image paths that do not exist on the local filesystem are replaced by
        a blank placeholder image. The description is the record's
        ``description_in_cultivation`` (prefixed with ``name_in_cultivation``
        and ``--`` when that key is present), or an empty string when the
        keyword lookup yields nothing usable.

        Returns:
            A ``zip`` iterator of ``(image_path, description)`` pairs. It can
            be consumed only once; wrap it in ``list()`` to reuse it.
        """
        samples = self.remote.random_sample(n)
        blank_image_path = "datas/blank_item.jpg"

        # Keep only images that are actually available locally.
        image_names = [
            sample['image_name'] if os.path.exists(sample['image_name'])
            else blank_image_path
            for sample in samples
        ]

        descriptions = []
        for sample in samples:
            result = self.remote.search_by_en_keyword(sample['translated_word'])
            if result and "description_in_cultivation" in result:
                description = result['description_in_cultivation']
                if "name_in_cultivation" in result:
                    description = result['name_in_cultivation'] + "--" + description
                descriptions.append(description)
            else:
                descriptions.append("")

        return zip(image_names, descriptions)

    def search_with_path(self, image_path, threshold=None):
        """Look up an image by feature similarity (the lightweight path).

        Args:
            image_path: Path of the image file to extract features from.
            threshold: Similarity the best image match must exceed; defaults
                to ``self.minimal_image_threshold`` when ``None``.

        Returns:
            A tuple ``(search_result, backup_results, image_feature)``.
            ``search_result`` is the entry found for the best image match,
            with its ``similarity`` added, or ``None`` when there is no match
            above the threshold or the entry has no ``name_in_cultivation``.
            ``backup_results`` is the top-5 search against ``text_feature``
            and is always computed.
        """
        image_feature = self.clip_extractor.extract_image_from_file(image_path)
        image_search_result = self.remote.top_k_search(image_feature, 'clip_feature', top_k=1)

        if threshold is None:
            threshold = self.minimal_image_threshold

        search_result = None
        if image_search_result and image_search_result[0]['similarity'] > threshold:
            best_match = image_search_result[0]
            # Resolve the matched image to its entry via the translated keyword.
            result = self.remote.search_by_en_keyword(best_match['translated_word'])
            if result and "name_in_cultivation" in result:
                search_result = result
                search_result['similarity'] = best_match['similarity']
            else:
                print("Warning! Unfound keyword: ", best_match['translated_word'])

        # Always fetched so the caller can pass it on to
        # generate_cultivation_data without searching again.
        backup_results = self.remote.top_k_search(image_feature, 'text_feature', top_k=5)

        return search_result, backup_results, image_feature

    def generate_cultivation_data(self, image_path, image_feature, text_search_result):
        """Caption an image and resolve or generate its entry (expensive path).

        Args:
            image_path: Path of the image to caption and register.
            image_feature: Feature of the image, as returned by
                `search_with_path`.
            text_search_result: Records from a prior ``text_feature`` search,
                or ``None`` to run that search here.

        Returns:
            The existing entry when the detected object is already in the
            base, the newly created entry when one had to be generated, or
            ``None`` when captioning, object extraction or generation failed
            or no candidate was produced.
        """
        cultivation_data = None

        try:
            caption_response = self.captioner.caption(image_path)
        except Exception as exc:
            # Only ordinary errors are treated as "captioning failed";
            # KeyboardInterrupt/SystemExit must still propagate.
            print("Error occurred while captioning the image ", image_path, "-", repr(exc))
            return cultivation_data

        if text_search_result is None:
            # The caller did not supply results, so run the text search now.
            text_search_result = self.remote.top_k_search(image_feature, 'text_feature', top_k=5)

        # De-duplicate while keeping first-occurrence order.
        keywords = list(dict.fromkeys(res['translated_word'] for res in text_search_result))

        try:
            json_response = get_major_object(caption_response, keywords)
        except Exception as exc:
            print("Error occurred while getting major object from caption ", caption_response, "-", repr(exc))
            return cultivation_data

        in_base_data, alt_data = verify_keyword_in_base(json_response, self.remote)

        if in_base_data is not None:
            cultivation_data = in_base_data
            # A new image for an entry that already exists: register the image
            # only, no additional entry needs to be generated.
            image_data = {
                'image_name': image_path,
                'keyword': in_base_data['keyword'],
                'translated_word': in_base_data['translated_word']
            }
            self.remote.add_data(image_data, None, image_feature, None)
        elif alt_data is not None:
            try:
                cultivation_data = generate_cultivation_with_rag(alt_data, text_search_result)
            except Exception as exc:
                print("Error occurred while generating cultivation data", "-", repr(exc))
                return cultivation_data

            new_data = {
                "keyword": alt_data['keyword'],
                "name_in_cultivation": cultivation_data['new_name'],
                "description_in_cultivation": cultivation_data['final_enhanced_description'],
                "translated_word": alt_data['translated_word'],
                "description": alt_data['description']
            }
            text_feature = self.textdb.clip_extractor.extract_text(
                new_data['translated_word'] + '.' + new_data['description']
            )
            print("Added new data to textdb: ", new_data["name_in_cultivation"])

            image_data = {
                'image_name': image_path,
                'keyword': new_data['keyword'],
                'translated_word': new_data['translated_word']
            }
            self.remote.add_data(image_data, new_data, image_feature, text_feature)
            print("Added new image to imgdb: ", image_data["keyword"])

            cultivation_data = new_data

        # NOTE(review): the original nesting of this call was lost; it is kept
        # at function level because both branches above register image_path.
        # Confirm against the upstream file.
        self.remote.add_file(image_path)

        return cultivation_data
if __name__ == "__main__":
    # Manual smoke test: run the lightweight search over a folder of images,
    # then the full generation pipeline on one fixed image.

    # Send HTTP(S) traffic through a proxy on localhost:8234.
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'

    game_master = GameMaster()

    target_folder = "temp_images"
    image_files = glob(os.path.join(target_folder, "*.jpg"))

    for index, image_path in enumerate(image_files):
        print("index:", index)
        search_result, backup_results, image_feature = game_master.search_with_path(image_path)
        if search_result:
            print(search_result)
            # NOTE(review): nesting reconstructed from flattened source;
            # assumed to stop at the first image that yields a match.
            break

    test_image_path = "temp_images/向日葵.jpg"
    search_result, backup_results, image_feature = game_master.search_with_path(test_image_path)

    # Reuse the feature and text results from the search above.
    cultivation_data = game_master.generate_cultivation_data(
        test_image_path, image_feature, backup_results
    )
    print(cultivation_data)