import json

import requests


def check_server_health(cloud_gateway_api: str):
    """
    Use the appropriate API endpoint to check the server health.

    Args:
        cloud_gateway_api: API endpoint to probe.

    Returns:
        True if the server is active, False otherwise.
    """
    try:
        response = requests.get(cloud_gateway_api + "/health")
        if response.status_code == 200:
            return True
    except requests.ConnectionError:
        print("Failed to establish connection to the server.")
    return False


def request_generation(message: str,
                       system_prompt: str,
                       cloud_gateway_api: str,
                       max_new_tokens: int = 1024,
                       temperature: float = 0.6,
                       top_p: float = 0.9,
                       top_k: int = 50,
                       repetition_penalty: float = 1.2,
                       ):
    """
    Request streaming generation from the cloud gateway API. Uses the requests
    module with stream=True to receive token-by-token generation from the LLM.

    Args:
        message: prompt from the user.
        system_prompt: system prompt to prepend.
        cloud_gateway_api (str): API endpoint to send the request.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
        temperature: the value used to modulate the next token probabilities.
        top_p: if set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        top_k: the number of highest probability vocabulary tokens to keep for top-k filtering.
        repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.

    Yields:
        Chunks of generated text (str) as they are streamed back from the server.
    """
    payload = {
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message}
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "stream": True  # Enable streaming
    }

    with requests.post(cloud_gateway_api + "/v1/chat/completions", json=payload, stream=True) as response:
        for chunk in response.iter_lines():
            if chunk:
                # Convert the chunk from bytes to a string
                chunk_str = chunk.decode('utf-8')
                # Remove the `data: ` prefix from the chunk if it exists
                if chunk_str.startswith("data: "):
                    chunk_str = chunk_str[len("data: "):]
                # Stop at the [DONE] sentinel that marks the end of the stream
                if chunk_str.strip() == "[DONE]":
                    break
                # Parse the chunk into a JSON object
                try:
                    chunk_json = json.loads(chunk_str)
                    # Extract the "content" field from the first choice's delta
                    content = chunk_json["choices"][0]["delta"].get("content", "")
                    # Yield the generated content as it is streamed
                    if content:
                        yield content
                except json.JSONDecodeError:
                    # Skip chunks that cannot be decoded (e.g. keep-alive lines)
                    continue
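

# --- Usage example (illustrative, not part of the original module) ---
# A minimal sketch of how the two functions above might be driven together:
# probe the /health endpoint first, then stream tokens to stdout. The gateway
# URL, prompt text, and this __main__ guard are assumptions added for the
# example; replace them with values appropriate to your deployment.
if __name__ == "__main__":
    gateway_url = "http://localhost:8000"  # hypothetical endpoint; replace with your gateway

    if check_server_health(gateway_url):
        for token in request_generation(
            message="Explain streaming generation in one sentence.",
            system_prompt="You are a concise, helpful assistant.",
            cloud_gateway_api=gateway_url,
        ):
            # Print each streamed chunk as soon as it arrives
            print(token, end="", flush=True)
        print()
    else:
        print("Cloud gateway is not reachable; skipping generation.")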