"""Gradio demo: generate a plot summary from a plot synopsis with a fine-tuned LED model."""

import re

import gradio as gr
import torch
from datasets import load_dataset
from transformers import LEDForConditionalGeneration, LEDTokenizer

# Local directory holding the fine-tuned checkpoint and its tokenizer.
MODEL_DIR = "./summary_generation_Led_4"

# Tokenization / generation settings (values unchanged from the original script).
MAX_INPUT_TOKENS = 3000
MAX_SUMMARY_TOKENS = 315
MIN_SUMMARY_TOKENS = 20
LENGTH_PENALTY = 2.0
NUM_BEAMS = 4
TASK_PREFIX = "summarize: "

# Set device to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the LED model and tokenizer
model = LEDForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)
tokenizer = LEDTokenizer.from_pretrained(MODEL_DIR)


def normalize_text(text):
    """Normalize a plot synopsis before tokenization.

    Lowercases the text, collapses runs of whitespace (including newlines)
    into single spaces, strips leading/trailing whitespace, then removes every
    character that is neither a word character nor whitespace.

    NOTE(review): because punctuation is removed *after* whitespace is
    collapsed, inputs like "a - b" become "a  b" (double space). This is kept
    as-is on the assumption that it mirrors the preprocessing used when the
    model was fine-tuned — confirm before changing the order.
    """
    text = text.lower()  # Lowercase the text
    text = re.sub(r'\s+', ' ', text).strip()  # Remove extra spaces and newlines
    text = re.sub(r'[^\w\s]', '', text)  # Remove non-alphanumeric characters
    return text


def generate_summary(plot_synopsis):
    """Generate a summary for ``plot_synopsis`` with the LED model.

    Args:
        plot_synopsis: Raw synopsis text as entered in the UI.

    Returns:
        The decoded summary string, or ``""`` when the synopsis is empty or
        contains nothing left after normalization.
    """
    if not plot_synopsis:
        return ""
    normalized = normalize_text(plot_synopsis)
    if not normalized:
        # Nothing to summarize; skip an expensive beam search on the prefix alone.
        return ""

    # No padding for a single example: padding to a fixed length only adds
    # tokens the encoder must process. Truncation still caps the input length.
    inputs = tokenizer(
        TASK_PREFIX + normalized,
        max_length=MAX_INPUT_TOKENS,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # LED uses windowed local attention; for summarization the first token is
    # given global attention so it can aggregate information from the whole input.
    global_attention_mask = torch.zeros_like(inputs["attention_mask"])
    global_attention_mask[:, 0] = 1

    # Generate the summary. attention_mask is passed explicitly so any padding
    # is never attended to.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=MAX_SUMMARY_TOKENS,
        min_length=MIN_SUMMARY_TOKENS,
        length_penalty=LENGTH_PENALTY,
        num_beams=NUM_BEAMS,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Gradio interface to take plot synopsis and output a generated summary
interface = gr.Interface(
    fn=generate_summary,
    inputs=gr.Textbox(label="Plot Synopsis", lines=10, placeholder="Enter the plot synopsis here..."),
    outputs=gr.Textbox(label="Generated Summary"),
    title="Plot Summary Generator",
    description="This demo generates a plot summary based on the plot synopsis using a fine-tuned LED model.",
)

# Launch the Gradio interface
interface.launch()