End-to-end instruction example using LLMs with fine-tuning¶
This notebook provides a comprehensive, end-to-end demonstration of fine-tuning a Large Language Model (LLM) for medical forecasting tasks using the twinweaver library. The workflow begins by processing raw medical data (events, constants, and lab values) into structured instruction-tuning datasets using DataManager and ConverterInstruction, effectively translating patient histories into prompt-completion pairs. We then implement Parameter-Efficient Fine-Tuning (PEFT) using QLoRA (4-bit quantization) and the SFTTrainer to adapt a microsoft/Phi-4-mini-instruct model, optimizing it for clinical predictions while managing memory constraints. Finally, the example concludes with an inference pipeline that loads the trained adapter to predict future clinical outcomes—such as hemoglobin levels and mortality risks—and reverse-converts the LLM's text output back into structured data.
Note: You need a GPU with at least 30 GB of memory for this example to work. We have also not tested the performance of the PEFT models beyond these examples.
Important: Please first install the packages from requirements.txt in the /examples folder, e.g. using pip install -r requirements.txt.
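If you are unsure whether your machine meets the GPU requirement above, a quick check like the following can help (an optional sketch, not part of the original pipeline):
import torch

# Optional sanity check: confirm a CUDA-capable GPU with enough memory is available
assert torch.cuda.is_available(), "This example requires a CUDA-capable GPU."
total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"GPU: {torch.cuda.get_device_name(0)}, total memory: {total_gb:.1f} GB")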
import gc

import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
from twinweaver import (
DataManager,
Config,
DataSplitterForecasting,
DataSplitterEvents,
ConverterInstruction,
DataSplitter,
)
# Some key settings
BASE_MODEL = "microsoft/Phi-4-mini-instruct" # NOTE: we haven't tested the performance of this model beyond examples
Generate training data¶
First, we need to set up the configuration. This includes specifying which constant variables to use.
# Load data
df_events = pd.read_csv("./example_data/events.csv")
df_constant = pd.read_csv("./example_data/constant.csv")
df_constant_description = pd.read_csv("./example_data/constant_description.csv")
# Manually set up which constant columns we want to use
config = Config() # Override values here to customize pipeline
config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
config.constant_birthdate_column = "birthyear"
Here we initialize the DataManager to handle data loading and processing.
We also set up the Data Splitters:
DataSplitterEvents: Handles splitting of event data (diagnoses, treatments).
DataSplitterForecasting: Handles splitting of time-series data (lab values) and statistics generation.
Finally, ConverterInstruction is initialized. This component is responsible for translating the structured patient data splits into the textual instruction format (Prompt + Completion) that the LLM understands.
dm = DataManager(config=config)
dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
dm.setup_dataset_splits()
dm.infer_var_types()
data_splitter_events = DataSplitterEvents(dm, config=config)
data_splitter_events.setup_variables()
data_splitter_forecasting = DataSplitterForecasting(
data_manager=dm,
config=config,
)
# If you don't want to do forecasting QA, proportional sampling, or 3-sigma filtering, you can skip this step
data_splitter_forecasting.setup_statistics()
# We will also use the easier interface that combines both data splitters
data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)
converter = ConverterInstruction(
dm.data_frames["constant_description"],
nr_tokens_budget_total=8192,
config=config,
dm=dm,
variable_stats=data_splitter_forecasting.variable_stats, # Optional, needed for forecasting QA tasks
)
# Get all training + validation patientids
training_patientids = dm.get_all_patientids_in_split(config.train_split_name)
validation_patientids = dm.get_all_patientids_in_split(config.validation_split_name)
The generate_transformers_df function iterates through each patient and generates input/output pairs.
For each patient, it may generate multiple "splits" (different reference dates in their history). Each split is converted into a text prompt (history) and a text completion (future events/values).
The result is a DataFrame with "prompt" and "completion" columns.
def generate_transformers_df(patientids_list):
df = []
for patientid in patientids_list:
patient_data = dm.get_patient_data(patientid)
forecasting_splits, events_splits, reference_dates = data_splitter.get_splits_from_patient_with_target(
patient_data,
forecasting_filter_outliers=False,
)
for split_idx in range(len(forecasting_splits)):
p_converted = converter.forward_conversion(
forecasting_splits=forecasting_splits[split_idx],
event_splits=events_splits[split_idx],
override_mode_to_select_forecasting="both",
)
new_data = {
"prompt": p_converted["instruction"],
"completion": p_converted["answer"],
"patientid": f"{patientid}_split{split_idx}", # Just for ease of finding later
}
df.append(new_data)
df = pd.DataFrame(df)
return df
# Generate training and validation dfs
df_train = generate_transformers_df(training_patientids)
df_validation = generate_transformers_df(validation_patientids)
df_train
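To get a feel for what ConverterInstruction produces, you can optionally print a truncated view of the first prompt/completion pair (purely an inspection step; the column names match those created above):
# Peek at the first generated prompt/completion pair (truncated for readability)
print(df_train.iloc[0]["prompt"][:500])
print("--- completion ---")
print(df_train.iloc[0]["completion"][:500])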
Fine-tune LLM¶
We start by setting up the tokenizer. We set the padding token to be the same as the EOS (End of Sequence) token, which is a common practice for causal language models.
# Setup tokenizer and datasets
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Set padding token to eos_token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
train_dataset = Dataset.from_pandas(df_train)
validation_dataset = Dataset.from_pandas(df_validation)
Instruction-tuned models expect data in a specific conversational format (e.g., User: ... Assistant: ...).
We use format_chat_template to structure our raw prompt/completion strings into this list-of-messages format using the user and assistant roles.
# Format data for chat template
def format_chat_template(example):
"""Convert prompt/completion pairs to proper prompt/completion format"""
return {
"prompt": [{"role": "user", "content": example["prompt"]}],
"completion": [{"role": "assistant", "content": example["completion"]}],
}
# Apply formatting to datasets
train_dataset = train_dataset.map(format_chat_template)
validation_dataset = validation_dataset.map(format_chat_template)
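To verify what the trainer will actually see, you can optionally render one formatted example through the tokenizer's chat template (a small inspection sketch using the standard transformers apply_chat_template API):
# Render one example with the model's chat template to inspect the final training text
example = train_dataset[0]
rendered = tokenizer.apply_chat_template(
    example["prompt"] + example["completion"],
    tokenize=False,
)
print(rendered[:500])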
We configure 4-bit quantization using BitsAndBytesConfig (QLoRA). This significantly lowers memory usage, allowing us to fine-tune the model on consumer GPUs.
# Define Quantization Config (4-bit loading)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16, # This should be set based on your GPU capabilities
bnb_4bit_use_double_quant=True,
)
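If your GPU does not support bfloat16 (e.g. T4/V100), you can pick the compute dtype dynamically instead. The snippet below is a sketch of how the config above could be adapted:
# Choose the compute dtype based on GPU capability (bfloat16 on Ampere+, otherwise float16)
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)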
Here we set up the Low-Rank Adaptation (LoRA) configuration. LoraConfig defines the adapter parameters (rank r, alpha). We target all linear projection layers (q_proj, k_proj, etc.), which generally yields better results than adapting only the query/value projections.
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=8, # Rank (higher = more parameters to train)
bias="none",
task_type="CAUSAL_LM",
    # Target all linear projection layers for best performance (module names follow Llama-style naming; adjust for other architectures)
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
We define the training arguments in SFTConfig. Note the higher learning rate (1e-4) compared to the typical full fine-tuning setup in the GDT paper. We also set bf16=True for newer GPUs (Ampere and later) to improve training stability.
training_arguments = SFTConfig(
output_dir="./results",
num_train_epochs=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=1,
optim="paged_adamw_32bit",
save_steps=10,
logging_steps=10,
eval_strategy="steps",
eval_steps=10,
per_device_eval_batch_size=1,
learning_rate=1e-4, # LR is higher for PEFT, see TwinWeaver paper for full fine-tuning details
fp16=False, # Use fp16 for T4/V100, bf16 for Ampere and later (A100/3090/4090)
bf16=True,
max_grad_norm=1.0,
warmup_ratio=0.1,
group_by_length=True,
save_total_limit=1,
lr_scheduler_type="cosine",
max_length=8192,
packing=False, # Disable packing for instruction tuning
completion_only_loss=True, # Only compute loss on assistant responses
)
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=False,
)
# Disable cache for training (required for gradient checkpointing)
model.config.use_cache = False
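As optional sanity checks (sketches, not required for training), you can report the quantized model's memory footprint and list the projection-layer names to confirm that the target_modules chosen above actually exist in this architecture:
# Report the 4-bit model's memory footprint
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
# List projection-layer names to cross-check against peft_config.target_modules
proj_names = sorted({name.split(".")[-1] for name, _ in model.named_modules() if name.endswith("proj")})
print(proj_names)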
trainer = SFTTrainer(
model=model,
train_dataset=train_dataset,
processing_class=tokenizer,
args=training_arguments,
eval_dataset=validation_dataset,
peft_config=peft_config,
)
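Before launching training, it can be useful to confirm how few parameters LoRA actually updates. Since SFTTrainer wraps the model in a PeftModel when a peft_config is passed, the standard peft helper can report this (optional):
# Show trainable vs. total parameter counts of the LoRA-wrapped model
trainer.model.print_trainable_parameters()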
# Start training - takes around 5 mins, depending on hardware
trainer.train()
# Save the fine-tuned adapter
adapter_path = "./results/final_adapter"
trainer.save_model(adapter_path)
print(f"Adapter saved to {adapter_path}")
del trainer
del model
gc.collect()
torch.cuda.empty_cache()
Inference example¶
In this inference example, we take a test set patient and make predictions after their first line of therapy.
# Get the first test set patient
test_patientid = dm.get_all_patientids_in_split(config.test_split_name)[0]
patient_data = dm.get_patient_data(test_patientid)
# Let's simulate forecasts after the first line of therapy
df_constant_patient = patient_data["constant"].copy()
df_events_patient = patient_data["events"].copy()
date_of_first_lot = df_events_patient.loc[
df_events_patient["event_category"] == config.event_category_lot, "date"
].min()
# Only keep data until (and including) the first line of therapy
df_events_patient = df_events_patient.loc[df_events_patient["date"] <= date_of_first_lot]
# Use the truncated history as the patient data for inference
patient_data["events"] = df_events_patient
patient_data["constant"] = df_constant_patient
# Let's forecast hemoglobin at 4, 8, and 12 weeks,
# and death within 52 weeks
forecasting_times_to_predict = {
"hemoglobin_-_718-7": [4, 8, 12],
}
forecast_split, events_split = data_splitter.get_splits_from_patient_inference(
patient_data,
inference_type="both",
forecasting_override_variables_to_predict=["hemoglobin_-_718-7"],
events_override_category="death",
events_override_observation_time_delta=pd.Timedelta(days=52 * 7),
)
We convert the patient data into an instruction prompt. Unlike training, forward_conversion_inference only generates the input prompt (without the target answer), as we want the LLM to generate the answer.
# Convert to instruction
converted = converter.forward_conversion_inference(
forecasting_split=forecast_split,
forecasting_future_weeks_per_variable=forecasting_times_to_predict,
event_split=events_split,
custom_tasks=None,
)
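You can optionally inspect the generated instruction before sending it to the model (the "instruction" key is the same one used for generation below):
# Inspect the inference prompt (truncated for display)
print(converted["instruction"][:1000])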
For inference, we load the base model again (clean slate) to avoid any state from training, and then attach the adapter we trained. PeftModel handles the integration of the LoRA weights with the base model.
# 1. Load the Base Model again (clean instance)
base_model_inference = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
quantization_config=bnb_config, # Reuse the 4-bit config
device_map="auto",
trust_remote_code=False,
)
# 2. Load the Saved Adapter
# This wraps the base model with the fine-tuned LoRA layers
inference_model = PeftModel.from_pretrained(base_model_inference, adapter_path)
# 3. Switch to evaluation mode
inference_model.eval()
# Create text generation pipeline
# Re-enable cache for inference
inference_model.config.use_cache = True
text_gen_pipeline = pipeline("text-generation", model=inference_model, tokenizer=tokenizer)
# Generate with LLM
generated_answer = text_gen_pipeline(
[{"role": "user", "content": converted["instruction"]}],
max_new_tokens=128,
return_full_text=False,
do_sample=True, # Using nucleus sampling
temperature=0.7,
top_p=0.9,
)[0]["generated_text"]
# Show the generated answer
generated_answer
The raw text output from the model needs to be parsed back into structured data. reverse_conversion handles this, returning a list of dictionaries with the predicted results for each task.
# Reverse convert
return_list = converter.reverse_conversion(generated_answer, dm, date_of_first_lot)
# Task 1 reverse conversion
return_list[0]["result"]
# Task 2 reverse conversion
return_list[1]["result"]