import streamlit as st
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Initialize the tokenizer for the fine-tuned summarization model
tokenizer = AutoTokenizer.from_pretrained("kskathe/finetuned-llama-text-summarization")
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Article:
{}
### Highlights:
{}"""
# Streamlit UI: collect the article text from the user
st.title("AI Article Summarizer")
input_text = st.text_area("Enter the article content:")
# Fill the prompt template; the Highlights slot is left empty for the model to complete
formatted_input = alpaca_prompt.format(input_text, "")
if st.button("Generate Highlights"):
    # Tokenize the formatted prompt
    inputs = tokenizer([formatted_input], return_tensors="pt")

    # Load the fine-tuned PEFT model on CPU without quantization
    text_model = AutoPeftModelForCausalLM.from_pretrained(
        "kskathe/finetuned-llama-text-summarization",
        device_map="cpu",       # force the model to run on CPU
        load_in_8bit=False,     # disable 8-bit quantization if it was enabled
        torch_dtype="float32",  # float32 precision is CPU friendly
    )

    # Generate the highlights
    output = text_model.generate(**inputs, max_new_tokens=128)

    # Decode the generated tokens back to text
    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)

    # Display the result
    st.write("### Highlights:")
    st.write(decoded_output[0])
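
# ---------------------------------------------------------------------------
# Optional sketch (not part of the original app): loading the model inside the
# button handler means it is re-downloaded/re-initialized on every click. A
# minimal variant, assuming Streamlit >= 1.18 where st.cache_resource is
# available, would cache the model once per session and reuse it:
#
#   @st.cache_resource
#   def load_model():
#       return AutoPeftModelForCausalLM.from_pretrained(
#           "kskathe/finetuned-llama-text-summarization",
#           device_map="cpu",
#           torch_dtype="float32",
#       )
#
#   text_model = load_model()  # then call text_model.generate(...) in the button block
# ---------------------------------------------------------------------------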