alpaca_orca_open_llama: An OpenLLaMA-3B model trained on an Alpaca dataset using approaches from the Orca research paper
Dataset and Training
We trained the OpenLLaMA-3B model on a custom Alpaca dataset built using the approaches described in the Orca research paper.
Please pay attention to how the system prompt is added and used for each instruction.
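To illustrate what adding a system prompt to each instruction can look like, here is a minimal sketch. The section headers (`### System:`, `### User:`, `### Input:`, `### Response:`) are an assumed Orca-style layout, not a confirmed copy of the training template, and `build_prompt` is a hypothetical helper.

```python
# Hypothetical Orca-style prompt layout: the section headers below are an
# assumption for illustration, not the confirmed training template.
def build_prompt(system: str, instruction: str, input_text: str = "") -> str:
    """Assemble a prompt that places the system message before the instruction."""
    parts = [f"### System:\n{system}", f"### User:\n{instruction}"]
    if input_text:
        # Optional context, e.g. the data the instruction refers to
        parts.append(f"### Input:\n{input_text}")
    # Leave the response section open for the model to complete
    parts.append("### Response:\n")
    return "\n\n".join(parts)


if __name__ == "__main__":
    print(build_prompt(
        "You are an AI assistant.",
        "Use the given data to calculate the median.",
        "[7, 3, 8, 2, 10]",
    ))
```

Given a system message, an instruction and optional input data, the helper returns a single string that can be passed to `tokenizer.encode`. Keeping the system message in the same position for every instruction is the point being illustrated; the exact header text should match whatever template the model was trained with.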
The training configurations are provided in the table below.
Training ran on 4 x A6000 (50G) GPUs and took around 20 hours, at a cost of about $66.
We used DeepSpeed with ZeRO stage 3 (ZeRO-3) for parallel multi-GPU training.
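From the stated figures (about $66 for roughly 20 hours on 4 GPUs), the approximate hourly rates work out as follows. These are derived estimates only, since both the duration and the cost are approximate.

$$
\frac{\$66}{20\ \text{h}} \approx \$3.30\ \text{per hour for the 4-GPU node},
\qquad
\frac{\$3.30}{4} \approx \$0.83\ \text{per GPU-hour}
$$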
| Parameter | Value |
|---|---|
| Batch size | 16 |
| train_micro_batch_size_per_gpu | 2 |
| gradient_accumulation_steps | 2 |
| Learning rate | 2e-5 |
| Epochs | 3 |
| Max length | 1024 |
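The batch settings in the table are consistent with a four-GPU run: 2 samples per GPU per micro-step × 2 gradient accumulation steps × 4 GPUs = 16 samples per optimizer step. Below is a minimal sketch of a DeepSpeed config dict expressing this, assuming a world size of 4. Only the batch-size and ZeRO keys are shown. Optimizer, learning-rate scheduler and precision settings are not given in the source and are deliberately omitted. `NUM_GPUS` and `effective_batch_size` are hypothetical names.

```python
# Sketch of a DeepSpeed config matching the table, assuming 4 GPUs.
# Optimizer, scheduler and precision settings are not stated in the
# source and are omitted here.
NUM_GPUS = 4  # assumption taken from the hardware description

ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 2,
    # DeepSpeed checks that train_batch_size ==
    #   micro_batch_per_gpu * gradient_accumulation_steps * world_size
    "train_batch_size": 2 * 2 * NUM_GPUS,
    "zero_optimization": {
        # ZeRO-3 partitions optimizer states, gradients and parameters
        "stage": 3,
    },
}


def effective_batch_size(cfg: dict, world_size: int) -> int:
    """Number of samples that contribute to one optimizer step."""
    return (cfg["train_micro_batch_size_per_gpu"]
            * cfg["gradient_accumulation_steps"]
            * world_size)


if __name__ == "__main__":
    print(effective_batch_size(ds_config, NUM_GPUS))  # 16
```

A dict like this can be passed as the `config` argument of `deepspeed.initialize`. Fixing the micro-batch at 2 and accumulating gradients over 2 steps is what keeps the global batch at 16 without holding 4 samples per GPU in memory at once.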
Example Usage
Below is an example of how to use the model.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# Hugging Face Hub id of this model
model_path = 'psmathur/alpaca_orca_open_llama_3b'
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path).cuda()
tokenizer.bos_token_id, tokenizer.eos_token_id = 1, 2  # see https://github.com/openlm-research/open_llama#preview-weights-release-and-usage

# same system prompt as provided in the Orca research paper (quoted verbatim)
system = 'You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.'
instruction = 'Use the given data to calculate the median.'
input_text = '[7, 3, 8, 2, 10]'  # named input_text to avoid shadowing the built-in input()
# prompt template from the original example; note that it does not interpolate system or input_text
prompt_no_input = f'.\n\n### Instruction:\n{instruction}\n\n### Response:'

# tokenize, add a batch dimension and move the ids to the model's device
tokens = tokenizer.encode(prompt_no_input)
tokens = torch.LongTensor(tokens).unsqueeze(0).to(model.device)

# sampling settings
instance = {'input_ids': tokens,
            'top_k': 50,
            'top_p': 0.9,
            'generate_len': 128}

length = len(tokens[0])
with torch.no_grad():
    rest = model.generate(
        input_ids=tokens,
        max_length=length + instance['generate_len'],  # prompt + up to 128 new tokens
        use_cache=True,
        do_sample=True,
        top_p=instance['top_p'],
        top_k=instance['top_k'],
    )

# drop the prompt tokens and decode only the newly generated ones
output = rest[0][length:]
string = tokenizer.decode(output, skip_special_tokens=True)
print(f'[!] Generation results: {string}')
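The `do_sample=True, top_k=50, top_p=0.9` settings in the example control which tokens the model may sample at each step. In transformers these are applied as logits warpers (top-k, then top-p) before sampling. The function below is a simplified pure-Python illustration of that filtering on a single probability distribution, not the library's implementation, and `top_k_top_p_filter` is a hypothetical name.

```python
def top_k_top_p_filter(probs: list[float], top_k: int, top_p: float) -> list[float]:
    """Keep the top_k most likely tokens, then the smallest prefix of those whose
    cumulative (renormalized) probability reaches top_p; renormalize the rest."""
    # Token indices sorted by probability, most likely first (stable for ties)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Top-k: keep at most top_k candidates
    order = order[:top_k]
    # Top-p works on the distribution renormalized over the top-k survivors
    total = sum(probs[i] for i in order)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i] / total
        if cumulative >= top_p:
            break
    # Zero out everything outside the kept set and renormalize what remains
    kept_set = set(kept)
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept_set else 0.0 for i in range(len(probs))]


if __name__ == "__main__":
    # Top-3 keeps tokens 0, 1, 2; top-p = 0.8 then keeps only tokens 0 and 1
    print(top_k_top_p_filter([0.5, 0.3, 0.1, 0.1], top_k=3, top_p=0.8))
```

With `top_k=50` and `top_p=0.9`, the example can therefore only sample from at most 50 tokens per step, and often fewer when a handful of tokens already cover 90% of the probability mass.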