---
library_name: peft
base_model: google/gemma-2b
license: mit
tags:
- Mathematical Reasoning
language:
- en
datasets:
- adityasihag/math_QAaugP
---

**This repo contains LoRA adapter weights**.

### Model Description
- **Project GitHub Page:** https://github.com/adityasihag1996/math_QA.git
- **Developed by:** [Aditya Sihag](https://www.linkedin.com/in/aditya-sihag-ab29681a9/)
- **Model type:** fine-tuned using QLoRA on 1x RTX 4090
- **Finetuned from model:** google/gemma-2b

## Results
| Prompt Approach | GSM8k | MATH |
|-----------------|-------|------|
| Zero-Shot CoT   | 43.66 | -    |
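
This card does not describe how the GSM8k score was computed. As a rough illustration only: with one generated answer per problem, Pass@1 reduces to plain accuracy. The sketch below assumes the final answer is the last number in the generated text; `extract_final_number` and `pass_at_1` are hypothetical helper names, not part of this project.

```python
import re

# Hypothetical helpers: the scoring script for this model is not published here.
_NUMBER = re.compile(r"-?\d+(?:\.\d+)?")


def extract_final_number(text):
    """Return the last number found in `text` as a float, or None if there is none."""
    # Drop thousands separators so "1,000" is read as 1000.
    matches = _NUMBER.findall(text.replace(",", ""))
    return float(matches[-1]) if matches else None


def pass_at_1(predictions, references):
    """Percentage of problems whose single generated answer matches the reference."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    correct = 0
    for prediction, reference in zip(predictions, references):
        value = extract_final_number(prediction)
        # Compare numerically with a small tolerance (an assumption of this sketch).
        if value is not None and abs(value - float(reference)) < 1e-6:
            correct += 1
    return 100.0 * correct / len(references)
```

Under this assumed rule, an output ending in "x = 3." counts as correct for a reference answer of 3.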
## Training procedure

The following `bitsandbytes` quantization config was used during training:
- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: float16

`LoraConfig` params:
- r: 128
- lora_alpha: 256 (lora_r * 2)
- lora_dropout: 0.05
- bias: "none"
- task_type: "CAUSAL_LM"
- target_modules: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

The hyperparameters for the LoRA fine-tuning are listed below:
- epochs: 3
- learning_rate: 5e-5
- batch_size: 256
- max_grad_norm: 1.0
- weight_decay: 0.001
- lr_scheduler_type: "cosine"
- warmup_ratio: 0.03

## Dataset
The math_QA dataset combines [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA) and [MathInstruct](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) with some internal data.
See [math_QAaugP](https://huggingface.co/datasets/adityasihag/math_QAaugP) for details.

## Model Usage
```
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from peft import PeftModel

model_path = "google/gemma-2b"

# Load the base model in half precision on GPU 0
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype = torch.float16,
    device_map = {"": 0},
)

# Load LoRA and merge
model = PeftModel.from_pretrained(model, "adityasihag/math_QA-gemma-2B-QLoRA-adapter")
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

question = """Solve the linear equations. $3(x+2)-x=x + 9$. Find the value of x."""
sample_input = f"""Question: {question} \n Answer: """
sample_input_tokenised = tokenizer(sample_input, return_tensors = "pt").to("cuda")

generated_ids = model.generate(
    **sample_input_tokenised,
    max_new_tokens = 1024,
    do_sample = True,  # temperature only takes effect when sampling is enabled
    temperature = 0.3
)

# The decoded text contains the prompt followed by the generated answer
output = tokenizer.decode(generated_ids[0], skip_special_tokens = True)
print(output)
```

##### Sample Input:
```
Question: Solve the linear equations.
$3(x+2)-x=x + 9$. Find the value of x. \n Answer:
```

##### Model Output:
```
Given the linear equation 3(x+2)-x=x+9.
First, distribute the 3 in the brackets to get 3x + 6 - x = x + 9.
Simplify the equation to get 2x + 6 = x + 9.
Next, transpose x from the right side to the left side and from the left side to the right side to get x = 9 - 6.
Finally, solve for x to get x = 3.
```

#### Prompt Template:
```
Question: {question} \n Answer:
```

## Comparing math_QA models with other SFT LLM models
| Model | GSM8k Pass@1 | MATH Pass@1 |
|---------------------|--------------|-------------|
| LLaMA-2-7B | 14.6 | 2.5 |
| gemma-2b | 17.7 | |
| LLaMA-2-13B | 28.7 | 3.9 |
| LLaMA-2-34B | 42.2 | 6.24 |
| **math_QA-gemma-2B** | **43.66** | |
| gemma-7b | 46.4 | |
| WizardMath-7B | 54.9 | 10.7 |
| Mistral-7B | 35.4 | |
| WizardMath-13B | 63.9 | 14.0 |
| MetaMath-7B | 66.5 | 19.8 |
| MetaMath-13B | 72.3 | 22.4 |
| **math_QA-Mistral-7B** | **75.81** | |
| Arithmo2-Mistral-7B | 76.4 | 27.2 |
| MetaMath-Mistral-7B | 77.7 | 28.2 |
| DeepSeekMath-Instruct-7B | 82.9 | 46.8 |
| GPT4 | 92.0 | 52.9 |
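
The prompt format used in the Model Usage section can be wrapped in two small helpers so that every question is formatted the same way and the prompt is stripped from the decoded output. This is only a sketch: `build_prompt` and `extract_answer` are hypothetical names, and splitting on the first `Answer:` assumes the question itself does not contain that string.

```python
def build_prompt(question):
    """Format a question the same way as `sample_input` in the usage example."""
    return f"Question: {question} \n Answer: "


def extract_answer(decoded_output):
    """Return only the generated answer from text that still includes the prompt."""
    # Split on the first "Answer:" marker; assumes the question does not contain it.
    _, marker, completion = decoded_output.partition("Answer:")
    # If the marker is missing, fall back to the full text.
    return completion.strip() if marker else decoded_output.strip()
```

With these, `tokenizer(build_prompt(question), return_tensors = "pt")` replaces the hand-written f-string, and `extract_answer(output)` returns just the model's solution.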