Spaces:
Sleeping
Sleeping
import functools
import logging

# NOTE(review): project and third-party imports are kept in their original
# order; `config` may prepare environment settings that modelscope/diffusers
# read at import time — confirm before regrouping.
import gradio as gr
from langchain_core.messages import HumanMessage, AIMessage
from llm import DeepSeekLLM, OpenRouterLLM, TongYiLLM
from config import settings
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import cv2
from diffusers import StableDiffusionXLPipeline
import torch
# One client per chat provider offered in the UI; each is constructed with
# its own API key taken from `settings`.
deep_seek_llm = DeepSeekLLM(
    api_key=settings.deep_seek_api_key,
)
open_router_llm = OpenRouterLLM(
    api_key=settings.open_router_api_key,
)
tongyi_llm = TongYiLLM(
    api_key=settings.tongyi_api_key,
)
def init_chat():
    """Create the fallback chat engine used when no engine is stored yet.

    The DeepSeek client is asked for an engine without any arguments, so
    whatever defaults its ``get_chat_engine`` applies are used.
    """
    default_engine = deep_seek_llm.get_chat_engine()
    return default_engine
def predict(message, history, chat):
    """Stream the assistant's reply to ``message`` given the prior ``history``.

    Args:
        message: The user's new turn. Gradio's multimodal ChatInterface passes
            a dict with ``"text"`` and ``"files"`` keys. An object exposing
            ``.text``, or a plain string, is also accepted.
        history: Earlier turns as ``(user, assistant)`` pairs (Gradio's
            "tuples" chat format). Only plain-string sides are replayed to
            the model. Uploaded files show up as ``(path,)`` tuples and an
            unanswered turn as ``None``; neither is valid message content.
        chat: The chat engine kept in ``gr.State``. ``None`` means none has
            been built yet, so :func:`init_chat` provides one.

    Yields:
        The reply accumulated so far, once per streamed chunk.
    """
    if chat is None:
        chat = init_chat()

    if isinstance(message, str):
        text = message
    elif isinstance(message, dict):
        text = message['text']
    else:
        text = message.text

    history_messages = []
    for human, assistant in history:
        # Skip non-text entries (file tuples, missing replies) instead of
        # building messages with invalid content.
        if isinstance(human, str):
            history_messages.append(HumanMessage(content=human))
        if isinstance(assistant, str):
            history_messages.append(AIMessage(content=assistant))
    history_messages.append(HumanMessage(content=text))

    response_message = ''
    for chunk in chat.stream(history_messages):
        response_message += chunk.content
        yield response_message
def update_chat(_provider: str, _chat, _model: str, _temperature: float, _max_tokens: int):
    """Build a chat engine for ``_provider`` with the given generation settings.

    Args:
        _provider: ``'DeepSeek'``, ``'OpenRouter'`` or ``'Tongyi'``.
        _chat: The current engine. It is returned unchanged when
            ``_provider`` is not one of the names above.
        _model: Model name forwarded to the provider's ``get_chat_engine``.
        _temperature: Sampling temperature forwarded to ``get_chat_engine``.
        _max_tokens: Token limit forwarded to ``get_chat_engine``.

    Returns:
        The newly built engine, or ``_chat`` for an unknown provider.
    """
    # Debug-level trace of the requested settings. It is only visible when
    # logging is configured for it, instead of printing on every UI change.
    logging.getLogger(__name__).debug(
        'update_chat provider=%s model=%s temperature=%s max_tokens=%s',
        _provider, _model, _temperature, _max_tokens,
    )
    llm = {
        'DeepSeek': deep_seek_llm,
        'OpenRouter': open_router_llm,
        'Tongyi': tongyi_llm,
    }.get(_provider)
    if llm is None:
        return _chat
    return llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
@functools.lru_cache(maxsize=2)
def _inpainting_pipeline(refined):
    """Return the LaMa inpainting pipeline for ``refined``, building it on first use.

    The UI feeds this from a checkbox (True/False), so at most two pipelines
    are kept.
    """
    return pipeline(Tasks.image_inpainting, model='damo/cv_fft_inpainting_lama', refined=refined)


def object_remove(_image, refined):
    """Erase the brushed region of an image-editor value by inpainting it.

    Args:
        _image: The ``gr.ImageMask`` value (``type='pil'``). It is a dict
            whose ``'background'`` is the source image and whose first
            ``'layers'`` entry holds the user's brush strokes.
        refined: Forwarded to the pipeline's ``refined`` option. The UI
            labels it as GPU-only.

    Returns:
        ``(result, mask)``: the inpainted image converted from the pipeline's
        BGR output to RGB, and the single-channel mask that was used.
    """
    # Collapse the first brush layer to a grayscale ('L') mask.
    mask = _image['layers'][0].convert('L')
    _input = {
        'img': _image['background'].convert('RGB'),
        'mask': mask,
    }
    # Reuse the loaded model instead of rebuilding it on every click.
    inpainting = _inpainting_pipeline(refined)
    result = inpainting(_input)
    vis_img = cv2.cvtColor(result[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGR2RGB)
    return vis_img, mask
@functools.lru_cache(maxsize=2)
def _matting_pipeline(portrait):
    """Return the portrait (``True``) or universal (``False``) matting pipeline, built on first use."""
    if portrait:
        return pipeline(Tasks.portrait_matting, model='damo/cv_unet_image-matting')
    return pipeline(Tasks.universal_matting, model='damo/cv_unet_universal-matting')


def bg_remove(_image, _type):
    """Remove the background of the editor's source image.

    Args:
        _image: The ``gr.ImageMask`` value (``type='pil'``). Only its
            ``'background'`` image is used.
        _type: ``'人像'`` selects the portrait matting model. Any other value
            selects the universal matting model.

    Returns:
        The matted image converted from the pipeline's BGRA output to RGBA.
    """
    input_image = _image['background'].convert('RGB')
    # Reuse the loaded model instead of rebuilding it on every click.
    matting = _matting_pipeline(_type == '人像')
    result = matting(input_image)
    vis_img = cv2.cvtColor(result[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGRA2RGBA)
    return vis_img
@functools.lru_cache(maxsize=1)
def _sdxl_pipeline():
    """Load the SDXL base pipeline (fp16 variant) onto CUDA once and reuse it.

    The weights then stay resident on the GPU between requests, trading
    memory for not reloading several GB of weights on every click. A failed
    load raises and is not cached, so the next call retries.
    """
    return StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")


def text_to_image(_image, _prompt):
    """Generate a 1024x1024 image for ``_prompt`` with Stable Diffusion XL.

    Args:
        _image: Not used by this function. Arguments are positional
            ``(image, prompt)``, so callers must pass them in that order.
        _prompt: The text prompt.

    Returns:
        The first entry of the pipeline's ``images`` output.
    """
    result = _sdxl_pipeline()(
        prompt=_prompt,
        negative_prompt='ugly',
        num_inference_steps=22,
        width=1024,
        height=1024,
        guidance_scale=7,
    ).images[0]
    return result
# Chat providers offered in the UI, in display order. Each maps to its LLM
# client and the (minimum, maximum, step) range of its "Max Tokens" slider.
_CHAT_PROVIDERS = {
    'DeepSeek': (deep_seek_llm, (1024, 1024 * 20, 128)),
    'OpenRouter': (open_router_llm, (1024, 1024 * 20, 128)),
    'Tongyi': (tongyi_llm, (1000, 2000, 100)),
}


def _default_chat(_provider):
    """Return a chat engine for ``_provider`` built from that client's default settings.

    Returns ``None`` for an unknown provider. The chat handler treats
    ``None`` as "no engine yet" and falls back to ``init_chat()``.
    """
    entry = _CHAT_PROVIDERS.get(_provider)
    if entry is None:
        return None
    llm = entry[0]
    return llm.get_chat_engine(
        model=llm.default_model,
        temperature=llm.default_temperature,
        max_tokens=llm.default_max_tokens,
    )


with gr.Blocks() as app:
    with gr.Tab('聊天'):
        # Per-session chat engine. It stays None until the provider or a
        # setting is changed, and the chat handler then uses init_chat().
        chat_engine = gr.State(value=None)
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                chatbot = gr.ChatInterface(
                    predict,
                    multimodal=True,
                    chatbot=gr.Chatbot(elem_id="chatbot", height=600, show_share_button=False),
                    textbox=gr.MultimodalTextbox(lines=1),
                    additional_inputs=[chat_engine]
                )
            with gr.Column(scale=1, min_width=300):
                with gr.Accordion('参数设置', open=True):
                    with gr.Column():
                        provider = gr.Dropdown(
                            label='模型厂商',
                            choices=list(_CHAT_PROVIDERS),
                            value='DeepSeek',
                            info='不同模型厂商参数,效果和价格略有不同,请先设置好对应模型厂商的 API Key。',
                        )
                        # Re-rendering the panel below creates fresh controls
                        # but fires no .change event. Rebuild the engine here,
                        # otherwise chats keep using the previous provider.
                        provider.change(
                            fn=_default_chat,
                            inputs=[provider],
                            outputs=[chat_engine],
                        )

                        # gr.render runs this on load and whenever `provider`
                        # changes, replacing the controls it created last time.
                        @gr.render(inputs=provider)
                        def show_model_config_panel(_provider):
                            """Render model / temperature / max-tokens controls for ``_provider``.

                            Changing any of them rebuilds the session's chat
                            engine through ``update_chat``.
                            """
                            entry = _CHAT_PROVIDERS.get(_provider)
                            if entry is None:
                                return
                            llm, (tokens_min, tokens_max, tokens_step) = entry
                            with gr.Column():
                                model = gr.Dropdown(
                                    label='模型',
                                    choices=llm.support_models,
                                    value=llm.default_model
                                )
                                # Keys are per provider. Gradio keeps the value of
                                # same-key components across re-renders, which
                                # would carry one provider's setting into
                                # another's panel (e.g. a max-tokens value outside
                                # Tongyi's range) while the engine is reset to
                                # that provider's defaults.
                                temperature = gr.Slider(
                                    minimum=0.0,
                                    maximum=1.0,
                                    step=0.1,
                                    value=llm.default_temperature,
                                    label="Temperature",
                                    key=f"{_provider}-temperature",
                                )
                                max_tokens = gr.Slider(
                                    minimum=tokens_min,
                                    maximum=tokens_max,
                                    step=tokens_step,
                                    value=llm.default_max_tokens,
                                    label="Max Tokens",
                                    key=f"{_provider}-max_tokens",
                                )
                                config_inputs = [provider, chat_engine, model, temperature, max_tokens]
                                for control in (model, temperature, max_tokens):
                                    control.change(
                                        fn=update_chat,
                                        inputs=config_inputs,
                                        outputs=[chat_engine],
                                    )
    with gr.Tab('图像编辑'):
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                image = gr.ImageMask(
                    type='pil',
                    brush=gr.Brush(colors=["rgba(255, 255, 255, 0.9)"]),
                )
                with gr.Row():
                    mask_preview = gr.Image(label='蒙板预览')
                    image_preview = gr.Image(label='图片预览')
            with gr.Column(scale=1, min_width=300):
                with gr.Accordion(label="物体移除"):
                    object_remove_refined = gr.Checkbox(label="Refined(GPU)", info="只支持 GPU, 开启将获得更好的效果")
                    object_remove_btn = gr.Button('物体移除', variant='primary')
                with gr.Accordion(label="背景移除"):
                    bg_remove_type = gr.Radio(["人像", "通用"], label="类型", value='人像')
                    bg_remove_btn = gr.Button('背景移除', variant='primary')
        object_remove_btn.click(fn=object_remove, inputs=[image, object_remove_refined], outputs=[image_preview, mask_preview])
        bg_remove_btn.click(fn=bg_remove, inputs=[image, bg_remove_type], outputs=[image_preview])
    with gr.Tab('画图(GPU)'):
        with gr.Row():
            with gr.Column(scale=2, min_width=600):
                t2i_image = gr.Image()
            with gr.Column(scale=1, min_width=300):
                with gr.Accordion(label="图像生成"):
                    prompt = gr.Textbox(label="提示语", value="", lines=3)
                    t2i_btn = gr.Button('画图', variant='primary')
        # text_to_image takes (image, prompt) positionally, so inputs are
        # listed in that order.
        t2i_btn.click(fn=text_to_image, inputs=[t2i_image, prompt], outputs=[t2i_image])
# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    app.launch(debug=settings.debug, show_api=False)