#!/usr/bin/env python

import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

DESCRIPTION = """# Swallow-13B instruct"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model_name = "tokyotech-llm/Swallow-13b-instruct-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Load the 13B model with 8-bit quantization so it fits on a single GPU.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        low_cpu_mem_usage=True,
        device_map="auto",
    )
MAX_INPUT_TOKENS = 2048

# Alpaca-style Japanese prompt templates. "prompt_input" reads, roughly:
# "Below is an instruction that describes a task, paired with an input that
# provides further context. Write a response that appropriately completes
# the request." followed by "### Instruction / ### Input / ### Response"
# sections. "prompt_no_input" is the same without the input section.
PROMPT_DICT = {
    "prompt_input": (
        "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:"
    ),
    "prompt_no_input": (
        "以下に、あるタスクを説明する指示があります。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 応答:"
    ),
}
def create_prompt(instruction: str, input_text: str | None = None) -> str:
    """Generates a prompt from the given instruction and an optional input.

    If input_text is provided, the 'prompt_input' template from PROMPT_DICT
    is used; otherwise, the 'prompt_no_input' template is used.

    Args:
        instruction (str): The instruction describing the task.
        input_text (str, optional): Additional input providing context for the task. Defaults to None.

    Returns:
        str: The generated prompt.
    """
    if input_text:
        # Use the 'prompt_input' template when additional input is provided
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input_text)
    # Use the 'prompt_no_input' template when no additional input is provided
    return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
@spaces.GPU
@torch.inference_mode()
def run(
    instruction: str,
    input_text: str | None = None,
    max_new_tokens: int = 256,
    temperature: float = 0.99,
    top_p: float = 0.95,
) -> Iterator[str]:
    # Gradio passes an empty string when the optional field is left blank.
    if input_text == "":
        input_text = None
    prompt = create_prompt(instruction, input_text)

    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        raise gr.Error(f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})")

    # Run generation in a background thread and stream tokens back as they
    # are produced, so the UI can update incrementally.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids.to(model.device)},
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
def process_example(instruction: str, input_text: str) -> Iterator[str]:
    yield from run(instruction, input_text)

with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            instruction = gr.Textbox(label="Instruction", lines=5)
            input_text = gr.Textbox(label="Input (optional)", lines=5)
            run_button = gr.Button()
            with gr.Accordion(label="Advanced Options", open=False):
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=256)
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=0.99)
                top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
        with gr.Column():
            output = gr.Textbox(label="Output", lines=10)

    run_button.click(
        fn=run,
        inputs=[instruction, input_text, max_new_tokens, temperature, top_p],
        outputs=output,
        api_name="run",
    )
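
    # Because the click handler is registered with api_name="run", the demo
    # can also be called programmatically. A hedged sketch using
    # gradio_client (the Space id below is a placeholder); positional
    # arguments follow the `inputs` list above:
    #   from gradio_client import Client
    #   client = Client("<space-id>")
    #   client.predict("指示...", "", 256, 0.99, 0.95, api_name="/run")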

    # Example prompts (in Japanese). The first two ask for detailed
    # information on a topic: the main campuses of Tokyo Institute of
    # Technology, and the "it was all a dream" (yume-ochi) plot device.
    # The third asks "Who is the Abarenbo Shogun?" with no extra input.
    gr.Examples(
        examples=[
            [
                "以下のトピックに関する詳細な情報を提供してください。",
                "東京工業大学の主なキャンパスについて教えてください。",
            ],
            [
                "以下のトピックに関する詳細な情報を提供してください。",
                "夢オチとは何かについて教えてください。",
            ],
            ["暴れん坊将軍って誰のことですか?", ""],
        ],
        inputs=[instruction, input_text],
        outputs=output,
        fn=process_example,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
        api_name=False,
    )

if __name__ == "__main__":
    demo.launch()