Spaces:
Runtime error
Runtime error
File size: 6,772 Bytes
0117cec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import json
def data2reference(top_k_items, output_n=3):
    """Render retrieved items as the ``#Reference:`` section of the prompt.

    Items are emitted in input order as one JSON object each, separated by a
    blank line. An item whose ``"keyword"`` was already emitted is skipped,
    and at most ``output_n`` distinct keywords are emitted.

    Args:
        top_k_items: iterable of dicts, each with ``"keyword"``,
            ``"name_in_cultivation"`` and ``"description_in_cultivation"``
            keys (a missing key raises ``KeyError``).
        output_n: maximum number of distinct items to include; values <= 0
            yield an empty section.

    Returns:
        ``"#Reference:"`` followed by the JSON objects, with surrounding
        whitespace stripped.
    """
    entries = []
    seen_keywords = set()
    # Guard up front: the quota check inside the loop only runs after an item
    # has been appended, so without this output_n <= 0 would still emit one.
    if output_n > 0:
        for item in top_k_items:
            item_in_life = item["keyword"]
            if item_in_life in seen_keywords:
                continue
            entries.append(
                json.dumps(
                    {
                        "name_in_life": item_in_life,
                        "name_in_cultivation": item["name_in_cultivation"],
                        "description_in_cultivation": item["description_in_cultivation"],
                    },
                    ensure_ascii=False,
                )
            )
            seen_keywords.add(item_in_life)
            if len(entries) >= output_n:
                break
    # Entries are separated by a blank line; strip() turns an empty section
    # into just "#Reference:".
    return ("#Reference:\n" + "\n\n".join(entries)).strip()
def data2prompt(query_item, top_k_items):
    """Compose the full cultivation-rewrite prompt for the LLM.

    The prompt is four pieces joined back to back:

    1. the ``#Reference:`` section built by ``data2reference`` from at most
       three distinct items of ``top_k_items``;
    2. a task instruction asking the model to rewrite the input item as the
       corresponding cultivation-world item, following the references;
    3. an ``# Input:`` section describing ``query_item``;
    4. the step-by-step list of JSON fields the model should answer with.

    Args:
        query_item: dict describing the item to rewrite. If it has a
            ``"keyword"`` key, that value (and ``"description"`` when present)
            is written as labelled lines; otherwise the whole dict is dumped
            as JSON.
        top_k_items: retrieved reference items, passed to ``data2reference``.

    Returns:
        The assembled prompt string.
    """
    reference_section = data2reference(top_k_items, 3)
    instruction = "\n请参考Reference中的物品描述,将Input中的输入物品,联系改写成修仙世界中的对应物品\n"

    input_lines = ["# Input:\n"]
    if "keyword" not in query_item:
        # No known schema: hand the raw item to the model as JSON.
        input_lines.append(json.dumps(query_item, ensure_ascii=False) + "\n")
    else:
        input_lines.append(f"input_name:{query_item['keyword']}\n")
        if "description" in query_item:
            input_lines.append(f"description_in_life:{query_item['description']}\n")

    # Field-by-field chain-of-thought answer schema; the continuation lines are
    # flush-left so the string itself carries no indentation.
    field_guide = """Let's think it step by step,以json形式输出逐个字段。包含以下字段
- name_in_life: 进一步明确要生成描述的物品名称
- name_in_cultivation_1: 尝试编写物品在修仙界对应的名称
- description_in_cultivation_1: 尝试编写物品在修仙界对应的描述
- echo_1: "我将分析description_in_cultivation_1与Reference中的差异,分析description_in_cultivation_1是否已经足够生动"
- critique: 相比于Reference中的描述,分析description_in_cultivation_1在哪些方面有所欠缺
- echo_2: "根据input_name和description_in_cultivation_1,我将分析从物体的哪些属性,可以进一步加强、夸张和修改描述"
- analysis: 分析从物体的哪些属性,可以进一步加强、夸张和修改描述
- echo_3: "我将尝试3次,从不同角度加强description_in_cultivation_1的描述"
- candidate_descriptions: 从不同角度,输出3次不同的加强后的描述
- analysis_candidates: 分析各个candidates有什么优点
- echo_4: "根据analysis_candidates,我将merge出一个最终的描述"
- final_enhanced_description: 通过各个candidates的优点, merge出一个最终的描述
- echo_5: "我将分析根据final_description,是否简易将物品名称替换为新的名词"
- name_fit_analysis: 分析item_name是否还匹配final_description的描述,是否需要给input_name起一个更响亮的名字
- new_name: 如果需要,给input_name起一个更响亮的名字, 如果不需要,则仍然输出name_in_cultivation_1
"""
    return "".join([reference_section, instruction, *input_lines, field_guide])
try:
    from src.ZhipuClient import ZhipuClient
except ImportError:
    # Fall back to a flat import when `src` is not importable as a package
    # (presumably when this file is run from inside src/). Only ImportError is
    # caught so genuine errors raised while importing the module surface.
    from ZhipuClient import ZhipuClient

# Shared LLM client; stays None until generate_cultivation_with_rag creates it
# on its first call.
zhipu_client = None
import json
def markdown_to_json(markdown_str):
    """Parse JSON from an LLM reply that may be wrapped in a Markdown code fence.

    Surrounding whitespace is ignored. If the reply starts with a triple-
    backtick fence (optionally tagged ``json`` in any letter case), the opening
    fence and tag are removed, and the closing fence is removed when present.

    Args:
        markdown_str: raw model reply.

    Returns:
        The value produced by ``json.loads`` for the unwrapped text.

    Raises:
        json.JSONDecodeError: (a ``ValueError`` subclass) if the unwrapped
            text is not valid JSON.
    """
    # Strip first: replies often end with a newline after the closing fence,
    # which would make a fixed-length slice cut into the JSON payload.
    text = markdown_str.strip()
    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
        # Only drop a closing fence that is actually there; the reply may have
        # been truncated before it.
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
    return json.loads(text)
import re
def forced_extract(input_str, keywords):
    """Best-effort extraction of string fields from text that is not valid JSON.

    For every key, the first ``"key": "value"`` pair in ``input_str`` is
    located with a regex. The value may contain backslash escapes (such as
    ``\\"``) and raw newlines. Valid JSON escape sequences in the value are
    decoded; if decoding fails, the raw matched text is kept.

    Args:
        input_str: raw model reply.
        keywords: iterable of field names to look for.

    Returns:
        dict mapping every key in ``keywords`` to its extracted string, or
        ``""`` when no string value was found. Non-string values such as
        numbers or lists are not extracted.
    """
    result = {}
    for key in keywords:
        result[key] = ""
        # Raw f-string: "\s" in a plain string literal is an invalid escape
        # sequence. re.escape keeps keys containing regex metacharacters
        # literal. The value group accepts backslash escapes instead of
        # stopping at the first quote, and with DOTALL it may span newlines.
        pattern = rf'"{re.escape(key)}"\s*:\s*"((?:[^"\\]|\\.)*)"'
        match = re.search(pattern, input_str, re.DOTALL)
        if match is None:
            continue
        raw_value = match.group(1)
        try:
            # strict=False tolerates raw control characters such as newlines.
            result[key] = json.loads(f'"{raw_value}"', strict=False)
        except ValueError:
            # Invalid escape sequence: keep the text exactly as matched.
            result[key] = raw_value
    return result
def generate_cultivation_with_rag(query_item, search_result):
    """Ask the LLM to rewrite ``query_item`` as a cultivation-world item.

    Builds the prompt with ``data2prompt``, using ``search_result`` as the
    retrieval references. Sends it through the module-level ``ZhipuClient``,
    which is created on the first call, and parses the reply.

    Args:
        query_item: item to rewrite; forwarded to ``data2prompt``.
        search_result: retrieved reference items; forwarded to
            ``data2prompt`` as ``top_k_items``.

    Returns:
        dict parsed from the reply. If the reply is not a JSON object, the
        known string fields are salvaged with ``forced_extract`` instead.
        The dict always contains ``"new_name"`` and
        ``"final_enhanced_description"``. When the model left either missing
        or empty, ``new_name`` falls back to ``"name_in_cultivation_1"``
        (else ``""``). ``final_enhanced_description`` falls back to
        ``"description_in_cultivation_1"`` (else ``new_name``).
    """
    global zhipu_client
    if zhipu_client is None:
        zhipu_client = ZhipuClient()

    prompt = data2prompt(query_item, search_result)
    response = zhipu_client.prompt2response(prompt)

    try:
        json_response = markdown_to_json(response)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        json_response = None
    if not isinstance(json_response, dict):
        # Unparseable reply, or valid JSON that is not an object (e.g. a list):
        # salvage the individual string fields instead.
        keyword_list = ["name_in_life", "name_in_cultivation_1", "description_in_cultivation_1", "final_enhanced_description", "new_name"]
        json_response = forced_extract(response, keyword_list)

    # Fall back on emptiness (missing, "", or null), not just key presence:
    # forced_extract always returns every key, defaulting to "".
    json_response["new_name"] = (
        json_response.get("new_name")
        or json_response.get("name_in_cultivation_1")
        or ""
    )
    json_response["final_enhanced_description"] = (
        json_response.get("final_enhanced_description")
        or json_response.get("description_in_cultivation_1")
        or json_response["new_name"]
    )
    return json_response
if __name__ == '__main__':
    # Manual end-to-end run: caption a sample image, retrieve related items,
    # pick the major object, then rewrite it with the LLM.
    try:
        from src.Database import Database
    except ImportError:
        from Database import Database
    db = Database()
    try:
        from src.Captioner import Captioner
    except ImportError:
        from Captioner import Captioner
    import os
    # Route HTTP(S) traffic for the rest of this run through a local proxy
    # (presumably the author's setup for reaching the remote APIs).
    os.environ['HTTP_PROXY'] = 'http://localhost:8234'
    os.environ['HTTPS_PROXY'] = 'http://localhost:8234'
    captioner = Captioner()
    test_image = "temp_images/3or47vg0.jpg"
    caption_response = captioner.caption(test_image)
    search_result = db.search_with_image_name(test_image)
    # De-duplicate translated words while keeping first-seen order.
    keywords = list(dict.fromkeys(res['translated_word'] for res in search_result))
    from get_major_object import get_major_object, verify_keyword_in_base
    json_response = get_major_object(caption_response, keywords)
    print(json_response)
    print()
    in_base_data, alt_data = verify_keyword_in_base(json_response, db)
    if alt_data is not None:
        result = generate_cultivation_with_rag(alt_data, search_result)
        print(result)
|