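# sprigs / app.py
# Author: kookoobau (commit 03813a3, message "update"; file size 1.62 kB).
# This header came from a file-view page scrape; its "raw" and "history blame" page links are omitted.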
import gradio as gr
from langchain import LLMChain, PromptTemplate
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import ConversationChain
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
# Pull the pretrained "gpt2" checkpoint: its tokenizer first, then the
# causal language-model weights. Both are used below to build the LLM.
tokenizer, model = (
    GPT2Tokenizer.from_pretrained("gpt2"),
    GPT2LMHeadModel.from_pretrained("gpt2"),
)
# Prompt for ConversationChain. The chain fills {history} from `memory`
# (ConversationBufferMemory's default memory key is "history") and {input}
# from the user's message (the chain's default input key). The chain validates
# that these are exactly the template's variables, so no other names may be used.
template = """{history}
Question: {input}
------------------
Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["history", "input"])
# Buffer memory keeps the whole transcript; it has no size cap.
# NOTE(review): GPT-2's context window is 1024 tokens, so a long conversation
# will eventually overflow it. Consider ConversationBufferWindowMemory.
memory = ConversationBufferMemory()
# Callbacks receive chain/LLM events. StreamingStdOutCallbackHandler can only
# print token-by-token for LLMs that report tokens incrementally.
callbacks = [StreamingStdOutCallbackHandler()]
# LangChain chains need a LangChain LLM, not a raw transformers model/tokenizer
# (LLMChain has no `model`/`tokenizer` fields). Wrap the loaded GPT-2 in a
# text-generation pipeline. Bound the reply with max_new_tokens: GPT-2's
# default max_length counts the prompt, so a long history would leave no room
# for new text.
text_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=64,
)
llm = HuggingFacePipeline(pipeline=text_generator)
# The conversation chain formats `prompt` with the stored history, calls the
# LLM, and saves each exchange back into `memory`.
conversation = ConversationChain(
    llm=llm,
    memory=memory,
    prompt=prompt,
    callbacks=callbacks,
    verbose=True,
)
def chatbot_interface(input_text):
    """Return the chatbot's reply to a single user message.

    Runs the module-level ``conversation`` chain. The chain reads the prior
    transcript from its memory and, after answering, saves this exchange back
    to that memory itself. Writing the messages here as well would record
    every turn twice.

    Args:
        input_text: The message submitted through the Gradio textbox.

    Returns:
        The generated reply, or an empty string when the input is blank.
    """
    # Blank submissions would otherwise run the model and be stored as a
    # conversation turn; answer them with nothing instead.
    if not input_text or not input_text.strip():
        return ""
    # ConversationChain.predict takes keyword inputs named after the chain's
    # input key ("input"). A positional argument would bind to `callbacks`.
    return conversation.predict(input=input_text)
# Gradio UI: one text box in, one text box out, wired to chatbot_interface.
# It uses the top-level gr.Textbox component, because the gr.inputs /
# gr.outputs namespaces are deprecated in Gradio 3.x and removed in 4.0.
gradio_app = gr.Interface(
    fn=chatbot_interface,
    inputs=gr.Textbox(label="Say something..."),
    outputs=gr.Textbox(),
    title="ConversationChain Chatbot",
    description="A chatbot interface powered by ConversationChain and Hugging Face.",
)
# Start the web server only when this file is executed as a script, not when
# it is imported. A gr.Interface is started with launch(); it has no run()
# method, so calling run() raised AttributeError.
if __name__ == "__main__":
    gradio_app.launch()