import os
from glob import glob

try:
    from src.Database import Database
    from src.Captioner import Captioner
    from src.ImageBase import Imagebase
    from src.RemoteDatabase import RemoteDatabase
    from src.get_major_object import get_major_object, verify_keyword_in_base
    from src.generate_cultivation import generate_cultivation_with_rag
except ImportError:
    # Fall back to flat imports (e.g. when run from inside the src/ folder).
    # Only ImportError is caught so real errors inside these modules surface.
    from Database import Database
    from Captioner import Captioner
    from ImageBase import Imagebase
    from RemoteDatabase import RemoteDatabase
    from get_major_object import get_major_object, verify_keyword_in_base
    from generate_cultivation import generate_cultivation_with_rag


class GameMaster:
    """Coordinate image lookup and cultivation-entry generation.

    Combines a local text database (which owns the BGE and CLIP extractors),
    a local image database, a captioner and a remote database. Searches are
    answered by the remote database, and newly created entries and images are
    written back to it.
    """

    # Placeholder image used when a sampled image file is missing locally.
    blank_image_path = "datas/blank_item.jpg"

    def __init__(self):
        self.textdb = self.init_textdb()
        # Reuse the CLIP extractor loaded by the text database.
        self.clip_extractor = self.textdb.clip_extractor
        self.imgdb = self.init_imgdb()
        self.captioner = Captioner()
        # An image match counts only if its similarity is strictly above this.
        self.minimal_image_threshold = 0.9
        self.remote = RemoteDatabase()

    def init_textdb(self):
        """Create the text database and load its BGE and CLIP extractors."""
        text_db = Database()
        text_db.init_bge_extractor()
        text_db.init_clip_extractor()
        return text_db

    def init_imgdb(self):
        """Create and return the local image database."""
        return Imagebase()

    def _describe_keyword(self, keyword):
        """Return "<name_in_cultivation>--<description>" for an English keyword.

        Returns just the description when the entry has no cultivation name,
        and "" when the keyword is unknown or has no cultivation description.
        """
        result = self.remote.search_by_en_keyword(keyword)
        if not result or "description_in_cultivation" not in result:
            return ""
        description = result['description_in_cultivation']
        if "name_in_cultivation" in result:
            description = result['name_in_cultivation'] + "--" + description
        return description

    def random_image_text_data(self, n=12):
        """Sample ``n`` random remote entries for display.

        Args:
            n: Number of entries requested from ``remote.random_sample``.

        Returns:
            A ``zip`` iterator of ``(image_path, description)`` pairs. Image
            paths that do not exist locally are replaced by
            ``blank_image_path``. Descriptions come from ``_describe_keyword``.
        """
        random_img_datas = self.remote.random_sample(n) or []

        image_names = []
        for img_data in random_img_datas:
            name = img_data['image_name']
            image_names.append(name if os.path.exists(name) else self.blank_image_path)

        descriptions = [
            self._describe_keyword(img_data['translated_word'])
            for img_data in random_img_datas
        ]
        return zip(image_names, descriptions)

    def search_with_path(self, image_path, threshold=None):
        """Run a relatively lightweight lookup of an image file in the remote DB.

        Args:
            image_path: Path of the image to look up.
            threshold: Minimum CLIP similarity, compared with a strict ``>``,
                for the best image match to be accepted. Defaults to
                ``self.minimal_image_threshold``.

        Returns:
            ``(search_result, backup_results, image_feature)``:

            - ``search_result`` is the remote entry for the best image match,
              with a ``'similarity'`` key added. It is ``None`` when there is
              no confident match or the matched keyword has no
              ``name_in_cultivation``.
            - ``backup_results`` are the top-5 matches of the image feature
              against the remote ``'text_feature'`` index.
            - ``image_feature`` is the CLIP feature of the image.
        """
        if threshold is None:
            threshold = self.minimal_image_threshold

        image_feature = self.clip_extractor.extract_image_from_file(image_path)
        image_search_result = self.remote.top_k_search(image_feature, 'clip_feature', top_k=1)

        search_result = None
        if image_search_result and image_search_result[0]['similarity'] > threshold:
            best_match = image_search_result[0]
            # Resolve the matched image back to its full entry via the English keyword.
            result = self.remote.search_by_en_keyword(best_match['translated_word'])
            if result and "name_in_cultivation" in result:
                search_result = result
                search_result['similarity'] = best_match['similarity']
            else:
                print("Warning! Unfound keyword: ", best_match['translated_word'])

        # Always computed so callers can fall back to text-feature candidates.
        backup_results = self.remote.top_k_search(image_feature, 'text_feature', top_k=5)
        return search_result, backup_results, image_feature

    def generate_cultivation_data(self, image_path, image_feature, text_search_result):
        """Caption an image and attach it to an existing or newly generated entry.

        This is expensive: it calls the captioner and the major-object
        extractor, and may call RAG generation. The pipeline is:

        1. Caption the image.
        2. Extract the major object from the caption, using the text-search
           candidates' keywords as hints.
        3. If the object already exists in the remote base, register this
           image under that entry.
        4. Otherwise, generate a new cultivation entry with RAG and store both
           the entry and the image.
        5. Upload the image file whenever a record for it was stored.

        Args:
            image_path: Path of the image to process.
            image_feature: CLIP feature of the image, e.g. from
                ``search_with_path``.
            text_search_result: Text-feature search results. If ``None``, a
                top-5 text-feature search is run here.

        Returns:
            The existing entry (step 3) or the new entry dict (step 4).
            Returns ``None`` if any step fails or no usable object is found.
            Failures are printed rather than raised.
        """
        try:
            caption_response = self.captioner.caption(image_path)
        except Exception as exc:
            print("Error occurred while captioning the image ", image_path, exc)
            return None

        if text_search_result is None:
            text_search_result = self.remote.top_k_search(image_feature, 'text_feature', top_k=5)

        # Unique candidate keywords, in first-seen order.
        keywords = list(dict.fromkeys(
            res['translated_word'] for res in (text_search_result or [])
        ))

        try:
            json_response = get_major_object(caption_response, keywords)
        except Exception as exc:
            print("Error occurred while getting major object from caption ", caption_response, exc)
            return None

        in_base_data, alt_data = verify_keyword_in_base(json_response, self.remote)

        if in_base_data is not None:
            # The object already has an entry: this is just a new image for it,
            # so no additional entry needs to be generated.
            cultivation_data = in_base_data
            image_data = {
                'image_name': image_path,
                'keyword': in_base_data['keyword'],
                'translated_word': in_base_data['translated_word'],
            }
            self.remote.add_data(image_data, None, image_feature, None)
        elif alt_data is not None:
            try:
                rag_data = generate_cultivation_with_rag(alt_data, text_search_result)
                new_name = rag_data['new_name']
                new_description = rag_data['final_enhanced_description']
            except Exception as exc:
                print("Error occurred while generating cultivation data", exc)
                return None

            new_data = {
                "keyword": alt_data['keyword'],
                "name_in_cultivation": new_name,
                "description_in_cultivation": new_description,
                "translated_word": alt_data['translated_word'],
                "description": alt_data['description'],
            }
            text_feature = self.clip_extractor.extract_text(
                new_data['translated_word'] + '.' + new_data['description']
            )
            print("Added new data to textdb: ", new_data["name_in_cultivation"])

            image_data = {
                'image_name': image_path,
                'keyword': new_data['keyword'],
                'translated_word': new_data['translated_word'],
            }
            self.remote.add_data(image_data, new_data, image_feature, text_feature)
            print("Added new image to imgdb: ", image_data["keyword"])
            cultivation_data = new_data
        else:
            # Nothing usable was found; do not upload an unreferenced image.
            return None

        self.remote.add_file(image_path)
        return cultivation_data


if __name__ == "__main__":
    # Route outbound HTTP(S) traffic through a local proxy.
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'

    game_master = GameMaster()

    # Smoke test 1: print the first image in temp_images/ that has a confident match.
    target_folder = "temp_images"
    image_files = glob(os.path.join(target_folder, "*.jpg"))
    for index, image_path in enumerate(image_files):
        print("index:", index)
        search_result, backup_results, image_feature = game_master.search_with_path(image_path)
        if search_result:
            print(search_result)
            break

    # Smoke test 2: run the full (expensive) generation pipeline on one image.
    test_image_path = "temp_images/向日葵.jpg"
    search_result, backup_results, image_feature = game_master.search_with_path(test_image_path)
    cultivation_data = game_master.generate_cultivation_data(
        test_image_path, image_feature, backup_results
    )
    print(cultivation_data)