"""FastAPI service for brain-MRI tumor segmentation.

Lazily loads a UNet++ (resnet34 encoder) model, runs binary
segmentation on an uploaded image, draws bounding boxes around the
predicted tumor regions and streams back the composited PNG, with
summary statistics exposed as response headers.
"""
from io import BytesIO
from typing import Tuple

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, ImageChops
from starlette.responses import StreamingResponse
from torchvision import transforms

import segmentation_models_pytorch as smp

# Model/inference configuration.
constants: dict = {
    'encoder_name': 'resnet34',
    'encoder_weights': 'imagenet',
    'sigmoid_threshold': 0.55,
    'model_path': 'models/production/unetplusplus_resnet34.pth',
}


def load_model() -> smp.UnetPlusPlus:
    """Construct the segmentation model and load its weights, once.

    The model is cached in the module-level ``model`` variable so the
    expensive construction and state-dict load only happen on the first
    request.

    Returns:
        model: smp.UnetPlusPlus moved to ``device`` with weights loaded.
    """
    global model
    if model is None:
        model = smp.UnetPlusPlus(
            encoder_name=constants['encoder_name'],
            encoder_weights=constants['encoder_weights'],
            in_channels=3,
            classes=1,
        ).to(device)
        # NOTE(review): torch.load unpickles the file. The path is a
        # local, trusted artifact here; consider weights_only=True on
        # torch >= 1.13 to avoid arbitrary-code execution via pickle.
        model.load_state_dict(
            torch.load(constants['model_path'], map_location=device))
    return model


def draw_bounding_boxes(mask: np.ndarray) -> Tuple[np.ndarray, float]:
    """Draw a bounding box around each connected region of ``mask``.

    Arguments:
        mask: 2-D binary mask (values 0/255) from thresholding the
            model output; cast to uint8 internally.

    Returns:
        Tuple of a 3-channel copy of the mask with one rectangle per
        external contour, and the mean bounding-box area in pixels
        (0.0 when the mask is empty).
    """
    mask = mask.astype(np.uint8)
    contours, _ = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Always return a 3-channel image so callers see a consistent shape
    # whether or not any region was found (previously the empty case
    # returned the raw 2-D mask).
    mask_bgr = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
    if not contours:
        return mask_bgr, 0.0
    areas = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # (255, 0, 0) is blue in OpenCV's BGR order, but PIL later
        # interprets the array as RGB, so the boxes render red/blue
        # swapped in the response image.
        cv2.rectangle(mask_bgr, (x, y), (x + w, y + h), (255, 0, 0), 1)
        areas.append(w * h)
    return mask_bgr, sum(areas) / len(areas)


# Preprocessing: HWC uint8 -> CHW float in [0, 1], resized to the
# model's expected 256x256 input.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256), antialias=True),
])
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = None  # populated lazily by load_model()

app = FastAPI(title='Brain MRI Medical Images Segmentation Hugging Face')
# FIX: allow_origins was missing; CORSMiddleware defaults to allowing no
# origin at all, which made the exposed headers unreachable from any
# browser client.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_headers=['*'],
    expose_headers=['has_tumor', 'average_pixel_area'],
)


@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded brain-MRI image.

    Returns a PNG of the resized input screened with the predicted
    tumor mask (with bounding boxes), plus two response headers:
    ``has_tumor`` ("True"/"False") and ``average_pixel_area`` (mean
    bounding-box area in pixels, rounded to 2 decimals).
    """
    input_image = Image.open(BytesIO(await file.read())).convert('RGB')
    batch = transform(np.array(input_image)).unsqueeze(0).to(device)

    net = load_model()  # renamed local: previously shadowed the global
    net.eval()
    with torch.no_grad():
        logits = net(batch)

    # Binarize the sigmoid probabilities and scale to an 8-bit mask.
    mask = (torch.sigmoid(logits) > constants['sigmoid_threshold']).float()
    mask = mask.squeeze().cpu().numpy() * 255
    mask, mean_area = draw_bounding_boxes(mask)

    # Recover the resized input as uint8 HWC for compositing.
    resized = (batch[0].cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    base = Image.fromarray(resized).convert('RGBA')
    overlay = Image.fromarray(mask).convert('RGBA')
    result = ImageChops.screen(base, overlay)

    bytes_io = BytesIO()
    result.save(bytes_io, format='PNG')
    bytes_io.seek(0)
    return StreamingResponse(
        bytes_io,
        media_type='image/png',
        headers={
            'has_tumor': f'{mask.max() > 0}',
            'average_pixel_area': f'{round(mean_area, 2)}',
        })