Spaces:
Runtime error
Runtime error
File size: 3,097 Bytes
9e6b8ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from typing import Any, Dict, Generator, List
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
# Hugging Face model repository shared by the tokenizer and the inference client.
_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"

# Tokenizer used by format_prompt to render the chat template into a prompt string.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)

# Sampling settings.
# NOTE(review): generate() declares its own keyword defaults (temperature=0.9,
# top_p=0.95, repetition_penalty=1.0) and does not read these module-level
# values — confirm where, if anywhere, they are consumed.
temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2

# Remote inference client for the same model.
text_client = InferenceClient(_MODEL_ID)
def format_prompt(message: str) -> str:
    """
    Render a single user message through the tokenizer's chat template.

    Args:
        message (str): Text typed by the user.

    Returns:
        str: The prompt string produced by ``tokenizer.apply_chat_template``
        (untokenized, i.e. ``tokenize=False``).
    """
    # The chat template expects a conversation: a list of role/content dicts.
    conversation: List[Dict[str, Any]] = [dict(role='user', content=message)]
    rendered = tokenizer.apply_chat_template(conversation, tokenize=False)
    return rendered
def generate(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
             top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Stream a completion for ``prompt`` from the Mistral inference client.

    Args:
        prompt (str): The user message to complete.
        history (str): Conversation history supplied by the caller. Accepted for
            interface compatibility; it is not used when building the prompt.
        temperature (float, optional): Sampling temperature. Values below 0.01
            are raised to 0.01. Defaults to 0.9.
        max_new_tokens (int, optional): Maximum number of tokens to generate.
            Defaults to 256.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Penalty for repeated tokens.
            Defaults to 1.0.

    Yields:
        str: The accumulated generated text after each non-special token. If the
        inference call fails, a single fallback message is yielded instead so
        that streaming consumers actually display it.

    Returns:
        str: The final generated text, or the fallback message on failure
        (available as the generator's ``StopIteration.value``).
    """
    temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
    top_p = float(top_p)
    generate_kwargs = {
        'temperature': temperature,
        'max_new_tokens': max_new_tokens,
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'do_sample': True,
        'seed': 42,
    }
    formatted_prompt = format_prompt(prompt)
    output = ""
    try:
        stream = text_client.text_generation(formatted_prompt, **generate_kwargs,
                                             stream=True, details=True, return_full_text=False)
        for response in stream:
            # With details=True, special tokens (e.g. the end-of-sequence marker)
            # are streamed as well; skip them so they don't leak into the reply.
            if response.token.special:
                continue
            output += response.token.text
            yield output
    except Exception as e:
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on Mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            message = "Unfortunately, I am not able to process your request now."
        else:
            print("Unhandled Exception:", str(e))
            gr.Warning("Unfortunately Mistral is unable to process")
            message = "I do not know what happened, but I couldn't understand you."
        # A generator's return value is not shown by streaming consumers, so
        # yield the fallback text first; still return it for callers that
        # inspect StopIteration.value.
        yield message
        return message
    return output
|