# Commit metadata: Calin Rada — "Beds fix" (7b68836, unverified)
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from fastapi import APIRouter, Query, Request
from mappingservice.constants import DEFAULT_LABEL, DEFAULT_SCORE
from mappingservice.dependencies import (
mc,
)
from mappingservice.models import (
AllPredictionsResponse,
PredictionResponse,
Predictions,
RoomData,
)
from mappingservice.ms.ml_models.bed_type import BedType as BedTypeModel
from mappingservice.ms.ml_models.environment import Environment
from mappingservice.ms.ml_models.room_category import RoomCategory
from mappingservice.ms.ml_models.room_features import RoomFeatures
from mappingservice.ms.ml_models.room_type import RoomType
from mappingservice.ms.ml_models.room_view import RoomView
from mappingservice.utils import (
get_bed_predictions,
process_predictions,
safe_round,
)
# NOTE(review): basicConfig here configures the *root* logger at import time
# (DEBUG level, timestamped format); consider moving it to the app entry point.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Every route below is registered under the "/predict/room" prefix.
router = APIRouter(
    prefix="/predict/room",
    tags=["room"],
    responses={404: {"description": "Not found"}},
)
def get_room_type_prediction(room_description: str, language: str = "en"):
    """Run the room-type model on *room_description*.

    Args:
        room_description: Free-text room description to classify.
        language: Key selecting the per-language pipeline in ``mc['room_type']``.

    Returns:
        ``{"msg": <RoomType.predict output>}``.
    """
    pipeline = mc['room_type'][language]
    # The room_type pipeline must be wrapped by RoomType (as the /type endpoint
    # does), not BedTypeModel, or the output gets bed-type post-processing.
    model = RoomType()
    return {"msg": model.predict(room_description, pipeline, language)}
def get_view_prediction(room_description: str, language: str = "en"):
    """Run the room-view model on *room_description*.

    Args:
        room_description: Free-text room description.
        language: Key selecting the per-language pipeline in ``mc['room_view']``.

    Returns:
        ``{"view_prediction": <RoomView.predict output>}``.
    """
    view_pipeline = mc['room_view'][language]
    view_model = RoomView()
    result = view_model.predict(room_description, view_pipeline, language)
    return {"view_prediction": result}
def get_room_category_prediction(room_description: str, language: str = "en"):
    """Run the room-category model on *room_description*.

    Args:
        room_description: Free-text room description.
        language: Key selecting the per-language pipeline in ``mc['room_category']``.

    Returns:
        ``{"msg": <RoomCategory.predict output>}``.
    """
    category_pipeline = mc['room_category'][language]
    category_model = RoomCategory()
    result = category_model.predict(room_description, category_pipeline, language)
    return {"msg": result}
def get_feature_prediction(room_description: str, language: str = "en"):
    """Run the room-features model on *room_description*.

    Args:
        room_description: Free-text room description.
        language: Key selecting the per-language pipeline in ``mc['room_features']``.

    Returns:
        ``{"feature_prediction": <RoomFeatures.predict output>}``.
    """
    features_pipeline = mc['room_features'][language]
    features_model = RoomFeatures()
    result = features_model.predict(room_description, features_pipeline, language)
    return {"feature_prediction": result}
def get_room_environment_prediction(room_description: str, language: str = "en"):
    """Run the environment model on *room_description*.

    Args:
        room_description: Free-text room description.
        language: Key selecting the per-language pipeline in ``mc['environment']``.

    Returns:
        ``{"msg": <Environment.predict output>}``.
    """
    env_pipeline = mc['environment'][language]
    env_model = Environment()
    result = env_model.predict(room_description, env_pipeline, language)
    return {"msg": result}
@router.post("/predict/beds")
async def predict_beds(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict bed types from *room_description*.

    Registered on this router as POST "/predict/beds" (beneath the router's
    "/predict/room" prefix). The pipeline is chosen by the language stored in
    ``request.state.predicted_language``.

    Returns:
        The raw output of ``BedTypeModel.predict``.
    """
    detected_language = request.state.predicted_language
    bed_pipeline = mc['bed_type'][detected_language]
    bed_model = BedTypeModel()
    return bed_model.predict(room_description, bed_pipeline, detected_language)
@router.get("/type")
async def predict_room_type_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Classify the room type of *room_description*.

    The pipeline is chosen by the language stored in
    ``request.state.predicted_language``.

    Returns:
        The raw output of ``RoomType.predict``.
    """
    detected_language = request.state.predicted_language
    type_pipeline = mc['room_type'][detected_language]
    type_model = RoomType()
    return type_model.predict(room_description, type_pipeline, detected_language)
@router.get("/category")
async def predict_room_category_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Classify the room category of *room_description*.

    Uses the pipeline for the language detected for this request
    (``request.state.predicted_language``), consistent with ``/type`` and
    ``/predict/all``, rather than always the English pipeline.

    Returns:
        The raw output of the selected pipeline's ``predict``.
    """
    language = request.state.predicted_language
    return mc['room_category'][language].predict(room_description)
@router.get("/environment")
async def predict_room_environment_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict environment labels for *room_description*.

    Uses the pipeline for the language detected for this request
    (``request.state.predicted_language``), consistent with ``/type`` and
    ``/predict/all``, rather than always the English pipeline.

    Returns:
        The raw output of the selected pipeline's ``predict``.
    """
    language = request.state.predicted_language
    return mc['environment'][language].predict(room_description)
@router.get("/view")
async def predict_view_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict room-view labels for *room_description*.

    Uses the pipeline for the language detected for this request
    (``request.state.predicted_language``), consistent with ``/type`` and
    ``/predict/all``, rather than always the English pipeline.

    Returns:
        The raw output of the selected pipeline's ``predict``.
    """
    language = request.state.predicted_language
    return mc['room_view'][language].predict(room_description)
@router.get("/feature")
async def predict_feature_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict room-feature labels for *room_description*.

    Uses the pipeline for the language detected for this request
    (``request.state.predicted_language``), consistent with ``/type`` and
    ``/predict/all``, rather than always the English pipeline.

    Returns:
        The raw output of the selected pipeline's ``predict``.
    """
    language = request.state.predicted_language
    return mc['room_features'][language].predict(room_description)
def _collect_bed_predictions(room_data: RoomData):
    """Return the bed list for *room_data*, extracting beds from the description if none were given.

    Builds a fresh list rather than extending ``room_data.beds`` in place, so a
    ``None`` ``beds`` field can neither crash (``None.extend``) nor end up as
    ``[None]``, and the request model is not mutated.
    """
    beds = room_data.beds
    if not beds:
        logger.debug("No bed data provided or valid; extracting from description.")
        extracted_beds = get_bed_predictions(room_data.room_description)
        beds = list(extracted_beds) if extracted_beds else []
    if not isinstance(beds, list):
        # Wrap a non-list value (e.g. a single bed entry) so Predictions gets a list.
        beds = [beds]
    return beds


def _to_prediction_responses(preds):
    """Convert processed prediction dicts into PredictionResponse objects ([] for empty/None)."""
    return [PredictionResponse(**pred) for pred in preds or []]


@router.post("/predict/all", response_model=AllPredictionsResponse)
async def predict_all(request: Request, room_data: RoomData):
    """Run every room model on one description and return the combined predictions.

    Args:
        request: Incoming request; its ``state.predicted_language`` selects the
            per-language pipelines (presumably set by middleware — confirm).
        room_data: Request body, already parsed and validated by FastAPI.

    Returns:
        AllPredictionsResponse with type, category, environment, feature, view,
        the detected language and the bed list.
    """
    start_time = time.perf_counter()
    language = request.state.predicted_language
    description = room_data.room_description

    # NOTE(review): waiting on these futures blocks the event loop because this
    # handler is ``async``. Kept deliberately: handing the work to the loop's
    # executor would let concurrent requests use the same pipelines at once.
    with ThreadPoolExecutor() as executor:
        type_future = executor.submit(
            get_room_type_prediction, description, language
        )
        category_future = executor.submit(
            get_room_category_prediction, description, language
        )
        environment_future = executor.submit(
            get_room_environment_prediction, description, language
        )
        feature_future = executor.submit(
            get_feature_prediction, description, language
        )
        view_future = executor.submit(
            get_view_prediction, description, language
        )
        type_pred = type_future.result()["msg"]
        category_pred = category_future.result()["msg"]
        environment_pred_results = environment_future.result()["msg"]
        feature_pred_results = feature_future.result()["feature_prediction"]
        view_pred_results = view_future.result()["view_prediction"]

    bed_predictions = _collect_bed_predictions(room_data)

    env_preds = process_predictions(environment_pred_results)
    feat_preds = process_predictions(
        feature_pred_results.get("features", []), label_key="word"
    )
    view_preds = process_predictions(
        view_pred_results.get("views", []), label_key="word"
    )

    predictions = Predictions(
        type=PredictionResponse(
            label=type_pred.get("label", DEFAULT_LABEL),
            score=safe_round(type_pred.get("score", DEFAULT_SCORE), 3),
        ),
        category=PredictionResponse(
            label=category_pred.get("label", DEFAULT_LABEL),
            score=safe_round(category_pred.get("score", DEFAULT_SCORE), 3),
        ),
        environment=_to_prediction_responses(env_preds),
        feature=_to_prediction_responses(feat_preds),
        view=_to_prediction_responses(view_preds),
        language_detected=language,
        beds=bed_predictions,
    )

    # perf_counter is monotonic, so the measured duration is immune to clock changes.
    logger.info(
        "Total processing time: %.3f seconds", time.perf_counter() - start_time
    )
    return AllPredictionsResponse(predictions=predictions)