# Spaces: Running
import asyncio
import base64
import io
import json
import tempfile
from typing import List, Optional, Any

import numpy as np
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse, FileResponse, Response
from kraken import binarization
from kraken import blla
from kraken import pageseg
from kraken import rpred
from kraken import serialization
from kraken.lib import models
from kraken.lib.exceptions import KrakenInvalidModelException
from PIL import Image
from pydantic import BaseModel
# FastAPI application object; served by uvicorn when this module is run as a
# script (see the ``__main__`` guard at the end of the file).
app = FastAPI()
class RawResponse(BaseModel):
    """Response envelope that carries an endpoint's payload under ``result``.

    Returned by the line-detection, OCR, segmentation and process-all handlers.
    """

    # Endpoint-specific payload; its structure differs per handler.
    result: Any
def serialize_line(line):
    """Return a JSON-friendly ``dict`` snapshot of a kraken line object.

    All instance attributes of *line* (as given by ``vars``) are copied into a
    new dict, so the caller's object is never modified. If the line has no
    ``bbox`` attribute but has a non-empty ``polygon``, an axis-aligned
    ``[xmin, ymin, xmax, ymax]`` box is derived from the polygon's points.
    NumPy arrays and NumPy scalars, including ones nested inside lists, tuples
    or dicts, are converted to builtin Python values so the result can be
    JSON-encoded.

    Args:
        line: Any object with an instance ``__dict__`` (e.g. a kraken line).

    Returns:
        A new ``dict`` of the line's attributes, plus ``bbox`` when derivable.
    """

    def _to_builtin(value):
        # Recursively replace NumPy containers/scalars with builtin equivalents.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, list):
            return [_to_builtin(item) for item in value]
        if isinstance(value, tuple):
            return tuple(_to_builtin(item) for item in value)
        if isinstance(value, dict):
            return {key: _to_builtin(item) for key, item in value.items()}
        return value

    # Copy: ``vars`` returns the object's live ``__dict__``, and writing into it
    # would add/overwrite attributes on the segmentation result itself.
    line_dict = dict(vars(line))

    polygon = line_dict.get('polygon')
    # ``len`` rather than truthiness: ``polygon`` may be a NumPy array, whose
    # truth value is ambiguous. An empty polygon has no box to derive.
    if 'bbox' not in line_dict and polygon is not None and len(polygon) > 0:
        x_coords, y_coords = zip(*polygon)
        line_dict['bbox'] = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]

    return {key: _to_builtin(value) for key, value in line_dict.items()}
def serialize_region(region):
    """Build a plain ``dict`` with a region's ``id``, ``boundary`` and ``tags``.

    Attributes missing from *region* are reported as ``None``. Values that are
    NumPy arrays are turned into nested lists so they can be JSON-encoded;
    every other value is passed through unchanged.
    """

    def _plain(value):
        return value.tolist() if isinstance(value, np.ndarray) else value

    return {
        attr: _plain(getattr(region, attr, None))
        for attr in ("id", "boundary", "tags")
    }
async def detect_lines(file: UploadFile = File(...)):
    """Run kraken baseline and layout analysis (BLLA) on an uploaded image.

    ``blla.segment`` is called with its default model. Lines go through
    ``serialize_line`` and regions through ``serialize_region``.

    Args:
        file: Uploaded page image.

    Returns:
        ``RawResponse`` whose ``result`` has the keys ``lines``, ``regions``,
        ``type``, ``text_direction`` and ``script_detection``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        # Decode now so corrupt/truncated uploads fail here, not deep in kraken.
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    # Segmentation is CPU-bound; run it in the default thread pool so the event
    # loop can keep serving other requests meanwhile.
    loop = asyncio.get_running_loop()
    baseline_seg = await loop.run_in_executor(None, blla.segment, image)

    # Recent kraken versions store ``regions`` as a dict mapping region type to
    # a list of regions. Iterating that dict directly would yield the type
    # names, so flatten it. A plain list (or a missing value) is also accepted.
    regions = baseline_seg.regions or []
    if isinstance(regions, dict):
        regions = [region for group in regions.values() for region in group]

    serialized_seg = {
        "lines": [serialize_line(line) for line in baseline_seg.lines],
        "regions": [serialize_region(region) for region in regions],
        "type": baseline_seg.type,
        "text_direction": baseline_seg.text_direction,
        "script_detection": baseline_seg.script_detection,
    }
    return RawResponse(result=serialized_seg)
async def perform_ocr(
    file: UploadFile = File(...),
    model_name: str = Form("catmus-medieval.mlmodel"),
    binarize: bool = Form(False)
):
    """Recognize the text of an uploaded page image with a kraken model.

    The image is always binarized (``binarization.nlbin``) and segmented into
    lines with ``pageseg.segment`` on the binarized image. Recognition
    (``rpred.rpred``) then runs on the binarized image when *binarize* is true
    and on the original image otherwise.

    Args:
        file: Uploaded page image.
        model_name: Model path passed to ``models.load_any``.
        binarize: Feed the binarized image, rather than the original, to the
            recognizer.

    Returns:
        ``RawResponse`` whose ``result`` is a list of
        ``{"bbox", "text", "cuts"}`` dicts, one per recognized record.

    Raises:
        HTTPException: 400 if the upload is not a readable image or the model
            cannot be loaded.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        # Decode now so corrupt/truncated uploads fail here, not deep in kraken.
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    def _recognize():
        # Load the model first so an invalid model name is rejected before any
        # binarization/segmentation work is spent on the image.
        try:
            model = models.load_any(model_name)
        except KrakenInvalidModelException as err:
            raise HTTPException(status_code=400, detail=f"Model '{model_name}' not found or invalid") from err
        bw_img = binarization.nlbin(image)
        # Segmentation always runs on the binarized image.
        segmentation = pageseg.segment(bw_img)
        ocr_image = bw_img if binarize else image
        return list(rpred.rpred(model, ocr_image, segmentation))

    # The whole pipeline is CPU-bound; run it in the default thread pool so the
    # event loop keeps serving other requests. HTTPException propagates through
    # the await unchanged.
    records = await asyncio.get_running_loop().run_in_executor(None, _recognize)

    # ``confidence`` and ``line_id`` are intentionally not included in the output.
    serialized_result = [
        {"bbox": record.bbox, "text": record.prediction, "cuts": record.cuts}
        for record in records
    ]
    return RawResponse(result=serialized_result)
async def segment_image(
    file: UploadFile = File(...),
    baseline: bool = Form(True)
):
    """Binarize an uploaded image and segment it with ``pageseg.segment``.

    Args:
        file: Uploaded page image.
        baseline: Selects between the two ``pageseg.segment`` calls below (see
            the NOTE there).

    Returns:
        ``RawResponse`` whose ``result`` has the keys ``lines``, ``regions``,
        ``type``, ``text_direction`` and ``script_detection``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        # Decode now so corrupt/truncated uploads fail here, not deep in kraken.
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    def _segment():
        bw_img = binarization.nlbin(image)
        # NOTE(review): 'horizontal-lr' appears to be pageseg.segment's default
        # text_direction, in which case both branches are identical and
        # ``baseline`` has no effect. Despite its name, it does not switch to
        # blla baseline segmentation. Confirm the intended behavior.
        if baseline:
            return pageseg.segment(bw_img)
        return pageseg.segment(bw_img, text_direction='horizontal-lr')

    # Binarization and segmentation are CPU-bound; keep them off the event loop.
    segmentation = await asyncio.get_running_loop().run_in_executor(None, _segment)

    # ``regions`` may be a dict of region type -> list of regions (recent
    # kraken), a list, or empty/None; normalize it to a flat list.
    regions = segmentation.regions or []
    if isinstance(regions, dict):
        regions = [region for group in regions.values() for region in group]

    serialized_seg = {
        "lines": [serialize_line(line) for line in segmentation.lines],
        # Same JSON-safe region serializer as detect_lines (NumPy arrays -> lists).
        "regions": [serialize_region(region) for region in regions],
        "type": segmentation.type,
        "text_direction": segmentation.text_direction,
        "script_detection": segmentation.script_detection,
    }
    return RawResponse(result=serialized_seg)
async def binarize_image(file: UploadFile = File(...)):
    """Binarize an uploaded image with ``binarization.nlbin`` and return a PNG.

    The PNG is sent as an attachment named ``binarized.png``.

    Args:
        file: Uploaded image.

    Returns:
        A ``Response`` with ``image/png`` content.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        # Decode now so corrupt/truncated uploads fail here, not inside nlbin.
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    # nlbin is CPU-bound; keep it off the event loop.
    loop = asyncio.get_running_loop()
    bw_img = await loop.run_in_executor(None, binarization.nlbin, image)

    # Encode in memory. A temporary file created with delete=False would stay
    # on disk after every request.
    buffer = io.BytesIO()
    bw_img.save(buffer, format="PNG")
    return Response(
        content=buffer.getvalue(),
        media_type="image/png",
        headers={"Content-Disposition": 'attachment; filename="binarized.png"'},
    )
async def process_all(
    file: UploadFile = File(...),
    model_name: str = Form("catmus-medieval")
):
    """Binarize, segment and OCR an uploaded page image in one request.

    Pipeline: ``binarization.nlbin`` -> ``pageseg.segment`` on the binarized
    image -> ``rpred.rpred`` with the loaded model on the binarized image.

    Args:
        file: Uploaded page image.
        model_name: Model path passed to ``models.load_any``. If it does not end
            in ``.mlmodel`` and cannot be loaded as given, ``<name>.mlmodel`` is
            tried as well.

    Returns:
        ``RawResponse`` whose ``result`` has these keys:
        ``binarized_image``: the binarized page as a base64-encoded PNG.
        ``segmentation``: ``lines``, ``regions``, ``type``,
        ``text_direction``, ``script_detection``.
        ``ocr_result``: a list of ``{"bbox", "text", "cuts"}`` dicts.

    Raises:
        HTTPException: 400 if the upload is not a readable image or no model
            can be loaded.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        # Decode now so corrupt/truncated uploads fail here, not deep in kraken.
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    def _load_model():
        # This endpoint's default name has no ".mlmodel" suffix, unlike
        # perform_ocr's default ("catmus-medieval.mlmodel"). Retry with the
        # suffix so either spelling resolves to the same model file.
        candidates = [model_name]
        if not model_name.endswith(".mlmodel"):
            candidates.append(f"{model_name}.mlmodel")
        last_err = None
        for candidate in candidates:
            try:
                return models.load_any(candidate)
            except KrakenInvalidModelException as err:
                last_err = err
        raise HTTPException(status_code=400, detail=f"Model '{model_name}' not found or invalid") from last_err

    def _pipeline():
        # Load the model first so a bad model name fails before the image work.
        model = _load_model()

        # Step 1: binarization, also returned to the client as a base64 PNG.
        bw_img = binarization.nlbin(image)
        buffered = io.BytesIO()
        bw_img.save(buffered, format="PNG")
        binarized_base64 = base64.b64encode(buffered.getvalue()).decode()

        # Step 2: segmentation. Step 3: OCR on the binarized image.
        segmentation = pageseg.segment(bw_img)
        records = list(rpred.rpred(model, bw_img, segmentation))
        return binarized_base64, segmentation, records

    # Everything above is CPU-bound; run it in the default thread pool so the
    # event loop keeps serving other requests.
    binarized_base64, segmentation, records = await asyncio.get_running_loop().run_in_executor(
        None, _pipeline
    )

    # ``regions`` may be a dict of region type -> list of regions (recent
    # kraken), a list, or empty/None; normalize it to a flat list.
    regions = segmentation.regions or []
    if isinstance(regions, dict):
        regions = [region for group in regions.values() for region in group]

    serialized_seg = {
        "lines": [serialize_line(line) for line in segmentation.lines],
        "regions": [serialize_region(region) for region in regions],
        "type": segmentation.type,
        "text_direction": segmentation.text_direction,
        "script_detection": segmentation.script_detection,
    }

    # ``confidence`` and ``line_id`` are intentionally not included in the output.
    serialized_result = [
        {"bbox": record.bbox, "text": record.prediction, "cuts": record.cuts}
        for record in records
    ]

    return RawResponse(result={
        "binarized_image": binarized_base64,
        "segmentation": serialized_seg,
        "ocr_result": serialized_result,
    })
async def root():
    """Return a welcome message and the list of advertised endpoint paths."""
    # NOTE(review): no route decorators are visible on the handlers in this
    # view of the file; confirm they are registered on ``app`` (e.g.
    # ``@app.get("/")`` here and ``@app.post(...)`` on the upload handlers).
    endpoint_paths = [
        "/detect_lines",
        "/ocr",
        "/segment",
        "/binarize",
        "/process_all",
    ]
    payload = {"message": "Welcome to the Complete Kraken Python API"}
    payload["available_endpoints"] = endpoint_paths
    return payload
if __name__ == "__main__":
    # Script entry point: serve ``app`` with a single uvicorn worker on every
    # interface, port 7860.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860, workers=1)