End to end instruction example using LLMs with fine-tuning¶
This notebook shows an end-to-end example of fine-tuning an LLM on the pretraining data. This means the model is trained on full patient histories, without any task-specific instructions. This setup can be used to develop models that generate synthetic patients or embeddings.
Note: You need a GPU with at least 30 GB of memory for this example to work. We have not benchmarked the performance of PEFT models; they are included here only as examples.
Important: Please first install the fine-tuning packages with
pip install twinweaver[fine-tuning-example].
import gc

import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
from twinweaver import DataManager, Config, ConverterPretrain
# Some key settings
BASE_MODEL = "microsoft/Phi-4-mini-instruct" # NOTE: we haven't tested the performance of this model beyond examples
Generate training data¶
# Load data
df_events = pd.read_csv("../../example_data/events.csv")
df_constant = pd.read_csv("../../example_data/constant.csv")
df_constant_description = pd.read_csv("../../example_data/constant_description.csv")
# Manually set up which constant columns we want to use
config = Config() # Override values here to customize pipeline
config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
config.constant_birthdate_column = "birthyear"
dm = DataManager(config=config)
dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
converter = ConverterPretrain(config=config, dm=dm)
# Get all training + validation patientids
training_patientids = dm.get_all_patientids_in_split(config.train_split_name)
validation_patientids = dm.get_all_patientids_in_split(config.validation_split_name)
The generate_transformers_df function iterates through each patient and generates the text data.
def generate_transformers_df(patientids_list):
    rows = []
    for patientid in patientids_list:
        patient_data = dm.get_patient_data(patientid)
        p_converted = converter.forward_conversion(events=patient_data["events"], constant=patient_data["constant"])
        rows.append({
            "text": p_converted["text"],
            "patientid": f"{patientid}",  # Just for ease of finding later
        })
    return pd.DataFrame(rows)
# Generate training and validation dfs
df_train = generate_transformers_df(training_patientids)
df_validation = generate_transformers_df(validation_patientids)
df_train.head()
Fine-tune LLM¶
We start by setting up the tokenizer. We set the padding token to be the same as the EOS (End of Sequence) token, which is a common practice for causal language models.
# Setup tokenizer and datasets
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Set padding token to eos_token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
train_dataset = Dataset.from_pandas(df_train)
validation_dataset = Dataset.from_pandas(df_validation)
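Conceptually, reusing the EOS id as the padding id just fills shorter sequences up to the longest sequence in the batch. A toy sketch with made-up token ids (no real tokenizer involved):

```python
def pad_batch(seqs, pad_id):
    """Right-pad each id sequence to the longest length in the batch."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_id] * (max_len - len(s)) for s in seqs]

# Two mock sequences; 0 stands in for the EOS/pad id
padded = pad_batch([[5, 6], [7, 8, 9]], pad_id=0)
```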
Instruction-tuned models usually expect data in a conversational format (e.g., User: ... Assistant: ...). Since we pretrain on raw patient histories, the formatting function below simply passes the text through unchanged; adapt it if your base model requires a chat template.
# Format data for chat template
def format_chat_template(example):
    """Pass the pretraining text through unchanged; adapt here if your model needs a chat template."""
    return {
        "text": example["text"],
    }
# Apply formatting to datasets
train_dataset = train_dataset.map(format_chat_template)
validation_dataset = validation_dataset.map(format_chat_template)
We configure 4-bit quantization using BitsAndBytesConfig (QLoRA). This significantly lowers memory usage, allowing us to fine-tune the model on consumer GPUs.
# Define Quantization Config (4-bit loading)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # Set based on your GPU capabilities
    bnb_4bit_use_double_quant=True,
)
Here we set up the Low-Rank Adaptation (LoRA) configuration. LoraConfig defines the adapter parameters (rank r, alpha). We target all linear layers (q_proj, k_proj, v_proj, etc.), which generally yields better results than adapting only the query and value projections.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,  # Rank (higher = more parameters to train)
    bias="none",
    task_type="CAUSAL_LM",
    # Target all linear layers for best performance. Module names differ between
    # architectures (e.g. Llama uses q_proj/k_proj/..., some Phi variants fuse them),
    # so "all-linear" lets PEFT resolve them for the loaded model.
    target_modules="all-linear",
)
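For intuition on adapter size: a rank-r LoRA on a d_in × d_out linear layer adds two small matrices, A (d_in × r) and B (r × d_out), i.e. r·(d_in + d_out) trainable parameters. A quick sketch (the 3072 dimension is illustrative, not read from the actual model):

```python
def lora_param_count(d_in, d_out, r):
    # A: (d_in x r) plus B: (r x d_out)
    return r * d_in + r * d_out

# One hypothetical 3072 x 3072 projection at rank 8
n = lora_param_count(3072, 3072, r=8)
```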
We define the training arguments in SFTConfig. Notice the higher learning rate (1e-4) compared to the full fine-tuning setup described in the TwinWeaver paper; PEFT typically benefits from a larger learning rate. We also set bf16=True, which improves training stability on newer GPUs (Ampere and later).
training_arguments = SFTConfig(
    output_dir="./results",
    num_train_epochs=5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=10,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=10,
    per_device_eval_batch_size=1,
    learning_rate=1e-4,  # LR is higher for PEFT, see TwinWeaver paper for full fine-tuning details
    fp16=False,  # Use fp16 for older GPUs (T4/V100), bf16 for Ampere and later (A100/3090/4090)
    bf16=True,
    max_grad_norm=1.0,
    warmup_ratio=0.1,
    group_by_length=True,
    save_total_limit=1,
    lr_scheduler_type="cosine",
    max_length=8192,
    packing=False,  # Disable packing for more exact training, though it can be activated
    completion_only_loss=False,  # Compute loss on the entire text
)
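Two quantities implied by these arguments are the effective batch size (per-device batch size × gradient-accumulation steps × number of devices) and the number of warmup steps (warmup_ratio × total optimizer steps). A small sketch with the values above (the total step count is a made-up number, not taken from this run):

```python
def effective_batch_size(per_device, grad_accum, n_devices=1):
    # The optimizer sees this many samples per update
    return per_device * grad_accum * n_devices

def warmup_step_count(total_steps, warmup_ratio):
    # Linear warmup covers this many initial optimizer steps
    return int(total_steps * warmup_ratio)

ebs = effective_batch_size(per_device=1, grad_accum=1)
warm = warmup_step_count(total_steps=500, warmup_ratio=0.1)
```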
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=False,
)
# Disable cache for training (required for gradient checkpointing)
model.config.use_cache = False
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    processing_class=tokenizer,
    args=training_arguments,
    peft_config=peft_config,
)
# Start training - takes around 5 mins, depending on hardware
trainer.train()
# Save the fine-tuned adapter
adapter_path = "./results/final_adapter"
trainer.save_model(adapter_path)
print(f"Adapter saved to {adapter_path}")
del trainer
del model
gc.collect()
torch.cuda.empty_cache()
Inference example¶
An inference example for a test-set patient: we generate the remainder of the patient trajectory after the first line of therapy.
# Get the first test set patient
test_patientid = dm.get_all_patientids_in_split(config.test_split_name)[0]
patient_data = dm.get_patient_data(test_patientid)
# Let's simulate forecasts for after the first line of therapy
df_constant_patient = patient_data["constant"].copy()
df_events_patient = patient_data["events"].copy()
date_of_first_lot = df_events_patient.loc[df_events_patient["event_category"] == "lot", "date"].min()
date_of_first_event = df_events_patient["date"].min()
# Only keep data until (and including) first line of therapy
df_events_patient = df_events_patient.loc[df_events_patient["date"] <= date_of_first_lot]
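The truncation step is plain date filtering; a standalone sketch with made-up events (column names mirror the ones used above):

```python
import pandas as pd

events = pd.DataFrame({
    "event_category": ["diagnosis", "lot", "lab", "lot"],
    "date": pd.to_datetime(["2020-01-05", "2020-02-01", "2020-03-10", "2020-06-01"]),
})

# Keep everything up to and including the first line of therapy
first_lot = events.loc[events["event_category"] == "lot", "date"].min()
truncated = events.loc[events["date"] <= first_lot]
```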
We convert the truncated patient history into text, which serves as the prompt (the first part of the trajectory).
# Convert to instruction
converted = converter.forward_conversion(events=df_events_patient, constant=df_constant_patient)
For inference, we load the base model again (clean slate) to avoid any state from training, and then attach the adapter we trained. PeftModel handles the integration of the LoRA weights with the base model.
# 1. Load the Base Model again (clean instance)
base_model_inference = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,  # Reuse the 4-bit config
    device_map="auto",
    trust_remote_code=False,
)
# 2. Load the Saved Adapter
# This wraps the base model with the fine-tuned LoRA layers
inference_model = PeftModel.from_pretrained(base_model_inference, adapter_path)
# 3. Switch to evaluation mode
inference_model.eval()
# Create text generation pipeline
# Re-enable cache for inference
inference_model.config.use_cache = True
text_gen_pipeline = pipeline("text-generation", model=inference_model, tokenizer=tokenizer)
# Generate with LLM, for a given time
generated_answer = text_gen_pipeline(
    converted["text"],
    max_new_tokens=128,  # Set this higher for longer answers, lower for shorter ones
    return_full_text=False,
    do_sample=True,  # Nucleus sampling
    temperature=0.7,
    top_p=0.9,
)[0]["generated_text"]
# Show the generated answer
print(generated_answer)
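For reference, top_p (nucleus) sampling keeps only the smallest set of highest-probability tokens whose cumulative mass reaches p, then samples from that set. A toy sketch of the filtering step (the probabilities are invented):

```python
def nucleus_candidates(probs, top_p):
    """Return the smallest high-probability prefix with cumulative mass >= top_p."""
    ranked = sorted(probs.items(), key=lambda kv: -kv[1])
    kept, cumulative = [], 0.0
    for token, p in ranked:
        kept.append(token)
        cumulative += p
        if cumulative >= top_p:
            break
    return kept

# With top_p=0.9, the 0.05-probability tail token is excluded
kept = nucleus_candidates({"a": 0.5, "b": 0.3, "c": 0.15, "d": 0.05}, top_p=0.9)
```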
The raw text output from the model needs to be parsed back into structured data. reverse_conversion handles this, returning a dictionary with the structured data (including an events DataFrame).
# Reverse convert
full_trajectory = converted["text"] + generated_answer
ret_dict = converter.reverse_conversion(full_trajectory, dm, date_of_first_event)
ret_dict["events"].head()