---
license: apache-2.0
tags:
- generated_from_trainer
datasets:
- samsum
metrics:
- rouge
model-index:
- name: flan-t5-base
  results:
  - task:
      name: Summarization
      type: summarization
    dataset:
      name: samsum
      type: samsum
      split: validation
    metrics:
    - type: rouge1
      value: 46.819522%
    - type: rouge2
      value: 20.898074%
    - type: rougeL
      value: 37.300937%
    - type: rougeLsum
      value: 37.271341%
pipeline_tag: summarization
inference: false
library_name: transformers
language:
- en
---

# flan-t5-base-cnn-samsum-lora

This model is a fine-tuned version of [braindao/flan-t5-cnn](https://huggingface.co/braindao/flan-t5-cnn) on the [samsum](https://huggingface.co/datasets/samsum) dataset.

The base model [braindao/flan-t5-cnn](https://huggingface.co/braindao/flan-t5-cnn) is a fine-tuned version of [google/flan-t5-base](https://huggingface.co/google/flan-t5-base) on the cnn_dailymail 3.0.0 dataset.

## Model API Spaces

Please visit the Hugging Face Space [sooolee/summarize-transcripts-gradio](https://huggingface.co/spaces/sooolee/summarize-transcripts-gradio).

## Model description

* This model further fine-tunes [braindao/flan-t5-cnn](https://huggingface.co/braindao/flan-t5-cnn) on the more conversational samsum dataset.
* The Hugging Face [PEFT library](https://github.com/huggingface/peft) (LoRA, r = 16) and bitsandbytes int8 quantization were used to speed up training and reduce the model size.
* Only 1.7M parameters were trained (0.71% of the ~250M parameters of the original flan-t5-base).
* The (LoRA adapter) checkpoint is just 7MB.

## Intended uses & limitations

Summarizing transcripts, such as YouTube transcripts.

## Training and evaluation data

### Training hyperparameters

The following hyperparameters were used during training:

- learning_rate: 0.001
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 5

### Training results

- train_loss: 1.47

### How to use

Note that `max_new_tokens=60` is used in the example below to control the length of the summary.
The FLAN-T5 model has a max generation length of 200 and a min generation length of 20 (default).

```python
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the PEFT config, which records the base model checkpoint
peft_model_id = "sooolee/flan-t5-base-cnn-samsum-lora"
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model and tokenizer (add load_in_8bit=True to load the base model in int8)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, device_map='auto')  # load_in_8bit=True,
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Load the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(model, peft_model_id, device_map='auto')
model.eval()

# Tokenize the text inputs
texts = ""  # replace with the transcript(s) to summarize
inputs = tokenizer(texts, return_tensors="pt", padding=True)  # truncation=True

# Make inferences
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
    output = model.generate(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),  # needed when batching padded inputs
        max_new_tokens=60,
        do_sample=True,
        top_p=0.9,
    )
summary = tokenizer.batch_decode(output.detach().cpu().numpy(), skip_special_tokens=True)
print(summary)
```

### Framework versions

- Transformers 4.27.2
- Pytorch 1.13.1+cu116
- Datasets 2.9.0
- Tokenizers 0.13.3

## Other

Please check out the BART-Large-CNN-Samsum model fine-tuned for the same purpose: [sooolee/bart-large-cnn-finetuned-samsum-lora](https://huggingface.co/sooolee/bart-large-cnn-finetuned-samsum-lora)
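
As a rough check of the parameter figures in the Model description section, the snippet below recomputes the LoRA trainable-parameter count and adapter size. It is a sketch under assumptions that this card does not state: LoRA is applied only to the query (`q`) and value (`v`) attention projections (PEFT's default target modules for T5), flan-t5-base has `d_model = 768` with 12 encoder and 12 decoder layers, and the adapter weights are saved in fp32. The helper names `lora_params_per_module` and `count_lora_params` are hypothetical.

```python
# Hypothetical sanity check of the LoRA figures quoted in "Model description".
# Assumptions (not stated in this card): LoRA targets only the "q" and "v"
# attention projections, d_model = 768, 12 encoder and 12 decoder layers,
# and adapter weights are saved in fp32 (4 bytes per parameter).

D_MODEL = 768             # flan-t5-base hidden size (q/v are 768 x 768 linears)
NUM_ENCODER_LAYERS = 12
NUM_DECODER_LAYERS = 12
RANK = 16                 # LoRA rank r reported in this card
BASE_PARAMS = 250e6       # "250M" figure quoted in this card


def lora_params_per_module(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) beside a frozen d_out x d_in weight."""
    return r * d_in + d_out * r


def count_lora_params(r: int = RANK) -> int:
    """Count trainable LoRA parameters across all adapted q/v projections."""
    # Encoder blocks: self-attention only -> q and v = 2 adapted modules per layer.
    encoder_modules = NUM_ENCODER_LAYERS * 2
    # Decoder blocks: self-attention + cross-attention -> 4 adapted modules per layer.
    decoder_modules = NUM_DECODER_LAYERS * 4
    return (encoder_modules + decoder_modules) * lora_params_per_module(D_MODEL, D_MODEL, r)


if __name__ == "__main__":
    trainable = count_lora_params()
    print(f"trainable LoRA params: {trainable:,}")
    print(f"share of base model: {100 * trainable / BASE_PARAMS:.2f}%")
    print(f"fp32 adapter size: {trainable * 4 / 1e6:.2f} MB")
```

Under these assumptions the count is 1,769,472 trainable parameters (72 adapted modules of 24,576 parameters each), which is consistent with the quoted 1.7M and 0.71%, and at 4 bytes per parameter with the roughly 7MB checkpoint.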