Forecasting Inference with TwinWeaver and vLLM¶
This notebook demonstrates how to use TwinWeaver's forecasting inference pipeline to predict future clinical values (e.g., lab results like hemoglobin) using a fine-tuned LLM served via vLLM.
Unlike the TTE (time-to-event) probability pipeline, which scores fixed completions, this pipeline uses free-text generation: the model produces an answer string that is then reverse-converted back into a structured DataFrame with predicted dates and values.
Pipeline overview¶
Patient data ──► DataSplitter (forecasting) ──► ConverterInstruction
                                                         │
                                          instruction text per patient
                                                         │
             ┌───────────────────────────────────────────┘
             ▼
  vLLM server (OpenAI-compatible API)
             │
  generated text completions
             │
             ▼
  parse_forecasting_results()
  (calls converter.reverse_conversion internally)
             │
  structured DataFrame
  with predicted dates & values
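The last step, reverse conversion, is handled by `converter.reverse_conversion` inside `parse_forecasting_results`. As a rough illustration of the idea only — the answer format below (`week N: value` lines) and the function `toy_reverse_conversion` are hypothetical, not TwinWeaver's actual grammar:

```python
import re
from datetime import date, timedelta

import pandas as pd


def toy_reverse_conversion(generated: str, split_date: date, variable: str) -> pd.DataFrame:
    """Parse lines like 'week 4: 11.2' into rows of (date, event_name, event_value)."""
    rows = []
    for match in re.finditer(r"week\s+(\d+)\s*:\s*([-+]?\d+(?:\.\d+)?)", generated):
        weeks, value = int(match.group(1)), float(match.group(2))
        rows.append({
            "date": split_date + timedelta(weeks=weeks),  # offset from the split date
            "event_name": variable,
            "event_value": value,
        })
    return pd.DataFrame(rows)


answer = "week 4: 11.2\nweek 8: 10.9\nweek 12: 11.5"
df = toy_reverse_conversion(answer, date(2021, 1, 1), "hemoglobin_-_718-7")
```

The real converter additionally handles units, categorical values, and multiple variables per answer.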
⚠️ Important: The quality of the forecasts depends critically on having a fine-tuned model. An off-the-shelf instruction model (like the default `microsoft/Phi-4-mini-instruct` used here for demonstration) will produce meaningless predictions. Always fine-tune on your clinical dataset first (see `03_end_to_end_llm_finetuning.ipynb`).
Requirements:
- A GPU with enough memory to serve the model via vLLM (≥16 GB for a 4-bit 8B model)
pip install twinweaver[fine-tuning-example] vllm openai
import subprocess
import time
import sys
import os
import pandas as pd
from twinweaver import (
    DataManager,
    Config,
    DataSplitterForecasting,
    DataSplitterEvents,
    DataSplitter,
    ConverterInstruction,
    run_forecasting_inference_notebook,
    parse_forecasting_results,
)
1. Configuration¶
We define all key settings up front so they are easy to change in one place.
Note: Replace `MODEL_PATH` with the path to your fine-tuned model for meaningful results. The default `microsoft/Phi-4-mini-instruct` is only used here so that the notebook is self-contained.
# ---------------------------------------------------------------------------
# Model & server settings
# ---------------------------------------------------------------------------
MODEL_PATH = "microsoft/Phi-4-mini-instruct"  # ⚠️ Replace with your fine-tuned model path!
TOKENIZER_PATH = MODEL_PATH # Usually the same as the model path
VLLM_PORT = 8000
PREDICTION_URL = f"http://localhost:{VLLM_PORT}/v1/"  # client URL (0.0.0.0 is a bind address, not a destination)
MAX_CONTEXT_LENGTH = 4096 # Must match what the model was trained with
# Generation settings
MAX_NEW_TOKENS = 256 # Max tokens for the generated forecast answer
TEMPERATURE = 0.9 # Sampling temperature (0 = greedy)
TOP_P = 0.9 # Nucleus sampling
N_SAMPLES = 3 # Number of independent samples per patient (>1 enables aggregation)
# ---------------------------------------------------------------------------
# TwinWeaver data settings (same as used during training)
# ---------------------------------------------------------------------------
config = Config()
# 1. Event category used for data splitting
config.split_event_category = "lot"
# 2. Forecasting categories
config.event_category_forecast = ["lab"]
# 3. Time-to-event variables (needed to initialise splitters, even if we only forecast)
config.event_category_events_prediction_with_naming = {
    "death": "death",
    "progression": "next progression",
}
# 4. Constant (static) columns
config.constant_columns_to_use = ["birthyear", "gender", "histology", "smoking_history"]
config.constant_birthdate_column = "birthyear"
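The generation settings above map directly onto an OpenAI-compatible `/v1/chat/completions` request. A minimal sketch of the per-patient payload that the inference call assembles — the helper `build_chat_payload` is illustrative, not TwinWeaver's actual internals, but the field names follow the OpenAI chat API:

```python
MODEL_PATH = "microsoft/Phi-4-mini-instruct"
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.9
TOP_P = 0.9
N_SAMPLES = 3


def build_chat_payload(instruction: str) -> dict:
    """One request per patient; 'n' asks the server for N_SAMPLES independent completions."""
    return {
        "model": MODEL_PATH,
        "messages": [{"role": "user", "content": instruction}],
        "max_tokens": MAX_NEW_TOKENS,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "n": N_SAMPLES,
    }


payload = build_chat_payload("Predict hemoglobin at weeks 4, 8, and 12.")
```

Requesting `n` samples in one call lets vLLM share the prompt's KV cache across samples, which is cheaper than sending the prompt three times.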
2. Load and prepare data¶
We use the same example data shipped with TwinWeaver. In a real scenario you would load your own clinical dataset here.
# Load example data (adjust paths if running from a different directory)
df_events = pd.read_csv("../../example_data/events.csv")
df_constant = pd.read_csv("../../example_data/constant.csv")
df_constant_description = pd.read_csv("../../example_data/constant_description.csv")
print(f"Loaded {len(df_events)} events for {df_events['patientid'].nunique()} patients")
# Initialise DataManager and all splitters
dm = DataManager(config=config)
dm.load_indication_data(
    df_events=df_events,
    df_constant=df_constant,
    df_constant_description=df_constant_description,
)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
dm.setup_hold_out_sets(validation_split=0.1, test_split=0.1)
dm.infer_var_types()
data_splitter_events = DataSplitterEvents(
    dm,
    config=config,
    max_length_to_sample=pd.Timedelta(weeks=104),
    min_length_to_sample=pd.Timedelta(weeks=1),
)
data_splitter_events.setup_variables()
data_splitter_forecasting = DataSplitterForecasting(
    data_manager=dm,
    config=config,
    max_forecasted_trajectory_length=pd.Timedelta(days=90),
)
data_splitter_forecasting.setup_statistics()
# Combined interface
data_splitter = DataSplitter(data_splitter_events, data_splitter_forecasting)
converter = ConverterInstruction(
    nr_tokens_budget_total=MAX_CONTEXT_LENGTH,
    config=config,
    dm=dm,
    variable_stats=data_splitter_forecasting.variable_stats,
)
print("✅ Data pipeline ready")
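`setup_hold_out_sets` assigns whole patients to validation and test splits, so no patient's events leak across splits. A rough sketch of patient-level splitting with NumPy — this mirrors the idea, not TwinWeaver's internal implementation:

```python
import numpy as np


def split_patients(patientids, validation_split=0.1, test_split=0.1, seed=0):
    """Deterministically assign each patient to exactly one split."""
    rng = np.random.default_rng(seed)
    ids = rng.permutation(sorted(set(patientids)))  # shuffle unique patient IDs
    n_val = int(len(ids) * validation_split)
    n_test = int(len(ids) * test_split)
    return {
        "validation": set(ids[:n_val]),
        "test": set(ids[n_val:n_val + n_test]),
        "train": set(ids[n_val + n_test:]),
    }


splits = split_patients([f"p{i}" for i in range(100)])
```

Splitting at the patient level (rather than the event level) is what makes the test-set evaluation later in this notebook honest: the model has never seen any part of a test patient's history.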
3. Generate forecasting instruction prompts¶
For each test patient we generate a forecasting-only instruction prompt. The key parameters are:
- `forecasting_override_variables_to_predict` – which variables to forecast (e.g. hemoglobin)
- `forecasting_future_weeks_per_variable` – at which future time points (in weeks) to request predictions
The converter produces a text instruction asking the model to predict the specified variable(s) at the given future time points.
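The week offsets are relative to each patient's split date. A quick sketch of how they translate into the dates that forecast rows will carry (mirroring, not reproducing, the converter's logic):

```python
import pandas as pd

split_date = pd.Timestamp("2021-06-01")  # last date included in the model's input
future_weeks = {"hemoglobin_-_718-7": [4, 8, 12]}

# Each requested week offset becomes split_date + N weeks
prediction_dates = {
    variable: [split_date + pd.Timedelta(weeks=w) for w in weeks]
    for variable, weeks in future_weeks.items()
}
```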
# Choose patients to evaluate (here: all test-set patients)
test_patientids = dm.get_all_patientids_in_split(config.test_split_name)
print(f"Number of test patients: {len(test_patientids)}")
# Define the prediction task:
# Which variables to predict and at which future week offsets
VARIABLES_TO_PREDICT = ["hemoglobin_-_718-7"]
FUTURE_WEEKS = {
    "hemoglobin_-_718-7": [4, 8, 12],  # Predict hemoglobin at 4, 8, and 12 weeks
}
# Build the list of prompt payloads for the forecasting API
# Each payload is a dict with "patientid", "instruction", and "split_date"
prompts_with_meta: list[dict] = []
for pid in test_patientids:
    patient_data = dm.get_patient_data(pid)

    # Use only data up to first lot event (simulates baseline information)
    patient_data["events"] = patient_data["events"].sort_values("date")
    first_lot_date = patient_data["events"][patient_data["events"]["event_category"] == "lot"]["date"].min()
    assert pd.notna(first_lot_date), f"Patient {pid} has no lot event"
    patient_data["events"] = patient_data["events"][patient_data["events"]["date"] <= first_lot_date].copy()

    # Generate the forecasting-only split for inference
    forecast_split, _ = data_splitter.get_splits_from_patient_inference(
        patient_data,
        inference_type="forecasting",
        forecasting_override_variables_to_predict=VARIABLES_TO_PREDICT,
    )

    # Convert to instruction text (no target answer)
    converted = converter.forward_conversion_inference(
        forecasting_split=forecast_split,
        forecasting_future_weeks_per_variable=FUTURE_WEEKS,
    )

    # Collect the prompt payload
    prompts_with_meta.append(
        {
            "patientid": pid,
            "instruction": converted["instruction"],
            "split_date": forecast_split.split_date_included_in_input,
        }
    )
print(f"Generated {len(prompts_with_meta)} instruction prompts")
# Let's inspect one instruction to see what the model will receive
sample = prompts_with_meta[0]
print(f"=== Patient: {sample['patientid']} ===")
print(f"Split date: {sample['split_date']}\n")
print(sample["instruction"])
4. Launch the vLLM server¶
We launch a vLLM OpenAI-compatible server as a background process.
If you already have a vLLM server running, skip this cell and just update `PREDICTION_URL` and `MODEL_PATH` in the configuration section above.
Tip: For production use, launch the server in a separate terminal.
# Launch vLLM server as a background subprocess
# Set to True to skip launching (if you already have a server running)
SKIP_VLLM_LAUNCH = False
vllm_process = None
if not SKIP_VLLM_LAUNCH:
    env = os.environ.copy()
    env["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
    vllm_command = [
        sys.executable,
        "-m",
        "vllm.entrypoints.openai.api_server",
        "--port",
        str(VLLM_PORT),
        "--model",
        MODEL_PATH,
        "--tokenizer",
        TOKENIZER_PATH,
        "--enable-prefix-caching",
    ]
    print(f"🚀 Launching vLLM server:\n  {' '.join(vllm_command)}\n")
    vllm_process = subprocess.Popen(
        vllm_command,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )

    # Wait for the server to be ready
    WAIT_SECONDS = 240
    print(f"⏳ Waiting up to {WAIT_SECONDS}s for the server to start...")
    import urllib.request

    for i in range(WAIT_SECONDS):
        time.sleep(1)
        try:
            urllib.request.urlopen(f"http://localhost:{VLLM_PORT}/health")
            print(f"✅ vLLM server is ready after {i + 1}s")
            break
        except Exception:
            pass
    else:
        print("⚠️ Server did not respond in time. Check GPU memory and logs.")
        print("   You can read server output with: vllm_process.stdout.read()")
else:
    print("Skipping vLLM launch – using existing server.")
5. Run forecasting inference¶
This is the core step. run_forecasting_inference_notebook sends each patient's
instruction to the vLLM server and collects the generated text completions.
Under the hood it:
- Wraps each instruction in a chat message (with optional system prompt).
- Calls the OpenAI-compatible `/v1/chat/completions` endpoint.
- Returns the generated text(s) alongside patient metadata.
When n_samples > 1, multiple independent completions are generated per
patient, which can later be aggregated into a mean trajectory.
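Aggregation across samples can be pictured as a groupby-mean over the long-format predictions. A toy sketch with made-up values (TwinWeaver's `aggregate_multiple_responses` performs the real aggregation internally):

```python
import pandas as pd

# Three sampled trajectories for one patient, in long format
df = pd.DataFrame({
    "patientid": ["p1"] * 6,
    "event_name": ["hb"] * 6,
    "date": pd.to_datetime(["2021-07-01", "2021-08-01"] * 3),
    "sample_idx": [0, 0, 1, 1, 2, 2],
    "event_value": [11.0, 10.5, 11.4, 10.9, 11.2, 10.7],
})

# Average the samples at each (patient, variable, date) point
aggregated = (
    df.groupby(["patientid", "event_name", "date"], as_index=False)["event_value"]
    .mean()
)
```

Averaging three samples per time point here yields one mean trajectory per patient, reducing sampling variance at the cost of three times the generation work.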
# Run the forecasting inference
# This calls the vLLM server asynchronously for all patients
raw_results = await run_forecasting_inference_notebook(
    prompts_with_meta,
    prediction_url=PREDICTION_URL,
    prediction_model=MODEL_PATH,
    max_concurrent_requests=40,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    n_samples=N_SAMPLES,
    api_key="EMPTY",
    timeout=600.0,
)
# Check for failures
n_success = sum(1 for r in raw_results if r is not None)
n_fail = sum(1 for r in raw_results if r is None)
print(f"✅ Generated forecasts for {n_success} patients, {n_fail} failures")
# Let's inspect the raw generated text for one patient
for r in raw_results:
    if r is not None:
        print(f"=== Patient: {r['patientid']} ===")
        for i, text in enumerate(r["generated_texts"]):
            print(f"\n--- Sample {i} ---")
            print(text)
        break
6. Parse results with reverse conversion¶
`parse_forecasting_results` takes the raw generated texts and:

- Calls `converter.reverse_conversion` on each generated text to parse it back into a structured DataFrame with dates and predicted values.
- When `n_samples > 1` and `aggregate_samples=True`, aggregates multiple trajectories using `converter.aggregate_multiple_responses` (e.g. averaging numeric predictions).
- Returns a single long-format DataFrame with all patients' predictions.

Note: Reverse conversion is robust to slightly malformed model output thanks to `inference_override=True`, but a fine-tuned model will produce much more parseable results than a generic instruction model.
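This kind of tolerance typically works by skipping lines the parser cannot interpret rather than failing the whole response. Sketched in miniature — the line format and the function `parse_lines_tolerantly` are hypothetical, not TwinWeaver's actual grammar:

```python
def parse_lines_tolerantly(text: str) -> list[float]:
    """Extract one numeric value per parseable line; silently skip everything else."""
    values = []
    for line in text.splitlines():
        try:
            values.append(float(line.split(":")[-1]))
        except ValueError:
            continue  # malformed line, e.g. free-form chatter from a generic model
    return values


parsed = parse_lines_tolerantly("week 4: 11.2\nSure! Here are my predictions\nweek 8: 10.9")
```

A generic instruction model tends to wrap its answers in conversational filler, which is exactly the kind of line a tolerant parser discards.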
# Parse the generated texts into structured DataFrames
df_results = parse_forecasting_results(
    raw_results,
    converter,
    dm,
    drop_failures=True,
    aggregate_samples=(N_SAMPLES > 1),  # Only aggregate if we have multiple samples
)

if df_results.empty:
    print("⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.")
    print("   Fine-tune a model first (see 03_end_to_end_llm_finetuning.ipynb) for meaningful results.")
else:
    print(f"Parsed {len(df_results)} prediction rows for {df_results['patientid'].nunique()} patients")
df_results.head(20)
Understanding the output¶
The returned DataFrame has the standard TwinWeaver event format:
| Column | Description |
|---|---|
| `date` | Predicted date (computed from `split_date` + week offset) |
| `event_name` | The variable being predicted (e.g. `hemoglobin_-_718-7`) |
| `event_value` | The predicted value |
| `event_category` | Category of the event (e.g. `lab`) |
| `patientid` | Patient identifier |
| `task_type` | Which task type produced this prediction |
| `sample_idx` | Sample index (when `aggregate_samples=False` and `n_samples > 1`) |
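For quick inspection, the long-format output can be pivoted into one column per predicted date. A sketch using the columns listed above, with toy values in place of real model output:

```python
import pandas as pd

# Stand-in for the parsed output (standard TwinWeaver event format, long layout)
df_results = pd.DataFrame({
    "patientid": ["p1", "p1", "p2", "p2"],
    "event_name": ["hb"] * 4,
    "date": pd.to_datetime(["2021-07-01", "2021-08-01"] * 2),
    "event_value": [11.2, 10.7, 9.8, 10.1],
})

# One row per patient, one column per forecast date
wide = df_results.pivot(index="patientid", columns="date", values="event_value")
```

The wide layout makes it easy to eyeball per-patient trajectories or to compare predictions against ground-truth labs date by date.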
7. Multi-sample aggregation (optional)¶
When using n_samples > 1, each patient gets multiple independent forecast
trajectories. Aggregation (e.g. averaging numeric predictions) can reduce
variance and give more robust estimates.
Below is an example of running with multiple samples and then aggregating.
# Example: generate 3 samples per patient and aggregate
N_SAMPLES_AGG = 3
raw_results_multi = await run_forecasting_inference_notebook(
    prompts_with_meta,
    prediction_url=PREDICTION_URL,
    prediction_model=MODEL_PATH,
    max_concurrent_requests=40,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    n_samples=N_SAMPLES_AGG,
    api_key="EMPTY",
)
# Parse with aggregation enabled
df_aggregated = parse_forecasting_results(
    raw_results_multi,
    converter,
    dm,
    drop_failures=True,
    aggregate_samples=True,  # Average numeric values across samples
)

if df_aggregated.empty:
    print("⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.")
else:
    print(f"Aggregated results: {len(df_aggregated)} rows for {df_aggregated['patientid'].nunique()} patients")
df_aggregated.head(20)
# You can also get the individual (non-aggregated) samples for deeper analysis
df_individual = parse_forecasting_results(
    raw_results_multi,
    converter,
    dm,
    drop_failures=True,
    aggregate_samples=False,  # Keep individual samples
)

if df_individual.empty:
    print("⚠️ No predictions could be parsed. This is expected when using a non-fine-tuned model.")
else:
    print(f"Individual results: {len(df_individual)} rows")
    print(f"Samples per patient: {df_individual.groupby('patientid')['sample_idx'].nunique().describe()}")
df_individual.head(20)
8. Clean up¶
Shut down the vLLM server if we launched it from this notebook.
if vllm_process is not None:
    print("Terminating vLLM server...")
    vllm_process.terminate()
    vllm_process.wait(timeout=10)
    print("✅ Server stopped.")
Summary¶
This notebook demonstrated the full forecasting inference workflow:
- Data preparation – Load clinical data and generate forecasting instruction prompts
- Model serving – Launch (or connect to) a vLLM OpenAI-compatible server
- Text generation – Use `run_forecasting_inference_notebook` to generate completions
- Reverse conversion – Use `parse_forecasting_results` to convert text → structured DataFrame
- Aggregation – Optionally average multiple samples for more robust predictions
Key functions¶
| Function | Purpose |
|---|---|
| `run_forecasting_inference()` | Generate completions for all patients (sync wrapper) |
| `run_forecasting_inference_notebook()` | Same but async – for notebooks |
| `parse_forecasting_results()` | Reverse-convert generated text → structured DataFrame |
Comparison with TTE probability inference¶
| Aspect | TTE Inference | Forecasting Inference |
|---|---|---|
| Method | Log-prob scoring of fixed completions | Free-text generation |
| Output | Probabilities (censored/occurred/not occurred) | Predicted values at future time points |
| API endpoint | `/v1/completions` (scoring) | `/v1/chat/completions` (generation) |
| Post-processing | Softmax over log-probs | Reverse conversion (text → DataFrame) |
| Multi-sample | Not applicable | Average trajectories via `aggregate_multiple_responses` |
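For contrast, the TTE pipeline's post-processing step (a softmax over the per-completion log-probs) can be sketched as follows; the outcome labels here are illustrative:

```python
import math


def softmax_from_logprobs(logprobs: dict[str, float]) -> dict[str, float]:
    """Turn per-completion log-probabilities into a normalised distribution."""
    m = max(logprobs.values())  # subtract the max for numerical stability
    exps = {k: math.exp(v - m) for k, v in logprobs.items()}
    total = sum(exps.values())
    return {k: v / total for k, v in exps.items()}


probs = softmax_from_logprobs({"occurred": -1.2, "not occurred": -0.4, "censored": -2.5})
```

This is why the TTE pipeline needs the scoring-style `/v1/completions` endpoint: it requires the log-probability of each fixed completion, whereas forecasting only needs the generated text itself.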
Next steps¶
- Fine-tune a model on your dataset using `03_end_to_end_llm_finetuning.ipynb`
- Combine with TTE by using `inference_type="both"` in the data splitter
- Experiment with different variables, time horizons, and sample counts