# screenshot2html / app.py
import os
import subprocess
import torch
import gradio as gr
from gradio_client.client import DEFAULT_TEMP_DIR
from playwright.sync_api import sync_playwright
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from typing import List
from PIL import Image
from transformers.image_transforms import resize, to_channel_dimension_format
API_TOKEN = os.getenv("HF_AUTH_TOKEN")
DEVICE = torch.device("cuda")
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/img2html",
    token=API_TOKEN,
)
MODEL = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/img2html",  # TODO
    token=API_TOKEN,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(DEVICE)
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2
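# Worked example (hypothetical config values): with a 960-px vision encoder and
# 14-px patches, the non-resampler branch gives (960 // 14) ** 2 = 68 ** 2 = 4624
# <image> placeholder tokens per screenshot.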
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
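# Banning these two strings during generation (via `bad_words_ids` below) presumably
# keeps the model from emitting its own image-placeholder tokens inside the HTML output.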
## Utils
def convert_to_rgb(image):
    # `image.convert("RGB")` only works reliably for .jpg images: for transparent
    # images it silently drops the alpha channel, producing a wrong background.
    # Compositing onto an opaque white background via `alpha_composite` handles that case.
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite
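# Minimal illustration of what `convert_to_rgb` guards against (synthetic image,
# values are illustrative only):
#
#   rgba = Image.new("RGBA", (64, 64), (255, 0, 0, 0))  # fully transparent pixels
#   rgba.convert("RGB").getpixel((0, 0))                # (255, 0, 0): alpha silently dropped
#   convert_to_rgb(rgba).getpixel((0, 0))               # (255, 255, 255): composited onto white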
# The processor is the same as the Idefics processor, except for the BICUBIC
# interpolation inside siglip, so this is a hack to redefine ONLY the transform method
def custom_transform(x):
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x
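# Sketch of the expected output (path is hypothetical):
#
#   img = Image.open("example_images/some_template.png")
#   t = custom_transform(img)
#   # `t` is a float torch.Tensor of shape (3, 960, 960), normalized, channel-first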
## End of Utils
IMAGE_GALLERY_PATHS = [
    f"example_images/{ex_image}"
    for ex_image in os.listdir("example_images")
]
def install_playwright():
    try:
        subprocess.run(["playwright", "install"], check=True)
        print("Playwright installation successful.")
    except subprocess.CalledProcessError as e:
        print(f"Error during Playwright installation: {e}")


install_playwright()
def add_file_gallery(
    selected_state: gr.SelectData,
    gallery_list: List[str],
):
    # NOTE: despite the loose `List[str]` hint, gradio passes the Gallery value as
    # an object whose `.root` holds the image entries.
    return Image.open(gallery_list.root[selected_state.index].image.path)
def render_webpage(
    html_css_code,
):
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        context = browser.new_context(
            user_agent=(
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
                " Safari/537.36"
            )
        )
        page = context.new_page()
        page.set_content(html_css_code)
        page.wait_for_load_state("networkidle")
        output_path_screenshot = f"{DEFAULT_TEMP_DIR}/{hash(html_css_code)}.png"
        _ = page.screenshot(path=output_path_screenshot, full_page=True)

        context.close()
        browser.close()

    return Image.open(output_path_screenshot)
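# Standalone usage sketch (assumes Playwright's Chromium is installed; see
# `install_playwright` above):
#
#   snippet = "<html><body><h1>Hello</h1></body></html>"
#   preview = render_webpage(snippet)  # PIL.Image screenshot of the rendered page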
def model_inference(
    image,
):
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")

    inputs = PROCESSOR.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
    )
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    # NOTE: `max_new_tokens=4096` is an assumed budget; without it, `generate`
    # defaults to ~20 new tokens, far too few for a full HTML page.
    generated_ids = MODEL.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=4096)
    generated_text = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)
    # Unused example template (kept for reference / manual testing of `render_webpage`).
    CAR_COMPANY = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>XYZ Car Company</title>
    <style>
        body {
            font-family: 'Arial', sans-serif;
            margin: 0;
            padding: 0;
            background-color: #f4f4f4;
        }
        header {
            background-color: #333;
            color: #fff;
            padding: 1em;
            text-align: center;
        }
        nav {
            background-color: #555;
            color: #fff;
            padding: 0.5em;
            text-align: center;
        }
        nav a {
            color: #fff;
            text-decoration: none;
            padding: 0.5em 1em;
            margin: 0 1em;
        }
        section {
            padding: 2em;
        }
        h2 {
            color: #333;
        }
        .car-container {
            display: flex;
            flex-wrap: wrap;
            justify-content: space-around;
        }
        .car-card {
            width: 300px;
            margin: 1em;
            border: 1px solid #ddd;
            border-radius: 5px;
            overflow: hidden;
            box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
        }
        .car-image {
            width: 100%;
            height: 150px;
            object-fit: cover;
        }
        .car-details {
            padding: 1em;
        }
        footer {
            background-color: #333;
            color: #fff;
            text-align: center;
            padding: 1em;
            position: fixed;
            bottom: 0;
            width: 100%;
        }
    </style>
</head>
<body>
    <header>
        <h1>XYZ Car Company</h1>
    </header>
    <nav>
        <a href="#">Home</a>
        <a href="#">Models</a>
        <a href="#">About Us</a>
        <a href="#">Contact</a>
    </nav>
    <section>
        <h2>Our Cars</h2>
        <div class="car-container">
            <div class="car-card">
                <img src="car1.jpg" alt="Car 1" class="car-image">
                <div class="car-details">
                    <h3>Model A</h3>
                    <p>Description of Model A.</p>
                </div>
            </div>
            <div class="car-card">
                <img src="car2.jpg" alt="Car 2" class="car-image">
                <div class="car-details">
                    <h3>Model B</h3>
                    <p>Description of Model B.</p>
                </div>
            </div>
            <!-- Add more car cards as needed -->
        </div>
    </section>
    <footer>
        &copy; 2024 XYZ Car Company. All rights reserved.
    </footer>
</body>
</html>"""

    rendered_page = render_webpage(generated_text)
    return generated_text, rendered_page
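# End-to-end sketch outside the UI (path is illustrative):
#
#   screenshot = Image.open("example_images/some_template.png")
#   html, preview = model_inference(screenshot)
#   preview.save("preview.png")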
generated_html = gr.Code(
    label="Extracted HTML",
    elem_id="generated_html",
)
rendered_html = gr.Image(
    label="Rendered HTML",
)
# rendered_html = gr.HTML(
# label="Rendered HTML"
# )
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
with gr.Blocks(title="Img2html", theme=gr.themes.Base(), css=css) as demo:
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                type="pil",
                label="Screenshot to extract",
                visible=True,
                sources=["upload", "clipboard"],
            )
            with gr.Group():
                with gr.Row():
                    submit_btn = gr.Button(
                        value="▶️ Submit", visible=True, min_width=120
                    )
                    clear_btn = gr.ClearButton(
                        [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
                    )
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", visible=True, min_width=120
                    )
        with gr.Column(scale=4) as result_area:
            rendered_html.render()
            with gr.Row():
                generated_html.render()
    with gr.Row():
        template_gallery = gr.Gallery(
            value=IMAGE_GALLERY_PATHS,
            label="Templates Gallery",
            allow_preview=False,
            columns=4,
            elem_id="gallery",
            show_share_button=False,
            height=400,
        )
    gr.on(
        triggers=[
            imagebox.upload,
            submit_btn.click,
            # `regenerate_btn.click` is registered here once; the separate, duplicate
            # registration was removed since it fired `model_inference` twice per click.
            regenerate_btn.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
        queue=False,
    )
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[imagebox],
        queue=False,
    ).success(
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    demo.load(queue=False)

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)