TTE Inference¶
twinweaver.utils.tte_inference ¶
Time-to-event (TTE) inference helpers for probability estimation.
Provides functions to estimate the length-normalized probabilities of three mutually exclusive TTE outcomes (censored, occurred, not occurred) by scoring each completion against an OpenAI-compatible API (e.g. a local vLLM server).
The prompt construction is driven entirely by a `Config` (`twinweaver.common.config.Config`) instance, so the same code works for any dataset / prompt template.
Typical usage:

```python
from transformers import AutoTokenizer
from twinweaver.common.config import Config
from twinweaver.utils.tte_inference import (
    run_tte_probability_estimation,
    compute_length_normalized_probabilities,
)

config = Config()
tokenizer = AutoTokenizer.from_pretrained("my-model")
data = [("patient_1", "instruction text ...")]

# run_tte_probability_estimation is synchronous (it calls asyncio.run
# internally), so no event-loop handling is needed in plain scripts.
raw = run_tte_probability_estimation(
    data, tokenizer, config,
    prediction_url="http://localhost:8000/v1/",
    prediction_model="my-model",
)
df = compute_length_normalized_probabilities(raw)
```
Classes¶
Functions¶
build_scored_prompt ¶
Construct the full prompt prefix and the three scored completions.
This is the model-agnostic equivalent of the prompt assembly that was
previously hard-coded for Llama-3.1 chat tokens. It uses the tokenizer's
apply_chat_template when available, falling back to a simple
concatenation otherwise.
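The template-or-fallback logic described above can be sketched as follows. `assemble_prompt` and the exact fallback format are hypothetical illustrations, not the library's actual implementation; only the `apply_chat_template` call mirrors the real HuggingFace tokenizer API.

```python
# Sketch of the chat-template fallback (hypothetical helper).
# `tokenizer` is any object; only `apply_chat_template` is probed.

def assemble_prompt(tokenizer, instruction, system_prompt=None):
    """Build a prompt prefix, preferring the tokenizer's chat template."""
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": instruction})

    if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
        # Real HF tokenizers render the chat template here;
        # add_generation_prompt appends the assistant header so
        # completions can be scored right after it.
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    # Fallback: simple concatenation when no chat template is available.
    parts = [f"{m['role']}: {m['content']}" for m in messages]
    return "\n".join(parts) + "\nassistant:"
```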
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `instruction` | `str` | The user-facing instruction text (one patient / time-point). | *required* |
| `tokenizer` | `Any` | A HuggingFace-compatible tokenizer (ideally supporting `apply_chat_template`). | *required* |
| `config` | `Config` | The configuration object. | *required* |
| `system_prompt` | `str or None` | An optional system prompt. When `None` the chat template is built without a system message. | `None` |
Returns:

| Name | Type | Description |
|---|---|---|
| `prompt_prefix` | `str` | The fully assembled prompt up to (and including) the point where the scored completion begins. |
| `completions` | `list[tuple[str, str]]` | The three scored completions, one per TTE outcome (censored, occurred, not occurred). |
Notes
The slicing index (to separate prefix tokens from completion tokens) is not computed here because BPE tokenizers can merge tokens across the prefix/suffix boundary. The caller should compute the slicing index from the full concatenated prompt instead.
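The boundary-safe computation recommended in the note can be sketched with character offsets: tokenize the full concatenated prompt once (e.g. with a fast HF tokenizer and `return_offsets_mapping=True`) and locate the first token that starts at or after the prefix's character length. The helper name `slicing_index` is hypothetical.

```python
# Sketch: compute the prefix/completion slicing index from the *full*
# prompt's token offsets, so BPE merges across the boundary are handled.

def slicing_index(offsets, prefix_char_len):
    """Index of the first token starting at or after the prefix boundary.

    offsets: list of (start_char, end_char) pairs, one per token of the
    full concatenated prompt.
    prefix_char_len: len(prompt_prefix) in characters.
    """
    for i, (start, _end) in enumerate(offsets):
        if start >= prefix_char_len:
            return i
    return len(offsets)
```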
Source code in twinweaver/utils/tte_inference.py
compute_length_normalized_probabilities ¶
Convert raw per-token log-probs into length-normalized softmax probabilities.
For each patient the function:
- Computes the mean log-probability (length-normalized) of each of the three completion strings.
- Applies a softmax over the three mean log-probs to obtain calibrated probabilities that sum to one.
- Derives a hard prediction (`censored`, `occurred` columns) by picking the class with the highest softmax score.
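The normalize-then-softmax steps above can be sketched as a small helper (the function name and input shape are illustrative, not the library's API):

```python
import math

def length_normalized_probs(token_logprobs):
    """token_logprobs: dict mapping outcome -> list of per-token log-probs.

    Returns calibrated probabilities that sum to one.
    """
    # Mean log-prob per completion removes the bias toward short strings.
    means = {k: sum(v) / len(v) for k, v in token_logprobs.items()}
    # Numerically stable softmax over the three means.
    m = max(means.values())
    exps = {k: math.exp(v - m) for k, v in means.items()}
    z = sum(exps.values())
    return {k: e / z for k, e in exps.items()}
```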
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `raw_results` | `list[dict or None]` | Output of `run_tte_probability_estimation` (one entry per patient; `None` marks a failed request). | *required* |
| `drop_failures` | `bool` | If `True` (default), silently drop `None` entries instead of raising. | `True` |
Returns:

| Type | Description |
|---|---|
| `DataFrame` | One row per patient, with the length-normalized probability of each outcome and the derived hard-prediction columns. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `drop_failures` is `False` and any result is `None`. |
Source code in twinweaver/utils/tte_inference.py
run_tte_probability_estimation ¶
```python
run_tte_probability_estimation(
    instructions_with_ids,
    tokenizer,
    config,
    *,
    prediction_url="http://0.0.0.0:8000/v1/",
    prediction_model="default-model",
    max_concurrent_requests=40,
    system_prompt=None,
    api_key="EMPTY",
    timeout=600.0
)
```
Score all patients against an OpenAI-compatible API and return raw log-probs.
This is the main entry-point for running TTE probability estimation.
It is synchronous (calls asyncio.run internally) so it can be used
from plain scripts or notebooks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `instructions_with_ids` | `list[tuple[str, str]]` | Each element is a `(patient_id, instruction)` pair. | *required* |
| `tokenizer` | `Any` | HuggingFace-compatible tokenizer (used for token counting). | *required* |
| `config` | `Config` | TwinWeaver configuration object. | *required* |
| `prediction_url` | `str` | Base URL of the OpenAI-compatible inference server. | `'http://0.0.0.0:8000/v1/'` |
| `prediction_model` | `str` | Model name / path served by the inference server. | `'default-model'` |
| `max_concurrent_requests` | `int` | Maximum number of concurrent API requests. | `40` |
| `system_prompt` | `str or None` | Optional system prompt. | `None` |
| `api_key` | `str` | API key for the inference server (local vLLM servers accept the placeholder `EMPTY`). | `'EMPTY'` |
| `timeout` | `float` | Per-request timeout in seconds. | `600.0` |
Returns:

| Type | Description |
|---|---|
| `list[dict or None]` | One dict per patient with the raw per-token log-probs for each completion; `None` for failed requests. |
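The `max_concurrent_requests` bound is the standard semaphore pattern for fanning out many API calls without overwhelming the server. A minimal sketch (the helper name `gather_bounded` is hypothetical; the library's internal implementation may differ):

```python
import asyncio

async def gather_bounded(coros, max_concurrent):
    """Run coroutines concurrently with at most `max_concurrent` in flight.

    Results are returned in the same order as the input, like
    asyncio.gather.
    """
    sem = asyncio.Semaphore(max_concurrent)

    async def run(coro):
        async with sem:  # blocks when max_concurrent tasks are in flight
            return await coro

    return await asyncio.gather(*(run(c) for c in coros))
```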
Source code in twinweaver/utils/tte_inference.py
run_tte_probability_estimation_notebook ¶
```python
run_tte_probability_estimation_notebook(
    instructions_with_ids,
    tokenizer,
    config,
    *,
    prediction_url="http://0.0.0.0:8000/v1/",
    prediction_model="default-model",
    max_concurrent_requests=40,
    system_prompt=None,
    api_key="EMPTY",
    timeout=600.0
)
```
Score all patients against an OpenAI-compatible API and return raw log-probs.
This is the main entry-point for running TTE probability estimation, for use in Jupyter notebooks. It is asynchronous and returns a coroutine, so it can be awaited directly in notebooks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `instructions_with_ids` | `list[tuple[str, str]]` | Each element is a `(patient_id, instruction)` pair. | *required* |
| `tokenizer` | `Any` | HuggingFace-compatible tokenizer (used for token counting). | *required* |
| `config` | `Config` | TwinWeaver configuration object. | *required* |
| `prediction_url` | `str` | Base URL of the OpenAI-compatible inference server. | `'http://0.0.0.0:8000/v1/'` |
| `prediction_model` | `str` | Model name / path served by the inference server. | `'default-model'` |
| `max_concurrent_requests` | `int` | Maximum number of concurrent API requests. | `40` |
| `system_prompt` | `str or None` | Optional system prompt. | `None` |
| `api_key` | `str` | API key for the inference server (local vLLM servers accept the placeholder `EMPTY`). | `'EMPTY'` |
| `timeout` | `float` | Per-request timeout in seconds. | `600.0` |
Returns:

| Type | Description |
|---|---|
| `list[dict or None]` | One dict per patient with the raw per-token log-probs for each completion; `None` for failed requests. |