# humblemikey's picture
# Update app.py
# b129969 verified
import os
import re
import subprocess
import sys

import gradio as gr
import numpy as np
import spaces
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Install flash-attn at startup with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
# (presumably so pip does not compile the CUDA kernels locally — confirm).
# The variable is merged into the inherited environment rather than replacing
# it, so PATH/HOME remain visible to pip; invoking pip via sys.executable
# installs into the same interpreter that runs this app. check=False keeps the
# original best-effort behavior: a failed install does not stop startup here.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# Hub repository of the captioning model; used for the model, processor and title.
MODEL_ID = 'thwri/CogFlorence-2.1-Large'

# trust_remote_code=True runs Python code shipped in the model repository.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True).to("cuda").eval()
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

TITLE = f"# [{MODEL_ID}](https://huggingface.co/{MODEL_ID}/)"
DESCRIPTION = "[microsoft/Florence-2-large](https://huggingface.co/microsoft/Florence-2-large) tuned on [Ejafa/ye-pop](https://huggingface.co/datasets/Ejafa/ye-pop) captioned with [CogVLM2](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B)"
# Boilerplate phrases stripped from captions, matched case-insensitively.
# Order matters: the longer phrases must run before the bare "the image ".
# NOTE(review): each pattern removes every occurrence, not only a leading one,
# so a mid-sentence "... of the image is ..." is also cut — confirm intended.
_CAPTION_BOILERPLATE = [
    re.compile(r'the image is ', re.IGNORECASE),
    re.compile(r'the image captures ', re.IGNORECASE),
    re.compile(r'the image showcases ', re.IGNORECASE),
    re.compile(r'the image shows ', re.IGNORECASE),
    re.compile(r'the image ', re.IGNORECASE),
]

# Zero-width match right after . , ? ! when the next character needs a
# separating space. Nothing is inserted before whitespace, further punctuation
# (keeps "..." intact), a closing quote/bracket, or inside numbers such as
# "3.5" or "1,000" (digit, separator, digit).
_MISSING_SPACE_AFTER_PUNCT = re.compile(r'''(?<=[.,?!])(?=[^\s.,?!"')\]])(?!(?<=\d[.,])\d)''')


def modify_caption(caption: str) -> str:
    """Clean a raw model caption into a single tidy line.

    Steps: collapse all whitespace (including line breaks) to single spaces,
    strip the "the image ..." boilerplate phrases, ensure a space follows
    sentence punctuation, and trim.

    Args:
        caption: raw caption text produced by the model.

    Returns:
        The cleaned caption on one line.
    """
    # Normalize whitespace first so a line break between two words becomes a
    # space instead of gluing the words together.
    caption = " ".join(caption.split())
    for pattern in _CAPTION_BOILERPLATE:
        caption = pattern.sub('', caption)
    caption = _MISSING_SPACE_AFTER_PUNCT.sub(' ', caption)
    # Phrase removal can leave edge/doubled spaces; normalize once more.
    return " ".join(caption.split())
@spaces.GPU
def process_image(image):
    """Generate a cleaned-up detailed caption for one image.

    Args:
        image: a NumPy array, a file path string, or a PIL image.

    Returns:
        The ``<MORE_DETAILED_CAPTION>`` output of the model, post-processed
        by ``modify_caption``.

    Raises:
        gr.Error: if no image was supplied (e.g. Submit clicked with an empty input).
    """
    if image is None:
        raise gr.Error("Please provide an image to caption.")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str):
        # Load eagerly inside a context manager so the file handle is closed
        # immediately instead of lingering until garbage collection.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    if image.mode != "RGB":
        image = image.convert("RGB")

    task = "<MORE_DETAILED_CAPTION>"
    inputs = processor(text=task, images=image, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        do_sample=True,
    )
    # Special tokens are kept because post_process_generation parses them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task, image_size=(image.width, image.height)
    )
    return modify_caption(parsed_answer[task])
def extract_frames(image_path, output_folder):
    """Save every frame of a (possibly animated) image as a numbered PNG.

    Frames are written as ``<base>_frame_<NNN>.png`` in ``output_folder``,
    where ``<base>`` is the source file name without its extension.

    Args:
        image_path: path to the source image.
        output_folder: directory to write the frames into; created if missing.

    Returns:
        Paths of the PNG files written, in frame order. If seeking hits
        ``EOFError`` (sequence ends early), the frames saved so far are returned.
    """
    os.makedirs(output_folder, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    frame_paths = []
    with Image.open(image_path) as img:
        # ``n_frames`` is only defined by multi-frame Pillow plugins; fall back
        # to 1 so a still image yields one frame instead of AttributeError.
        frame_count = getattr(img, "n_frames", 1)
        try:
            for index in range(frame_count):
                img.seek(index)
                frame_path = os.path.join(output_folder, f"{base_name}_frame_{index:03d}.png")
                img.save(frame_path)
                frame_paths.append(frame_path)
        except EOFError:
            pass  # Sequence ended early: keep the frames already written.
    return frame_paths
# File extensions (lower-case) that process_folder treats as images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.heic')


def _is_multi_frame(image_path):
    """Return True if the image at ``image_path`` is animated with more than one frame."""
    with Image.open(image_path) as img:
        # ``is_animated``/``n_frames`` exist only on multi-frame plugins, hence getattr.
        return bool(getattr(img, "is_animated", False)) and getattr(img, "n_frames", 1) > 1


def process_folder(folder_path):
    """Caption every image in ``folder_path``, writing one ``.txt`` per image.

    For ``name.ext`` the caption goes to ``name.txt`` in the same folder;
    images that already have that file are skipped. Animated images with more
    than one frame are split into PNG frames via ``extract_frames`` (written
    into the same folder), and each frame gets its own ``<frame>.txt``.

    A failure on one file (unreadable image, captioning error, ...) is recorded
    in the report and the batch continues with the next file.

    NOTE(review): ``folder_path`` comes from a UI textbox, so any directory the
    server process can access may be read and written — restrict if exposed.

    Args:
        folder_path: directory to scan (non-recursive); surrounding whitespace
            is ignored.

    Returns:
        A newline-separated report (processed, then skipped, then failed
        entries), or a message if the path is invalid or nothing was reported.
    """
    folder_path = (folder_path or "").strip()
    if not os.path.isdir(folder_path):
        return "Invalid folder path."

    processed_files = []
    skipped_files = []
    failed_files = []

    def caption_to_txt(image_path):
        """Caption one image or frame into its sibling .txt and record the outcome."""
        display_name = os.path.basename(image_path)
        txt_path = os.path.splitext(image_path)[0] + '.txt'
        txt_filename = os.path.basename(txt_path)
        if os.path.exists(txt_path):
            skipped_files.append(f"Skipped {display_name} (text file already exists)")
            return
        try:
            caption = process_image(image_path)
            with open(txt_path, 'w', encoding='utf-8') as f:
                f.write(caption)
        except Exception as err:  # one bad file must not abort the whole batch
            failed_files.append(f"Failed {display_name} ({err})")
        else:
            processed_files.append(f"Processed {display_name} -> {txt_filename}")

    # os.listdir order is arbitrary; sorting makes the run and report deterministic.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(_IMAGE_EXTENSIONS):
            continue
        image_path = os.path.join(folder_path, filename)
        txt_path = os.path.splitext(image_path)[0] + '.txt'
        # Skip before opening the file at all if its caption already exists.
        if os.path.exists(txt_path):
            skipped_files.append(f"Skipped {filename} (text file already exists)")
            continue
        try:
            if _is_multi_frame(image_path):
                targets = extract_frames(image_path, folder_path)
            else:
                targets = [image_path]
        except Exception as err:  # e.g. Pillow cannot identify/decode the file
            failed_files.append(f"Failed {filename} ({err})")
            continue
        for target in targets:
            caption_to_txt(target)

    result = "\n".join(processed_files + skipped_files + failed_files)
    return result if result else "No image files found or all files were skipped in the specified folder."
css = """
#output { height: 500px; overflow: auto; border: 1px solid #ccc; }
"""
with gr.Blocks(css=css) as demo:
gr.Markdown(TITLE)
gr.Markdown(DESCRIPTION)
with gr.Tab(label="Single Image Processing"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input Picture")
submit_btn = gr.Button(value="Submit")
with gr.Column():
output_text = gr.Textbox(label="Output Text")
gr.Examples(
[["image1.jpg"], ["image2.jpg"], ["image3.png"], ["image4.jpg"], ["image5.jpg"], ["image6.PNG"]],
inputs=[input_img],
outputs=[output_text],
fn=process_image,
label='Try captioning on below examples'
)
submit_btn.click(process_image, [input_img], [output_text])
with gr.Tab(label="Batch Processing"):
with gr.Row():
folder_input = gr.Textbox(label="Input Folder Path")
batch_submit_btn = gr.Button(value="Process Folder")
batch_output = gr.Textbox(label="Batch Processing Results", lines=10)
batch_submit_btn.click(process_folder, [folder_input], [batch_output])
demo.launch(debug=True)