# Spaces: Runtime error
import logging | |
from flask import Flask, request, jsonify | |
import os | |
from wtforms import Form, StringField | |
from wtforms.validators import DataRequired | |
from config import model_ckpt, pipe, labels, THRESHOLD | |
app = Flask(__name__)
# Module-level logger so functions that log keep working when this module is
# imported by a WSGI server; the ``if __name__ == '__main__'`` block only binds
# ``logger`` when the file is executed directly.
logger = logging.getLogger(__name__)
class PredictForm(Form):
    """WTForms form declaring a single required ``text`` field."""

    text = StringField(label='text', validators=[DataRequired()])
def predict(text: str) -> "dict | tuple[dict, int]":
    """
    Compute the most likely language for text.

    :param text: str : The text to be analyzed.
    :return: One of
        - ``{language: score}``: a single entry for the highest-scoring label
          when its score exceeds ``THRESHOLD``; the label is mapped through
          ``labels`` and falls back to the raw model label;
        - ``{'error': 'Prediction score below threshold'}`` when the best
          score does not exceed ``THRESHOLD``;
        - ``{}`` when the pipeline returns no scores;
        - the tuple ``({'error': <message>}, 500)`` if anything raises while
          predicting.
    """
    # Resolve the logger locally: a module-level ``logger`` may only be bound
    # under ``if __name__ == '__main__'``, which would raise NameError when the
    # app is served by a WSGI server instead.
    log = logging.getLogger(__name__)
    try:
        preds = pipe(text, return_all_scores=True, truncation=True, max_length=128)
        # preds[0] is taken to be the list of {'label', 'score'} dicts for this
        # single input (presumably the text-classification output shape with
        # return_all_scores=True -- confirm against config.pipe).
        scores = preds[0] if preds else None
        if not scores:
            # An empty score list means "no predictions", not an indexing error.
            return {}
        # max() returns the first highest-scoring entry, same as taking the
        # head of a stable descending sort.
        best = max(scores, key=lambda p: p["score"])
        if best["score"] > THRESHOLD:
            return {labels.get(best["label"], best["label"]): float(best["score"])}
        log.error("Prediction score below threshold. text: %s, score: %s", text, best["score"])
        return {'error': "Prediction score below threshold"}
    except Exception as e:
        # Top-level boundary for the request: log with traceback and report a 500.
        log.exception("Error processing request: %s", str(e))
        return {'error': str(e)}, 500
def predict_language():
    """
    A Language Prediction API which accepts 'text' as input and returns the language of the text along with its score
    ---
    parameters:
      - in: body
        name: body
        required: true
        schema:
          type: object
          required:
            - text
          properties:
            text:
              type: string
              description: The text to be analyzed
    responses:
      200:
        description: A JSON object containing the language and its score
        schema:
          type: object
      400:
        description: Invalid request (missing, empty or non-string text), no predictions found, or prediction score below threshold
      500:
        description: Internal server error
    """
    # NOTE(review): no @app.route decorator is visible for this view, so Flask
    # will not serve it unless it is registered elsewhere -- confirm.

    # get_json(silent=True) yields None instead of raising on a missing or
    # malformed JSON body; a non-object body (e.g. a JSON list) has no 'text'.
    payload = request.get_json(silent=True)
    text = payload.get('text') if isinstance(payload, dict) else None
    if not text:
        return jsonify({'error': 'Empty text provided'}), 400
    if not isinstance(text, str):
        return jsonify({'error': 'Text must be a string'}), 400

    result = predict(text)
    # predict() reports internal failures as a (body, status) tuple; unpack it
    # so the status code reaches the client instead of being serialized into
    # a 200 JSON array.
    if isinstance(result, tuple):
        body, status = result
        return jsonify(body), status
    if not result:
        return jsonify({'error': 'No predictions found'}), 400
    if 'error' in result:
        # Below-threshold predictions are a 400, as the API spec documents.
        return jsonify(result), 400
    return jsonify(result)
if __name__ == '__main__':
    log_file = 'app.log'
    logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)
    logger.info("Running the app...")
    # HOST / PORT let the bind address be overridden (e.g. 0.0.0.0 inside a
    # container). When they are unset, None is passed and Flask keeps its own
    # defaults, exactly as a bare app.run() would.
    port = os.environ.get('PORT')
    app.run(host=os.environ.get('HOST') or None, port=int(port) if port else None)