import datetime import logging import os from os import getenv import time import gradio as gr import requests # Setting up the logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) API_URL = getenv('API_URL') BEARER = getenv('BEARER') headers = { "Authorization": f"Bearer {BEARER}", "Content-Type": "application/json" } def call_jais(payload): try: response = requests.post(API_URL, headers=headers, json=payload) response.raise_for_status() # This will raise an exception for HTTP error codes return response.json() except requests.exceptions.HTTPError as http_err: # Check if the error is a 5XX server error if 500 <= http_err.response.status_code < 600: raise gr.Warning("The endpoint is loading, it takes about 4 min from the first call.") else: raise gr.Warning(f"An error occurred while processing the request. {http_err}") except Exception as err: raise gr.Warning(f"Check Inference Endpoint Status. An error occurred while processing the request. {err}") def generate(prompt: str): start_time = time.perf_counter() payload = {'inputs': '', 'prompt': prompt} response = call_jais(payload) end_time = time.perf_counter() elapsed_time = end_time - start_time logger.warning(f"Function took {elapsed_time:.1f} seconds to execute") return response def check_endpoint_status(): # Replace with the actual API URL and headers api_url = os.getenv("ENDPOINT_URL") headers = { 'accept': 'application/json', 'Authorization': f'Bearer {os.getenv("BEARER")}' } try: response = requests.get(api_url, headers=headers) response.raise_for_status() data = response.json() # Extracting the status information status = data.get('status', {}).get('state', 'No status found') message = data.get('status', {}).get('message', 'No message found') if status == "scaledToZero": return f"