Spaces:
Sleeping
Sleeping
Tuchuanhuhuhu
commited on
Commit
•
2af553a
1
Parent(s):
ed27351
为xmbot加入图片压缩功能,防止上传的图像过大
Browse files- config_example.json +2 -0
- modules/config.py +5 -2
- modules/models.py +35 -20
config_example.json
CHANGED
@@ -2,6 +2,8 @@
|
|
2 |
// 你的OpenAI API Key,一般必填,
|
3 |
// 若缺省填为 "openai_api_key": "" 则必须再在图形界面中填入API Key
|
4 |
"openai_api_key": "",
|
|
|
|
|
5 |
"language": "auto",
|
6 |
// 如果使用代理,请取消注释下面的两行,并替换代理URL
|
7 |
// "https_proxy": "http://127.0.0.1:1079",
|
|
|
2 |
// 你的OpenAI API Key,一般必填,
|
3 |
// 若缺省填为 "openai_api_key": "" 则必须再在图形界面中填入API Key
|
4 |
"openai_api_key": "",
|
5 |
+
// 你的xmbot API Key,与OpenAI API Key不同
|
6 |
+
"xmbot_api_key": "",
|
7 |
"language": "auto",
|
8 |
// 如果使用代理,请取消注释下面的两行,并替换代理URL
|
9 |
// "https_proxy": "http://127.0.0.1:1079",
|
modules/config.py
CHANGED
@@ -31,7 +31,7 @@ if os.path.exists("config.json"):
|
|
31 |
config = json.load(f)
|
32 |
else:
|
33 |
config = {}
|
34 |
-
|
35 |
language = config.get("language", "auto")
|
36 |
|
37 |
if os.path.exists("api_key.txt"):
|
@@ -64,9 +64,12 @@ if os.environ.get("dockerrun") == "yes":
|
|
64 |
dockerflag = True
|
65 |
|
66 |
## 处理 api-key 以及 允许的用户列表
|
67 |
-
my_api_key = config.get("openai_api_key", "")
|
68 |
my_api_key = os.environ.get("my_api_key", my_api_key)
|
69 |
|
|
|
|
|
|
|
70 |
## 多账户机制
|
71 |
multi_api_key = config.get("multi_api_key", False) # 是否开启多账户机制
|
72 |
if multi_api_key:
|
|
|
31 |
config = json.load(f)
|
32 |
else:
|
33 |
config = {}
|
34 |
+
|
35 |
language = config.get("language", "auto")
|
36 |
|
37 |
if os.path.exists("api_key.txt"):
|
|
|
64 |
dockerflag = True
|
65 |
|
66 |
## 处理 api-key 以及 允许的用户列表
|
67 |
+
my_api_key = config.get("openai_api_key", "")
|
68 |
my_api_key = os.environ.get("my_api_key", my_api_key)
|
69 |
|
70 |
+
xmbot_api_key = config.get("xmbot_api_key", "")
|
71 |
+
os.environ["XMBOT_API_KEY"] = xmbot_api_key
|
72 |
+
|
73 |
## 多账户机制
|
74 |
multi_api_key = config.get("multi_api_key", False) # 是否开启多账户机制
|
75 |
if multi_api_key:
|
modules/models.py
CHANGED
@@ -9,6 +9,9 @@ import sys
|
|
9 |
import requests
|
10 |
import urllib3
|
11 |
import platform
|
|
|
|
|
|
|
12 |
|
13 |
from tqdm import tqdm
|
14 |
import colorama
|
@@ -328,15 +331,6 @@ class LLaMA_Client(BaseLLMModel):
|
|
328 |
data_args=data_args,
|
329 |
pipeline_args=pipeline_args,
|
330 |
)
|
331 |
-
# Chats
|
332 |
-
# model_name = model_args.model_name_or_path
|
333 |
-
# if model_args.lora_model_path is not None:
|
334 |
-
# model_name += f" + {model_args.lora_model_path}"
|
335 |
-
|
336 |
-
# context = (
|
337 |
-
# "You are a helpful assistant who follows the given instructions"
|
338 |
-
# " unconditionally."
|
339 |
-
# )
|
340 |
|
341 |
def _get_llama_style_input(self):
|
342 |
history = []
|
@@ -406,26 +400,45 @@ class XMBot_Client(BaseLLMModel):
|
|
406 |
self.session_id = str(uuid.uuid4())
|
407 |
return [], "已重置"
|
408 |
|
409 |
-
def
|
410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
def is_image_file(filepath):
|
413 |
# 判断文件是否为图片
|
414 |
valid_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
|
415 |
file_extension = os.path.splitext(filepath)[1].lower()
|
416 |
return file_extension in valid_image_extensions
|
417 |
|
418 |
-
def read_image_as_bytes(filepath):
|
419 |
-
# 读取图片文件并返回比特流
|
420 |
-
with open(filepath, "rb") as f:
|
421 |
-
image_bytes = f.read()
|
422 |
-
return image_bytes
|
423 |
-
|
424 |
if is_image_file(filepath):
|
425 |
logging.info(f"读取图片文件: {filepath}")
|
426 |
-
image_bytes =
|
427 |
-
base64_encoded_image = base64.b64encode(image_bytes).decode()
|
428 |
-
self.image_bytes = base64_encoded_image
|
429 |
self.image_path = filepath
|
430 |
else:
|
431 |
self.image_bytes = None
|
@@ -529,6 +542,8 @@ def get_model(
|
|
529 |
msg += f" + {lora_model_path}"
|
530 |
model = LLaMA_Client(model_name, lora_model_path)
|
531 |
elif model_type == ModelType.XMBot:
|
|
|
|
|
532 |
model = XMBot_Client(api_key=access_key)
|
533 |
elif model_type == ModelType.Unknown:
|
534 |
raise ValueError(f"未知模型: {model_name}")
|
|
|
9 |
import requests
|
10 |
import urllib3
|
11 |
import platform
|
12 |
+
import base64
|
13 |
+
from io import BytesIO
|
14 |
+
from PIL import Image
|
15 |
|
16 |
from tqdm import tqdm
|
17 |
import colorama
|
|
|
331 |
data_args=data_args,
|
332 |
pipeline_args=pipeline_args,
|
333 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
|
335 |
def _get_llama_style_input(self):
|
336 |
history = []
|
|
|
400 |
self.session_id = str(uuid.uuid4())
|
401 |
return [], "已重置"
|
402 |
|
403 |
+
def image_to_base64(self, image_path):
|
404 |
+
# 打开并加载图片
|
405 |
+
img = Image.open(image_path)
|
406 |
+
|
407 |
+
# 获取图片的宽度和高度
|
408 |
+
width, height = img.size
|
409 |
+
|
410 |
+
# 计算压缩比例,以确保最长边小于4096像素
|
411 |
+
max_dimension = 2048
|
412 |
+
scale_ratio = min(max_dimension / width, max_dimension / height)
|
413 |
+
|
414 |
+
if scale_ratio < 1:
|
415 |
+
# 按压缩比例调整图片大小
|
416 |
+
new_width = int(width * scale_ratio)
|
417 |
+
new_height = int(height * scale_ratio)
|
418 |
+
img = img.resize((new_width, new_height), Image.ANTIALIAS)
|
419 |
|
420 |
+
# 将图片转换为jpg格式的二进制数据
|
421 |
+
buffer = BytesIO()
|
422 |
+
if img.mode == "RGBA":
|
423 |
+
img = img.convert("RGB")
|
424 |
+
img.save(buffer, format='JPEG')
|
425 |
+
binary_image = buffer.getvalue()
|
426 |
+
|
427 |
+
# 对二进制数据进行Base64编码
|
428 |
+
base64_image = base64.b64encode(binary_image).decode('utf-8')
|
429 |
+
|
430 |
+
return base64_image
|
431 |
+
|
432 |
+
def try_read_image(self, filepath):
|
433 |
def is_image_file(filepath):
|
434 |
# 判断文件是否为图片
|
435 |
valid_image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
|
436 |
file_extension = os.path.splitext(filepath)[1].lower()
|
437 |
return file_extension in valid_image_extensions
|
438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
if is_image_file(filepath):
|
440 |
logging.info(f"读取图片文件: {filepath}")
|
441 |
+
self.image_bytes = self.image_to_base64(filepath)
|
|
|
|
|
442 |
self.image_path = filepath
|
443 |
else:
|
444 |
self.image_bytes = None
|
|
|
542 |
msg += f" + {lora_model_path}"
|
543 |
model = LLaMA_Client(model_name, lora_model_path)
|
544 |
elif model_type == ModelType.XMBot:
|
545 |
+
if os.environ.get("XMBOT_API_KEY") != "":
|
546 |
+
access_key = os.environ.get("XMBOT_API_KEY")
|
547 |
model = XMBot_Client(api_key=access_key)
|
548 |
elif model_type == ModelType.Unknown:
|
549 |
raise ValueError(f"未知模型: {model_name}")
|