Spaces:
Runtime error
Runtime error
File size: 2,469 Bytes
bd0332f 933ec2b a475ce0 bd0332f 9ce660a bd0332f 933ec2b 4becd74 933ec2b a475ce0 b94cdc8 933ec2b d57f720 933ec2b d57f720 933ec2b bb98ae2 933ec2b bb98ae2 933ec2b bb98ae2 933ec2b bb98ae2 bd0332f d57f720 933ec2b d57f720 933ec2b b94cdc8 bd0332f 933ec2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import spaces
model_id = "mistralai/Mistral-Nemo-Instruct-2407"
# Загрузка токенизатора и модели
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
load_in_8bit=True
)
@spaces.GPU
def predict(message, history, max_tokens, temperature, top_p):
# Формирование чата из истории и нового сообщения
chat = [{"role": "user" if i % 2 == 0 else "assistant", "content": m}
for i, (m, _) in enumerate(history)] + [{"role": "user", "content": message}]
# Применение шаблона чата
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# Кодирование входных данных
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
# Генерация ответа
outputs = model.generate(
input_ids=inputs,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
)
# Декодирование результата
response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
return response.strip().replace("assistant", "", 1)
# Настройка интерфейса Gradio
iface = gr.ChatInterface(
predict,
chatbot=gr.Chatbot(height=600),
textbox=gr.Textbox(placeholder="Введите ваше сообщение здесь...", container=False, scale=7),
title="Чат с Aeonium v1.1",
description="Это чат-интерфейс для модели Aeonium v1.1 Chat 4B. Задавайте вопросы и получайте ответы!",
theme="soft",
retry_btn="Повторить",
undo_btn="Отменить последнее",
clear_btn="Очистить",
additional_inputs=[
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Максимальное количество новых токенов"),
gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Температура"),
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
],
)
# Запуск интерфейса
iface.launch() |