File size: 5,364 Bytes
8824f88
 
9ae4463
 
8824f88
 
 
 
9ae4463
 
 
 
 
 
8824f88
0adfa8b
ba0ce5d
8824f88
 
c4bb11b
8824f88
 
77bf3e6
c4bb11b
8824f88
ba0ce5d
8824f88
 
 
 
 
 
c4bb11b
 
 
8824f88
 
9ae4463
 
8824f88
 
 
 
c4bb11b
0737a9d
 
34353a1
0737a9d
8824f88
c4bb11b
8824f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba0ce5d
8824f88
88a6c27
17aeee6
 
88a6c27
9ae4463
c4bb11b
8824f88
 
 
 
 
 
 
 
 
 
c4bb11b
 
 
9ae4463
8824f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4868a7b
 
 
 
8824f88
9ae4463
 
 
 
 
 
 
 
 
 
8824f88
 
 
 
 
 
ba0ce5d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python
import os
import random
from threading import Thread  # noqa
from typing import Iterator

import gradio as gr
import spaces
import torch  # noqa
from transformers import (
    AutoModelForCausalLM,  # noqa
    AutoTokenizer,  # noqa
    TextIteratorStreamer,  # noqa
)

from chat_interface_preference import ChatInterface

# Upper bound and initial value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024

# Prompt budget in tokens; overridable through the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "8192"))

# The model and tokenizer are only loaded when a CUDA device is visible;
# without one, `model` and `tokenizer` stay undefined.
if torch.cuda.is_available():
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.06,
    top_p: float = 0.95,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a model reply to ``message`` given the earlier ``chat_history``.

    Every call prepends a system prompt of the form ``"Communicate <style>."``
    where the style word is picked at random, so replies vary in tone.

    Args:
        message: The new user message to answer.
        chat_history: Previous ``(user, assistant)`` turns, replayed into the
            conversation in the order given.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
            NOTE(review): the default 0.06 may be a typo for 0.6; it is left
            unchanged to keep the signature stable.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, growing each time the streamer produces
        another piece of text.
    """
    styles = ("concise", "explicit", "simple", "complex", "useful", "helpful")
    system_message = random.choice(styles)

    conversation = [{"role": "system", "content": f"Communicate {system_message}."}]
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest part of the prompt is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Per the TextIteratorStreamer docs, `timeout` bounds the wait for the next
    # piece of text and `skip_prompt` keeps the prompt out of the stream.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # Generation runs in a background thread so the streamer can be consumed
    # here while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# Chat UI that streams replies from `generate` and collects human feedback.
# NOTE(review): the keyword is spelled `prefence_technique`; presumably this
# matches the parameter name in `chat_interface_preference.ChatInterface`.
# Confirm there before renaming it.
chat_interface = ChatInterface(
    fn=generate,
    prefence_technique="kto",
    min_turns=1,
    max_turns=10,
    repo_id="llm-human-feedback-collector-chat-interface-kto",
    chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
    cache_examples=False,
    # Order matters: these sliders are listed in the same order as the extra
    # parameters of `generate` (max_new_tokens, temperature, top_p, top_k,
    # repetition_penalty).
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1.2,
            step=0.05,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["""What word doesn't make sense in this row: "car, airplane, lama, bus"?"""],
        ["Write a news article about the usage of Lama's by the CSI"],
        ["What are great things to cook when getting started with Asian cooking?"],
        ["Who was Anthony Bourdain?"],
    ],
    title="💪🏽🦾 Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (KTO) 🦾💪🏽",
    description="".join(
        [
            "This is an adaptation of the [`gr.ChatInterface`](https://www.gradio.app/docs/gradio/chatinterface) which also uses the [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler) to allow for human feedback collection. ",
            "Another cool tool for capturing Gradio interactions is the [`gr.HuggingFaceDatasetSaver`](https://www.gradio.app/guides/using-flagging#the-hugging-face-dataset-saver-callback). ",
            "This demo shows how you might capture human feedback directly from applications within Gradio. ",
            "The captured feedback can directly be used for fine-tuning LLMs within frameworks like [transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl) or [AutoTrain](https://huggingface.co/autotrain), ",
            "however, it might benefit from additional data curation with something like [Argilla](https://github.com/argilla-io/argilla/) for human feedback and/or [distilabel](https://github.com/argilla-io/distilabel/) for AI feedback. Argilla can even be [deployed for free on Hugging Face Spaces](https://argilla-io.github.io/argilla/latest/getting_started/huggingface-spaces/).",
        ]
    ),
)

# Host the chat interface inside a Blocks app that loads the custom stylesheet.
demo = gr.Blocks(css="style.css")
with demo:
    chat_interface.render()

if __name__ == "__main__":
    # Allow at most 20 queued requests, then start the server.
    demo.queue(max_size=20)
    demo.launch()