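# app.py for a Gradio Space: product description generator.
# Pipeline: briaai/RMBG-1.4 removes the image background, uform-gen2-qwen-500m
# captions the product photo, and google/gemma-7b-it (via the serverless
# Inference API) drafts product titles and an SEO description in a chat panel.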
import gradio as gr
import torch
import numpy as np
import torch.nn.functional as F
import random
from threading import Thread
from transformers import AutoModel, AutoProcessor
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download, InferenceClient
from briarmbg import BriaRMBG
from PIL import Image
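
# --- Model setup: RMBG-1.4 segmentation net and the uform-gen2 captioning model ---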
net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)
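
# Stops uform-gen2 generation on token id 151645, the <|im_end|> turn delimiter
# in the Qwen tokenizer this checkpoint builds on.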
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [151645]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
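
# Renders the chat history into Gemma's turn format:
# <start_of_turn>user\n...<end_of_turn>\n<start_of_turn>model\n...<end_of_turn>\n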
def format_prompt(message, history):
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt
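
# "Upload" flow: caption the image with the vision model, then ask Gemma for
# title options; getProductDescription is chained afterwards via .then().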
def getProductTitle(history, context, image):
    product_description = getImageDescription(image)
    prompt = f"We have a product which is a {context}. Product description is as follows: {product_description}. Please write a few product title options for it."
    yield interactWithModel(history, prompt)

def getProductDescription(history):
    prompt = "Please also write an SEO-friendly description for it describing its value to its users."
    yield interactWithModel(history, prompt)
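
# Sends one prompt to gemma-7b-it over the Inference API, streams the reply,
# and appends the (prompt, reply) pair to the chat history.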
def interactWithModel(history, prompt):
    system_prompt = "You're a helpful e-commerce marketing assistant working on art products."
    client = InferenceClient("google/gemma-7b-it")
    rand_val = random.randint(1, 1111111111111111)
    if not history:
        history = []
    generate_kwargs = dict(
        temperature=0.67,
        max_new_tokens=1024,
        top_p=0.9,
        repetition_penalty=1.0,
        do_sample=True,
        seed=rand_val,
    )
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
    history.append((prompt, output))
    return history
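
# Captions a product image with uform-gen2-qwen-500m. The <image> tag in the
# user message marks where the processor inserts the image latents.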
def getImageDescription(image):
    message = "Generate an ecommerce product description for the image"
    stop = StopOnTokens()
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    if len(messages) == 1:
        message = f" <image>{message}"
    messages.append({"role": "user", "content": message})
    model_inputs = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    image = (
        processor.feature_extractor(image)
        .unsqueeze(0)
    )
    attention_mask = torch.ones(
        1, model_inputs.shape[1] + processor.num_image_latents - 1
    )
    model_inputs = {
        "input_ids": model_inputs,
        "images": image,
        "attention_mask": attention_mask
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    streamer = TextIteratorStreamer(processor.tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    # generate in a background thread so tokens can be consumed from the streamer
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
    return partial_response
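
# RMBG-1.4 expects a fixed 1024x1024 RGB input.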
def resize_image(image):
    image = image.convert('RGB')
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image
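
# Background removal: normalize, run RMBG-1.4, rescale the predicted mask to
# the original resolution, and composite the input onto a transparent canvas.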
def process(image):
    # prepare input
    orig_image = image
    w, h = orig_im_size = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()
    # inference
    result = net(im_tensor)
    # post process: resize the mask back to the original size and min-max rescale to [0, 1]
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    # mask to PIL
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # paste the original image onto a transparent canvas, using the mask as alpha
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
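
# --- Gradio UI: image + context column on the left, chat panel on the right ---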
title = """<h1 style="text-align: center;">Product description generator</h1>"""
css = """
div#col-container {
    margin: 0 auto;
    max-width: 840px;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column(elem_id="col-container"):
            image = gr.Image(type="pil")
            output = gr.Image(type="pil", interactive=False, label="Without background")
            context = gr.Textbox(label="Small description")
            submit = gr.Button(value="Upload", variant="primary")
        with gr.Column():
            chat = gr.Chatbot(show_label=False)
            user_input = gr.Textbox()
            send = gr.Button(value="Send")
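
    # Each handler is an (fn, inputs, outputs) tuple, unpacked into the click/then
    # events below: "Upload" chains title generation into the SEO description and
    # fires background removal as a separate, independent click event.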
    title_handler = (
        getProductTitle,
        [chat, context, image],
        [chat]
    )
    description_handler = (
        getProductDescription,
        [chat],
        [chat]
    )
    interaction_handler = (
        interactWithModel,
        [chat, user_input],
        [chat]
    )
    background_remover_handler = (
        process,
        [image],
        [output]
    )

    submit.click(*title_handler).then(*description_handler)
    submit.click(*background_remover_handler)
    send.click(*interaction_handler)

demo.launch(share=True)