Spaces:
Sleeping
Sleeping
import openai | |
import gradio as gr | |
import os | |
from typing import Any, Dict, Generator, List | |
from huggingface_hub import InferenceClient | |
from transformers import AutoTokenizer | |
# --- Credentials and shared API clients --------------------------------------
# All values below are read from the environment when the module is imported.
_HF_MODEL_ID = os.getenv("HF_MODEL")

OPENAI_KEY = os.getenv("OPENAI_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# Default tokenizer and inference client, both bound to the HF_MODEL repo id.
TOKENIZER = AutoTokenizer.from_pretrained(_HF_MODEL_ID)
HF_CLIENT = InferenceClient(_HF_MODEL_ID, token=HF_TOKEN)

# OpenAI client authenticated with OPENAI_API_KEY.
OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from the environment.

    ``bool()`` on a string is True for anything non-empty, so values such as
    ``"false"`` or ``"0"`` would wrongly enable the flag. Here the usual
    negative spellings (and the empty string) mean False; any other value
    means True.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set.

    Returns:
        The parsed flag.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# Sampling settings shared by both back ends, parsed once.
# A temperature of 0 is bumped to 0.01 so it always stays strictly positive.
_TEMPERATURE = max(float(os.getenv("TEMPERATURE", "0.9")), 1e-2)
_MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
_TOP_P = float(os.getenv("TOP_P", "0.6"))

# Keyword arguments forwarded to the Hugging Face text-generation call.
HF_GENERATE_KWARGS = {
    'temperature': _TEMPERATURE,
    'max_new_tokens': _MAX_NEW_TOKENS,
    'top_p': _TOP_P,
    'repetition_penalty': float(os.getenv("REP_PENALTY", "1.2")),
    'do_sample': _env_flag("DO_SAMPLE", True),
}

# Keyword arguments forwarded to the OpenAI chat-completions call.
OAI_GENERATE_KWARGS = {
    'temperature': _TEMPERATURE,
    'max_tokens': _MAX_NEW_TOKENS,
    'top_p': _TOP_P,
    # Clamp FREQ_PENALTY into the [-2, 2] interval.
    'frequency_penalty': max(-2, min(float(os.getenv("FREQ_PENALTY", "0")), 2)),
}
# Tokenizers already loaded by format_prompt, keyed by model/repo name, so a
# given tokenizer is loaded once per process instead of once per message.
_TOKENIZER_CACHE: Dict[str, Any] = {}


def format_prompt(message: str, api_kind: str, tokenizer_name=None):
    """
    Format a single user message for the selected LLM API.

    Args:
        message (str): The user message to be formatted.
        api_kind (str): LLM API provider, either "openai" or "hf".
        tokenizer_name (str, optional): Model/repo name whose tokenizer
            supplies the chat template for the "hf" provider. When omitted,
            the module-level TOKENIZER is used.

    Returns:
        For "openai": a one-element list of ``{'role', 'content'}`` dicts.
        For "hf": the string produced by the tokenizer's chat template.

    Raises:
        ValueError: If ``api_kind`` is anything other than "openai" or "hf"
            (including empty or None).
    """
    # Single-turn conversation: one user message, no history.
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]

    if api_kind == "openai":
        return messages

    if api_kind == "hf":
        if tokenizer_name is None:
            tokenizer = TOKENIZER
        else:
            tokenizer = _TOKENIZER_CACHE.get(tokenizer_name)
            if tokenizer is None:
                tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
                _TOKENIZER_CACHE[tokenizer_name] = tokenizer
        # NOTE(review): add_generation_prompt is left at its default; confirm
        # whether the target models' templates need it for generation.
        return tokenizer.apply_chat_template(messages, tokenize=False)

    # Every other value, falsy ones included, is an unsupported provider.
    raise ValueError("API is not supported")
def generate_hf(prompt: str, history: str, hf_model_name: str) -> Generator[str, None, None]:
    """
    Stream a completion for ``prompt`` from a Hugging Face inference endpoint.

    Args:
        prompt (str): The user prompt for the text generation.
        history (str): Conversation history; accepted but not used by this
            function.
        hf_model_name (str): Model/repo name used both for the inference
            client and for the chat-template tokenizer.

    Yields:
        str: The text generated so far (cumulative), once per received
        non-special token.

    Raises:
        gr.Error: If the request fails; the message distinguishes rate
            limiting, authentication failures and any other error.
    """
    client = InferenceClient(
        hf_model_name,
        token=os.getenv("HF_TOKEN")
    )
    formatted_prompt = format_prompt(prompt, "hf", hf_model_name)

    try:
        stream = client.text_generation(
            formatted_prompt,
            **HF_GENERATE_KWARGS,
            stream=True,
            details=True,
            return_full_text=False
        )
        output = ""
        for response in stream:
            # Skip special tokens (e.g. end-of-sequence markers) so they do
            # not leak into the text shown to the user.
            if getattr(response.token, "special", False):
                continue
            output += response.token.text
            yield output
    except Exception as e:
        reason = str(e)
        if "Too Many Requests" in reason:
            raise gr.Error(f"Too many requests: {reason}") from e
        if "Authorization header is invalid" in reason:
            raise gr.Error("Authentication error: HF token was either not provided or incorrect") from e
        raise gr.Error(f"Unhandled Exception: {reason}") from e
def generate_openai(prompt: str, history: str, model_name: str) -> Generator[str, None, None]:
    """
    Stream a chat completion for ``prompt`` from the OpenAI API.

    Args:
        prompt (str): The user prompt for the text generation.
        history (str): Conversation history; accepted but not used by this
            function.
        model_name (str): Name of the OpenAI model to query.

    Yields:
        str: The text generated so far (cumulative), once per chunk that
        carries new content.

    Raises:
        gr.Error: If the request fails; the message distinguishes rate
            limiting, authentication failures and any other error.
    """
    formatted_prompt = format_prompt(prompt, "openai")
    auth_message = "Authentication error: OpenAI key was either not provided or incorrect"
    rate_message = "ERROR: Too many requests on OpenAI client"

    try:
        stream = OAI_CLIENT.chat.completions.create(
            model=model_name,
            messages=formatted_prompt,
            **OAI_GENERATE_KWARGS,
            stream=True
        )
        output = ""
        for chunk in stream:
            # Some stream chunks carry no choices at all; indexing [0] on
            # them would raise IndexError, so skip them.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece:
                output += piece
                yield output
    except openai.RateLimitError as e:
        raise gr.Error(rate_message) from e
    except openai.AuthenticationError as e:
        # Covers both a missing and an incorrect key, as the message says.
        raise gr.Error(auth_message) from e
    except Exception as e:
        # Fallback on message text for errors not raised as the typed
        # exceptions above.
        reason = str(e)
        if "Too Many Requests" in reason:
            raise gr.Error(rate_message) from e
        if "You didn't provide an API key" in reason:
            raise gr.Error(auth_message) from e
        raise gr.Error(f"Unhandled Exception: {reason}") from e