Time-to-Event (TTE) Probability Inference with TwinWeaver¶
This notebook demonstrates how to use TwinWeaver's TTE inference pipeline to estimate the probability of clinical events (e.g., death, disease progression) using a fine-tuned LLM served via vLLM.
Instead of generating free-text answers, this pipeline scores three mutually exclusive completions for each patient:
| Outcome | Meaning |
|---|---|
| Censored | The patient's observation window ended before the event could be observed |
| Occurred | The event happened within the observation window |
| Not occurred | The event did not happen within the observation window |
The model returns per-token log-probabilities for each completion; these are length-normalised and passed through a softmax to yield a probability distribution over the three outcomes that sums to 1.
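The length normalisation and softmax can be sketched in a few lines (a toy illustration with made-up log-probabilities; in the notebook this is handled by `compute_length_normalized_probabilities`):

```python
import math

def length_normalized_softmax(token_logprobs_per_outcome):
    """Mean log-prob per completion, then softmax across the outcomes."""
    avg = [sum(lps) / len(lps) for lps in token_logprobs_per_outcome]
    m = max(avg)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(a - m) for a in avg]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up per-token log-probs for [censored, occurred, not occurred]
probs = length_normalized_softmax([[-1.2, -0.8], [-0.5, -0.4, -0.6], [-2.0, -1.9]])
print(probs)  # three probabilities summing to 1; "occurred" gets the highest one
```

Because the mean (not the sum) of the log-probs is used, a three-token completion is not penalised relative to a two-token one.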
Pipeline overview¶
Patient data ──► DataSplitter (events) ──► ConverterInstruction
                                                  │
                                    instruction text per patient
                                                  │
                                                  ▼
                                vLLM server (OpenAI-compatible API)
                                                  │
                                     log-probs for 3 completions
                                                  │
                                                  ▼
                          compute_length_normalized_probabilities()
                                                  │
                                                  ▼
                          probabilities + hard predictions (DataFrame)
⚠️ Important: The quality of the probability estimates depends critically on having a fine-tuned model. An off-the-shelf instruction model (like the default `microsoft/Phi-4-mini-instruct` used here for demonstration) will produce random, meaningless probabilities. Always fine-tune on your clinical dataset first (see `03_end_to_end_llm_finetuning.ipynb`).
Requirements:
- A GPU with enough memory to serve the model via vLLM (≥16 GB for a 4-bit 8B model)
pip install twinweaver[fine-tuning-example] vllm openai
import subprocess
import time
import sys
import os
import pandas as pd
from transformers import AutoTokenizer
from twinweaver import (
DataManager,
Config,
DataSplitterForecasting,
DataSplitterEvents,
DataSplitter,
ConverterInstruction,
run_tte_probability_estimation_notebook,
compute_length_normalized_probabilities,
)
1. Configuration¶
We define all key settings up front so they are easy to change in one place.
Note: Replace `MODEL_PATH` with the path to your fine-tuned model for meaningful results. The default `microsoft/Phi-4-mini-instruct` is only used here so that the notebook is self-contained.
# ---------------------------------------------------------------------------
# Model & server settings
# ---------------------------------------------------------------------------
MODEL_PATH = "microsoft/Phi-4-mini-instruct" # ⚠️ Replace with your fine-tuned model path
TOKENIZER_PATH = MODEL_PATH # Usually the same as the model path
VLLM_PORT = 8000
PREDICTION_URL = f"http://0.0.0.0:{VLLM_PORT}/v1/"
MAX_CONTEXT_LENGTH = 4096 # Must match what the model was trained with
# ---------------------------------------------------------------------------
# TwinWeaver data settings (same as used during training)
# ---------------------------------------------------------------------------
config = Config()
# 1. Event category used for data splitting
config.split_event_category = "lot"
# 2. Forecasting categories (needed to initialise splitters, even if not used for TTE)
config.event_category_forecast = ["lab"]
# 3. Time-to-event variables to predict
# Keys = event_category values in the events DataFrame
# Values = human-readable name used in the prompt
config.data_splitter_events_variables_category_mapping = {
"death": "death",
"progression": "next progression",
}
# 4. Constant (static) columns
config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
config.constant_birthdate_column = "birthyear"
2. Load and prepare data¶
We use the same example data shipped with TwinWeaver. In a real scenario you would load your own clinical dataset here.
# Load example data (adjust paths if running from a different directory)
df_events = pd.read_csv("../../example_data/events.csv")
df_constant = pd.read_csv("../../example_data/constant.csv")
df_constant_description = pd.read_csv("../../example_data/constant_description.csv")
print(f"Loaded {len(df_events)} events for {df_events['patientid'].nunique()} patients")
# Initialise DataManager and all splitters
dm = DataManager(config=config)
dm.load_indication_data(
df_events=df_events,
df_constant=df_constant,
df_constant_description=df_constant_description,
)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
dm.setup_dataset_splits()
dm.infer_var_types()
data_splitter_events = DataSplitterEvents(dm, config=config)
data_splitter_events.setup_variables()
data_splitter_forecasting = DataSplitterForecasting(data_manager=dm, config=config)
data_splitter_forecasting.setup_statistics()
# Combined interface
data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)
converter = ConverterInstruction(
nr_tokens_budget_total=MAX_CONTEXT_LENGTH,
config=config,
dm=dm,
variable_stats=data_splitter_forecasting.variable_stats,
)
print("✅ Data pipeline ready")
3. Generate TTE instruction prompts¶
For each patient we want to score, we generate an events-only instruction prompt. The key parameters are:
- `events_override_category` – which event to predict (e.g. `"death"`)
- `events_override_observation_time_delta` – the prediction horizon
The converter produces a text instruction that asks the model to predict whether the event was censored, occurred, or did not occur within the given window.
# Choose patients to evaluate (here: all test-set patients)
test_patientids = dm.get_all_patientids_in_split(config.test_split_name)
print(f"Number of test patients: {len(test_patientids)}")
# Define the prediction task
EVENT_TO_PREDICT = "death"
OBSERVATION_WEEKS = 52 # predict death within 52 weeks
observation_delta = pd.Timedelta(weeks=OBSERVATION_WEEKS)
# Build (patientid, instruction) pairs for the TTE scoring API
instructions_with_ids: list[tuple[str, str]] = []
for pid in test_patientids:
patient_data = dm.get_patient_data(pid)
    # Use only events up to the first lot event, similar to baseline information in trials
patient_data["events"] = patient_data["events"].sort_values("date")
first_lot_event_date = patient_data["events"][patient_data["events"]["event_category"] == "lot"]["date"].min()
assert pd.notna(first_lot_event_date), f"Patient {pid} has no lot event"
patient_data["events"] = patient_data["events"][patient_data["events"]["date"] <= first_lot_event_date].copy()
# Generate the events-only split for inference
_, events_split = data_splitter.get_splits_from_patient_inference(
patient_data,
inference_type="events",
events_override_category=EVENT_TO_PREDICT,
events_override_observation_time_delta=observation_delta,
)
# Convert to instruction text (no target answer)
converted = converter.forward_conversion_inference(
event_split=events_split,
)
instructions_with_ids.append((pid, converted["instruction"]))
print(f"Generated {len(instructions_with_ids)} instruction prompts")
# Let's inspect one instruction to see what the model will receive
sample_pid, sample_instruction = instructions_with_ids[0]
print(f"=== Patient: {sample_pid} ===\n")
print(sample_instruction)
4. Launch the vLLM server¶
We launch a vLLM OpenAI-compatible server as a background process.
The server must support `echo=True` and `logprobs` in the completions endpoint,
which vLLM does by default.
If you already have a vLLM server running, skip this cell and just update `PREDICTION_URL` and `MODEL_PATH` in the configuration section above.
Tip: For production use, launch the server in a separate terminal.
# Launch vLLM server as a background subprocess
# Set to True to skip launching (if you already have a server running)
SKIP_VLLM_LAUNCH = False
vllm_process = None
if not SKIP_VLLM_LAUNCH:
env = os.environ.copy()
env["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
vllm_command = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--port",
str(VLLM_PORT),
"--model",
MODEL_PATH,
"--tokenizer",
TOKENIZER_PATH,
"--enable-prefix-caching",
]
print(f"🚀 Launching vLLM server:\n {' '.join(vllm_command)}\n")
vllm_process = subprocess.Popen(
vllm_command,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
# Wait for the server to be ready
WAIT_SECONDS = 240
print(f"⏳ Waiting up to {WAIT_SECONDS}s for the server to start...")
import urllib.request
for i in range(WAIT_SECONDS):
time.sleep(1)
try:
urllib.request.urlopen(f"http://localhost:{VLLM_PORT}/health")
print(f"✅ vLLM server is ready after {i + 1}s")
break
except Exception:
pass
else:
print("⚠️ Server did not respond in time. Check GPU memory and logs.")
print(" You can read server output with: vllm_process.stdout.read()")
else:
print("Skipping vLLM launch – using existing server.")
5. Run TTE probability estimation¶
This is the core step. `run_tte_probability_estimation_notebook` sends each patient's
instruction to the vLLM server three times (once per outcome) and collects the
per-token log-probabilities of each completion.
Under the hood it:
- Uses `build_scored_prompt` to construct `prompt_prefix + completion_suffix` for each of the 3 outcomes.
- Calls the OpenAI-compatible `/v1/completions` endpoint with `max_tokens=0` and `echo=True` to score (not generate) each completion.
- Slices the returned log-probs to keep only the completion tokens.
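The per-request mechanics can be mimicked offline. The sketch below simulates the echoed log-probs and shows the slicing step; the commented request shape and the `slice_completion_logprobs` helper are illustrative, not TwinWeaver's actual internals:

```python
# The real request (via the OpenAI-compatible endpoint) would look roughly like:
#   client.completions.create(model=..., prompt=prefix + suffix,
#                             max_tokens=0, echo=True, logprobs=1)
# and returns one log-prob per token of the *entire* echoed prompt.

def slice_completion_logprobs(token_logprobs, n_prefix_tokens):
    """Keep only the log-probs of the completion tokens.

    The first echoed token typically has logprob None, so filter those out.
    """
    return [lp for lp in token_logprobs[n_prefix_tokens:] if lp is not None]

# Simulated: 5 prefix tokens (first is None when echoed) + 3 completion tokens
echoed = [None, -2.1, -0.9, -1.4, -0.3, -0.5, -0.4, -0.6]
completion_lps = slice_completion_logprobs(echoed, n_prefix_tokens=5)
print(completion_lps)  # [-0.5, -0.4, -0.6]
```

The prefix token count comes from tokenizing `prompt_prefix` with the same tokenizer the server uses, which is why the tokenizer must match the served model.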
# Load tokenizer (must match the served model)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
print(f"Tokenizer loaded: {TOKENIZER_PATH}")
print(f"Vocab size: {tokenizer.vocab_size}")
# Run the TTE probability estimation
# This calls the vLLM server asynchronously for all patients
raw_results = await run_tte_probability_estimation_notebook(
instructions_with_ids,
tokenizer,
config,
prediction_url=PREDICTION_URL,
prediction_model=MODEL_PATH,
max_concurrent_requests=40, # Adjust based on your GPU memory / server capacity
api_key="EMPTY", # Use "EMPTY" for local vLLM servers
timeout=600.0,
)
# Check for failures
n_success = sum(1 for r in raw_results if r is not None)
n_fail = sum(1 for r in raw_results if r is None)
print(f"✅ Scored {n_success} patients, {n_fail} failures")
6. Compute length-normalised probabilities¶
compute_length_normalized_probabilities takes the raw per-token log-probs and:
- Length-normalises by computing the mean log-prob for each completion (so longer completions are not unfairly penalised).
- Applies a softmax across the three outcomes to get probabilities that sum to 1.
- Derives a hard prediction by selecting the outcome with the highest probability.
NOTE: As shown in our paper, these are not well-calibrated probabilities, so future research should explore better ways of calculating this quantity.
# Convert raw log-probs to probabilities
df_results = compute_length_normalized_probabilities(raw_results, drop_failures=True)
# Show the key columns
display_cols = [
"patientid",
"probability_occurrence",
"probability_no_occurrence",
"probability_censored",
"censored",
"occurred",
]
df_results[display_cols]
Understanding the output columns¶
| Column | Description |
|---|---|
| `probability_occurrence` | Softmax probability that the event occurred within the window |
| `probability_no_occurrence` | Softmax probability that the event did not occur |
| `probability_censored` | Softmax probability that the patient was censored |
| `censored` | Hard prediction: `True` if censored has the highest probability |
| `occurred` | Hard prediction: `True` if occurred has the highest probability |
The intermediate columns (avg_logprob_*, softmax_*) are also available for
deeper analysis.
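If a single categorical prediction is more convenient than the two boolean columns, it can be derived from the probability columns directly (a sketch with made-up values; assumes the column names shown above):

```python
import pandas as pd

# Hypothetical results frame with the probability columns described above
df = pd.DataFrame({
    "probability_occurrence": [0.6, 0.2, 0.1],
    "probability_no_occurrence": [0.3, 0.7, 0.2],
    "probability_censored": [0.1, 0.1, 0.7],
})

prob_cols = ["probability_occurrence", "probability_no_occurrence", "probability_censored"]
labels = {
    "probability_occurrence": "occurred",
    "probability_no_occurrence": "not occurred",
    "probability_censored": "censored",
}

# Single categorical prediction: the outcome with the highest probability per row
df["prediction"] = df[prob_cols].idxmax(axis=1).map(labels)
print(df["prediction"].tolist())  # ['occurred', 'not occurred', 'censored']
```

This mirrors how the `censored` and `occurred` hard-prediction columns are defined: each is `True` exactly when its outcome wins the argmax.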
# Summary statistics
print("=== Prediction Summary ===")
print(f"Total patients scored: {len(df_results)}")
print(f"Predicted occurred: {df_results['occurred'].sum()}")
print(f"Predicted not occurred: {(~df_results['occurred'] & ~df_results['censored']).sum()}")
print(f"Predicted censored: {df_results['censored'].sum()}")
print(f"\nMean P(occurrence): {df_results['probability_occurrence'].mean():.4f}")
print(f"Mean P(no occurrence): {df_results['probability_no_occurrence'].mean():.4f}")
print(f"Mean P(censored): {df_results['probability_censored'].mean():.4f}")
7. Evaluate multiple time horizons (optional)¶
In practice you may want to evaluate the same event across several observation windows (e.g. 8, 26, 52, and 104 weeks). The code below shows how to loop over multiple horizons and collect all results into a single DataFrame.
EVALUATION_HORIZONS_WEEKS = [8, 26, 52, 104]
all_horizon_results = []
for weeks in EVALUATION_HORIZONS_WEEKS:
delta = pd.Timedelta(weeks=weeks)
# Build instructions for this horizon
horizon_instructions = []
for pid in test_patientids:
patient_data = dm.get_patient_data(pid)
        # Keep events up to the first lot event as baseline, similar to trial inclusion criteria
patient_data["events"] = patient_data["events"].sort_values("date")
first_lot_event_date = patient_data["events"][patient_data["events"]["event_category"] == "lot"]["date"].min()
assert pd.notna(first_lot_event_date), f"Patient {pid} has no lot event"
patient_data["events"] = patient_data["events"][patient_data["events"]["date"] <= first_lot_event_date].copy()
# Get splits
_, events_split = data_splitter.get_splits_from_patient_inference(
patient_data,
inference_type="events",
events_override_category=EVENT_TO_PREDICT,
events_override_observation_time_delta=delta,
)
# Generate the actual prompt
converted = converter.forward_conversion_inference(event_split=events_split)
# Save the prompt
horizon_instructions.append((f"{pid}", converted["instruction"]))
# Get predicted probabilities from the model for this horizon
raw = await run_tte_probability_estimation_notebook(
horizon_instructions,
tokenizer,
config,
prediction_url=PREDICTION_URL,
prediction_model=MODEL_PATH,
)
# Post-process
df_horizon = compute_length_normalized_probabilities(raw, drop_failures=True)
df_horizon["week_horizon"] = weeks
all_horizon_results.append(df_horizon)
n_ok = sum(1 for r in raw if r is not None)
print(f" Week {weeks:>3d}: scored {n_ok} patients")
df_all_horizons = pd.concat(all_horizon_results, ignore_index=True)
print(f"\n✅ Total rows: {len(df_all_horizons)}")
df_all_horizons = df_all_horizons.sort_values(["patientid", "week_horizon"])
df_all_horizons[
["patientid", "week_horizon", "probability_occurrence", "probability_no_occurrence", "probability_censored"]
]
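For downstream analysis it is often handier to pivot the long-format results into one row per patient with one column per horizon (a sketch using a hypothetical frame that mirrors `df_all_horizons`):

```python
import pandas as pd

# Hypothetical long-format results, as produced by the horizon loop above
df_all_horizons = pd.DataFrame({
    "patientid": ["p1", "p1", "p2", "p2"],
    "week_horizon": [8, 52, 8, 52],
    "probability_occurrence": [0.10, 0.40, 0.05, 0.20],
})

# One row per patient, one column per observation horizon
df_wide = df_all_horizons.pivot(
    index="patientid", columns="week_horizon", values="probability_occurrence"
)
print(df_wide)
```

The wide layout makes it easy to plot per-patient event-probability curves over the horizons, analogous to a discrete survival curve.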
8. Clean up¶
Shut down the vLLM server if we launched it from this notebook.
if vllm_process is not None:
print("Terminating vLLM server...")
vllm_process.terminate()
vllm_process.wait(timeout=10)
print("✅ Server stopped.")
Summary¶
This notebook demonstrated the full TTE probability inference workflow:
- Data preparation – Load clinical data and generate events-only instruction prompts
- Model serving – Launch (or connect to) a vLLM OpenAI-compatible server
- Log-prob scoring – Use `run_tte_probability_estimation_notebook` to score three completions per patient
- Post-processing – Use `compute_length_normalized_probabilities` to get length-normalised probabilities
- Analysis – Inspect, visualise, and evaluate predictions across time horizons
Key functions¶
| Function | Purpose |
|---|---|
| `run_tte_probability_estimation()` | Score all patients against the API (sync wrapper) |
| `compute_length_normalized_probabilities()` | Convert raw log-probs → softmax probabilities |
| `build_scored_prompt()` | Build prompt prefix + 3 completion suffixes (useful for debugging) |
Next steps¶
- Fine-tune a model on your dataset using `03_end_to_end_llm_finetuning.ipynb`
- Evaluate predictions against ground truth using landmark analysis
- Experiment with different observation windows and event categories