# handler.py - custom inference handler.
# Author: arjunanand13 - commit ca98417 ("Create handler.py", verified)
import subprocess
import sys
import torch
import base64
from io import BytesIO
from PIL import Image
import requests
from transformers import AutoModelForCausalLM, AutoProcessor
import os
def install(package):
    """Install *package* into the running interpreter's environment via pip.

    The whole *package* string is handed to pip as ONE argument, so it must be
    a single requirement (e.g. ``"timm"``); pip options such as ``"-U name"``
    are not split apart and will be rejected by pip.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    # Use ``sys.executable -m pip`` so the package lands in the same
    # environment that is running this code.
    pip_command = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-warn-script-location",
        package,
    ]
    subprocess.check_call(pip_command)
class EndpointHandler:
    """Image-to-text request handler around a Florence checkpoint.

    The model and processor are loaded with ``trust_remote_code=True`` (this
    executes Python code shipped in the model repository). Instances are
    callable: ``handler(data)`` runs one request and returns a dict.
    """

    # Packages pip-installed best-effort on every construction.
    # NOTE: "-U transformers" used to be in this list. install() passes each
    # entry to pip as a single argv element, so pip's option parser rejected
    # "-U transformers" and it always ended up in the "Failed to install"
    # branch. Upgrading here would also be unsafe: transformers is already
    # imported at module load, before this runs. Pin the transformers version
    # in the deployment environment instead.
    _RUNTIME_PACKAGES = ("timm", "einops", "flash-attn", "Pillow")

    def __init__(self, path="", model_name="arjunanand13/LADP_Florence-60e"):
        """Install runtime packages, then load the model and its processor.

        Args:
            path: Accepted for interface compatibility and not used.
                NOTE(review): presumably the hosting toolkit passes the local
                repository path here; loading from it would avoid a Hub
                download. Confirm before switching.
            model_name: Hub repository id (or local directory) passed to
                ``from_pretrained`` for both the model and the processor.
        """
        for package in self._RUNTIME_PACKAGES:
            try:
                install(package)
                print(f"Successfully installed {package}")
            except Exception as e:
                # Best-effort: report the failure and keep loading.
                print(f"Failed to install {package}: {str(e)}")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            trust_remote_code=True,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(
            self.model_name,
            trust_remote_code=True,
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _strip_data_uri(image_data):
        """Return the payload after the first comma of a ``data:`` URI string.

        ``"data:image/png;base64,QUJD"`` becomes ``"QUJD"``. Any other value,
        including a ``data:`` string with no comma, is returned unchanged.
        """
        if isinstance(image_data, str) and image_data.startswith("data:"):
            _, separator, payload = image_data.partition(",")
            if separator:
                return payload
        return image_data

    def process_image(self, image_data):
        """Open an image given as a short file path or as base64 data.

        Args:
            image_data: Either the path of an existing file (only strings
                shorter than 256 characters are checked against the
                filesystem), or base64-encoded image bytes, optionally wrapped
                in a ``data:<mime>;base64,`` URI.

        Returns:
            A fully loaded PIL image, or ``None`` if the data could not be read
            or decoded. The error is printed.
        """
        print("[DEBUG] Attempting to process image")
        try:
            # NOTE(review): this branch lets a caller make the server read any
            # image file it can access; restrict or remove it if the endpoint
            # receives untrusted requests.
            if isinstance(image_data, str) and len(image_data) < 256 and os.path.exists(image_data):
                with open(image_data, "rb") as image_file:
                    print("[DEBUG] File opened successfully")
                    image = Image.open(image_file)
                    # Image.open is lazy. Decode the pixels now, while the file
                    # is still open; otherwise later access hits a closed file.
                    image.load()
            else:
                print("[DEBUG] Decoding base64 image data")
                # Strip a data-URI prefix first: b64decode would otherwise keep
                # its alphabet characters ("data", "image", ...) as payload.
                image_bytes = base64.b64decode(self._strip_data_uri(image_data))
                image = Image.open(BytesIO(image_bytes))
                # Decode eagerly so corrupt data is reported here (as None).
                image.load()
            print("[DEBUG] Image opened with PIL:", image.format, image.size, image.mode)
            return image
        except Exception as e:
            print(f"[ERROR] Error processing image: {str(e)}")
            return None

    @staticmethod
    def _parse_request(data):
        """Split a request payload into ``(image_data, text, generate_kwargs)``.

        Accepted payloads:
          * ``{"inputs": {"image": ..., "text": ...}}``. A missing text
            defaults to ``""``.
          * ``{"inputs": <image>}``. The prompt defaults to
            ``"What is in this image?"``.
          * A payload without ``"inputs"`` is treated as the inputs itself.

        An optional ``"parameters"`` dict is returned as keyword arguments for
        ``model.generate``. The payload is not modified.

        Raises:
            ValueError: if ``"parameters"`` is present but is not a dict.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, dict):
            image_data = inputs.get("image", None)
            text_input = inputs.get("text", "")
        else:
            image_data = inputs
            text_input = "What is in this image?"
        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            raise ValueError("'parameters' must be a dict of generate() keyword arguments")
        return image_data, text_input, parameters

    def __call__(self, data):
        """Run one image-to-text request.

        Args:
            data: Request payload; see ``_parse_request`` for accepted shapes.

        Returns:
            ``{"generated_text": <first decoded sequence>}`` on success, or
            ``{"error": <message>}`` on failure (the traceback is printed).
        """
        try:
            image_data, text_input, generate_kwargs = self._parse_request(data)
            # Truncate: base64 payloads can be megabytes long.
            print("[INFO]", str(image_data)[:100], text_input)
            image = self.process_image(image_data) if image_data else None
            print("[INFO]", image)
            if image_data and image is None:
                # The caller supplied an image that could not be decoded; do
                # not silently run the model without it.
                return {"error": "Could not decode the provided image"}
            model_inputs = self.processor(
                images=image,
                text=text_input,
                return_tensors="pt",
            )
            # Move tensor inputs to the model's device; leave other values as-is.
            model_inputs = {
                key: value.to(self.device) if isinstance(value, torch.Tensor) else value
                for key, value in model_inputs.items()
            }
            with torch.no_grad():
                outputs = self.model.generate(**model_inputs, **generate_kwargs)
            decoded_outputs = self.processor.batch_decode(outputs, skip_special_tokens=True)
            print(f"[INFO],{decoded_outputs}")
            return {"generated_text": decoded_outputs[0]}
        except Exception as e:
            import traceback

            traceback.print_exc()
            return {"error": str(e)}