"""Gradio demo: interactive latent-space traversal of a Latent Program Network (LPN)."""

from typing import Optional

import gradio as gr
from PIL import Image, ImageDraw

from inference import generate_image

TASK_TO_INDEX = {"Task 1": 0, "Task 2": 1, "Task 3": 2, "Task 4": 3}
# Precomputed heatmap pixel coordinates used by the "Find Optimal Latent" button.
TASK_OPTIMAL_COORDS = {0: (325, 326), 1: (59, 1126), 2: (47, 102), 3: (497, 933)}
# Divisor that maps heatmap pixel coordinates to the normalised values passed to
# `generate_image`.
# NOTE(review): assumes every imgs/heatmap_*.png is 1155x1155 pixels — confirm.
HEATMAP_SIZE = 1155


def _pattern_path(image_idx: int) -> str:
    """Return the path of the input-output pairs image for task ``image_idx``."""
    return f"imgs/pattern_{image_idx}.png"


def _heatmap_path(image_idx: int) -> str:
    """Return the path of the latent-space heatmap image for task ``image_idx``."""
    return f"imgs/heatmap_{image_idx}.png"


def create_marker_overlay(image_path: str, x: int, y: int) -> Image.Image:
    """Return a copy of the image at ``image_path`` with a red cross drawn at (x, y).

    The file is closed before returning; the result is an independent
    in-memory image, so the original file on disk is never modified.
    """
    # `copy()` forces the pixel data to load, so the file can be closed here
    # instead of leaking one open handle per click.
    with Image.open(image_path) as base_image:
        marked_image = base_image.copy()
    # A red marker is not reliably representable in non-RGB modes (grayscale
    # turns it gray; a full palette may be unable to allocate a new colour),
    # so draw on an RGBA copy in that case. RGB/RGBA images are used as-is.
    if marked_image.mode not in ("RGB", "RGBA"):
        marked_image = marked_image.convert("RGBA")
    draw = ImageDraw.Draw(marked_image)
    marker_size = 10
    marker_color = "red"
    draw.line([x - marker_size, y, x + marker_size, y], fill=marker_color, width=2)
    draw.line([x, y - marker_size, x, y + marker_size], fill=marker_color, width=2)
    return marked_image


def update_reference_image(choice: int) -> tuple[str, int, str]:
    """Return (pattern image path, task index, heatmap image path) for task ``choice``."""
    return _pattern_path(choice), choice, _heatmap_path(choice)


def select_task(task_name: str) -> tuple[str, int, str, None, None]:
    """Handle a task radio change.

    Returns the reference image path, task index and heatmap path, followed by
    ``None`` for the stored latent coordinates and ``None`` for the generated
    output, so the output panel never shows a result decoded for another task.
    """
    image_path, image_idx, heatmap_path = update_reference_image(TASK_TO_INDEX[task_name])
    return image_path, image_idx, heatmap_path, None, None


def update_marker(image_idx: int, evt: gr.SelectData) -> tuple[Image.Image, tuple[int, int]]:
    """Mark the clicked heatmap pixel and return it as the new latent coordinates.

    ``evt.index`` holds the (x, y) pixel position Gradio reports for the click.
    """
    x, y = evt.index[0], evt.index[1]
    return create_marker_overlay(_heatmap_path(image_idx), x, y), (x, y)


def generate_output_image(
    image_idx: int, coords: Optional[tuple[int, int]]
) -> Optional[Image.Image]:
    """Decode an output for task ``image_idx`` from heatmap pixel ``coords``.

    Pixel coordinates are divided by ``HEATMAP_SIZE`` before being passed to
    ``generate_image``. Returns ``None`` (clearing the output panel) when no
    latent has been selected yet.
    """
    if coords is None:
        return None
    x, y = coords
    return generate_image(image_idx, x / HEATMAP_SIZE, y / HEATMAP_SIZE)


def find_optimal_latent(image_idx: int) -> tuple[Image.Image, tuple[int, int], Image.Image]:
    """Jump to the precomputed optimal latent of task ``image_idx``.

    Returns the marked heatmap, the coordinates (to store in state) and the
    output generated from them.
    """
    coords = TASK_OPTIMAL_COORDS[image_idx]
    marked_heatmap = create_marker_overlay(_heatmap_path(image_idx), *coords)
    return marked_heatmap, coords, generate_output_image(image_idx, coords)


CSS = """
.container {
    max-width: 1200px !important;
    width: 100% !important;
    margin-left: auto !important;
    margin-right: auto !important;
    padding: 0 1rem !important;
}
.diagram-container {
    width: 100% !important;
    max-width: 1000px !important;
    margin: 2rem auto !important;
}
.diagram-container img {
    width: 100% !important;
    height: auto !important;
    display: block !important;
    margin: 0 auto !important;
    cursor: default !important;
}
.radio-container {
    width: 100% !important;
    max-width: 450px !important;
    margin-bottom: 1rem !important;
}
.image-preview-container {
    width: 100% !important;
    max-width: 450px !important;
}
.image-preview-container img {
    width: 100% !important;
    height: 100% !important;
    object-fit: contain !important;
    cursor: default !important;
}
.coordinate-container {
    width: 100% !important;
    aspect-ratio: 1 !important;
    position: relative !important;
    max-width: 550px !important;
}
.coordinate-container img {
    width: 100% !important;
    height: 100% !important;
    object-fit: contain !important;
}
.button-container {
    width: 100% !important;
    max-width: 450px !important;
    display: flex !important;
    justify-content: center !important;
    margin-bottom: 1rem !important;
}
/* Disabled rule (`#` is not a CSS comment; use block comments):
.documentation {
    margin-top: 2rem !important;
    padding: 1rem !important;
    background-color: #f8f9fa !important;
    border-radius: 8px !important;
}
*/
.optimal-button {
    width: 200px !important;
}
"""

with gr.Blocks(css=CSS) as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown(
            """
            # Interactive Visualization of a Latent Program Network (LPN)

            ## Introduction

            The LPN is an architecture for inductive program synthesis that builds in test-time adaptation by learning a latent space that can be used for search.
            This interactive demo showcases a latent traversal of the LPN in the latent program space.
            More specifically, the decoder of the LPN is conditioned on a latent vector representing an abstract program, which is then used to generate an output.
            """
        )

        with gr.Column(elem_classes="diagram-container"):
            gr.Image(
                value="imgs/lpn_diagram.png",
                show_label=False,
                interactive=False,
                show_download_button=False,
                show_fullscreen_button=False,
                show_share_button=False,
                container=False,
            )

        gr.Markdown(
            """
            ### How to Use

            1. Choose a pattern task using the radio buttons
            2. View the input-output pairs for your selected task
            3. The goal is to find the latent that will generate the right third image for the given input
            4. Click anywhere on the latent space to specify coordinates for the latent
            5. See the generated image based on your selected latent

            Use the "Find Optimal Latent" button to find the latent that maximizes likelihood of generating the other input-output pairs.
            """
        )

        with gr.Row():
            # Left column for controls
            with gr.Column(scale=1):
                selected_idx = gr.State(value=0)
                coords = gr.State()

                with gr.Column(elem_classes="radio-container"):
                    task_select = gr.Radio(
                        choices=list(TASK_TO_INDEX),
                        value="Task 1",
                        label="Select Task",
                        interactive=True,
                    )

                gr.Markdown("### Latent Space Search")
                gr.Markdown(
                    "Click anywhere in the 2D latent space below to condition the decoder on a specific latent vector. "
                    "The heatmap shows the decoder log-likelihood of generating the first two input-output pairs conditioned on any point in the latent space. "
                    "The goal is to find the latent that generates the third image for the given input."
                )

                with gr.Column(elem_classes="coordinate-container"):
                    coord_selector = gr.Image(
                        value=_heatmap_path(0),
                        show_label=False,
                        interactive=False,
                        sources=[],
                        container=True,
                        show_download_button=False,
                        show_fullscreen_button=False,
                        show_share_button=False,
                    )

                with gr.Column(elem_classes="button-container"):
                    optimal_button = gr.Button("Find Optimal Latent", elem_classes="optimal-button")

            # Right column for images
            with gr.Column(scale=1):
                gr.Markdown("### Input-Output Pairs")
                with gr.Column(elem_classes="image-preview-container"):
                    reference_image = gr.Image(
                        value=_pattern_path(0),
                        show_label=False,
                        interactive=False,
                        show_download_button=False,
                        show_fullscreen_button=False,
                        show_share_button=False,
                    )

                gr.Markdown("### Generated Output")
                with gr.Column(elem_classes="image-preview-container"):
                    output_image = gr.Image(
                        show_label=False,
                        interactive=False,
                        show_download_button=False,
                        show_fullscreen_button=False,
                        show_share_button=False,
                    )

    with gr.Column(elem_classes="container"):
        gr.Markdown(
            """
            ### Technical Details

            For more information, please refer to our [paper](https://arxiv.org/pdf/2411.08706) or GitHub [repository](https://github.com/clement-bonnet/lpn).
            """
        )

    # Event handlers
    # Switching tasks also clears the stored latent and the previous output.
    task_select.change(
        fn=select_task,
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector, coords, output_image],
    )

    coord_selector.select(
        fn=update_marker,
        inputs=[selected_idx],
        outputs=[coord_selector, coords],
        trigger_mode="multiple",
    ).then(
        fn=generate_output_image,
        inputs=[selected_idx, coords],
        outputs=output_image,
    )

    optimal_button.click(
        fn=find_optimal_latent,
        inputs=[selected_idx],
        outputs=[coord_selector, coords, output_image],
    )

if __name__ == "__main__":
    demo.launch()