kraken-api / app /main.py
wjbmattingly's picture
fixed ocr
56a0241
raw
history blame contribute delete
No virus
6.76 kB
import base64
import io
import json
import tempfile
from typing import Any, List, Optional

import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse, Response
from kraken import binarization
from kraken import blla
from kraken import pageseg
from kraken import rpred
from kraken import serialization
from kraken.lib import models
from kraken.lib.exceptions import KrakenInvalidModelException
from PIL import Image
from pydantic import BaseModel
# ASGI application object: every endpoint in this module is registered on it
# through @app.get / @app.post, and the __main__ block serves it with uvicorn.
app = FastAPI()
class RawResponse(BaseModel):
    """Response envelope shared by the JSON endpoints.

    Declared as ``response_model`` for /detect_lines, /ocr, /segment and
    /process_all; the payload is deliberately untyped.
    """

    # Endpoint-specific payload (the dicts/lists built by each handler).
    result: Any
def serialize_line(line):
    """Return a JSON-friendly ``dict`` built from a segmentation line object.

    Every instance attribute of *line* (as exposed by ``vars``) is copied into
    a new dict. numpy arrays become (nested) lists and numpy scalars become
    plain Python numbers. If the line has no ``bbox`` attribute but has a
    non-empty ``polygon``, a bounding box ``[min_x, min_y, max_x, max_y]`` is
    derived from the polygon's points.

    Args:
        line: Any object supporting ``vars()`` (i.e. having a ``__dict__``),
            such as an entry of a segmentation's ``lines``.

    Returns:
        A new dict. *line* itself is never modified.
    """
    serialized = {}
    # Build a fresh dict instead of writing into vars(line): vars() returns the
    # object's live __dict__, so mutating it would add 'bbox' to, and replace
    # arrays on, the caller's segmentation object.
    for key, value in vars(line).items():
        if isinstance(value, np.ndarray):
            value = value.tolist()
        elif isinstance(value, np.generic):
            # e.g. np.int64, which is not a Python int and not JSON-serializable
            value = value.item()
        serialized[key] = value

    # Derive the bbox after conversion so its corners are plain Python numbers.
    # An empty polygon has no extent: skip it (unpacking zip() of nothing would
    # raise ValueError).
    polygon = serialized.get('polygon')
    if 'bbox' not in serialized and polygon:
        x_coords, y_coords = zip(*polygon)
        serialized['bbox'] = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
    return serialized
def serialize_region(region):
    """Return a JSON-friendly ``dict`` holding a region's ``id``, ``boundary`` and ``tags``.

    Attributes missing from *region* are reported as ``None``. numpy arrays are
    converted to (nested) lists, and every other value is passed through unchanged.
    """
    def _plain(value):
        # numpy arrays are not JSON-serializable; hand them over as lists.
        return value.tolist() if isinstance(value, np.ndarray) else value

    return {name: _plain(getattr(region, name, None)) for name in ("id", "boundary", "tags")}
@app.post("/detect_lines", response_model=RawResponse)
async def detect_lines(file: UploadFile = File(...)):
    """Run baseline and layout analysis (``blla.segment``, default model) on an upload.

    Returns:
        ``RawResponse`` whose ``result`` contains ``lines`` (via
        ``serialize_line``), ``regions`` (via ``serialize_region``), ``type``,
        ``text_direction`` and ``script_detection``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        # Image.open is lazy; decode now so corrupt uploads fail here as a
        # client error rather than deep inside segmentation as a 500.
        image.load()
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    baseline_seg = blla.segment(image)

    # ``regions`` may be a flat list or a dict mapping region type -> list of
    # regions. NOTE(review): kraken 5's Segmentation uses the dict form.
    # Iterating such a dict directly yields the type *names*, not regions, so
    # flatten it. A missing value (None) means no regions.
    regions = baseline_seg.regions or []
    if isinstance(regions, dict):
        regions = [region for group in regions.values() for region in group]

    serialized_seg = {
        "lines": [serialize_line(line) for line in baseline_seg.lines],
        "regions": [serialize_region(region) for region in regions],
        "type": baseline_seg.type,
        "text_direction": baseline_seg.text_direction,
        "script_detection": baseline_seg.script_detection,
    }
    return RawResponse(result=serialized_seg)
@app.post("/ocr", response_model=RawResponse)
async def perform_ocr(
    file: UploadFile = File(...),
    model_name: str = Form("catmus-medieval.mlmodel"),
    binarize: bool = Form(False),
):
    """Recognize text in an uploaded image with a kraken recognition model.

    The image is always binarized (``binarization.nlbin``) and segmented with
    the legacy ``pageseg.segment``. Recognition (``rpred.rpred``) then runs on
    the binarized image when *binarize* is true, otherwise on the original.

    Args:
        file: Uploaded page image.
        model_name: Model identifier passed to ``models.load_any``.
        binarize: Run recognition on the binarized image instead of the original.

    Returns:
        ``RawResponse`` whose ``result`` is a list of ``{"bbox", "text", "cuts"}``
        dicts, one per recognition record.

    Raises:
        HTTPException: 400 if the model cannot be loaded or the upload is not a
            readable image.
    """
    # Validate the model first: a bad model name is rejected before the
    # comparatively expensive binarization and segmentation run.
    try:
        model = models.load_any(model_name)
    except KrakenInvalidModelException as err:
        raise HTTPException(status_code=400, detail=f"Model '{model_name}' not found or invalid") from err

    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        image.load()  # Image.open is lazy; decode now so bad uploads fail here
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    # The legacy segmenter expects a bi-level page, so binarize even when
    # recognition itself will use the original image.
    bw_img = binarization.nlbin(image)
    segmentation = pageseg.segment(bw_img)

    ocr_image = bw_img if binarize else image
    serialized_result = [
        {
            "bbox": record.bbox,
            "text": record.prediction,
            "cuts": record.cuts,
        }
        for record in rpred.rpred(model, ocr_image, segmentation)
    ]
    return RawResponse(result=serialized_result)
@app.post("/segment", response_model=RawResponse)
async def segment_image(
    file: UploadFile = File(...),
    baseline: bool = Form(True),
):
    """Segment an uploaded image into text lines with kraken's legacy segmenter.

    The image is binarized with ``binarization.nlbin`` and segmented with
    ``pageseg.segment(..., text_direction='horizontal-lr')``.

    Args:
        file: Uploaded page image.
        baseline: Accepted for API compatibility. NOTE(review): it currently
            has no effect. Both former branches ran ``pageseg.segment`` with
            horizontal-lr text direction, which appears to be pageseg's
            default. If baseline (BLLA) segmentation was intended, route
            ``baseline=True`` to ``blla.segment`` as /detect_lines does. That
            would change this endpoint's default output.

    Returns:
        ``RawResponse`` whose ``result`` contains ``lines``, ``regions``,
        ``type``, ``text_direction`` and ``script_detection``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        image.load()  # Image.open is lazy; decode now so bad uploads fail here
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    bw_img = binarization.nlbin(image)
    segmentation = pageseg.segment(bw_img, text_direction='horizontal-lr')

    # ``regions`` may be None, a list, or a dict of region type -> regions.
    # Normalize it to a flat list so serialization works for every form.
    regions = segmentation.regions or []
    if isinstance(regions, dict):
        regions = [region for group in regions.values() for region in group]

    serialized_seg = {
        "lines": [serialize_line(line) for line in segmentation.lines],
        # serialize_region (as used by /detect_lines) instead of raw vars():
        # it converts numpy arrays and exposes a stable set of keys.
        "regions": [serialize_region(region) for region in regions],
        "type": segmentation.type,
        "text_direction": segmentation.text_direction,
        "script_detection": segmentation.script_detection,
    }
    return RawResponse(result=serialized_seg)
@app.post("/binarize")
async def binarize_image(file: UploadFile = File(...)):
    """Binarize an uploaded image with ``binarization.nlbin`` and return it as a PNG.

    The PNG is encoded in memory and sent as an attachment named
    ``binarized.png``. Nothing is written to disk, so repeated requests cannot
    accumulate temporary files.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        image.load()  # Image.open is lazy; decode now so bad uploads fail here
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    bw_img = binarization.nlbin(image)
    buffer = io.BytesIO()
    bw_img.save(buffer, format="PNG")
    return Response(
        content=buffer.getvalue(),
        media_type="image/png",
        # Same disposition header FileResponse(filename="binarized.png") emits.
        headers={"Content-Disposition": 'attachment; filename="binarized.png"'},
    )
@app.post("/process_all", response_model=RawResponse)
async def process_all(
    file: UploadFile = File(...),
    model_name: str = Form("catmus-medieval"),
):
    """Binarize, segment and OCR an uploaded image in a single request.

    Pipeline: ``binarization.nlbin``, then ``pageseg.segment`` on the binarized
    image, then ``rpred.rpred`` with the requested model on the binarized image.

    Args:
        file: Uploaded page image.
        model_name: Model identifier passed to ``models.load_any``. If it
            cannot be loaded as given and has no ``.mlmodel`` suffix,
            ``<model_name>.mlmodel`` is tried as well. This endpoint's default
            omits the suffix that /ocr's default includes.

    Returns:
        ``RawResponse`` whose ``result`` holds ``binarized_image`` (base64 PNG),
        ``segmentation`` (lines/regions/type/text_direction/script_detection)
        and ``ocr_result`` (list of ``{"bbox", "text", "cuts"}`` dicts).

    Raises:
        HTTPException: 400 if no model candidate loads or the upload is not a
            readable image.
    """
    # Load the model before any heavy processing so a bad name fails fast.
    candidates = [model_name]
    if not model_name.endswith(".mlmodel"):
        candidates.append(f"{model_name}.mlmodel")
    model = None
    for candidate in candidates:
        try:
            model = models.load_any(candidate)
            break
        except KrakenInvalidModelException:
            continue
    if model is None:
        raise HTTPException(status_code=400, detail=f"Model '{model_name}' not found or invalid")

    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
        image.load()  # Image.open is lazy; decode now so bad uploads fail here
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from err

    # Step 1: binarization; the result is also returned as a base64 PNG.
    bw_img = binarization.nlbin(image)
    buffered = io.BytesIO()
    bw_img.save(buffered, format="PNG")
    binarized_base64 = base64.b64encode(buffered.getvalue()).decode()

    # Step 2: legacy box segmentation of the binarized page.
    segmentation = pageseg.segment(bw_img)
    # ``regions`` may be None, a list, or a dict of region type -> regions.
    regions = segmentation.regions or []
    if isinstance(regions, dict):
        regions = [region for group in regions.values() for region in group]
    serialized_seg = {
        "lines": [serialize_line(line) for line in segmentation.lines],
        "regions": [serialize_region(region) for region in regions],
        "type": segmentation.type,
        "text_direction": segmentation.text_direction,
        "script_detection": segmentation.script_detection,
    }

    # Step 3: OCR on the binarized page.
    serialized_result = [
        {
            "bbox": record.bbox,
            "text": record.prediction,
            "cuts": record.cuts,
        }
        for record in rpred.rpred(model, bw_img, segmentation)
    ]

    return RawResponse(result={
        "binarized_image": binarized_base64,
        "segmentation": serialized_seg,
        "ocr_result": serialized_result,
    })
@app.get("/")
async def root():
    """Return a greeting plus the list of POST endpoints this service exposes."""
    endpoints = ["/detect_lines", "/ocr", "/segment", "/binarize", "/process_all"]
    greeting = "Welcome to the Complete Kraken Python API"
    return {"message": greeting, "available_endpoints": endpoints}
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn using a single worker,
    # listening on all interfaces at port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860, "workers": 1}
    uvicorn.run(app, **server_options)