"""Room prediction endpoints: ML-model inference routes under /predict/room."""
import logging
import time
from concurrent.futures import ThreadPoolExecutor

from fastapi import APIRouter, Query, Request

from mappingservice.constants import DEFAULT_LABEL, DEFAULT_SCORE
from mappingservice.dependencies import (
    mc,
)
from mappingservice.models import (
    AllPredictionsResponse,
    PredictionResponse,
    Predictions,
    RoomData,
)
from mappingservice.ms.ml_models.bed_type import BedType as BedTypeModel
from mappingservice.ms.ml_models.environment import Environment
from mappingservice.ms.ml_models.room_category import RoomCategory
from mappingservice.ms.ml_models.room_features import RoomFeatures
from mappingservice.ms.ml_models.room_type import RoomType
from mappingservice.ms.ml_models.room_view import RoomView
from mappingservice.utils import (
    get_bed_predictions,
    process_predictions,
    safe_round,
)
# NOTE(review): calling basicConfig at import time forces DEBUG-level logging
# on the whole process for any app that imports this router; logging is
# usually configured once at the application entry point — confirm intended.
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Router grouping all room-related prediction endpoints under /predict/room.
router = APIRouter(
    prefix="/predict/room", tags=["room"], responses={404: {"description": "Not found"}}
)
def get_room_type_prediction(room_description: str, language: str = "en"):
    """Predict the room type of *room_description* using the *language* pipeline.

    Returns a dict with key ``"msg"`` holding the model prediction, matching
    the shape returned by the other ``get_*_prediction`` helpers.
    """
    pipeline = mc['room_type'][language]
    # Bug fix: this helper previously instantiated BedTypeModel() while feeding
    # it the room_type pipeline. RoomType is the model that matches this
    # pipeline (see predict_room_type_endpoint, which pairs mc['room_type']
    # with RoomType()).
    model = RoomType()
    return {"msg": model.predict(room_description, pipeline, language)}
def get_view_prediction(room_description: str, language: str = "en"):
    """Run the room-view model over *room_description* for *language*.

    Returns a dict with key ``"view_prediction"`` holding the model output.
    """
    view_model = RoomView()
    result = view_model.predict(room_description, mc['room_view'][language], language)
    return {"view_prediction": result}
def get_room_category_prediction(room_description: str, language: str = "en"):
    """Run the room-category model over *room_description* for *language*.

    Returns a dict with key ``"msg"`` holding the model output.
    """
    category_model = RoomCategory()
    result = category_model.predict(
        room_description, mc['room_category'][language], language
    )
    return {"msg": result}
def get_feature_prediction(room_description: str, language: str = "en"):
    """Run the room-features model over *room_description* for *language*.

    Returns a dict with key ``"feature_prediction"`` holding the model output.
    """
    features_model = RoomFeatures()
    result = features_model.predict(
        room_description, mc['room_features'][language], language
    )
    return {"feature_prediction": result}
def get_room_environment_prediction(room_description: str, language: str = "en"):
    """Run the environment model over *room_description* for *language*.

    Returns a dict with key ``"msg"`` holding the model output.
    """
    environment_model = Environment()
    result = environment_model.predict(
        room_description, mc['environment'][language], language
    )
    return {"msg": result}
async def predict_beds(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict bed types from *room_description*.

    Uses the language detected by upstream middleware
    (``request.state.predicted_language``) to select the pipeline.
    """
    detected_language = request.state.predicted_language
    bed_model = BedTypeModel()
    return bed_model.predict(
        room_description, mc['bed_type'][detected_language], detected_language
    )
async def predict_room_type_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict the room type for *room_description*.

    Uses the language detected by upstream middleware
    (``request.state.predicted_language``) to select the pipeline.
    """
    detected_language = request.state.predicted_language
    type_model = RoomType()
    return type_model.predict(
        room_description, mc['room_type'][detected_language], detected_language
    )
async def predict_room_category_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict the room category for *room_description*.

    NOTE(review): unlike predict_room_type_endpoint, this always uses the
    English pipeline and calls it directly (no model wrapper, no detected
    language) — confirm the hard-coded 'en' is intentional.
    """
    english_pipeline = mc['room_category']['en']
    return english_pipeline.predict(room_description)
async def predict_room_environment_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict the room environment for *room_description*.

    NOTE(review): always uses the English pipeline directly, ignoring the
    detected request language — confirm the hard-coded 'en' is intentional.
    """
    english_pipeline = mc['environment']['en']
    return english_pipeline.predict(room_description)
async def predict_view_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict the room view for *room_description*.

    NOTE(review): always uses the English pipeline directly, ignoring the
    detected request language — confirm the hard-coded 'en' is intentional.
    """
    english_pipeline = mc['room_view']['en']
    return english_pipeline.predict(room_description)
async def predict_feature_endpoint(request: Request, room_description: str = Query(...)):  # noqa: E501
    """Predict room features for *room_description*.

    NOTE(review): always uses the English pipeline directly, ignoring the
    detected request language — confirm the hard-coded 'en' is intentional.
    """
    english_pipeline = mc['room_features']['en']
    return english_pipeline.predict(room_description)
async def predict_all(request: Request, room_data: RoomData):
    """Run every room model over *room_data* and aggregate the results.

    The five classifiers (type, category, environment, features, view) run
    concurrently in a thread pool. Bed information is taken from the payload
    when present; otherwise it is extracted from the description text.

    Returns:
        AllPredictionsResponse bundling one prediction per model, the
        detected request language, and the bed predictions.
    """
    start_time = time.time()
    # Fix: ``room_data`` is already parsed and validated by FastAPI from the
    # request body; the previous ``RoomData(**await request.json())`` re-read
    # and re-validated the same payload for no benefit.
    language = request.state.predicted_language
    description = room_data.room_description

    with ThreadPoolExecutor() as executor:
        type_future = executor.submit(
            get_room_type_prediction, description, language
        )
        category_future = executor.submit(
            get_room_category_prediction, description, language
        )
        environment_future = executor.submit(
            get_room_environment_prediction, description, language
        )
        feature_future = executor.submit(
            get_feature_prediction, description, language
        )
        view_future = executor.submit(
            get_view_prediction, description, language
        )
        type_pred = type_future.result()["msg"]
        category_pred = category_future.result()["msg"]
        environment_pred_results = environment_future.result()["msg"]
        feature_pred_results = feature_future.result()["feature_prediction"]
        view_pred_results = view_future.result()["view_prediction"]

    # Prefer explicit bed data from the payload; fall back to extracting beds
    # from the free-text description only when none was supplied.
    bed_predictions = [
        bed_data for bed_data in room_data.beds
        if bed_data.type is not None and bed_data.count is not None
    ]
    if not bed_predictions:
        logger.debug("No bed data provided or valid; extracting from description.")
        extracted_beds = get_bed_predictions(description)
        if extracted_beds:
            bed_predictions.extend(extracted_beds)

    total_time = time.time() - start_time
    logger.info(f"Total processing time: {total_time:.3f} seconds")

    formatted_predictions = {
        "type": {
            "label": type_pred.get("label", DEFAULT_LABEL),
            "score": safe_round(type_pred.get("score", DEFAULT_SCORE), 3),
        },
        "category": {
            "label": category_pred.get("label", DEFAULT_LABEL),
            "score": safe_round(category_pred.get("score", DEFAULT_SCORE), 3),
        },
    }

    env_preds = process_predictions(environment_pred_results)
    feat_preds = process_predictions(
        feature_pred_results.get("features", []), label_key="word"
    )
    view_preds = process_predictions(
        view_pred_results.get("views", []), label_key="word"
    )

    predictions = Predictions(
        type=PredictionResponse(**formatted_predictions["type"]),
        category=PredictionResponse(**formatted_predictions["category"]),
        environment=[PredictionResponse(**pred) for pred in env_preds],
        feature=[PredictionResponse(**pred) for pred in feat_preds],
        view=[PredictionResponse(**pred) for pred in view_preds],
        language_detected=language,
        beds=bed_predictions,
    )
    return AllPredictionsResponse(predictions=predictions)