# screenshot2html / app.py
import os
import subprocess
import torch
import gradio as gr
from gradio_client.client import DEFAULT_TEMP_DIR
from playwright.sync_api import sync_playwright
from threading import Thread
from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from typing import List
from PIL import Image
from transformers.image_transforms import resize, to_channel_dimension_format
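
# Loaded once at startup and shared across requests. Assumes a CUDA-capable GPU
# is available (this Space runs on GPU hardware).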
DEVICE = torch.device("cuda")
PROCESSOR = AutoProcessor.from_pretrained(
"HuggingFaceM4/VLM_WebSight_finetuned",
)
MODEL = AutoModelForCausalLM.from_pretrained(
"HuggingFaceM4/VLM_WebSight_finetuned",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
).to(DEVICE)
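
# Number of <image> placeholder tokens the text prompt must contain: either one
# per perceiver-resampler latent, or one per vision-transformer patch.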
if MODEL.config.use_resampler:
image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
image_seq_len = (
MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
) ** 2
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
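# Forbid the image placeholder tokens from being generated in the output.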
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
## Utils
def convert_to_rgb(image):
    # `image.convert("RGB")` only works for fully opaque images such as .jpg; for
    # transparent images it would produce a wrong background. Compositing onto a
    # white background via `alpha_composite` handles that case.
if image.mode == "RGB":
return image
image_rgba = image.convert("RGBA")
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
alpha_composite = Image.alpha_composite(background, image_rgba)
alpha_composite = alpha_composite.convert("RGB")
return alpha_composite
# The processor is the same as the Idefics processor except for the BICUBIC interpolation
# inside SigLIP, so this hack redefines ONLY the transform method.
def custom_transform(x):
x = convert_to_rgb(x)
x = to_numpy_array(x)
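    # Resize to the fixed 960x960 resolution expected by the model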
x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
x = PROCESSOR.image_processor.normalize(
x,
mean=PROCESSOR.image_processor.image_mean,
std=PROCESSOR.image_processor.image_std
)
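    # Channels-first (C, H, W) layout, then convert to a torch tensor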
x = to_channel_dimension_format(x, ChannelDimension.FIRST)
x = torch.tensor(x)
return x
## End of Utils
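
# Example screenshots shown in the template gallery at the bottom of the UI.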
IMAGE_GALLERY_PATHS = [
f"example_images/{ex_image}"
    for ex_image in os.listdir("example_images")
]
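
# Playwright needs its browser binaries downloaded once; do this at startup so
# the first rendering request doesn't have to pay the cost (or fail).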
def install_playwright():
try:
subprocess.run(["playwright", "install"], check=True)
print("Playwright installation successful.")
except subprocess.CalledProcessError as e:
print(f"Error during Playwright installation: {e}")
install_playwright()
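
# Gallery callback: load the selected template image into the input image box.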
def add_file_gallery(
selected_state: gr.SelectData,
gallery_list: List[str]
):
return Image.open(gallery_list.root[selected_state.index].image.path)
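
# Render the generated HTML/CSS in headless Chromium and return a full-page
# screenshot as a PIL image, so the user can compare it with the input.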
def render_webpage(
html_css_code,
):
with sync_playwright() as p:
browser = p.chromium.launch(headless=True)
context = browser.new_context(
user_agent=(
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
" Safari/537.36"
)
)
page = context.new_page()
page.set_content(html_css_code)
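        # Wait until network activity settles so any external assets are loaded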
page.wait_for_load_state("networkidle")
output_path_screenshot = f"{DEFAULT_TEMP_DIR}/{hash(html_css_code)}.png"
_ = page.screenshot(path=output_path_screenshot, full_page=True)
context.close()
browser.close()
return Image.open(output_path_screenshot)
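
# Main inference: build a text prompt containing the image placeholder tokens,
# preprocess the screenshot, generate HTML token by token in a background
# thread, then render the result.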
def model_inference(
image,
):
if image is None:
raise ValueError("`image` is None. It should be a PIL image.")
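    # The text prompt is just the BOS token plus the image placeholder tokens
    # wrapped in <fake_token_around_image>; the model generates the HTML.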
inputs = PROCESSOR.tokenizer(
f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
return_tensors="pt",
add_special_tokens=False,
)
inputs["pixel_values"] = PROCESSOR.image_processor(
[image],
transform=custom_transform
)
inputs = {
k: v.to(DEVICE)
for k, v in inputs.items()
}
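    # generate() runs in a background thread; the streamer yields decoded text
    # chunks to this thread as tokens are produced.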
    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
        # Decode kwargs are forwarded to `tokenizer.decode` via `**decode_kwargs`;
        # passing them in a nested `decode_kwargs` dict would be silently ignored.
        skip_special_tokens=True,
    )
generation_kwargs = dict(
inputs,
bad_words_ids=BAD_WORDS_IDS,
max_length=4096,
streamer=streamer,
)
thread = Thread(
target=MODEL.generate,
kwargs=generation_kwargs,
)
thread.start()
generated_text = ""
for new_text in streamer:
generated_text += new_text
print("before yield")
# yield generated_text, image
print("after yield")
rendered_page = render_webpage(generated_text)
return generated_text, rendered_page
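
## Gradio interface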
generated_html = gr.Code(
label="Extracted HTML",
elem_id="generated_html",
)
rendered_html = gr.Image(
label="Rendered HTML",
show_download_button=False,
show_share_button=False,
)
# rendered_html = gr.HTML(
# label="Rendered HTML"
# )
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
with gr.Row(equal_height=True):
with gr.Column(scale=4, min_width=250) as upload_area:
imagebox = gr.Image(
type="pil",
label="Screenshot to extract",
visible=True,
sources=["upload", "clipboard"],
)
with gr.Group():
with gr.Row():
submit_btn = gr.Button(
value="▶️ Submit", visible=True, min_width=120
)
clear_btn = gr.ClearButton(
[imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
)
regenerate_btn = gr.Button(
value="🔄 Regenerate", visible=True, min_width=120
)
with gr.Column(scale=4):
rendered_html.render()
with gr.Row():
generated_html.render()
with gr.Row():
template_gallery = gr.Gallery(
value=IMAGE_GALLERY_PATHS,
label="Templates Gallery",
allow_preview=False,
columns=5,
elem_id="gallery",
show_share_button=False,
height=400,
)
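    # Run inference whenever an image is uploaded or Submit/Regenerate is
    # clicked (Regenerate simply reruns inference on the current image).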
gr.on(
triggers=[
imagebox.upload,
submit_btn.click,
regenerate_btn.click,
],
fn=model_inference,
inputs=[imagebox],
outputs=[generated_html, rendered_html],
queue=False,
)
template_gallery.select(
fn=add_file_gallery,
inputs=[template_gallery],
outputs=[imagebox],
queue=False,
).success(
fn=model_inference,
inputs=[imagebox],
outputs=[generated_html, rendered_html],
queue=False,
)
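
# Serve the demo: cap the request queue at 40 pending jobs and keep the
# programmatic API closed.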
demo.load(queue=False)
demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)