Spaces:
Sleeping
Sleeping
File size: 6,764 Bytes
a0babed cf28c32 a0babed 903eae7 a0babed a7690a3 a0babed 7497534 cf28c32 fdd3057 cf28c32 903eae7 a0babed 7497534 87e06fb 7497534 87e06fb 7497534 903eae7 cf28c32 a0babed b7fa040 06f1e0d 903eae7 7497534 87e06fb 903eae7 a0babed 903eae7 a0babed 903eae7 a0babed 107dc4b a0babed 107dc4b a0babed 7497534 f381fe6 7497534 a0babed 107dc4b a0babed 107dc4b 903eae7 56a0241 903eae7 56a0241 903eae7 a0babed 903eae7 a0babed 903eae7 7497534 903eae7 a0babed 903eae7 a0babed 903eae7 a0babed 9bbf9be a0babed 903eae7 7497534 903eae7 a0babed 7497534 7944035 7497534 903eae7 56a0241 903eae7 56a0241 903eae7 56a0241 903eae7 a0babed cf28c32 a0babed 7497534 a0babed cf28c32 |
|
import base64
import io
import json
import tempfile
from typing import Any, List, Optional

import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse, Response
from kraken import binarization
from kraken import blla
from kraken import pageseg
from kraken import rpred
from kraken import serialization
from kraken.lib import models
from kraken.lib.exceptions import KrakenInvalidModelException
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
# The FastAPI application; every route decorator in this module registers on it.
app = FastAPI()


class RawResponse(BaseModel):
    # Uniform envelope for the JSON endpoints: the payload lives under ``result``.
    # Typed as ``Any`` so the nested dicts/lists the endpoints build all validate.
    result: Any
def serialize_line(line):
    """Convert a segmentation line object into a JSON-serializable dict.

    All instance attributes of ``line`` are copied into a NEW dict, so the
    original object is never modified. When the line has no ``bbox`` but does
    have a non-empty ``polygon`` (a sequence of ``(x, y)`` points), an
    axis-aligned box ``[x_min, y_min, x_max, y_max]`` is derived from it.
    Top-level numpy arrays become nested lists and numpy scalars become plain
    Python numbers, so the result can be JSON-encoded.

    Args:
        line: Any object that has a ``__dict__``.

    Returns:
        A new dict mapping attribute names to serializable values.

    Raises:
        TypeError: If ``line`` has no ``__dict__`` (``vars`` rejects it).
    """
    # ``vars`` returns the object's live ``__dict__``; copy it so that adding
    # ``bbox`` and converting arrays below cannot mutate ``line`` itself.
    line_dict = dict(vars(line))

    polygon = line_dict.get('polygon')
    # ``len`` works for lists and numpy arrays alike (array truthiness is
    # ambiguous) and skips empty polygons, for which unpacking ``zip(*polygon)``
    # below would raise ValueError.
    if 'bbox' not in line_dict and polygon is not None and len(polygon) > 0:
        x_coords, y_coords = zip(*polygon)
        bbox = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
        # min/max over numpy data yield numpy scalars (e.g. np.int64), which
        # json / pydantic's JSON mode cannot encode; unwrap them.
        line_dict['bbox'] = [v.item() if isinstance(v, np.generic) else v for v in bbox]

    # ``.tolist()`` turns arrays into nested lists and numpy scalars into plain
    # Python numbers; every other value is passed through unchanged.
    return {
        key: value.tolist() if isinstance(value, (np.ndarray, np.generic)) else value
        for key, value in line_dict.items()
    }
def serialize_region(region):
    """Return ``id``, ``boundary`` and ``tags`` of *region* as a plain dict.

    Attributes the object lacks are reported as ``None``. A numpy-array value
    is converted to nested lists so the dict can be JSON-encoded.
    """
    serialized = {}
    for attr_name in ("id", "boundary", "tags"):
        attr_value = getattr(region, attr_name, None)
        if isinstance(attr_value, np.ndarray):
            attr_value = attr_value.tolist()
        serialized[attr_name] = attr_value
    return serialized
@app.post("/detect_lines", response_model=RawResponse)
async def detect_lines(file: UploadFile = File(...)):
    """Run baseline and layout analysis (BLLA) on an uploaded page image.

    The original (non-binarized) image is segmented with ``blla.segment``
    using its default model.

    Args:
        file: The uploaded image.

    Returns:
        ``RawResponse`` whose ``result`` contains ``lines``, ``regions`` (a
        flat list), ``type``, ``text_direction`` and ``script_detection``.

    Raises:
        HTTPException: 400 if the upload cannot be identified as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
    except UnidentifiedImageError:
        # Report undecodable uploads as a client error instead of a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from None

    baseline_seg = blla.segment(image)

    # kraken's Segmentation documents ``regions`` as a mapping of region type
    # -> list of regions; iterating it directly yields the type names
    # (strings), not regions. Flatten a mapping; pass a sequence through.
    regions = baseline_seg.regions or []
    if isinstance(regions, dict):
        regions = [region for group in regions.values() for region in group]

    serialized_seg = {
        "lines": [serialize_line(line) for line in baseline_seg.lines],
        "regions": [serialize_region(region) for region in regions],
        "type": baseline_seg.type,
        "text_direction": baseline_seg.text_direction,
        "script_detection": baseline_seg.script_detection,
    }
    return RawResponse(result=serialized_seg)
@app.post("/ocr", response_model=RawResponse)
async def perform_ocr(
    file: UploadFile = File(...),
    model_name: str = Form("catmus-medieval.mlmodel"),
    binarize: bool = Form(False)
):
    """Segment an uploaded image with the legacy segmenter and recognize its text.

    Args:
        file: The uploaded image.
        model_name: Model identifier handed to ``models.load_any``.
        binarize: If true, recognition runs on the binarized image; otherwise
            on the original upload. Segmentation always uses the binarized image.

    Returns:
        ``RawResponse`` whose ``result`` is a list of ``{"bbox", "text", "cuts"}``
        dicts, one per record yielded by ``rpred.rpred``.

    Raises:
        HTTPException: 400 if the upload is not a readable image, or if the
            model cannot be loaded.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
    except UnidentifiedImageError:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from None

    # Load the model before the costly binarization/segmentation so an invalid
    # name fails fast.
    # NOTE(review): ``model_name`` comes straight from the client and appears to
    # be used as a filesystem path by ``load_any``; consider an allow-list.
    try:
        model = models.load_any(model_name)
    except KrakenInvalidModelException as exc:
        raise HTTPException(
            status_code=400, detail=f"Model '{model_name}' not found or invalid"
        ) from exc

    # The legacy ``pageseg`` segmenter is always fed the binarized page;
    # ``binarize`` only chooses which image the recognizer sees.
    bw_img = binarization.nlbin(image)
    baseline_seg = pageseg.segment(bw_img)
    ocr_image = bw_img if binarize else image

    serialized_result = [
        {
            "bbox": record.bbox,
            "text": record.prediction,
            "cuts": record.cuts,
        }
        for record in rpred.rpred(model, ocr_image, baseline_seg)
    ]
    return RawResponse(result=serialized_result)
@app.post("/segment", response_model=RawResponse)
async def segment_image(
    file: UploadFile = File(...),
    baseline: bool = Form(True)
):
    """Segment an uploaded page image into text lines.

    Args:
        file: The uploaded image.
        baseline: If true (default), run the baseline segmenter
            ``blla.segment`` on the original image. If false, run the legacy
            bounding-box segmenter ``pageseg.segment`` on a binarized copy.

    Returns:
        ``RawResponse`` whose ``result`` contains ``lines``, ``regions`` (a
        flat list), ``type``, ``text_direction`` and ``script_detection``.

    Raises:
        HTTPException: 400 if the upload cannot be identified as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
    except UnidentifiedImageError:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from None

    # Both branches used to call ``pageseg.segment`` with equivalent arguments,
    # so ``baseline`` had no effect; it now actually selects the segmenter.
    if baseline:
        segmentation = blla.segment(image)
    else:
        # Only the legacy segmenter needs the binarized page.
        bw_img = binarization.nlbin(image)
        segmentation = pageseg.segment(bw_img, text_direction='horizontal-lr')

    # kraken's Segmentation documents ``regions`` as a mapping of region type
    # -> list of regions (iterating it yields type-name strings, which
    # ``vars`` rejects). Flatten a mapping; pass a sequence through.
    regions = segmentation.regions or []
    if isinstance(regions, dict):
        regions = [region for group in regions.values() for region in group]

    serialized_seg = {
        "lines": [serialize_line(line) for line in segmentation.lines],
        "regions": [serialize_region(region) for region in regions],
        "type": segmentation.type,
        "text_direction": segmentation.text_direction,
        "script_detection": segmentation.script_detection,
    }
    return RawResponse(result=serialized_seg)
@app.post("/binarize")
async def binarize_image(file: UploadFile = File(...)):
    """Binarize an uploaded image with ``binarization.nlbin`` and return a PNG.

    The PNG is encoded in memory and sent as an attachment named
    ``binarized.png``; nothing is written to disk.

    Raises:
        HTTPException: 400 if the upload cannot be identified as an image.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
    except UnidentifiedImageError:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from None

    bw_img = binarization.nlbin(image)

    # Encode in memory: a NamedTemporaryFile(delete=False) is never cleaned up
    # and would leak one PNG on disk per request.
    buffer = io.BytesIO()
    bw_img.save(buffer, format="PNG")
    return Response(
        content=buffer.getvalue(),
        media_type="image/png",
        # Same disposition header FileResponse(filename="binarized.png") sends.
        headers={"Content-Disposition": 'attachment; filename="binarized.png"'},
    )
@app.post("/process_all", response_model=RawResponse)
async def process_all(
    file: UploadFile = File(...),
    model_name: str = Form("catmus-medieval")
):
    """Binarize, segment (legacy ``pageseg``) and OCR an uploaded image.

    Args:
        file: The uploaded image.
        model_name: Model identifier handed to ``models.load_any``.

    Returns:
        ``RawResponse`` whose ``result`` has ``binarized_image`` (base64 PNG),
        ``segmentation`` (lines, flat regions list, type, text_direction,
        script_detection) and ``ocr_result`` (``{"bbox", "text", "cuts"}`` dicts).

    Raises:
        HTTPException: 400 if the upload is not a readable image, or if the
            model cannot be loaded.
    """
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content))
    except UnidentifiedImageError:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from None

    # Load the model before any image processing so an invalid name fails fast.
    # NOTE(review): this default differs from /ocr's "catmus-medieval.mlmodel";
    # confirm which one resolves on the host. ``model_name`` is client-controlled
    # and appears to be used as a filesystem path; consider an allow-list.
    try:
        model = models.load_any(model_name)
    except KrakenInvalidModelException as exc:
        raise HTTPException(
            status_code=400, detail=f"Model '{model_name}' not found or invalid"
        ) from exc

    # Step 1: binarization; the result is also returned as a base64 PNG.
    bw_img = binarization.nlbin(image)
    buffered = io.BytesIO()
    bw_img.save(buffered, format="PNG")
    binarized_base64 = base64.b64encode(buffered.getvalue()).decode()

    # Step 2: segmentation of the binarized page.
    segmentation = pageseg.segment(bw_img)

    # Step 3: OCR on the binarized page. Run it before serializing the
    # segmentation so recognition never sees line objects touched by
    # serialization.
    serialized_result = [
        {
            "bbox": record.bbox,
            "text": record.prediction,
            "cuts": record.cuts,
        }
        for record in rpred.rpred(model, bw_img, segmentation)
    ]

    # kraken's Segmentation documents ``regions`` as a mapping of region type
    # -> list of regions (iterating it yields type-name strings, which
    # ``vars`` rejects). Flatten a mapping; pass a sequence through.
    regions = segmentation.regions or []
    if isinstance(regions, dict):
        regions = [region for group in regions.values() for region in group]

    serialized_seg = {
        "lines": [serialize_line(line) for line in segmentation.lines],
        "regions": [serialize_region(region) for region in regions],
        "type": segmentation.type,
        "text_direction": segmentation.text_direction,
        "script_detection": segmentation.script_detection,
    }

    return RawResponse(result={
        "binarized_image": binarized_base64,
        "segmentation": serialized_seg,
        "ocr_result": serialized_result
    })
@app.get("/")
async def root():
    # Landing route: a greeting plus the POST endpoints this service exposes.
    # (Kept as comments, not a docstring, so the OpenAPI description is unchanged.)
    endpoint_paths = [
        "/detect_lines",
        "/ocr",
        "/segment",
        "/binarize",
        "/process_all",
    ]
    payload = {"message": "Welcome to the Complete Kraken Python API"}
    payload["available_endpoints"] = endpoint_paths
    return payload
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860, with a
    # single worker process. (The stray trailing "|" was a syntax error.)
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)