"""Gradio demo that runs a vision-language model over an image and shows its fire/smoke analysis."""

import base64
import io
import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import spaces
import torch
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration


def check_environment():
    """Verify that every required environment variable is set and non-empty.

    Raises:
        ValueError: If any required variable is missing or empty.
    """
    required_vars = ["HF_TOKEN"]
    # An empty value is as unusable as a missing one, so test truthiness
    # rather than mere presence.
    missing_vars = [var for var in required_vars if not os.environ.get(var)]
    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}\n"
            "Please set the HF_TOKEN environment variable with your Hugging Face token"
        )


# Login to Hugging Face
check_environment()
login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)

base_model_path = "taesiri/FireNet-LLama-3.2-11B-Vision-Base"
processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path, torch_dtype=torch.bfloat16, device_map="cuda"
)
model.tie_weights()


def create_color_palette_image(colors):
    """Draw a horizontal strip of color swatches.

    Args:
        colors: A non-empty list of strings, each starting with ``"#"``.

    Returns:
        A matplotlib figure with one unit-wide rectangle per color, or
        ``None`` if ``colors`` is empty, not a list, contains an entry that is
        not a ``"#"``-prefixed string, or drawing fails.
    """
    if not colors or not isinstance(colors, list):
        return None
    if not all(isinstance(color, str) and color.startswith("#") for color in colors):
        return None

    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 2))

        # One unit-wide swatch per color, laid out left to right.
        for i, color in enumerate(colors):
            ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color))

        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])
        return fig
    except Exception as e:
        # pyplot keeps a reference to every open figure; close the
        # half-built one so a failed call does not leak it.
        if fig is not None:
            plt.close(fig)
        print(f"Error creating color palette: {e}")
        return None


def _failure_outputs(message, raw="", detail=""):
    """Build the 8-value result tuple for a failed analysis.

    The click handler is wired to eight output components, so every return
    path of ``inference`` must supply eight values. The four JSON panels get
    a dict (a bare non-JSON string is not a valid value for them).
    """
    panel = {"error": message}
    return (panel, panel, panel, panel, raw, detail, "Error", {})


def _extract_json_text(decoded):
    """Return the text following the first ``assistant`` marker in ``decoded``.

    Raises:
        ValueError: If the marker is absent.
    """
    # partition() keeps everything after the first marker, so a later
    # occurrence of the marker inside the payload cannot truncate it.
    _, marker, tail = decoded.strip().partition("assistant\n")
    if not marker:
        raise ValueError("assistant marker not found in model output")
    return tail.strip()


@spaces.GPU
def inference(image):
    """Run the model on ``image`` and split its JSON answer into UI sections.

    Args:
        image: A PIL image, or an object ``Image.fromarray`` can convert.

    Returns:
        An 8-tuple matching the click handler's outputs: fire panel,
        environment panel, detection panel, report panel, raw response text,
        error text, status text, and the full parsed JSON.
    """
    if image is None:
        return _failure_outputs("Please provide an image")

    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return _failure_outputs("Invalid image format", detail=str(e))

    # Prepare input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": "Analyze this image for fire, smoke, haze, or other related conditions.",
                },
            ],
        }
    ]

    try:
        input_text = processor.apply_chat_template(
            messages, add_generation_prompt=True
        )
        # Move inputs to the correct device
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)
    except Exception as e:
        print(f"Inference error: {e}")
        return _failure_outputs("Error during inference", detail=str(e))
    finally:
        # Clear the CUDA cache whether or not generation succeeded.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Decode output
    result = processor.decode(output[0], skip_special_tokens=True)
    print("DEBUG: Full decoded output:", result)

    try:
        json_str = _extract_json_text(result)
        parsed_json = json.loads(json_str)
        if not isinstance(parsed_json, dict):
            raise ValueError("model output is not a JSON object")

        detections = parsed_json.get("detections", [])

        # Create specific JSON subsets for each section
        fire_analysis = {
            "predictions": parsed_json.get("predictions", "N/A"),
            "description": parsed_json.get("description", "No description available"),
            # NOTE(review): reads "confidence_score" (singular) but publishes
            # it as "confidence_scores" - confirm against the model's schema.
            "confidence_scores": parsed_json.get("confidence_score", {}),
        }
        environment_analysis = {
            "environmental_factors": parsed_json.get("environmental_factors", {})
        }
        detection_analysis = {
            "detections": detections,
            "detection_count": len(detections),
        }
        report_analysis = {
            "uncertainty_factors": parsed_json.get("uncertainty_factors", []),
            "false_positive_indicators": parsed_json.get(
                "false_positive_indicators", []
            ),
        }
    except (ValueError, TypeError) as e:
        # json.JSONDecodeError is a ValueError; TypeError covers a
        # "detections" value that has no length.
        print("DEBUG: Error processing response:", e)
        return _failure_outputs(
            "Error processing response", raw=str(result), detail=str(e)
        )

    return (
        fire_analysis,
        environment_analysis,
        detection_analysis,
        report_analysis,
        json_str,
        "",
        "Analysis complete",
        parsed_json,
    )


# Update Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Fire Detection Demo")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
            )
            submit_btn = gr.Button("Analyze Image", variant="primary")

            # Updated examples
            gr.Examples(
                examples=[
                    "examples/1727808849.jpg",
                    "examples/1727809389.jpg",
                    "examples/Birch MWF014-0001.jpg",
                    "examples/frame_000036.jpg",
                    "examples/frame_000168.jpg",
                ],
                inputs=image_input,
                label="Example Images",
                examples_per_page=5,
            )

    with gr.Tabs() as tabs:
        with gr.Tab("Analysis Results"):
            with gr.Row():
                with gr.Column():
                    fire_output = gr.JSON(
                        label="Fire Details",
                    )
                with gr.Column():
                    environment_output = gr.JSON(
                        label="Environment Details",
                    )
            with gr.Row():
                with gr.Column():
                    detection_output = gr.JSON(
                        label="Detection Details",
                    )
                with gr.Column():
                    report_output = gr.JSON(
                        label="Report Details",
                    )

        with gr.Tab("JSON Output", id=0):
            json_output = gr.JSON(
                label="Detailed JSON Results",
            )

        with gr.Tab("Raw Output"):
            raw_output = gr.Textbox(
                label="Raw JSON Response",
                lines=10,
            )

    error_box = gr.Textbox(label="Error Messages", visible=False)
    status_text = gr.Textbox(label="Status", value="Ready", interactive=False)

    # The order here must match the tuple returned by inference().
    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=[
            fire_output,
            environment_output,
            detection_output,
            report_output,
            raw_output,
            error_box,
            status_text,
            json_output,
        ],
    )

demo.launch(share=True)