import streamlit as st
from langchain_community.document_loaders import UnstructuredPDFLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, pipeline
import torch
import os
from datasets import Dataset
import pandas as pd
import re
# Set up page
st.set_page_config(
    page_title="Tweet Style Cloning",
    page_icon="🐦",
    layout="centered"
)
st.title("🐦 Clone Tweet Style from PDF")
# Step 1: Upload PDF (restrict the picker to .pdf files)
uploaded_file = st.file_uploader("Upload a PDF with tweets", type="pdf")

if uploaded_file is not None:
    # Step 2: Extract text from PDF
    def load_pdf_text(file_path):
        loader = UnstructuredPDFLoader(file_path)
        documents = loader.load()
        return " ".join([doc.page_content for doc in documents])
    # Save the uploaded PDF file temporarily
    with open("uploaded_tweets.pdf", "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Extract text from PDF
    extracted_text = load_pdf_text("uploaded_tweets.pdf")
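    # Optional cleanup: remove the temporary file now that its text is extracted
    os.remove("uploaded_tweets.pdf")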
    # Step 3: Preprocess text to separate tweets (assumes each tweet ends with a newline)
    tweets = re.split(r'\n+', extracted_text)
    tweets = [tweet.strip() for tweet in tweets if len(tweet.strip()) > 0]
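    # The newline split is a heuristic: PDF extraction often interleaves page
    # numbers and headers with the tweet text. As a rough hedge, drop purely
    # numeric lines (likely page numbers); tune this to your PDF's layout.
    tweets = [t for t in tweets if not t.isdigit()]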
    # Display a few sample tweets for verification
    st.write("Sample Tweets Extracted:")
    st.write(tweets[:5])

    # Step 4: Fine-tune a model on the extracted tweets
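    # Cache the fine-tune so Streamlit doesn't retrain on every rerun (the
    # script re-executes top to bottom whenever a widget changes, e.g. when a
    # prompt is typed below). The cache is keyed on the tweet list, so a new
    # PDF still triggers a fresh fine-tune.
    @st.cache_resource(show_spinner=False)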
    def fine_tune_model(tweets):
        # Convert tweets to a DataFrame and Dataset
        df = pd.DataFrame(tweets, columns=["text"])
        tweet_dataset = Dataset.from_pandas(df)

        # Load model and tokenizer
        model_name = "gpt2"  # Replace with a suitable model if needed
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # GPT-2 has no pad token, so reuse the end-of-text token for padding
        tokenizer.pad_token = tokenizer.eos_token
        # Tokenize the dataset
        def tokenize_function(examples):
            tokens = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
            tokens["labels"] = tokens["input_ids"].copy()  # Causal LM: labels are the input ids
            return tokens

        tokenized_tweets = tweet_dataset.map(tokenize_function, batched=True)
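        # Copying input_ids into labels also scores the padding positions. A
        # common refinement (a sketch, not required for this demo) is to let
        # DataCollatorForLanguageModeling build the causal-LM labels and mask
        # the padding instead:
        #   from transformers import DataCollatorForLanguageModeling
        #   collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        #   # ...then pass data_collator=collator to the Trainer below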
        # Training arguments
        training_args = TrainingArguments(
            output_dir="./fine_tuned_tweet_model",
            per_device_train_batch_size=4,
            num_train_epochs=3,
            save_steps=10_000,
            save_total_limit=1,
            logging_dir='./logs',
        )
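        # On a CUDA machine, mixed precision speeds training up noticeably; a
        # hedged tweak would be to add fp16=torch.cuda.is_available() to the
        # TrainingArguments above.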
        # Initialize the Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_tweets,
        )

        # Fine-tune the model
        trainer.train()

        # Save the fine-tuned model
        model.save_pretrained("fine_tuned_tweet_model")
        tokenizer.save_pretrained("fine_tuned_tweet_model")
        return model, tokenizer
    # Trigger fine-tuning and notify user
    with st.spinner("Fine-tuning model..."):
        model, tokenizer = fine_tune_model(tweets)
    st.success("Model fine-tuned successfully!")

    # Step 5: Set up text generation
    tweet_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
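    # Depending on the transformers version, the pipeline may run on CPU even
    # after GPU training; it accepts a device argument (e.g. device=0) if
    # generation should be pinned to the GPU.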
    # Generate a new tweet based on user input
    prompt = st.text_input("Enter a prompt for a new tweet in the same style:")
    if prompt:
        with st.spinner("Generating tweet..."):
            generated_tweet = tweet_generator(prompt, max_length=50, num_return_sequences=1)
        st.write("Generated Tweet:")
        st.write(generated_tweet[0]["generated_text"])
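        # Generation notes: max_length counts tokens, not characters, so the
        # output can exceed Twitter's 280-character limit; sampling parameters
        # such as do_sample=True, top_p=0.95, temperature=0.8 usually give more
        # varied, style-faithful tweets than the greedy default.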