File size: 3,097 Bytes
9e6b8ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Any, Dict, Generator, List

import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

# Tokenizer for the instruct model; format_prompt uses it to render the chat template.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

# Module-level sampling settings.
# NOTE(review): generate() declares parameters with these same names, so these
# module values are not read by the code shown here.
temperature, top_p, repetition_penalty = 0.9, 0.6, 1.2

# Remote inference client for the same model; generate() streams from it.
text_client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")


def format_prompt(message: str) -> str:
    """
    Render a single user message through the tokenizer's chat template.

    Args:
        message (str): Text of the user's turn.

    Returns:
        str: The templated prompt as text (not token ids, since tokenize=False).
    """
    # A one-turn conversation containing only the user's message.
    conversation: List[Dict[str, Any]] = [
        {'role': 'user', 'content': message},
    ]
    rendered = tokenizer.apply_chat_template(conversation, tokenize=False)
    return rendered


def generate(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
             top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Stream a reply to ``prompt`` from the Mistral inference client.

    Each yielded value is the full text generated so far (not only the newest
    chunk), which is what a streaming Gradio chat callback displays.

    Args:
        prompt (str): The user message to answer.
        history (str): Conversation history supplied by the caller.
            NOTE: currently ignored -- only ``prompt`` is sent to the model.
        temperature (float, optional): Sampling temperature, clamped to at
            least 0.01. Defaults to 0.9.
        max_new_tokens (int, optional): Maximum number of tokens to generate;
            coerced to ``int``. Defaults to 256.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Penalty for repeated tokens.
            Defaults to 1.0.

    Yields:
        str: The accumulated generated text after each streamed token. If the
        request fails, a short apology message is yielded instead.

    Returns:
        str: The final text (the accumulated output, or the apology message on
        error). A generator's return value is only visible to callers that read
        ``StopIteration.value``; streaming UIs show the yielded values.
    """

    temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
    top_p = float(top_p)

    generate_kwargs = {
        'temperature': temperature,
        # UI sliders may deliver floats (e.g. 256.0); the token budget must be an int.
        'max_new_tokens': int(max_new_tokens),
        'top_p': top_p,
        'repetition_penalty': float(repetition_penalty),
        'do_sample': True,
        'seed': 42,  # fixed seed, presumably for reproducible sampling
        }

    formatted_prompt = format_prompt(prompt)

    output = ""
    try:
        stream = text_client.text_generation(formatted_prompt, **generate_kwargs,
                                             stream=True, details=True, return_full_text=False)
        for response in stream:
            # With details=True each chunk carries token metadata; skip special
            # tokens (e.g. the end-of-sequence marker) so "</s>" is not shown.
            if response.token.special:
                continue
            output += response.token.text
            yield output

    except Exception as e:
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on Mistral client")
            message = "Unfortunately, I am not able to process your request now."
        else:
            print("Unhandled Exception:", str(e))
            message = "I do not know what happened, but I couldn't understand you."
        gr.Warning("Unfortunately Mistral is unable to process")
        # `return value` from a generator is not displayed by streaming consumers,
        # so the message must be yielded for the user to see it.
        yield message
        return message

    return output