🏆 Challenge 2: End-to-End LLM Fine-tuning
Difficulty: ⭐⭐⭐ (Advanced) | Time: 90-120 minutes
🎯 Learning Objectives
By completing this challenge, you will:
- Generate training datasets from clinical data
- Configure and understand LoRA/QLoRA parameters
- Fine-tune an LLM for medical forecasting
- Run inference and evaluate model predictions
⚠️ Prerequisites
- Complete Challenge 1 first!
- GPU with at least 30GB memory
- Install:
pip install twinweaver[fine-tuning-example]
📋 Rules
- Complete all # TODO: sections
- Answer quiz questions before proceeding
- Make predictions about hyperparameter effects BEFORE running experiments
- No peeking at the original tutorial!
import pandas as pd
import gc
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from peft import LoraConfig, PeftModel
from trl import SFTConfig, SFTTrainer
from twinweaver import (
    DataManager,
    Config,
)
Part 1: Decide on Model and Context Length
Before starting, you need to make important decisions about your setup.
❓ Quiz 1: Model Selection
Q1.1: Why might you choose a smaller model (e.g., Phi-4-mini) over a larger one (e.g., Llama-70B) for this task?
Q1.2: What is the trade-off between context length and memory usage?
Q1.3: Why do we use an "instruction-tuned" model instead of a plain (non-instruction-tuned) base model?
Your Answers:
Q1.1:
Q1.2:
Q1.3:
# TODO: Choose your base model
# Options to consider:
# - "microsoft/Phi-4-mini-instruct" (small, fast)
# - "meta-llama/Llama-3.2-3B-Instruct" (medium)
# - "mistralai/Mistral-7B-Instruct-v0.3" (larger)
BASE_MODEL = None # Choose your model
# TODO: Choose your context length
# Consider: Longer = more patient history, but more memory
# Reasonable range: 2048 - 16384
MAX_CONTEXT_LENGTH = None # Choose your context length
# Load data
df_events = pd.read_csv("../example_data/events.csv")
df_constant = pd.read_csv("../example_data/constant.csv")
df_constant_description = pd.read_csv("../example_data/constant_description.csv")
# TODO: Configure the pipeline (you did this in Challenge 1!)
# Set up:
# - split_event_category
# - event_category_forecast
# - event_category_events_prediction_with_naming
# - constant_columns_to_use
# - constant_birthdate_column
config = Config()
# Your configuration here...
# TODO: Initialize DataManager and all splitters
# This should be familiar from Challenge 1
dm = DataManager(config=config)
# ... complete the setup
# Initialize splitters
# ...
# Initialize converter with YOUR chosen context length
# ...
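If you are unsure where to start, the sketch below shows the general shape of this cell. The attribute names come from the TODO above, but every value ("lot", "lab", the column names) is a placeholder from this notebook's example data, not a reference solution; replace them with the choices you made in Challenge 1.

# Sketch only - all values below are placeholders, reuse your Challenge 1 configuration
config = Config()
config.split_event_category = "lot"       # assumed: split timelines at lines of therapy
config.event_category_forecast = "lab"    # assumed: forecast lab measurements
# config.event_category_events_prediction_with_naming = ...  # reuse your Challenge 1 mapping
# config.constant_columns_to_use = [...]                      # baseline columns from constant.csv
config.constant_birthdate_column = "birthdate"                # assumed column name

dm = DataManager(config=config)
# Initialize your splitters and the converter (with MAX_CONTEXT_LENGTH) exactly as in Challenge 1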
🔧 Exercise 2.1: Implement Dataset Generator
Write a function to generate the training dataset. This is a key skill!
# Get patient IDs for each split
training_patientids = dm.get_all_patientids_in_split(config.train_split_name)
validation_patientids = dm.get_all_patientids_in_split(config.validation_split_name)
print(f"Training patients: {len(training_patientids)}")
print(f"Validation patients: {len(validation_patientids)}")
# TODO: Implement the dataset generation function
# This function should:
# 1. Iterate through each patient
# 2. Get splits for each patient
# 3. Convert each split to instruction format
# 4. Return a DataFrame with 'prompt' and 'completion' columns
def generate_transformers_df(patientids_list):
    """
    Generate training data from a list of patient IDs.

    Args:
        patientids_list: List of patient IDs to process

    Returns:
        pd.DataFrame with columns: prompt, completion, patientid
    """
    df = []
    # TODO: Implement the function
    # HINT: Use dm.get_patient_data(), data_splitter.get_splits_from_patient_with_target(),
    # and converter.forward_conversion()
    pass
    return pd.DataFrame(df)
# Generate datasets
df_train = generate_transformers_df(training_patientids)
df_validation = generate_transformers_df(validation_patientids)
print(f"Training examples: {len(df_train)}")
print(f"Validation examples: {len(df_validation)}")
📍 Checkpoint 2.1: Validate Dataset
# Validate the generated dataset
def validate_dataset(df, name):
    errors = []
    if df is None or len(df) == 0:
        errors.append(f"❌ {name} is empty")
    elif "prompt" not in df.columns:
        errors.append(f"❌ {name} missing 'prompt' column")
    elif "completion" not in df.columns:
        errors.append(f"❌ {name} missing 'completion' column")
    else:
        print(f"✅ {name}: {len(df)} examples")
        print(f"   Avg prompt length: {df['prompt'].str.len().mean():.0f} chars")
        print(f"   Avg completion length: {df['completion'].str.len().mean():.0f} chars")
        return True
    for e in errors:
        print(e)
    return False
train_valid = validate_dataset(df_train, "Training set")
val_valid = validate_dataset(df_validation, "Validation set")
if train_valid and val_valid:
    print("\n🎉 Datasets ready for training!")
Part 3: Tokenizer and Data Formatting
LLMs expect data in a specific chat format. Let's set this up.
# TODO: Load the tokenizer for your chosen model
tokenizer = None # Load tokenizer
# TODO: Set the padding token
# HINT: A common approach is to use the EOS token as the padding token
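A minimal sketch of this step using the standard Hugging Face API (assuming your chosen model's tokenizer defines an EOS token but no pad token):

from transformers import AutoTokenizer

# Load the tokenizer for the chosen base model
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=False)

# Reuse the EOS token as the padding token if none is defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token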
❓ Quiz 2: Chat Templates
Q2.1: What is a "chat template" and why do instruction-tuned models need it?
Q2.2: What roles are typically used in a chat format?
Q2.3: Why do we set completion_only_loss=True during training?
Your Answers:
Q2.1:
Q2.2:
Q2.3:
# TODO: Implement the chat formatting function
# Convert raw prompt/completion to chat message format
def format_chat_template(example):
    """
    Convert a single example to chat format.

    Args:
        example: Dict with 'prompt' and 'completion' keys

    Returns:
        Dict with 'prompt' as list of user messages and 'completion' as list of assistant messages
    """
    # TODO: Implement this
    # HINT: Return format should be:
    # {"prompt": [{"role": "user", "content": ...}],
    #  "completion": [{"role": "assistant", "content": ...}]}
    pass
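Following the hinted return format, a minimal sketch of this function could look like:

# Sketch: wrap the raw text in the role/content message structure from the hint
def format_chat_template_sketch(example):
    return {
        "prompt": [{"role": "user", "content": example["prompt"]}],
        "completion": [{"role": "assistant", "content": example["completion"]}],
    }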
# Convert to HuggingFace datasets and apply formatting
train_dataset = Dataset.from_pandas(df_train)
validation_dataset = Dataset.from_pandas(df_validation)
train_dataset = train_dataset.map(format_chat_template)
validation_dataset = validation_dataset.map(format_chat_template)
Part 4: Configure Quantization (QLoRA)
Quantization allows training large models on limited hardware.
❓ Quiz 3: Quantization Understanding
Q3.1: What does "4-bit quantization" mean? What are we quantizing?
Q3.2: What is the trade-off between quantization level (4-bit vs 8-bit vs full precision)?
Q3.3: What does "nf4" (NormalFloat4) quantization type do differently than regular int4?
Your Answers:
Q3.1:
Q3.2:
Q3.3:
# TODO: Configure 4-bit quantization
# Parameters to set:
# - load_in_4bit: Enable 4-bit loading
# - bnb_4bit_quant_type: Use "nf4" for better quality
# - bnb_4bit_compute_dtype: Use torch.bfloat16 for modern GPUs
# - bnb_4bit_use_double_quant: Enable for additional memory savings
bnb_config = BitsAndBytesConfig(
    # TODO: Fill in the parameters
)
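For reference, a configuration that matches the parameter list above might look like this sketch (one reasonable set of values, not the only valid one):

# Sketch: 4-bit NF4 quantization, double quantization, bfloat16 compute dtype
bnb_config_sketch = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)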
Part 5: Configure LoRA
LoRA (Low-Rank Adaptation) enables efficient fine-tuning by training only a small number of parameters.
🧪 Experiment 5.1: LoRA Hyperparameter Impact
Before configuring, make predictions about how these hyperparameters affect training:
| Parameter | Your Prediction: What happens if we INCREASE it? |
|---|---|
| r (rank) | |
| lora_alpha | |
| lora_dropout | |
| Number of target modules | |
Your Predictions:
- r (rank):
- lora_alpha:
- lora_dropout:
- Number of target modules:
# TODO: Configure LoRA
# Make deliberate choices for each parameter and justify them!
peft_config = LoraConfig(
    # TODO: Set lora_alpha (scaling factor, common values: 8, 16, 32)
    lora_alpha=None,
    # TODO: Set lora_dropout (regularization, common values: 0.05-0.2)
    lora_dropout=None,
    # TODO: Set r (rank - higher = more parameters, common values: 4, 8, 16, 32)
    r=None,
    bias="none",
    task_type="CAUSAL_LM",
    # TODO: Choose which modules to target
    # Options:
    # - Minimal: ["q_proj", "v_proj"] - faster but less expressive
    # - Full attention: ["q_proj", "k_proj", "v_proj", "o_proj"]
    # - All linear: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    target_modules=None,
)
# Document your reasoning:
print("My LoRA configuration choices:")
print(f" r={peft_config.r}: [Your reasoning here]")
print(f" lora_alpha={peft_config.lora_alpha}: [Your reasoning here]")
print(f" target_modules={peft_config.target_modules}: [Your reasoning here]")
🧪 Experiment 6.1: Learning Rate Analysis
Q6.1: The tutorial uses learning_rate=1e-4, which is higher than the learning rates typically used for full fine-tuning (1e-5 to 5e-5). Why might PEFT methods benefit from higher learning rates?
Q6.2: What problems might you see if the learning rate is:
- Too high?
- Too low?
Your Answers:
Q6.1:
Q6.2 (too high):
Q6.2 (too low):
# TODO: Configure training arguments
# Think carefully about each parameter!
training_arguments = SFTConfig(
    output_dir="./results_challenge",
    # TODO: Set number of training epochs (consider: small dataset = more epochs OK)
    num_train_epochs=None,
    # TODO: Set batch size (limited by GPU memory)
    per_device_train_batch_size=None,
    # TODO: Set gradient accumulation (effective_batch = batch_size * grad_accum)
    gradient_accumulation_steps=None,
    # Optimizer settings
    optim="paged_adamw_32bit",
    # Logging and evaluation
    save_steps=10,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=10,
    per_device_eval_batch_size=1,
    # TODO: Set learning rate (PEFT typically uses higher LR than full fine-tuning)
    learning_rate=None,
    # Precision settings
    fp16=False,  # Set True for older GPUs (V100, T4)
    bf16=True,   # Set True for newer GPUs (A100, 3090, 4090)
    # Regularization
    max_grad_norm=1.0,
    # TODO: Set warmup ratio (what fraction of training for warmup?)
    warmup_ratio=None,
    group_by_length=True,
    save_total_limit=1,
    lr_scheduler_type="cosine",
    max_length=MAX_CONTEXT_LENGTH,
    packing=False,
    completion_only_loss=True,
)
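If you want a sanity check, here is one plausible set of values for the TODO fields above, collected in a plain dict for discussion. They are illustrative, not prescribed; adjust them to your GPU budget and dataset size.

# Sketch: example values for the TODO fields, not a recommendation
example_sft_values = {
    "num_train_epochs": 3,             # small dataset, a few passes are usually enough
    "per_device_train_batch_size": 1,  # long clinical prompts rarely fit more per device
    "gradient_accumulation_steps": 8,  # effective batch size = 1 * 8
    "learning_rate": 1e-4,             # PEFT tolerates higher LRs than full fine-tuning
    "warmup_ratio": 0.03,              # roughly 3% of training steps for warmup
}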
# TODO: Load the base model with quantization
# Parameters needed:
# - Model name (BASE_MODEL)
# - quantization_config (bnb_config)
# - device_map="auto"
# - trust_remote_code=False
model = None # Load the model
# Don't forget: Disable cache for training
# model.config.use_cache = False
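A sketch of this step with standard transformers usage (the parameter names match the list above):

from transformers import AutoModelForCausalLM

# Load the quantized base model across available devices
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=False,
)
model.config.use_cache = False  # the KV cache is not useful during training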
# TODO: Create the SFTTrainer
# Parameters needed:
# - model
# - train_dataset
# - processing_class (tokenizer)
# - args (training_arguments)
# - eval_dataset (validation_dataset)
# - peft_config
trainer = None # Create trainer
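The trainer wiring could look like the sketch below. Note that processing_class is how recent TRL versions accept the tokenizer; older releases use tokenizer= instead.

from trl import SFTTrainer

# Sketch: connect model, datasets, tokenizer, LoRA config, and training arguments
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    processing_class=tokenizer,
    args=training_arguments,
    peft_config=peft_config,
)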
🧪 Experiment 7.1: Training Observation
Before running training, predict what you expect to see:
Q7.1: What should happen to the training loss over time?
Q7.2: What might it mean if validation loss increases while training loss decreases?
Q7.3: How long do you expect training to take? (Make a guess!)
Your Predictions:
Q7.1:
Q7.2:
Q7.3:
# Run training!
# Note: This takes ~5 minutes on a good GPU
# Watch the training loss and validation loss as it runs
trainer.train()
📊 Post-Training Analysis
After training completes, analyze what happened.
# TODO: Analyze the training results
# Questions to answer:
# 1. What was the final training loss?
# 2. What was the final validation loss?
# 3. Did validation loss ever increase? (sign of overfitting)
# 4. Did your predictions match reality?
print("Training Analysis:")
print("==================")
# Your analysis here...
# Save the adapter
adapter_path = "./results_challenge/final_adapter"
trainer.save_model(adapter_path)
print(f"Adapter saved to {adapter_path}")
# Clean up
del trainer
del model
gc.collect()
torch.cuda.empty_cache()
# Get a test patient
test_patientid = dm.get_all_patientids_in_split(config.test_split_name)[0]
patient_data = dm.get_patient_data(test_patientid)
# Get the date of first line of therapy
df_events_patient = patient_data["events"].copy()
date_of_first_lot = df_events_patient.loc[df_events_patient["event_category"] == "lot", "date"].min()
print(f"Test patient: {test_patientid}")
print(f"First LoT date: {date_of_first_lot}")
🔧 Exercise 8.1: Design Your Prediction Task
Choose what you want to predict for this patient.
# TODO: Design your forecasting task
# What variable do you want to forecast? At what time points?
# Example structure:
# forecasting_times_to_predict = {
# "variable_name": [week1, week2, week3, ...]
# }
forecasting_times_to_predict = {
    # TODO: Fill in - choose a lab value and time points
}
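As a purely illustrative sketch (both the variable name and the time representation are assumptions; use a lab that actually exists in your events data and the time format your converter expects):

# Sketch: placeholder lab variable and week offsets after the first line of therapy
forecasting_times_to_predict_sketch = {
    "hemoglobin": [4, 8, 12, 26],  # assumed: weeks since first LoT
}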
# Get inference splits
forecast_split, events_split = data_splitter.get_splits_from_patient_inference(
    patient_data,
    inference_type="both",
    # TODO: Set the variable(s) to predict
    forecasting_override_variables_to_predict=None,  # List of variable names
    # TODO: Set the event to predict
    events_override_category=None,  # e.g., "death"
    events_override_observation_time_delta=pd.Timedelta(days=52 * 7),  # 1 year
)
# TODO: Convert to instruction format for inference
# Use converter.forward_conversion_inference()
converted = None # Your code here
# Print the prompt to see what we're asking the model
print("=" * 50)
print("INFERENCE PROMPT:")
print("=" * 50)
print(converted["instruction"][:2000]) # First 2000 chars
print("\n... [truncated]")
🧪 Experiment 8.2: Prediction Before Running
Based on the patient's history in the prompt, make your own predictions:
Q8.1: What do you predict for the forecasted values?
Q8.2: What do you predict for the time-to-event?
Q8.3: How confident are you in these predictions? Why?
Your Predictions:
Q8.1:
Q8.2:
Q8.3:
# TODO: Load the base model and adapter for inference
# 1. Load base model with quantization
base_model_inference = None # Your code
# 2. Load the saved adapter
inference_model = None # Use PeftModel.from_pretrained()
# 3. Set to evaluation mode
# inference_model.eval()
# TODO: Create text generation pipeline
inference_model.config.use_cache = True
text_gen_pipeline = None # Create pipeline("text-generation", ...)
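Attaching the saved adapter to a freshly loaded quantized base model and building the generation pipeline might look like this sketch (standard peft and transformers usage):

from peft import PeftModel
from transformers import AutoModelForCausalLM, pipeline

# Reload the quantized base model, then attach the trained LoRA adapter
base_model_inference = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=False,
)
inference_model = PeftModel.from_pretrained(base_model_inference, adapter_path)
inference_model.eval()

# Re-enable the KV cache for faster generation and build the pipeline
inference_model.config.use_cache = True
text_gen_pipeline = pipeline(
    "text-generation",
    model=inference_model,
    tokenizer=tokenizer,
)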
# TODO: Generate prediction
# Use the pipeline with appropriate generation parameters:
# - max_new_tokens: 128 is usually enough
# - return_full_text: False (we only want the generated part)
# - do_sample: True (for nucleus sampling)
# - temperature: 0.7 (controls randomness)
# - top_p: 0.9 (nucleus sampling threshold)
generated_answer = None # Your generation code
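One way to call the pipeline with the parameters suggested above (a sketch; it assumes the inference prompt should be wrapped in the same chat format used during training):

# Sketch: apply the chat template to the inference prompt, then sample a completion
prompt_text = tokenizer.apply_chat_template(
    [{"role": "user", "content": converted["instruction"]}],
    tokenize=False,
    add_generation_prompt=True,
)
outputs = text_gen_pipeline(
    prompt_text,
    max_new_tokens=128,
    return_full_text=False,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
generated_answer = outputs[0]["generated_text"]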
# Show the generated answer
print("=" * 50)
print("MODEL PREDICTION:")
print("=" * 50)
print(generated_answer)
# TODO: Reverse convert to structured data
return_list = None # Use converter.reverse_conversion()
# Display structured results
for i, result in enumerate(return_list):
    print(f"\nTask {i + 1}:")
    print(result["result"])
📊 Final Analysis
Compare the model's predictions to your own and reflect on the results.
# TODO: Write your analysis
# 1. How did the model's predictions compare to yours?
# 2. Are the predictions clinically reasonable?
# 3. What improvements would you suggest?
print("""
Final Analysis:
===============
1. Comparison to my predictions:
[Your analysis here]
2. Clinical reasonableness:
[Your analysis here]
3. Suggested improvements:
[Your analysis here]
""")
🏆 Bonus Challenge 1: Hyperparameter Experiment
+20 points
Train two more models with different LoRA configurations:
- Low rank (r=4) with minimal target modules
- High rank (r=32) with all linear modules
Compare their performance on the same test patient.
# BONUS: Implement your hyperparameter experiment
🏆 Bonus Challenge 2: Multi-Sample Inference
+15 points
Generate multiple predictions (N=5) for the same patient using different random seeds. Analyze the variance in predictions. What does this tell you about model confidence?
# BONUS: Implement multi-sample inference and analyze variance
🏆 Bonus Challenge 3: Evaluation Framework
+25 points
Build an evaluation framework that:
- Runs inference on all test patients
- Compares predictions to ground truth
- Computes metrics (MAE for forecasting, accuracy for events)
- Generates a summary report
# BONUS: Implement the evaluation framework
🎉 Challenge Complete!
Congratulations on completing the advanced challenge! You've learned to:
- ✅ Generate training datasets from clinical data
- ✅ Configure quantization for memory-efficient training
- ✅ Design and justify LoRA hyperparameter choices
- ✅ Fine-tune an LLM for medical forecasting
- ✅ Run inference and analyze predictions
- ✅ Convert model outputs back to structured data
📝 Reflection Questions
Take a moment to reflect on what you learned:
- What was the most challenging part of this challenge?
- What would you do differently if you had more compute resources?
- How would you adapt this pipeline for a different clinical task?