Code Example on Instruction Fine-tuning of llama2-7B using LoRA
In earlier articles we discussed instruction fine-tuning, LoRA, and quantization. Here we tie these concepts together and walk through example code that performs instruction fine-tuning of llama2-7B using LoRA. The run was done on an A5000 GPU with 24 GB of RAM.
Imports
import argparse
import json
from types import SimpleNamespace as Namespace
from typing import Union
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, PeftConfig
device = "cuda" if torch.cuda.is_available() else "cpu"
Instruction Fine-tuning Prompt Templates
For instruction fine-tuning, we use the 52K instruction-following examples released by the Stanford Alpaca project, along with its two prompt templates: one for examples that include an input field and one for examples that do not.
prompt_input = (
    "Below is an instruction that describes a task, paired with an input that provides further context. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)

prompt_no_input = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)
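Each record in alpaca_data.json supplies the instruction, input, and output fields that are substituted into these templates; when input is empty, prompt_no_input is used. An illustrative record (not an exact entry from the dataset) looks like this:

{
    "instruction": "Name three primary colors.",
    "input": "",
    "output": "Red, blue, and yellow."
}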
Helper Methods to Generate and Tokenize Prompts
def tokenize(tokenizer, prompt, add_eos_token=True):
    # Tokenize without padding; padding is handled dynamically by the data collator.
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=256,
        padding=False,
        return_tensors=None,
    )
    # Append the EOS token if the sequence was not truncated at max_length.
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < 256
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)
    # Labels start as a copy of the input ids; the prompt portion is masked later.
    result["labels"] = result["input_ids"].copy()
    return result
def generate_prompt(instruction: str, input: Union[None, str] = None, label: Union[None, str] = None):
    if input is None or input == "":
        prompt = prompt_no_input.format(instruction=instruction)
    else:
        prompt = prompt_input.format(instruction=instruction, input=input)
    if label is not None:
        prompt = f"{prompt}{label}"
    return prompt
def generate_and_tokenize_prompt(data_point, tokenizer):
    full_prompt = generate_prompt(instruction=data_point["instruction"], input=data_point["input"], label=data_point["output"])
    full_prompt_tokenized = tokenize(tokenizer, full_prompt)
    user_prompt = generate_prompt(instruction=data_point["instruction"], input=data_point["input"])
    user_prompt_tokenized = tokenize(tokenizer, user_prompt, add_eos_token=False)
    user_prompt_len = len(user_prompt_tokenized["input_ids"])
    # Mask the prompt tokens with -100 so the loss is computed only on the response tokens.
    full_prompt_tokenized["labels"] = [-100] * user_prompt_len + full_prompt_tokenized["labels"][user_prompt_len:]
    return full_prompt_tokenized
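The -100 value is the ignore index of the Hugging Face cross-entropy loss, so only the response tokens contribute to training. A quick sanity check (a hypothetical snippet, assuming a tokenizer has already been loaded as in the training method below):

example = {
    "instruction": "Name three primary colors.",
    "input": "",
    "output": "Red, blue, and yellow.",
}
tokenized = generate_and_tokenize_prompt(example, tokenizer)
num_masked = sum(1 for label in tokenized["labels"] if label == -100)
print(f"{num_masked} prompt tokens masked out of {len(tokenized['labels'])} total")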
Training Method
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
def train(params):
    # e.g. tokenizer_id = "TinyPixel/Llama-2-7B-bf16-sharded"
    tokenizer = AutoTokenizer.from_pretrained(params.tokenizer_id)
    tokenizer.pad_token_id = 0  # 0 is <unk> for the Llama tokenizer, so padding differs from the eos token

    dataset = load_dataset("json", data_files="../alpaca_data.json")
    train_data = dataset["train"].shuffle().map(generate_and_tokenize_prompt, fn_kwargs={"tokenizer": tokenizer})
    data_collator = transformers.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)

    # e.g. model_id = "TinyPixel/Llama-2-7B-bf16-sharded"
    model = AutoModelForCausalLM.from_pretrained(params.model_id, load_in_8bit=True, device_map="auto")
    model = prepare_model_for_kbit_training(model)

    config = LoraConfig(
        r=params.lora.r,
        lora_alpha=params.lora.alpha,
        target_modules=["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    print_trainable_parameters(model)

    training_args = transformers.TrainingArguments(
        per_device_train_batch_size=params.train.train_batch_size,
        gradient_accumulation_steps=params.train.gradient_accumulation_steps,
        warmup_steps=5,
        max_steps=60,  # max_steps takes precedence over num_train_epochs
        num_train_epochs=1,
        learning_rate=3e-4,
        output_dir="outputs",
        report_to="none",
    )
    trainer = transformers.Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_data,
    )
    trainer.train()

    trainer.model.save_pretrained(params.model_dir)
    tokenizer.save_pretrained(params.model_dir)
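The train method reads its settings from the params object built in the main controller below. An example params file matching the attributes accessed above might look as follows; the model and tokenizer ids come from the comments in the code, while the output directory, LoRA rank/alpha, and batch settings are illustrative values rather than ones prescribed by the script:

{
    "tokenizer_id": "TinyPixel/Llama-2-7B-bf16-sharded",
    "model_id": "TinyPixel/Llama-2-7B-bf16-sharded",
    "model_dir": "lora-llama2-alpaca",
    "lora": {"r": 8, "alpha": 16},
    "train": {"train_batch_size": 4, "gradient_accumulation_steps": 4}
}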
Example Inference Method
Performing inference took about 10 GB of GPU RAM.
def inference(params):
    # Load the LoRA adapter config to find the base model, then attach the adapter weights on top of it.
    config = PeftConfig.from_pretrained(params.model_dir)
    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    model = PeftModel.from_pretrained(model, params.model_dir)
    model.eval()
    tokenizer.pad_token_id = model.config.pad_token_id

    instruction = "Tell me about deep learning and transformers."
    prompt = generate_prompt(instruction=instruction, input=None)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    generation_config = GenerationConfig(temperature=0.1, top_p=1.0, top_k=10, num_beams=3)
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=128,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    print(output)
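The decoded sequence contains the full prompt followed by the generated text. If only the answer is wanted, it can be split off at the response marker, for example by appending a line like the following at the end of inference (an optional post-processing step, not part of the original script):

print(output.split("### Response:")[-1].strip())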
Main Controller Method
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", required=True, choices=["train", "inference"])
    parser.add_argument("--params", required=True)
    args = parser.parse_args()

    # Load the JSON params file into nested namespaces so fields can be accessed as attributes.
    with open(args.params, "r", encoding="utf-8") as f:
        params = json.load(f, object_hook=lambda d: Namespace(**d))

    if args.mode == "train":
        train(params)
    elif args.mode == "inference":
        inference(params)
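Assuming the code above is saved as a single script (the file name finetune_llama2_lora.py is illustrative) and the params file from the training section is saved as params.json, the two modes are invoked as:

python finetune_llama2_lora.py --mode train --params params.json
python finetune_llama2_lora.py --mode inference --params params.json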
Written on October 1, 2023