# Jacks_Clone / app.py
# (Hugging Face Space header, preserved as comments: Manasa1's picture /
#  "Update app.py" / commit 0bc842c verified / raw / history blame / 3.63 kB)
import streamlit as st
from langchain_community.document_loaders import UnstructuredPDFLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, pipeline
import torch
import os
from datasets import Dataset
import pandas as pd
import re
# Configure the Streamlit page (must run before any other UI call).
_PAGE_CONFIG = {
    "page_title": "Tweet Style Cloning",
    "page_icon": "🐦",
    "layout": "centered",
}
st.set_page_config(**_PAGE_CONFIG)
st.title("🐦 Clone Tweet Style from PDF")

# Step 1: ask the user for a PDF containing the tweets to imitate.
uploaded_file = st.file_uploader("Upload a PDF with tweets")
if uploaded_file is not None:
    # Step 2: Extract text from PDF
    def load_pdf_text(file_path):
        """Return the concatenated text of every element in the PDF at *file_path*."""
        loader = UnstructuredPDFLoader(file_path)
        documents = loader.load()
        return " ".join(doc.page_content for doc in documents)

    # Save the uploaded PDF temporarily so the loader can read it from disk.
    with open("uploaded_tweets.pdf", "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Extract text from PDF
    extracted_text = load_pdf_text("uploaded_tweets.pdf")

    # Step 3: Preprocess text to separate each tweet (assuming tweets end with newline)
    tweets = [t.strip() for t in re.split(r'\n+', extracted_text) if t.strip()]

    # Display a few sample tweets for verification
    st.write("Sample Tweets Extracted:")
    st.write(tweets[:5])

    # Step 4: Fine-tune a model on the extracted tweets
    def fine_tune_model(tweets):
        """Fine-tune GPT-2 on *tweets* (list of strings).

        Saves the result to ./fine_tuned_tweet_model and returns
        ``(model, tokenizer)``.
        """
        # Convert tweets to a DataFrame and Dataset
        df = pd.DataFrame(tweets, columns=["text"])
        tweet_dataset = Dataset.from_pandas(df)

        # Load model and tokenizer
        model_name = "gpt2"  # Replace with a suitable model if needed
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # GPT-2 ships without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

        def tokenize_function(examples):
            tokens = tokenizer(
                examples["text"], padding="max_length", truncation=True, max_length=128
            )
            # Causal-LM labels mirror input_ids, but padding positions are set
            # to -100 so the loss ignores them (plain .copy() would train the
            # model to emit EOS padding).
            tokens["labels"] = [
                [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
                for ids, attn in zip(tokens["input_ids"], tokens["attention_mask"])
            ]
            return tokens

        # Drop the raw "text" column so the Trainer only receives tensorizable
        # fields (input_ids / attention_mask / labels).
        tokenized_tweets = tweet_dataset.map(
            tokenize_function, batched=True, remove_columns=["text"]
        )

        # Training arguments
        training_args = TrainingArguments(
            output_dir="./fine_tuned_tweet_model",
            per_device_train_batch_size=4,
            num_train_epochs=3,
            save_steps=10_000,
            save_total_limit=1,
            logging_dir='./logs',
        )

        # Initialize the Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_tweets,
        )

        # Fine-tune the model
        trainer.train()

        # Save the fine-tuned model
        model.save_pretrained("fine_tuned_tweet_model")
        tokenizer.save_pretrained("fine_tuned_tweet_model")
        return model, tokenizer

    # Streamlit reruns this whole script on every widget interaction (e.g. each
    # keystroke-commit in the prompt box below). Cache the fine-tuned model in
    # session_state so training runs once per session instead of on every rerun.
    if "tweet_model" not in st.session_state:
        with st.spinner("Fine-tuning model..."):
            st.session_state["tweet_model"] = fine_tune_model(tweets)
        st.success("Model fine-tuned successfully!")
    model, tokenizer = st.session_state["tweet_model"]

    # Step 5: Set up text generation
    tweet_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

    # Generate a new tweet based on user input
    prompt = st.text_input("Enter a prompt for a new tweet in the same style:")
    if prompt:
        with st.spinner("Generating tweet..."):
            generated_tweet = tweet_generator(prompt, max_length=50, num_return_sequences=1)
        st.write("Generated Tweet:")
        st.write(generated_tweet[0]["generated_text"])