import os

# Pull in Detectron2 and the DiT code from the unilm repo at startup, then patch
# a Python 3.10+ incompatibility (collections.Iterable moved to collections.abc)
os.system('git clone https://github.com/facebookresearch/detectron2.git')
os.system('pip install -e detectron2')
os.system("git clone https://github.com/microsoft/unilm.git")
os.system("sed -i 's/from collections import Iterable/from collections.abc import Iterable/' unilm/dit/object_detection/ditod/table_evaluation/data_structure.py")
os.system("curl -LJ -o publaynet_dit-b_cascade.pth 'https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth?sv=2022-11-02&ss=b&srt=o&sp=r&se=2033-06-08T16:48:15Z&st=2023-06-08T08:48:15Z&spr=https&sig=a9VXrihTzbWyVfaIDlIT1Z0FoR1073VB0RLQUMuudD4%3D'")

import sys
sys.path.append("unilm")
sys.path.append("detectron2")

import cv2
import filetype
from PIL import Image
import numpy as np
from io import BytesIO
from pdf2image import convert_from_bytes, convert_from_path
import re
import requests
from collections import namedtuple
from urllib.parse import urlparse, parse_qs

from unilm.dit.object_detection.ditod import add_vit_config

import torch

from detectron2.config import CfgNode as CN
from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor

from huggingface_hub import hf_hub_download

import gradio as gr

# Step 1: instantiate config
cfg = get_cfg()
add_vit_config(cfg)
#cfg.merge_from_file("cascade_dit_base.yml")
cfg.merge_from_file("unilm/dit/object_detection/publaynet_configs/cascade/cascade_dit_base.yaml")

# Step 2: add model weights URL to config
filepath = hf_hub_download(repo_id="Sebas6k/DiT_weights", filename="publaynet_dit-b_cascade.pth", repo_type="model")
cfg.MODEL.WEIGHTS = filepath

# Step 3: set device
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Step 4: define model
predictor = DefaultPredictor(cfg)

# Internal data structure: one record per figure extracted from a page
ExtractedImage = namedtuple("ExtractedImage", [
    "image",             # cropped figure (RGB numpy array), or None if nothing was found
    "annotated_page",    # full page with predictions drawn on it
    "original_page",     # unmodified input page
    "confidence_score",  # detection score as a percentage; -1 means no detection
    "top_left",          # "x-y" of the box's top-left corner
    "bottom_right",      # "x-y" of the box's bottom-right corner
    "num_pixels",        # area of the crop in pixels
    "is_color",          # True if the crop actually uses color
])


def analyze_image(img):
    images = extract_images(img)
    # Bucket the extracted figures by confidence level
    high_confidence = []
    medium_confidence = []
    low_confidence = []
    result_image = img
    for extracted_image_object in images:
        cropped_img = extracted_image_object.image
        confidence_score = extracted_image_object.confidence_score
        confidence_text = f"Score: {confidence_score:.2f}%"
        if cropped_img is not None:
            # Overlay the confidence score on the cropped figure:
            # white text on an orange box in the lower-left corner
            font_scale = 0.9
            font_thickness = 2
            text_color = (255, 255, 255)      # white text
            background_color = (255, 165, 0)  # RGB for orange
            (text_width, text_height), _ = cv2.getTextSize(
                confidence_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
            padding = 12
            text_offset_x = padding - 3
            text_offset_y = cropped_img.shape[0] - padding + 2
            box_coords = ((text_offset_x, text_offset_y + padding // 2),
                          (text_offset_x + text_width + padding, text_offset_y - text_height - padding // 2))
            cv2.rectangle(cropped_img, box_coords[0], box_coords[1], background_color, cv2.FILLED)
            cv2.putText(cropped_img, confidence_text, (text_offset_x, text_offset_y),
                        cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness)
        if extracted_image_object.annotated_page is not None:
            result_image = extracted_image_object.annotated_page
        # Categorize images based on confidence levels
        if confidence_score > 85:
            high_confidence.append(cropped_img)
        elif confidence_score > 50:
            medium_confidence.append(cropped_img)
        elif cropped_img is not None:
            low_confidence.append(cropped_img)
    return result_image, high_confidence, medium_confidence, low_confidence
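
# A minimal usage sketch for analyze_image (load_image is defined further down;
# the file name here is hypothetical, and nothing in this block runs at import time):
#
#   page = load_image("sample_page.png")   # any document page as an RGB numpy array
#   annotated, high, medium, low = analyze_image(page)
#   # `annotated` is the page with DiT's detections drawn on; the three lists hold
#   # cropped figures bucketed at >85%, 50-85%, and <=50% confidence.
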
def extract_images(img):
    md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
    if cfg.DATASETS.TEST[0] == 'icdar2019_test':
        md.set(thing_classes=["table"])
    else:
        # PubLayNet categories (PubMed PDF/XML data): https://ieeexplore.ieee.org/document/8977963
        md.set(thing_classes=["text", "title", "list", "table", "figure"])

    is_color = None
    print(f"Input ndim: {img.ndim} -- channels: {img.shape[2] if img.ndim == 3 else 1}")

    # Ensure the image is in the format the predictor expects (3-channel BGR)
    if img.ndim == 2:
        # Single-channel grayscale: expand to three channels
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.ndim == 3 and img.shape[2] == 3:
        print(f"Is effectively grayscale? {is_effectively_grayscale_np(img)}")
        if not is_effectively_grayscale_np(img):
            # Three-channel image that genuinely uses color: convert RGB -> BGR
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            is_color = True

    outputs = predictor(img)
    # Operate on CPU for numpy compatibility
    instances = outputs["instances"].to("cpu")

    extracted_images = []
    v = Visualizer(img[:, :, ::-1], md, scale=1.0, instance_mode=ColorMode.SEGMENTATION)
    result_image = v.draw_instance_predictions(instances).get_image()[:, :, ::-1]
    result_image = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)

    for i in range(len(instances)):
        if md.thing_classes[instances.pred_classes[i]] == "figure":
            box = instances.pred_boxes.tensor[i].numpy().astype(int)
            cropped_img = img[box[1]:box[3], box[0]:box[2]]
            cropped_img = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB)
            confidence_score = instances.scores[i].numpy() * 100  # convert to percentage
            num_pixels = cropped_img.shape[0] * cropped_img.shape[1]
            is_color = (cropped_img.ndim == 3 and cropped_img.shape[2] == 3
                        and not is_effectively_grayscale_np(cropped_img))
            extracted_images.append(ExtractedImage(
                image=cropped_img,
                annotated_page=result_image,
                original_page=img,
                confidence_score=confidence_score,
                top_left=f"{box[0]}-{box[1]}",
                bottom_right=f"{box[2]}-{box[3]}",
                num_pixels=num_pixels,
                is_color=is_color,
            ))

    if not extracted_images:
        # No figures detected: still return the annotated page, with sentinel values
        extracted_images.append(ExtractedImage(
            image=None,
            annotated_page=result_image,
            original_page=img,
            confidence_score=-1,   # indicates no confidence
            top_left=None,
            bottom_right=None,     # no bounding box coordinates
            num_pixels=0,
            is_color=False,
        ))
    return extracted_images


def is_effectively_grayscale_np(array):
    if array.ndim != 3 or array.shape[2] != 3:
        raise ValueError("Input must be an RGB image")
    # The image is effectively grayscale when all three channels are identical
    r, g, b = array[:, :, 0], array[:, :, 1], array[:, :, 2]
    return np.array_equal(r, g) and np.array_equal(g, b)


def handle_input(input_data):
    images = []
    # input_data is a dict with keys 'text' and 'files'
    if 'text' in input_data and input_data['text']:
        input_text = input_data['text'].strip()  # either a URL or a PDF ID
        if input_text.startswith('http://') or input_text.startswith('https://'):
            # Extract the document ID from the URL fragment, e.g. .../docs/#id=yqgg0230
            url_parts = urlparse(input_text)
            query_params = parse_qs(url_parts.fragment)
            pdf_id = query_params.get('id', [None])[0]
            if not pdf_id:
                raise ValueError("PDF ID not found in URL")
        else:
            # Assume the input is a direct PDF ID
            pdf_id = input_text
            if not re.match(r'^[a-zA-Z]{4}\d{4}$', pdf_id):
                raise ValueError("Invalid PDF ID format. Expected four letters followed by four numbers.")
        # Construct the download URL from the ID, e.g.
        # https://download.industrydocuments.ucsf.edu/k/t/k/l/ktkl0236/ktkl0236.pdf
        pdf_url = construct_download_url(pdf_id)
        pdf_data = download_pdf(pdf_url)
        images = pdf_to_images(pdf_data)

    if 'files' in input_data and input_data['files']:
        for file_path in input_data['files']:
            print("Type of file as uploaded:", type(file_path))
            print(f"  File: {file_path}")
            # Sniff the file type, then route to the right loader
            kind = filetype.guess(file_path)
            if kind is None:
                raise ValueError("Unsupported file type.")
            if kind.mime.startswith('image'):
                images.append(load_image(file_path))     # process a single image directly
            elif kind.mime == 'application/pdf':
                images.extend(pdf_to_images(file_path))  # convert PDF pages to images
            else:
                raise ValueError("Unsupported file type.")

    if not images:
        raise ValueError("No valid input provided. Please upload a file or enter a PDF ID.")

    # process_images returns galleries of images grouped by confidence
    return process_images(images)
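
# For reference, the dict gr.MultimodalTextbox hands to handle_input looks roughly
# like this (the upload path is illustrative; the IDs come from the examples below):
#
#   {'text': 'kllx0250', 'files': []}                                  # a bare OIDA PDF ID
#   {'text': 'https://www.industrydocuments.ucsf.edu/opioids/docs/#id=yqgg0230', 'files': []}
#   {'text': '', 'files': ['/tmp/gradio/…/lrpw0232.pdf']}              # an uploaded file
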
def load_image(img_path):
    print(f"Loading image: {img_path}")
    # Load an image from a file path and normalize it to an RGB numpy array
    image = Image.open(img_path)
    print(f"  Image mode: {image.mode}")
    if image.mode != 'RGB':
        print(f"  Converting from {image.mode} to RGB")
        image = image.convert('RGB')
    image = np.array(image)  # convert PIL Image to numpy array
    print(f"  Array shape: {image.shape}")
    return image


def construct_download_url(pdf_id):
    # Construct the download URL from the PDF ID, e.g. for 'ktkl0236':
    # https://download.industrydocuments.ucsf.edu/k/t/k/l/ktkl0236/ktkl0236.pdf
    path_parts = '/'.join(pdf_id[i] for i in range(4))  # 'k/t/k/l'
    download_url = f"https://download.industrydocuments.ucsf.edu/{path_parts}/{pdf_id}/{pdf_id}.pdf"
    return download_url


def download_pdf(pdf_url):
    # Download the PDF file from the given URL
    response = requests.get(pdf_url)
    response.raise_for_status()  # surface bad responses early
    return BytesIO(response.content)


def pdf_to_images(data_or_path):
    # Create a temporary directory to store the page images
    temp_dir = "temp_images"
    os.makedirs(temp_dir, exist_ok=True)
    try:
        # Convert the PDF to a list of PIL images,
        # handling both BytesIO and file-path input
        if isinstance(data_or_path, BytesIO):
            pages = convert_from_bytes(data_or_path.read())
        elif isinstance(data_or_path, str):
            pages = convert_from_path(data_or_path)
        else:
            raise TypeError("Expected BytesIO or a file path")
        # Save each page as an image file, then reload it as a numpy array
        page_images = []
        for i, page in enumerate(pages):
            image_path = os.path.join(temp_dir, f"page_{i+1}.jpg")
            page.save(image_path, "JPEG")
            page_images.append(load_image(image_path))
        return page_images
    except Exception as e:
        print(f"Error converting PDF to images: {str(e)}")
        return []
    finally:
        # Clean-up of the temporary directory is intentionally deferred
        pass
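
# pdf_to_images deliberately leaves its page renders in temp_images/ (see the
# finally block above). A minimal cleanup sketch follows; this helper is purely
# illustrative, the app never calls it, and the directory name is the one assumed above.
import shutil

def cleanup_temp_images(temp_dir="temp_images"):
    # Remove the temporary page renders written during PDF conversion
    shutil.rmtree(temp_dir, ignore_errors=True)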
def process_images(images):
    all_processed_images = []
    all_high_confidence = []
    all_medium_confidence = []
    all_low_confidence = []
    for idx, img in enumerate(images, start=1):
        processed_image, high_confidence, medium_confidence, low_confidence = analyze_image(img)
        if processed_image is None:
            print(f"  processed_image is None on page: {idx}")
        else:
            all_processed_images.append(processed_image)
        if not high_confidence:
            print(f"  high_confidence is empty on page: {idx}")
        all_high_confidence.extend(high_confidence)
        if not medium_confidence:
            print(f"  medium_confidence is empty on page: {idx}")
        all_medium_confidence.extend(medium_confidence)
        if not low_confidence:
            print(f"  low_confidence is empty on page: {idx}")
        all_low_confidence.extend(low_confidence)
    print(f"  Sizes: annotated={len(all_processed_images)}, high={len(all_high_confidence)}, "
          f"medium={len(all_medium_confidence)}, low={len(all_low_confidence)}")
    return all_processed_images, all_high_confidence, all_medium_confidence, all_low_confidence


title = "OIDA Image Collection Interactive demo: Document Layout Analysis with DiT and PubLayNet"
description = "Paper | Github Repo | HuggingFace doc | PubLayNet paper"
" #examples =[['fpmj0236_Page_012.png'],['fnmf0234_Page_2.png'],['publaynet_example.jpeg'],['fpmj0236_Page_018.png'],['lrpw0232_Page_14.png'],['kllx0250'],['https://www.industrydocuments.ucsf.edu/opioids/docs/#id=yqgg0230']] examples =[{'files': ['fnmf0234_Page_2.png']},{'files': ['fpmj0236_Page_012.png']},{'files': ['lrpw0232.pdf']},{'text': 'https://www.industrydocuments.ucsf.edu/opioids/docs/#id=yqgg0230'},{'files':['fpmj0236_Page_018.png']},{'files':['lrpw0232_Page_14.png']},{'files':['publaynet_example.jpeg']},{'text':'kllx0250'},{'text':'txhk0255'},{'text':'gpdk0256'}] css = ".output-image, .input-image, .image-preview {height: 600px !important} td.textbox {display:none;} #component-5 .submit-button {display:none;}" def setup_gradio_interface(): #iface = gr.Interface(fn=handle_input, # inputs=gr.MultimodalTextbox(interactive=True, # label="Upload image/PDF file OR enter OIDA ID or URL", # file_types=["image",".pdf"], # placeholder="Upload image/PDF file OR enter OIDA ID or URL"), # outputs=[gr.Gallery(label="annotated documents"), # gr.Gallery(label="Figures with High (>85%) Confidence Scores"), # gr.Gallery(label="Figures with Moderate (50-85%) Confidence Scores"), # gr.Gallery(label="Figures with Lower Confidence (under 50%) Scores")], # title=title, # description=description, # examples=examples, # article=article, # css=css) ## enable_queue=True) with gr.Blocks(css=css) as iface: gr.Markdown(f"# {title}") gr.HTML(description) with gr.Row(): with gr.Column(): input = gr.MultimodalTextbox(interactive=True, label="Upload image/PDF file OR enter OIDA ID or URL", file_types=["image",".pdf"], placeholder="Upload image/PDF file OR enter OIDA ID or URL", submit_btn=None) submit_btn = gr.Button("Submit") gr.HTML('