
This model is fine-tuned on 400+ test scripts written in Java with the Cucumber and Selenium frameworks.

The base model is `codellama/CodeLlama-7b-hf`, fine-tuned with QLoRA (LoRA adapters trained on top of a 4-bit quantized base model). The fine-tuning dataset is available on the Hugging Face Hub at `shyam-incedoinc/qa-finetune-dataset`.

Training metrics are available in the metrics section of this repository.

Training Parameters

```python
TrainingArguments(
    num_train_epochs=25,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    # save_steps=save_steps,
    logging_steps=25,
    save_strategy="epoch",
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    # max_steps=max_steps,
    group_by_length=False,
    lr_scheduler_type="cosine",
    disable_tqdm=False,
    report_to="tensorboard",
    seed=42
)
```
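To see what these settings imply for a dataset of roughly this size, the sketch below computes the number of optimizer steps, the warmup steps, and the learning rate under a linear-warmup cosine schedule. It assumes a single GPU and a hypothetical training split of 400 examples (the exact split size is not stated in this card). The formulas follow how `transformers` derives steps from `num_train_epochs`, `gradient_accumulation_steps`, `warmup_ratio`, and `lr_scheduler_type="cosine"`, but treat the result as an approximation rather than the trainer's exact bookkeeping.

```python
import math

# Values taken from the TrainingArguments above.
NUM_EPOCHS = 25
BATCH_SIZE = 2          # per_device_train_batch_size
GRAD_ACCUM = 1          # gradient_accumulation_steps
LEARNING_RATE = 2e-4
WARMUP_RATIO = 0.03
LOGGING_STEPS = 25


def schedule_summary(num_examples: int, num_gpus: int = 1) -> dict:
    """Approximate step counts; num_examples is an assumed training-split size."""
    batches_per_epoch = math.ceil(num_examples / (BATCH_SIZE * num_gpus))
    steps_per_epoch = max(batches_per_epoch // GRAD_ACCUM, 1)
    total_steps = steps_per_epoch * NUM_EPOCHS
    warmup_steps = math.ceil(WARMUP_RATIO * total_steps)
    return {
        "effective_batch_size": BATCH_SIZE * GRAD_ACCUM * num_gpus,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
        "log_events": total_steps // LOGGING_STEPS,
        "checkpoints": NUM_EPOCHS,  # save_strategy="epoch" saves once per epoch
    }


def cosine_lr(step: int, total_steps: int, warmup_steps: int) -> float:
    """Linear warmup from 0 to LEARNING_RATE, then cosine decay down to 0."""
    if step < warmup_steps:
        return LEARNING_RATE * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return LEARNING_RATE * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    summary = schedule_summary(num_examples=400)  # hypothetical split size
    print(summary)
    total, warm = summary["total_steps"], summary["warmup_steps"]
    for step in (0, warm, total // 2, total):
        print(step, f"{cosine_lr(step, total, warm):.2e}")
```

With 400 examples this gives 200 steps per epoch, 5000 steps in total, and 150 warmup steps, so the learning rate peaks at 2e-4 after about 0.75 epochs and decays to zero by the end of epoch 25.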

```python
LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)
```
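With `r=64` and `lora_alpha=16`, each adapted weight matrix W (shape d_out × d_in) stays frozen and receives a low-rank update B·A scaled by `lora_alpha / r = 0.25`; only A (r × d_in) and B (d_out × r) are trained. The sketch below shows that forward pass on toy lists and counts the trainable parameters. The config above does not set `target_modules`, so the example assumes PEFT's default for Llama-family models (`q_proj` and `v_proj`) and CodeLlama-7B's shape (32 decoder layers, hidden size 4096); adjust the shapes if the adapters were attached to other projections.

```python
from typing import Dict, List, Tuple

LORA_R = 64
LORA_ALPHA = 16

Matrix = List[List[float]]


def matvec(m: Matrix, x: List[float]) -> List[float]:
    """Plain matrix-vector product on nested lists."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: List[float],
                 alpha: float = LORA_ALPHA) -> List[float]:
    """h = W x + (alpha / r) * B (A x), where r is the number of rows of A.

    lora_dropout (0.1) is applied to x only during training, so it is omitted here.
    """
    scaling = alpha / len(A)
    base = matvec(W, x)                 # frozen (4-bit in QLoRA) weights
    update = matvec(B, matvec(A, x))    # trainable low-rank path
    return [b + scaling * u for b, u in zip(base, update)]


def lora_trainable_params(target_shapes: Dict[str, Tuple[int, int]],
                          num_layers: int, r: int = LORA_R) -> int:
    """Each adapted (d_out, d_in) matrix adds A (r x d_in) and B (d_out x r)."""
    per_layer = sum(r * (d_in + d_out) for d_out, d_in in target_shapes.values())
    return per_layer * num_layers


if __name__ == "__main__":
    # Assumed targets: PEFT's Llama default (q_proj, v_proj), each 4096 x 4096.
    shapes = {"q_proj": (4096, 4096), "v_proj": (4096, 4096)}
    print("scaling:", LORA_ALPHA / LORA_R)
    print("trainable params:", lora_trainable_params(shapes, num_layers=32))
```

Under these assumptions the adapters hold about 33.5M trainable parameters, a small fraction of the 7B frozen base weights. Because PEFT initialises B to zero, the adapted model starts out identical to the base model.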

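Run the code below to get inferences from this model. It loads the model in 4-bit precision, picks a random sample from the validation split of the dataset, generates a completion for its prompt, and prints it next to the ground truth. Loading in 4-bit requires a CUDA GPU with the `bitsandbytes` and `accelerate` packages installed. Because `do_sample=True`, the output varies from run to run.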

```python
import torch
from random import randrange
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

hf_model_repo = "shyam-incedoinc/codellama-7b-hf-peft-qlora-finetuned-qa"

# Get the tokenizer
tokenizer = AutoTokenizer.from_pretrained(hf_model_repo)

# Load the model in 4-bit precision (passing load_in_4bit directly to
# from_pretrained is deprecated in favour of a BitsAndBytesConfig)
bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(hf_model_repo,
                                             quantization_config=bnb_config,
                                             torch_dtype=torch.float16,
                                             device_map="auto")

# Load the validation split of the dataset from the Hub
hf_data_repo = "shyam-incedoinc/qa-finetune-dataset"
valid_dataset = load_dataset(hf_data_repo, split="validation")

# Pick a random sample and split it into the prompt and the expected output
sample = valid_dataset[randrange(len(valid_dataset))]["text"]
groundtruth = sample.split("### Output:\n")[1]
prompt = sample.split("### Output:\n")[0] + "### Output:\n"

# Generate a response; the inputs go to the same device as the model
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.to(model.device)
outputs = model.generate(input_ids=input_ids, max_new_tokens=1024,
                         do_sample=True, top_p=0.9, temperature=0.6)

# Decode only the newly generated tokens (outputs[0] also contains the prompt)
generated = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)

# Print the result
print(f"Generated response:\n{generated}")
print(f"Ground Truth:\n{groundtruth}")
```
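The inference snippet compares one sampled generation with the ground truth by eye. If you want a rough number alongside it, the sketch below splits a dataset row on the `### Output:\n` marker (the format the inference code relies on) and scores a generation with `difflib.SequenceMatcher.ratio()`. This character-level similarity is an illustrative assumption, not the metric used to train or evaluate this model, and it says nothing about whether the generated Java test compiles or runs. The helper names `split_sample` and `similarity` are hypothetical.

```python
import difflib
from typing import Tuple

OUTPUT_MARKER = "### Output:\n"


def split_sample(text: str) -> Tuple[str, str]:
    """Return (prompt, ground_truth); the prompt keeps the marker so the model continues after it."""
    if OUTPUT_MARKER not in text:
        raise ValueError("sample has no '### Output:' section")
    before, _, after = text.partition(OUTPUT_MARKER)
    return before + OUTPUT_MARKER, after


def _normalise(s: str) -> str:
    """Collapse all runs of whitespace so indentation differences do not count."""
    return " ".join(s.split())


def similarity(generated: str, ground_truth: str) -> float:
    """Character-level similarity in [0, 1] between two normalised strings."""
    return difflib.SequenceMatcher(None, _normalise(generated), _normalise(ground_truth)).ratio()


if __name__ == "__main__":
    # Hypothetical row, for illustration only.
    row = "Prompt text\n### Output:\nexpected test code"
    prompt, truth = split_sample(row)
    print(repr(prompt))
    print(similarity("expected  test\ncode", truth))
```

In the inference code, `split_sample(sample)` could replace the two `split` calls, and `similarity(generated, groundtruth)` could be printed after the ground truth. Since sampling is enabled, averaging the score over several samples gives a steadier picture than a single draw.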