Spaces:
Runtime error
Runtime error
File size: 2,144 Bytes
e4ccddf 0c5cd97 f9972d8 e4ccddf f9972d8 e4ccddf d7f1aae e4ccddf 0c5cd97 2278cc5 f34d089 2278cc5 f9972d8 9d0dd09 5ab4218 9d0dd09 5ab4218 e4ccddf f9972d8 9d0dd09 f9972d8 9d0dd09 f9972d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GemmaForCausalLM,
    pipeline,
)
# The tokenizer and model are loaded once at import time so that every request
# served by the UI reuses the same objects.
MODEL_ID = 'google/gemma-2-9b'

# Load the weights in 4-bit precision (bitsandbytes) to reduce GPU memory use.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# gemma-2 checkpoints use the Gemma 2 architecture, which differs from the
# original Gemma one. Loading through the Auto class lets transformers pick
# the model class that matches the checkpoint's config instead of forcing the
# Gemma 1 class onto Gemma 2 weights.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
)

# Upper bound and initial value for the "Max new tokens" control in the UI.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# NOTE(review): spaces.GPU is the ZeroGPU decorator; `duration` is understood
# to be the expected maximum runtime of one call in seconds - confirm against
# the Spaces documentation.
@spaces.GPU(duration=120)
def generate(
    message: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:
    """Generate text from ``message`` using the module-level model.

    Args:
        message: Prompt text to tokenize and feed to the model.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.
        repetition_penalty: Penalty applied to already generated tokens.

    Returns:
        The first generated sequence decoded to text with special tokens
        removed. The whole output sequence is decoded, which for causal
        language models includes the prompt.
    """
    # The tokenizer returns a mapping (input ids plus attention mask), which is
    # moved to the GPU and unpacked into keyword arguments for generate().
    encoded = tokenizer(message, return_tensors="pt").to("cuda")
    outputs = model.generate(
        **encoded,
        # Without do_sample=True decoding is greedy and temperature / top_p /
        # top_k are ignored, so the UI sliders would have no effect.
        do_sample=True,
        # UI sliders may deliver floats; these two must be integers.
        max_new_tokens=int(max_new_tokens),
        top_k=int(top_k),
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Web UI: one free-text prompt plus one slider per sampling parameter. The
# inputs are listed in the same order as generate()'s parameters, because they
# are passed to it positionally.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    outputs="text",
    # Each example row only supplies the prompt; the sliders keep their
    # default values.
    examples=[['Write me a poem about Machine Learning.']],
)
demo.launch()