import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider
import torch
import spaces
import asyncio

description = """## Compare Creative Writing: Standard Sampler vs. Backtrack Sampler with Creative Writing Strategy
This is a demo of the [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) framework using "Creative Writing Strategy".
On the left is the output of the standard sampler and on the right the output provided by Backtrack Sampler.
"""

model_name = "unsloth/Llama-3.2-1B-Instruct"
device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained(model_name)

# model1 serves the standard sampler; model2 is wrapped by the Backtrack Sampler provider.
model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
model2 = AutoModelForCausalLM.from_pretrained(model_name)

provider = TransformersProvider(model2, tokenizer, device)
strategy = CreativeWritingStrategy(provider,
                                   top_p_flat=0.65,
                                   top_k_threshold_flat=9,
                                   eos_penalty=0.75)
creative_sampler = BacktrackSampler(provider, strategy)


def create_chat_template_messages(history, prompt):
    # Replay the past (user, assistant) turns in order, then append the current prompt last.
    messages = []
    for input_text, response_text in history:
        messages.append({"role": "user", "content": input_text})
        messages.append({"role": "assistant", "content": response_text})
    messages.append({"role": "user", "content": prompt})
    return messages


@spaces.GPU(duration=60)
def generate_responses(prompt, history):
    messages = create_chat_template_messages(history, prompt)
    wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # wrapped_prompt already contains the special tokens, so don't add them a second time.
    inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to("cuda")

    async def custom_sampler_task():
        generated_list = []
        generator = creative_sampler.generate(wrapped_prompt, max_new_tokens=1024, temperature=1)
        for token in generator:
            generated_list.append(token)
        return tokenizer.decode(generated_list, skip_special_tokens=True)

    custom_output = asyncio.run(custom_sampler_task())

    # do_sample=True so the temperature actually applies instead of greedy decoding.
    standard_output = model1.generate(inputs, max_new_tokens=1024, temperature=1, do_sample=True)
    standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)

    return standard_response.strip(), custom_output.strip()


with gr.Blocks(theme=gr.themes.Citrus()) as demo:
    gr.Markdown(description)

    with gr.Row():
        standard_chat = gr.Chatbot(label="Standard Sampler")
        custom_chat = gr.Chatbot(label="Creative Writing Strategy")

    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)

    examples = [
        "Write me a short story about a talking dog who wants to be a detective.",
        "Tell me a short tale of a dragon who is afraid of heights.",
        "Create a short story where aliens land on Earth, but they just want to throw a party."
    ]
    gr.Examples(examples=examples, inputs=prompt_input)

    submit_button = gr.Button("Submit")

    def update_chat(prompt, standard_history, custom_history):
        standard_response, custom_response = generate_responses(prompt, standard_history)
        standard_history = standard_history + [(prompt, standard_response)]
        custom_history = custom_history + [(prompt, custom_response)]
        return standard_history, custom_history, ""

    prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat],
                        outputs=[standard_chat, custom_chat, prompt_input])
    submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat],
                        outputs=[standard_chat, custom_chat, prompt_input])

demo.queue().launch(debug=True)
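# Dependency sketch (inferred from the imports above; versions are not pinned here):
#   pip install gradio transformers torch spaces backtrack_sampler
# The script assumes a CUDA device is available (torch.device('cuda')), e.g. on a
# Hugging Face ZeroGPU Space, where the `spaces` package provides the @spaces.GPU decorator.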