{ "cells": [ { "cell_type": "markdown", "id": "e28a25ec-df73-4876-a380-1dc12b0f1b64", "metadata": { "id": "e28a25ec-df73-4876-a380-1dc12b0f1b64" }, "source": [ "# Fine-tuning Bioinspired-Llama 3.1 8B\n", "\n", "#### Markus J. Buehler, MIT" ] }, { "cell_type": "code", "execution_count": 1, "id": "d72ae88a-eac6-48e7-8e4a-49b60da52eba", "metadata": { "id": "d72ae88a-eac6-48e7-8e4a-49b60da52eba" }, "outputs": [], "source": [ "!pip install -U -q accelerate peft bitsandbytes transformers trl wandb" ] }, { "cell_type": "code", "execution_count": 2, "id": "aa0a69fb-f742-4d71-9ca8-9afd6e612ca6", "metadata": { "id": "aa0a69fb-f742-4d71-9ca8-9afd6e612ca6" }, "outputs": [], "source": [ "import os\n", "\n", "#os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "import torch\n", "from tqdm.notebook import tqdm\n", "import copy\n", "\n", "from datasets import load_dataset, load_from_disk\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments, pipeline\n", "from transformers import logging\n", "\n", "from trl import SFTTrainer\n", "from peft import LoraConfig, PeftModel\n" ] },
{ "cell_type": "code", "execution_count": 3, "id": "b3bc138a-c08f-4448-b623-c8ac0468d4cb", "metadata": { "id": "b3bc138a-c08f-4448-b623-c8ac0468d4cb" }, "outputs": [], "source": [
"def generate_response(text_input=\"What is spider silk?\",\n",
"                      system_prompt='',\n",
"                      num_return_sequences=1,\n",
"                      temperature=1.,  # the higher the temperature, the more creative the model becomes\n",
"                      max_new_tokens=127, device='cuda',\n",
"                      add_special_tokens=False,  # apply_chat_template already adds <|begin_of_text|>, so don't add it twice\n",
"                      num_beams=1,\n",
"                      eos_token_id=[128001, 128008, 128009],\n",
"                      verbatim=False,\n",
"                      top_k=50,\n",
"                      top_p=0.9,\n",
"                      repetition_penalty=1.1,\n",
"                      messages=None,\n",
"                      ):\n",
"    \"\"\"Generate a chat completion with the global `model` and `tokenizer`.\n",
"\n",
"    If `messages` is None or empty, a new conversation is started (with the system\n",
"    prompt, if one is given). Otherwise `text_input` is appended to `messages` as a\n",
"    new user turn; add the previous reply as an assistant message before calling\n",
"    again to continue a conversation. A list passed as `messages` is modified in place.\n",
"\n",
"    Returns:\n",
"        (list of decoded completions, the messages list that was sent to the model)\n",
"    \"\"\"\n",
"    # Create a fresh list per call: a mutable default (`messages=[]`) would be extended\n",
"    # in place and leak earlier prompts into every later call that omits `messages`.\n",
"    if messages is None:\n",
"        messages = []\n",
"\n",
"    if not messages:  # start a new conversation\n",
"        if system_prompt != '':\n",
"            messages.append({\"role\": \"system\", \"content\": system_prompt})\n",
"        messages.append({\"role\": \"user\", \"content\": text_input})\n",
"    else:  # continue the given conversation (system_prompt is ignored here)\n",
"        messages.append({\"role\": \"user\", \"content\": text_input})\n",
"\n",
"    prompt = tokenizer.apply_chat_template(\n",
"        messages,\n",
"        tokenize=False,\n",
"        add_generation_prompt=True,\n",
"    )\n",
"    inputs = tokenizer([prompt], add_special_tokens=add_special_tokens, return_tensors='pt').to(device)\n",
"    if verbatim:\n",
"        print(inputs)\n",
"    with torch.no_grad():\n",
"        outputs = model.generate(**inputs,\n",
"                                 max_new_tokens=max_new_tokens,\n",
"                                 temperature=temperature,\n",
"                                 num_beams=num_beams,\n",
"                                 top_k=top_k,\n",
"                                 eos_token_id=eos_token_id,\n",
"                                 top_p=top_p,\n",
"                                 num_return_sequences=num_return_sequences,\n",
"                                 do_sample=True,\n",
"                                 repetition_penalty=repetition_penalty,\n",
"                                 )\n",
"    # Keep only the newly generated tokens (drop the prompt echoed back by generate).\n",
"    outputs = outputs[:, inputs[\"input_ids\"].shape[1]:]\n",
"    return tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True), messages"
] },
{ "cell_type": "markdown", "id": "25675672-82ed-45a2-823f-802f9373ca8a", "metadata": { "id": "25675672-82ed-45a2-823f-802f9373ca8a" }, "source": [ "### Define parameters" ] }, { "cell_type": "code", "execution_count": null, "id": "aaaa8bbd-b4af-4474-a2dd-d484b85be81f", "metadata": { "id": "aaaa8bbd-b4af-4474-a2dd-d484b85be81f" }, "outputs": [], "source": [ "################################################################################\n", "# Base model to train from\n", "################################################################################\n", "\n", "# The model that you want to train from the Hugging Face hub\n", "base_model_name = \"lamm-mit/BioinspiredLlama-3-1-8B-128k\"\n", "\n", "# Fine-tuned model name\n", "new_model = \"protein_secondary_structure_predictor\"\n", "\n",
"################################################################################\n", "# LoRA parameters\n", "################################################################################\n", "\n", "# LoRA target modules\n", "lora_target_modules=['q_proj','k_proj','v_proj','o_proj','down_proj',\n", " 'up_proj','gate_proj']\n", "\n", "# LoRA attention dimension\n", "lora_r = 64\n", "\n", "# Alpha parameter for LoRA scaling\n", "lora_alpha = 64\n", "\n", "# Dropout probability for LoRA layers\n", "lora_dropout = 0.1\n", "\n", "################################################################################\n", "# bitsandbytes parameters\n", "################################################################################\n", "\n", "# Activate 4-bit precision base model loading\n", "use_4bit = True\n", "\n", "# Compute dtype for 4-bit base models\n", "bnb_4bit_compute_dtype = \"bfloat16\"\n", "\n", "# Quantization type (fp4 or nf4)\n", "bnb_4bit_quant_type = \"nf4\"\n", "\n", "# Activate nested quantization for 4-bit base models, i.e. 
double quantization\n", "use_nested_quant = False\n", "\n", "################################################################################\n", "# TrainingArguments parameters\n", "################################################################################\n", "\n", "# Output directory where the model predictions and checkpoints will be stored\n", "output_dir = new_model\n", "\n", "# Number of training epochs\n", "num_train_epochs = 5\n", "\n", "# Enable fp16/bf16 training (set bf16 to True with an A100)\n", "fp16 = False\n", "bf16 = True\n", "\n", "# Batch size per GPU for training\n", "#per_device_train_batch_size = 12\n", "per_device_train_batch_size = 4 #on Colab, 16 GB GPU\n", "\n", "# Batch size per GPU for evaluation\n", "per_device_eval_batch_size = 6\n", "\n", "# Number of update steps to accumulate the gradients for\n", "gradient_accumulation_steps = 4\n", "\n", "# Enable gradient checkpointing\n", "gradient_checkpointing = True\n", "\n", "# Maximum gradient norm (gradient clipping)\n", "max_grad_norm = 0.5\n", "\n", "# Initial learning rate (AdamW optimizer)\n", "learning_rate = 2e-4\n", "\n", "# Weight decay to apply to all layers except bias/LayerNorm weights\n", "weight_decay = 0.001\n", "\n", "# Optimizer to use\n", "optim = \"paged_adamw_32bit\"\n", "\n", "# Learning rate schedule (constant, cosine, etc.)\n", "lr_scheduler_type = \"cosine\" # \"constant\"\n", "\n", "# Number of training steps (a positive value overrides num_train_epochs; -1 = use num_train_epochs)\n", "max_steps = -1\n", "\n", "# Ratio of steps for a linear warmup (from 0 to learning rate)\n", "warmup_ratio = 0.03\n", "\n", "# Group sequences of similar length into the same batch\n", "# Saves memory and speeds up training considerably\n", "group_by_length = True\n", "\n", "# Log every X update steps\n", "logging_steps = 25\n", "save_strategy='epoch'\n", "eval_steps= 100\n", "evaluation_strategy = \"steps\" #\"steps\", \"epoch\", \"no\"\n", "\n", 
"################################################################################\n", "# SFT parameters\n", "################################################################################\n", "\n", "# Maximum sequence length (in tokens) for SFT; the dataset cell also uses it to filter out long sequences\n", "max_seq_length = 128\n", "\n", "# Pack multiple short examples in the same input sequence to increase efficiency\n", "packing = True\n", "\n", "# Load the entire model on GPU 0\n", "device_map = {\"\": 0}\n", "\n", "push_formatted_dataset_to_hub = False" ] }, { "cell_type": "markdown", "id": "1f989b7f-9e68-441e-ac61-f661dd1eddaa", "metadata": { "id": "1f989b7f-9e68-441e-ac61-f661dd1eddaa" }, "source": [ "### Prepare dataset\n", "\n", "We use a protein dataset here, with the goal of training a model that predicts the dominant secondary structure of a given protein sequence." ] }, { "cell_type": "markdown", "id": "6caab12c-62cb-447c-aaa9-12439dbf3d2c", "metadata": { "id": "6caab12c-62cb-447c-aaa9-12439dbf3d2c" }, "source": [ "#### Chat template: Example\n", "\n", "In this case, we use the chat template so that the model learns to respond to a particular type of query." ] }, { "cell_type": "code", "execution_count": 5, "id": "a0592c35-28b1-4421-b8f8-20150f7a39f8", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 194 }, "id": "a0592c35-28b1-4421-b8f8-20150f7a39f8", "outputId": "26bd6578-8a61-4275-85b4-648c7109fef1" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, 
{ "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are an expert in protein mechanics.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nWhat is a protein?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nA protein is a chain of amino acids...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True,)\n", "\n", "messages = [\n", " {\"role\": \"system\", \"content\": \"You are an expert in protein mechanics.\"},\n", " {\"role\": \"user\", \"content\": \"What is a protein?\"},\n", " {\"role\": \"assistant\", \"content\": \"A protein is a chain of amino acids...\"},\n", "]\n", "\n", "prompt = tokenizer.apply_chat_template(\n", " messages,\n", " tokenize=False,\n", " add_generation_prompt=True\n", ")\n", "prompt" ] }, { "cell_type": "markdown", "id": "3433c8c7-088d-4722-9068-26e9d29371e1", "metadata": { "id": "3433c8c7-088d-4722-9068-26e9d29371e1" }, "source": [ "#### Build dataset for LLM training from 'raw sources'" ] }, { "cell_type": "code", "execution_count": 6, "id": "724ace7f-d157-44b2-aa5d-961cd64e5367", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 309, "referenced_widgets": [ "3d441c6fe11d4c9cbb1a4d29e165b484", "c653871a43bf430b954b02f5fd91cdda", "502240b743a648c1a0f452182d312135", "81ce8c7c00f5412c8ed82a1cb8633af3", "17f5eda8dce247659adf33f598f15bbe", "027b94630fd34e8eaf915592a51edd89", "e821b92ac54b4569971ac51d7106f178", "7d1afc79ef2b4f78abe7611e81043350", "3d8a58b93b104e1e892e4fedb2d02f56", "8254778ae23f4572ac58824fd6f1750d", "008d5138dc2241aa84ed7eed5693ccd0", "72bb45159c0449958822b38dedfbbb22", "0cb92a1263814fd993fb835c516336cc", "e4b25d01700d493ebab4b51b857203f7", "c0e56ab1984d4da49f130c5e24b101d1", 
"7973783a74944950aa0b4aa9e1c25216", "be13531d3aa94e838a7aea7d18bf8c4f", "d6dc16801dca4f7fb2ec2b7f2504b4a1", "5dded94f8d944cdcbf22d4126e97cef7", "02b792973a674343a2d2c5e2efaeacc5", "4b6388050be8496d869084d8673fb0f5", "e18a5564d1f441049f6606f40def1be7" ] }, "id": "724ace7f-d157-44b2-aa5d-961cd64e5367", "outputId": "86884659-798b-4ac0-82dc-2c29c5b692c4", "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3d441c6fe11d4c9cbb1a4d29e165b484", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/12921 [00:00<|start_header_id|>user<|end_header_id|>\\n\\nDominant secondary structure of < S S N A K I D Q L S S D V Q T L N A K V D Q L S N D V N A M R S D V Q A A K D D A A R A N Q R L D N M A T K Y R ><|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nAH<|eot_id|>'}\n", "{'content': '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\\n\\nDominant secondary structure of < K D W S F Y K D W S F Y K D W S F Y K D W S F Y ><|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nUNSTRUCTURED<|eot_id|>'}\n" ] } ], "source": [ "from datasets import load_dataset, DatasetDict, concatenate_datasets\n", "\n", "# Define the preprocess function for conversations with a chat template\n", "def preprocess_conversation(samples):\n", " system_prompt = ''\n", " batch = []\n", " for question, answer in zip(samples[\"question\"], samples[\"answer\"]):\n", " messages=[]\n", " if system_prompt != '':\n", "\n", " messages.extend ([\n", " {\"role\": \"system\", \"content\": system_prompt},\n", "\n", " ])\n", " messages.extend ( [\n", " {\"role\": \"user\", \"content\": 'Dominant secondary structure of < '+question+' >'},\n", " {\"role\": \"assistant\", \"content\": answer}\n", " ])\n", " formatted_conv = tokenizer.apply_chat_template(\n", " messages,\n", " tokenize=False,\n", " add_generation_prompt=False\n", " )\n", " remove_gen_prompt='<|start_header_id|>assistant<|end_header_id|>\\n\\n'\n", " if 
formatted_conv.endswith(remove_gen_prompt):\n", " formatted_conv = formatted_conv[:-len(remove_gen_prompt)]\n", "\n", " formatted_conv\n", "\n", " batch.append(formatted_conv)\n", " return {\"content\": batch}\n", "\n", "\n", "# Load the original dataset\n", "dataset = load_dataset('lamm-mit/protein_secondary_structure_from_PDB')['train']\n", "#dataset = load_from_disk('protein_secondary_structure_from_PDB')\n", "\n", "dataset = dataset.filter(lambda example: example['Sequence_length'] < max_seq_length)\n", "\n", "# Rename columns\n", "dataset = dataset.rename_column('Sequence_spaced', 'question')\n", "dataset = dataset.rename_column('Primary_SS_Type', 'answer')\n", "\n", "dataset = dataset.train_test_split(test_size=0.1, seed=42)\n", "\n", "test_dataset = copy.deepcopy(dataset['test'])\n", "\n", "# Apply preprocessing to the original dataset\n", "dataset['train'] = dataset['train'].map(\n", " #preprocess,\n", " preprocess_conversation,\n", " batched=True,\n", " #This argument removes the original columns from the dataset after preprocessing. dataset[\"train\"].column_names retrieves the column names from the training subset of the dataset, and these columns will be removed from the resulting dataset.\n", " #remove_columns=dataset[\"train\"].column_names\n", " remove_columns=dataset['train'].column_names\n", ")\n", "dataset['test'] = dataset['test'].map(\n", " #preprocess,\n", " preprocess_conversation,\n", " batched=True,\n", " #This argument removes the original columns from the dataset after preprocessing. 
dataset[\"train\"].column_names retrieves the column names from the training subset of the dataset, and these columns will be removed from the resulting dataset.\n", " #remove_columns=dataset[\"train\"].column_names\n", " remove_columns=dataset['test'].column_names\n", ")\n", "# Shuffle the dataset\n", "dataset = dataset.shuffle(100)\n", "\n", "#dataset = dataset['train'].train_test_split(test_size=0.1)\n", "\n", "# Create a new with 'train' and 'test' splits\n", "new_dataset = DatasetDict({\n", " # \"train\": dataset[\"train\"],\n", " \"train\": concatenate_datasets([dataset['train'],\n", " # dataset_2['train'],\n", " ],\n", " ),\n", " \"test\": concatenate_datasets([dataset['test'],\n", " # dataset_2['train'],\n", " ],\n", " )\n", "})\n", "\n", "print(new_dataset)\n", "print(new_dataset[\"train\"][0])\n", "print(new_dataset[\"test\"][0])\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "6bcae638-0809-484d-bee7-93ac7c936157", "metadata": { "id": "6bcae638-0809-484d-bee7-93ac7c936157" }, "outputs": [], "source": [ "if push_formatted_dataset_to_hub:\n", " dataset_name='formatted_data_for_training'\n", " # Push the modified dataset to the hub\n", " new_dataset.push_to_hub(f\"lamm-mit/{dataset_name}\", private=True,\n", " )" ] }, { "cell_type": "markdown", "id": "7a54d3b0-3c56-4097-a4d7-afdea45e9287", "metadata": { "id": "7a54d3b0-3c56-4097-a4d7-afdea45e9287" }, "source": [ "### Train the model" ] }, { "cell_type": "code", "execution_count": 8, "id": "ed7f5a39-fc23-4b92-9129-3d9d34ccc8f2", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 66, "referenced_widgets": [ "cf3545758b0f48f2b159e1bd42fffee3", "14a244af747f42b58b88a2e46611d8bf", "328b2d32af0747a2978790e62217149f", "73f508c080c542cca5b7fc09406c594f", "879ffffedc9342159a47918f914837d4", "0675a2178c4844ba8ead2f56311a9c62", "edfb829258b04e13854ca9df8ffa2135", "132c1952b3b54a7580155b1cd05220d6", "44df33c09a29490a8a965cd67b2c2688", "64b10175cd5d4650bcd4f7d553c6292d", 
"3895b2ed98184ecc84bb43d43cc6cf3a" ] }, "id": "ed7f5a39-fc23-4b92-9129-3d9d34ccc8f2", "outputId": "121f719f-0590-4e4f-e49e-498e620abbb7" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cf3545758b0f48f2b159e1bd42fffee3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00= 8:\n", " print(\"=\" * 80)\n", " print(\"GPU supports bfloat16: accelerate training with bf16=True\")\n", " print(\"=\" * 80)\n", "\n", "# Load base model\n", "model = AutoModelForCausalLM.from_pretrained(\n", " base_model_name,\n", " quantization_config=bnb_config,\n", " device_map=device_map\n", ")\n", "model.config.use_cache = False\n", "#model.config.pretraining_tp = 1\n", "\n", "# Load tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)\n", "tokenizer.padding_side = \"right\"\n", "\n", "'''\n", "The pad_token_id and eos_token_id values should not be identical when using SFFT.\n", "This is because of the way masking is implemented, and can result in the end token to be masked out.\n", "As a result, in the model continuously generating questions and answers without eos token.\n", "To avoid this, we set the pad_token_id to a different value.\"\n", "'''\n", "\n", "tokenizer.pad_token_id= tokenizer.encode('<|end_of_text|>')[1] #set to 128001\n", "tokenizer.pad_token_id" ] }, { "cell_type": "code", "execution_count": 9, "id": "5b73c3b6-ac7b-4474-abd2-4dee3f2d4517", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5b73c3b6-ac7b-4474-abd2-4dee3f2d4517", "outputId": "656c87f9-9633-4992-96d2-7388f8a69b97" }, "outputs": [ { "data": { "text/plain": [ "('<|end_of_text|>', '<|eot_id|>')" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.pad_token, tokenizer.eos_token" ] }, { "cell_type": "code", "execution_count": 10, "id": "2c5c9772-0982-495d-bf8e-c457b08dbb65", "metadata": { "colab": { 
"base_uri": "https://localhost:8080/" }, "id": "2c5c9772-0982-495d-bf8e-c457b08dbb65", "outputId": "21d62446-af29-4ecc-8bf7-71d892a7196f" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n" ] } ], "source": [
"# Load LoRA configuration\n",
"peft_config = LoraConfig(\n",
"    lora_alpha=lora_alpha,\n",
"    lora_dropout=lora_dropout,\n",
"    r=lora_r,\n",
"    bias=\"none\",\n",
"    target_modules=lora_target_modules,\n",
"    task_type=\"CAUSAL_LM\",\n",
")\n",
"\n",
"# Set training parameters\n",
"# check: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments\n",
"training_arguments = TrainingArguments(\n",
"    output_dir=output_dir,\n",
"    num_train_epochs=num_train_epochs,\n",
"    per_device_train_batch_size=per_device_train_batch_size,\n",
"    # The next two settings are defined in the parameter cell; pass them so they take effect\n",
"    per_device_eval_batch_size=per_device_eval_batch_size,\n",
"    gradient_checkpointing=gradient_checkpointing,\n",
"    gradient_accumulation_steps=gradient_accumulation_steps,\n",
"    optim=optim,\n",
"    save_strategy=save_strategy,\n",
"    # `evaluation_strategy` is deprecated (removal announced for transformers 4.46, see the\n",
"    # warning above) and the first cell installs the latest release with `pip install -U`,\n",
"    # so pass the setting under its current name.\n",
"    eval_strategy=evaluation_strategy,\n",
"    eval_steps=eval_steps,\n",
"    logging_steps=logging_steps,\n",
"    learning_rate=learning_rate,\n",
"    weight_decay=weight_decay,\n",
"    fp16=fp16,\n",
"    bf16=bf16,\n",
"    max_grad_norm=max_grad_norm,\n",
"    max_steps=max_steps,\n",
"    warmup_ratio=warmup_ratio,\n",
"    group_by_length=group_by_length,\n",
"    lr_scheduler_type=lr_scheduler_type,\n",
"    report_to=\"wandb\",\n",
"    #report_to=\"none\",\n",
")"
] }, { "cell_type": "markdown", "id": "b829e338-868c-49f8-80e7-24a7eae231a4", "metadata": { "id": "b829e338-868c-49f8-80e7-24a7eae231a4" }, "source": [ "#### Check how the model performs BEFORE training" ] }, { "cell_type": "code", "execution_count": 11, "id": "90f373bc-e286-429d-8b8e-6d01ac776b15", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": 
"90f373bc-e286-429d-8b8e-6d01ac776b15", "outputId": "b1dcb6a7-20b9-4a84-901e-48991643be9a" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Answer LLM: The dominant secondary structure of the given protein sequence is a combination of alpha helices and beta sheets. The protein contains several alpha helical regions, including one that spans from amino acid positions 40 to 57 (FVVTDCIYK), another from positions 79 to 103 (CVEVC-PVNCF\n", "Correct answer: UNSTRUCTURED\n" ] } ], "source": [
"question = test_dataset[0]['question']\n",
"corr_answer = test_dataset[0]['answer']\n",
"\n",
"# Pass a fresh `messages` list so this query always starts a new conversation\n",
"answer, _ = generate_response(text_input='Dominant secondary structure of < ' + question + ' >',\n",
"                              max_new_tokens=64, messages=[])\n",
"\n",
"print(f\"Answer LLM: {answer[0]}\\nCorrect answer: {corr_answer}\")"
] }, { "cell_type": "code", "execution_count": 12, "id": "54022820-5634-4ca3-9fc6-b810d2add921", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "54022820-5634-4ca3-9fc6-b810d2add921", "outputId": "d0054631-039d-49b7-fc19-76561d689137" }, "outputs": [ { "data": { "text/plain": [ "128009" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.eos_token_id" ] }, { "cell_type": "code", "execution_count": 13, "id": "ee45fb17-152c-460b-8c63-241b5983e060", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "312932c72a38420db037e56c339db559", "4a35e2765bc44d42b63d7a823d9c0e28", "fb7fb94611a04098bfed8dd0bfb2497e", "6d2ffac7ab894099b8782175fdde56d9", "5baf0a5192544cf6bd3a138fb632f64e", "a5a7cba95a474882af91facd5d43a863", "43a3c01dc9434d6d9beda4cf96684596", "bc7eb5ead2204888b8837956185a853e", "a7a8ef45c40448a884528e135642af2d", "c7523ae0393c4cad8b3161142a381e0f", "bc0b398597574e8ea658b4060dc9865d", 
"7f8a13951c7e4ca7ae01fe99996ba9e5", "26cfb83cb26249c8989840637cdc1917", "dd7bb23a6f1a461ab2a5738356b8a1df", "3cb9714a03454ee491683f48c0197e95", "b292c34dd8a74df5ab8d566826720f88", "c76b2d301871420bbe3ddc6526332397", "52089ee51c5f4cc4b8360105bc994a03", "4cc376198019481d8c7c944972425a52", "b616c213da6346bfb62ab0e268a9fd5e", "f343263df351401d96549534ab288020", "79a4f2eeec274592a4cd7c19b759e60b" ] }, "id": "ee45fb17-152c-460b-8c63-241b5983e060", "outputId": "3754ad03-088e-491c-fb7b-2fdda171ca39", "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length, packing, dataset_kwargs. Will not be supported from version '1.0.0'.\n", "\n", "Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n", " warnings.warn(message, FutureWarning)\n", "/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:192: UserWarning: You passed a `packing` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. 
Use `eval_strategy` instead\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:280: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:318: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n", "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:366: UserWarning: You passed a `dataset_kwargs` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "312932c72a38420db037e56c339db559", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0 examples [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7f8a13951c7e4ca7ae01fe99996ba9e5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0 examples [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmarkusjbuehler\u001b[0m (\u001b[33mlamm-mit\u001b[0m). 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.17.5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /content/wandb/run-20240728_101654-9smq4g2s" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run protein_secondary_structure_predictor to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/lamm-mit/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/lamm-mit/huggingface/runs/9smq4g2s" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [3300/3300 1:44:35, Epoch 5/5]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation Loss
1002.2921002.290651
2001.9892002.018708
3001.8224001.855575
4001.7684001.763817
5001.6784001.703792
6001.6302001.639292
7001.4010001.580074
8001.3859001.548995
9001.3621001.504345
10001.2882001.470629
11001.3043001.436050
12001.2073001.396687
13001.1660001.363970
14000.8800001.424378
15000.7941001.385330
16000.8453001.372946
17000.8654001.340954
18000.7876001.315560
19000.8510001.308754
20000.5236001.526696
21000.4091001.531994
22000.3722001.512631
23000.3970001.523684
24000.3983001.511530
25000.4141001.517029
26000.4033001.511221
27000.1999001.774590
28000.2006001.821994
29000.2010001.833357
30000.2048001.828533
31000.1946001.834447
32000.2060001.833463
33000.2082001.832923

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Set supervised fine-tuning parameters\n", "trainer = SFTTrainer(\n", " model=model,\n", " tokenizer=tokenizer,\n", " train_dataset=new_dataset[\"train\"],\n", " eval_dataset=new_dataset[\"test\"],\n", " peft_config=peft_config,\n", " dataset_text_field=\"content\",\n", " max_seq_length=max_seq_length,\n", " args=training_arguments,\n", " packing=packing,\n", "\n", " dataset_kwargs={\n", " \"append_concat_token\": True, #If true, appends eos_token_id at the end of each sample being packed.\n", " },\n", ")\n", "\n", "# Train model\n", "trainer.train()\n", "\n", "# Save trained model\n", "trainer.model.save_pretrained(new_model)" ] }, { "cell_type": "code", "execution_count": 14, "id": "6bed4f75-7783-4db4-9bf5-5a98fa29250d", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 504 }, "id": "6bed4f75-7783-4db4-9bf5-5a98fa29250d", "outputId": "bc304f2c-48e7-4389-cc03-bee33b3cac14", "scrolled": true }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "

" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "The training loss at the step closest to the lowest evaluation loss (1.3087536096572876) is 0.851 at step 1900\n" ] } ], "source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Extract the logged losses from the trained Trainer's state:\n",
"# training-loss entries carry 'loss', evaluation entries carry 'eval_loss'.\n",
"training_loss = trainer.state.log_history\n",
"\n",
"steps, train_loss = [], []   # training-loss curve\n",
"l_steps, eval_loss = [], []  # evaluation-loss curve\n",
"\n",
"for entry in training_loss:\n",
"    if 'loss' in entry:\n",
"        steps.append(entry['step'])\n",
"        train_loss.append(entry['loss'])\n",
"    if 'eval_loss' in entry:\n",
"        eval_loss.append(entry['eval_loss'])\n",
"        l_steps.append(entry['step'])\n",
"\n",
"# Plot the training and evaluation loss\n",
"fig, ax = plt.subplots(figsize=(10, 5))\n",
"ax.plot(steps, train_loss, label='Training Loss')\n",
"if eval_loss:\n",
"    ax.plot(l_steps, eval_loss, label='Evaluation Loss')\n",
"ax.set_xlabel('Steps')\n",
"ax.set_ylabel('Loss')\n",
"ax.set_title('Training and Evaluation Loss')\n",
"ax.legend()\n",
"\n",
"# Save the plot as an SVG file, then show it\n",
"fig.savefig('training_evaluation_loss.svg')\n",
"plt.show()\n",
"\n",
"# Report the training loss logged closest to the step with the lowest eval loss.\n",
"# Guarded like the plot above: min() of an empty list raises if no evaluation ran.\n",
"if eval_loss and steps:\n",
"    min_eval_loss = min(eval_loss)\n",
"    min_eval_loss_step = l_steps[eval_loss.index(min_eval_loss)]\n",
"    closest_train_loss_step = min(steps, key=lambda x: abs(x - min_eval_loss_step))\n",
"    closest_train_loss = train_loss[steps.index(closest_train_loss_step)]\n",
"    print(f'The training loss at the step closest to the lowest evaluation loss ({min_eval_loss}) is {closest_train_loss} at step {closest_train_loss_step}')\n",
"else:\n",
"    print('No evaluation (or training) loss was logged; cannot locate the best step.')"
] }, { 
"cell_type": "markdown", "id": "ba9d4035-e79d-4502-9806-5e704dada48b", "metadata": { "id": "ba9d4035-e79d-4502-9806-5e704dada48b" }, "source": [ "### Load and test the fine-tuned model" ] }, { "cell_type": "code", "execution_count": 16, "id": "2ceca2f2-42b5-4f95-9c90-44f13d041185", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 84, "referenced_widgets": [ "c386ef03c3db48189d857e4a86125b05", "a4c4d886c5d540a6a6d683fa2c576240", "4000cf23d2a444dcb84a8254b947502c", "c07a3936cd08420f8583c34b73072028", "c26fe456208f454aba0ce3872eb74751", "45ea9cec348d450b88b9fdc3acc5541e", "0814ceb877c94bcc9fce074adc49cea8", "0ca3ec39ec5046768b77786c45d4a1ad", "708beaf047f546c68ba58b73cfcda51c", "64272d0e01fb46d38ef21d22a6d49b3d", "75de7b33bed94aef9ce3fca10b1ff09e" ] }, "id": "2ceca2f2-42b5-4f95-9c90-44f13d041185", "outputId": "82c410f0-e712-404c-cdec-9d80c9c68949", "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Unused kwargs: ['use_nested_quant']. 
These kwargs are not used in .\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "lamm-mit/BioinspiredLlama-3-1-8B-128k protein_secondary_structure_predictor/checkpoint-1980\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c386ef03c3db48189d857e4a86125b05", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00', max_new_tokens=16, temperature=0.1, messages=[], system_prompt='')\n", "\n", "print (f\"Answer LLM: {answer[0]}\\nCorrect answer: {corr_answer}\")" ] }, { "cell_type": "code", "execution_count": 20, "id": "28700cde-6489-4fa7-a57b-af8233e3f2aa", "metadata": { "id": "28700cde-6489-4fa7-a57b-af8233e3f2aa" }, "outputs": [], "source": [
"def calculate_accuracy(model, test_dataset, N_test):\n",
"    \"\"\"\n",
"    Calculate the exact-match accuracy of the model on the first N_test test samples.\n",
"\n",
"    Parameters:\n",
"    model (object): Not used directly. Predictions come from generate_response(), which\n",
"                    calls the global `model`; assign the model to evaluate to `model`\n",
"                    before calling. Kept in the signature for backward compatibility.\n",
"    test_dataset: Indexable collection of samples, each with a 'question' and an 'answer'.\n",
"    N_test (int): The number of test samples to evaluate (capped at len(test_dataset)).\n",
"\n",
"    Returns:\n",
"    float: Fraction of evaluated samples whose stripped prediction equals the stripped\n",
"           reference answer; 0.0 when no samples are evaluated.\n",
"    \"\"\"\n",
"    n_eval = min(N_test, len(test_dataset))\n",
"    if n_eval <= 0:\n",
"        return 0.0\n",
"\n",
"    correct_predictions = 0\n",
"    for i in tqdm(range(n_eval)):\n",
"        question = test_dataset[i]['question']\n",
"        correct_answer = test_dataset[i]['answer']\n",
"\n",
"        # A very low temperature keeps sampling close to greedy decoding;\n",
"        # messages=[] starts a fresh conversation for every sample.\n",
"        answer, _ = generate_response(\n",
"            text_input='Dominant secondary structure of < ' + question + ' >',\n",
"            max_new_tokens=16,\n",
"            temperature=0.01, messages=[], system_prompt='',\n",
"        )\n",
"\n",
"        if answer[0].strip() == correct_answer.strip():\n",
"            correct_predictions += 1\n",
"        else:\n",
"            # Show each miss next to the reference label\n",
"            print(f\"Answer LLM: {answer[0]}\\nExpected answer: {correct_answer}\")\n",
"\n",
"    return correct_predictions / n_eval\n"
] }, { "cell_type": "code", 
"execution_count": 21, "id": "a55e1a0c-87af-42b1-ba3c-15c2d8865215", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "ed008dac238848349eaf427fc764b4c8", "fa1ff9dfeef6453c9eae1b4d55ad6bca", "20bebd52648447ef9c947b196b271520", "31d66ef3d1c94745a7ed52d748cc7fcc", "fa7bd2fceb844f1da76a081303991874", "597d5a90a55e4d8f8da0a4d9d4661173", "8cda86c26f4b4213b98fd75f49500f47", "e10ac0a882174a00a908514cfd0f6f63", "6c8fc499373541c0856970a565961938", "dc4868d3c0e540dd980d282e1ee2ef59", "c26d3ecdf8314508aac6ba467315dc76" ] }, "id": "a55e1a0c-87af-42b1-ba3c-15c2d8865215", "outputId": "35eabea8-ba13-4d29-99ed-cee8ce8cf9bf", "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ed008dac238848349eaf427fc764b4c8", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1436 [00:00