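# Minimal Gradio chat app that streams completions from google/gemma-2b.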
import os
from threading import Thread

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Hugging Face access token (set as a Space secret); Gemma is a gated model.
token = os.environ["HF_TOKEN"]
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=token)
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", token=token)
def generate(message, history):
    inputs = tokenizer([message], return_tensors="pt")
    # TextIteratorStreamer yields decoded text chunks; the original TextStreamer
    # only prints to stdout, and yielding model.generate()'s raw token ids would
    # render a tensor repr in the chat window instead of the response text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Generate in a background thread so chunks can be streamed as they arrive.
    # max_new_tokens=256 is an assumed cap; the original relied on the default.
    Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 256}).start()
    response = ""
    for chunk in streamer:
        response += chunk
        yield response  # ChatInterface expects the full response so far on each yield
app = gr.ChatInterface(generate)
app.launch(debug=True)