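"""Gradio demo combining three tools: a streaming multi-provider LLM chat
(DeepSeek, OpenRouter, Tongyi), ModelScope-based image editing (object removal
and background matting), and Stable Diffusion XL text-to-image generation."""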
import gradio as gr
from langchain_core.messages import HumanMessage, AIMessage
from llm import DeepSeekLLM, OpenRouterLLM, TongYiLLM
from config import settings
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import cv2
from diffusers import StableDiffusionXLPipeline
import torch
deep_seek_llm = DeepSeekLLM(api_key=settings.deep_seek_api_key)
open_router_llm = OpenRouterLLM(api_key=settings.open_router_api_key)
tongyi_llm = TongYiLLM(api_key=settings.tongyi_api_key)
def init_chat():
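    """Create the default chat engine (DeepSeek with its default settings)."""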
return deep_seek_llm.get_chat_engine()
def predict(message, history, chat):
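    """Stream a chat reply.

    Rebuilds the LangChain message list from the Gradio (human, assistant)
    history pairs, appends the new user message (only its text is used), and
    yields the response accumulated chunk by chunk. `chat` is the engine held
    in the chat_engine state; a default engine is created if it is missing.
    """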
if chat is None:
chat = init_chat()
history_messages = []
for human, assistant in history:
history_messages.append(HumanMessage(content=human))
history_messages.append(AIMessage(content=assistant))
history_messages.append(HumanMessage(content=message.text))
response_message = ''
for chunk in chat.stream(history_messages):
response_message = response_message + chunk.content
yield response_message
def update_chat(_provider: str, _chat, _model: str, _temperature: float, _max_tokens: int):
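    """Rebuild the chat engine whenever the provider, model, temperature, or max tokens changes."""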
    if _provider == 'DeepSeek':
        _chat = deep_seek_llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
    elif _provider == 'OpenRouter':
        _chat = open_router_llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
    elif _provider == 'Tongyi':
        _chat = tongyi_llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
    return _chat
def object_remove(_image, refined):
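    """Remove the painted region.

    Uses the first mask layer drawn in the ImageMask editor as the inpainting
    mask and fills it with ModelScope's LaMa model ('damo/cv_fft_inpainting_lama').
    Returns the inpainted image and the mask that was used.
    """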
mask = _image['layers'][0]
mask = mask.convert('L')
_input = {
'img': _image['background'].convert('RGB'),
'mask': mask,
}
inpainting = pipeline(Tasks.image_inpainting, model='damo/cv_fft_inpainting_lama', refined=refined)
result = inpainting(_input)
vis_img = result[OutputKeys.OUTPUT_IMG]
vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
return vis_img, mask
def bg_remove(_image, _type):
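    """Cut out the foreground with ModelScope matting.

    Portrait matting is used for the 'Portrait' type, universal matting otherwise;
    the result is an RGBA image with a transparent background.
    """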
input_image = _image['background'].convert('RGB')
    if _type == 'Portrait':
matting = pipeline(Tasks.portrait_matting, model='damo/cv_unet_image-matting')
else:
matting = pipeline(Tasks.universal_matting, model='damo/cv_unet_universal-matting')
result = matting(input_image)
vis_img = result[OutputKeys.OUTPUT_IMG]
vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGRA2RGBA)
return vis_img
def text_to_image(_prompt):
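    """Generate a 1024x1024 image from the prompt with Stable Diffusion XL base 1.0.

    The pipeline is loaded from the Hub on every call; caching it at module
    level would avoid the repeated load, at the cost of keeping it in GPU memory.
    """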
t2i_pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
).to("cuda")
result = t2i_pipeline(
prompt=_prompt,
negative_prompt='ugly',
num_inference_steps=22,
width=1024,
height=1024,
guidance_scale=7,
).images[0]
return result
with gr.Blocks() as app:
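    # Chat tab: streaming chat with per-provider model / temperature / max-token controls.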
    with gr.Tab('Chat'):
chat_engine = gr.State(value=None)
with gr.Row():
with gr.Column(scale=2, min_width=600):
chatbot = gr.ChatInterface(
predict,
multimodal=True,
chatbot=gr.Chatbot(elem_id="chatbot", height=600, show_share_button=False),
textbox=gr.MultimodalTextbox(lines=1),
additional_inputs=[chat_engine]
)
with gr.Column(scale=1, min_width=300):
                with gr.Accordion('Parameter Settings', open=True):
with gr.Column():
provider = gr.Dropdown(
                            label='Model Provider',
                            choices=['DeepSeek', 'OpenRouter', 'Tongyi'],
                            value='DeepSeek',
                            info='Parameters, output quality, and pricing differ slightly between providers; set the corresponding provider API key first.',
)
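                        # Re-render the model settings panel whenever the selected provider changes.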
@gr.render(inputs=provider)
def show_model_config_panel(_provider):
if _provider == 'DeepSeek':
with gr.Column():
model = gr.Dropdown(
                                        label='Model',
choices=deep_seek_llm.support_models,
value=deep_seek_llm.default_model
)
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=deep_seek_llm.default_temperature,
label="Temperature",
key="temperature",
)
max_tokens = gr.Slider(
minimum=1024,
maximum=1024 * 20,
step=128,
value=deep_seek_llm.default_max_tokens,
label="Max Tokens",
key="max_tokens",
)
model.change(
fn=update_chat,
inputs=[provider, chat_engine, model, temperature, max_tokens],
outputs=[chat_engine],
)
temperature.change(
fn=update_chat,
inputs=[provider, chat_engine, model, temperature, max_tokens],
outputs=[chat_engine],
)
max_tokens.change(
fn=update_chat,
inputs=[provider, chat_engine, model, temperature, max_tokens],
outputs=[chat_engine],
)
if _provider == 'OpenRouter':
with gr.Column():
model = gr.Dropdown(
                                        label='Model',
choices=open_router_llm.support_models,
value=open_router_llm.default_model
)
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=open_router_llm.default_temperature,
label="Temperature",
key="temperature",
)
max_tokens = gr.Slider(
minimum=1024,
maximum=1024 * 20,
step=128,
value=open_router_llm.default_max_tokens,
label="Max Tokens",
key="max_tokens",
)
model.change(
fn=update_chat,
inputs=[provider, chat_engine, model, temperature, max_tokens],
outputs=[chat_engine],
)
temperature.change(
fn=update_chat,
inputs=[provider, chat_engine, model, temperature, max_tokens],
outputs=[chat_engine],
)
max_tokens.change(
fn=update_chat,
inputs=[provider, chat_engine, model, temperature, max_tokens],
outputs=[chat_engine],
)
if _provider == 'Tongyi':
with gr.Column():
model = gr.Dropdown(
                                        label='Model',
choices=tongyi_llm.support_models,
value=tongyi_llm.default_model
)
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=tongyi_llm.default_temperature,
label="Temperature",
key="temperature",
)
max_tokens = gr.Slider(
minimum=1000,
maximum=2000,
step=100,
value=tongyi_llm.default_max_tokens,
label="Max Tokens",
key="max_tokens",
)
model.change(
fn=update_chat,
inputs=[provider, chat_engine, model, temperature, max_tokens],
outputs=[chat_engine],
)
temperature.change(
fn=update_chat,
inputs=[provider, chat_engine, model, temperature, max_tokens],
outputs=[chat_engine],
)
max_tokens.change(
fn=update_chat,
inputs=[provider, chat_engine, model, temperature, max_tokens],
outputs=[chat_engine],
)
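    # Image editing tab: paint a mask for object removal, or strip the background.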
    with gr.Tab('Image Editing'):
with gr.Row():
with gr.Column(scale=2, min_width=600):
image = gr.ImageMask(
type='pil',
brush=gr.Brush(colors=["rgba(255, 255, 255, 0.9)"]),
)
with gr.Row():
                    mask_preview = gr.Image(label='Mask Preview')
                    image_preview = gr.Image(label='Image Preview')
with gr.Column(scale=1, min_width=300):
with gr.Accordion(label="物体移除"):
object_remove_refined = gr.Checkbox(label="Refined(GPU)", info="只支持 GPU, 开启将获得更好的效果")
object_remove_btn = gr.Button('物体移除', variant='primary')
with gr.Accordion(label="背景移除"):
bg_remove_type = gr.Radio(["人像", "通用"], label="类型", value='人像')
bg_remove_btn = gr.Button('背景移除', variant='primary')
object_remove_btn.click(fn=object_remove, inputs=[image, object_remove_refined], outputs=[image_preview, mask_preview])
bg_remove_btn.click(fn=bg_remove, inputs=[image, bg_remove_type], outputs=[image_preview])
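    # Text-to-image tab: Stable Diffusion XL generation (requires a GPU).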
    with gr.Tab('Image Generation (GPU)'):
with gr.Row():
with gr.Column(scale=2, min_width=600):
                t2i_image = gr.Image()
with gr.Column(scale=1, min_width=300):
with gr.Accordion(label="图像生成"):
prompt = gr.Textbox(label="提示语", value="", lines=3)
t2i_btn = gr.Button('画图', variant='primary')
        t2i_btn.click(fn=text_to_image, inputs=[prompt], outputs=[t2i_image])
app.launch(debug=settings.debug, show_api=False)