# Hugging Face Space: plot-summary generator (scraped "Spaces: Sleeping" page header removed).
# --- Dependencies -----------------------------------------------------------
import re

import gradio as gr
import torch
from datasets import load_dataset  # NOTE(review): appears unused here — confirm before removing
from transformers import LEDForConditionalGeneration, LEDTokenizer

# --- Model setup ------------------------------------------------------------
# Prefer the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned LED summarizer and its matching tokenizer from disk.
tokenizer = LEDTokenizer.from_pretrained("./summary_generation_Led_4")
model = LEDForConditionalGeneration.from_pretrained("./summary_generation_Led_4").to(device)
# Normalize the input text (plot synopsis)
def normalize_text(text):
    """Normalize a plot synopsis for the model.

    Lowercases the text, removes non-alphanumeric characters, and collapses
    every run of whitespace (spaces, tabs, newlines) into a single space.

    Args:
        text: Raw plot-synopsis string.

    Returns:
        The normalized string (empty if the input had no word characters).
    """
    text = text.lower()
    # BUG FIX: strip punctuation BEFORE collapsing whitespace. The previous
    # order left double/trailing spaces wherever a punctuation mark sat
    # between spaces, e.g. "a , b" -> "a  b".
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
# Function to preprocess and generate summaries
def generate_summary(plot_synopsis):
    """Generate an abstractive summary for a plot synopsis.

    Normalizes the synopsis, tokenizes it (truncated/padded to 3000 tokens),
    and decodes a summary with beam search on the fine-tuned LED model.

    Args:
        plot_synopsis: Free-text plot synopsis entered by the user.

    Returns:
        The decoded summary string ("" for an effectively empty synopsis).
    """
    text = normalize_text(plot_synopsis)
    if not text:
        # Guard: don't run the model on an empty prompt.
        return ""

    inputs = tokenizer("summarize: " + text,
                       max_length=3000, truncation=True,
                       padding="max_length", return_tensors="pt")
    inputs = inputs.to(device)

    # LED uses windowed local attention by default; the recommended setup for
    # summarization is global attention on the first token — TODO confirm this
    # matches how the model was fine-tuned.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    # Inference only: no_grad skips autograd bookkeeping and saves memory.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            # BUG FIX: attention_mask was previously omitted, so with
            # padding="max_length" the (~3000) pad tokens were attended to.
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
            max_length=315, min_length=20,
            length_penalty=2.0, num_beams=4, early_stopping=True,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# --- Gradio UI: plot synopsis in, generated summary out ---------------------
synopsis_input = gr.Textbox(label="Plot Synopsis", lines=10,
                            placeholder="Enter the plot synopsis here...")
summary_output = gr.Textbox(label="Generated Summary")

interface = gr.Interface(
    fn=generate_summary,
    inputs=synopsis_input,
    outputs=summary_output,
    title="Plot Summary Generator",
    description="This demo generates a plot summary based on the plot synopsis using a fine-tuned LED model.",
)

# Start the web app.
interface.launch()