TTE Inference

twinweaver.utils.tte_inference

Time-to-event (TTE) inference helpers for probability estimation.

Provides functions to estimate the length-normalized probabilities of three mutually exclusive TTE outcomes (censored, occurred, not occurred) by scoring each completion against an OpenAI-compatible API (e.g. a local vLLM server).

The prompt construction is driven entirely by a twinweaver.common.config.Config instance, so the same code works for any dataset / prompt template.

Typical usage

from transformers import AutoTokenizer
from twinweaver.common.config import Config
from twinweaver.utils.tte_inference import (
    run_tte_probability_estimation,
    compute_length_normalized_probabilities,
)

config = Config()
tokenizer = AutoTokenizer.from_pretrained("my-model")
data = [("patient_1", "instruction text ...")]
raw = run_tte_probability_estimation(
    data, tokenizer, config,
    prediction_url="http://localhost:8000/v1/",
    prediction_model="my-model",
)
df = compute_length_normalized_probabilities(raw)

Functions

build_scored_prompt

build_scored_prompt(
    instruction, tokenizer, config, *, system_prompt=None
)

Construct the full prompt prefix and the three scored completions.

This is the model-agnostic equivalent of the prompt assembly that was previously hard-coded for Llama-3.1 chat tokens. It uses the tokenizer's apply_chat_template when available, falling back to a simple concatenation otherwise.

Parameters:

- instruction (str, required): The user-facing instruction text (one patient / time-point).
- tokenizer (Any, required): A HuggingFace-compatible tokenizer (must support encode; ideally also apply_chat_template).
- config (Config, required): The configuration object.
- system_prompt (str or None, default None): An optional system prompt. When None the chat template is built without a system message.

Returns:

- prompt_prefix (str): The fully assembled prompt up to (and including) the config.target_prompt_start fragment.
- completions (list[tuple[str, str]]): The three (label, suffix_text) pairs to score.

Notes

The slicing index (to separate prefix tokens from completion tokens) is not computed here because BPE tokenizers can merge tokens across the prefix/suffix boundary. The caller should compute the slicing index from the full concatenated prompt instead.
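For illustration, a merge-safe slicing index can be computed as the longest common prefix of the two token sequences. This is a sketch of what a caller might do, not the library's actual implementation; `completion_slice_index` is a hypothetical helper:

```python
def completion_slice_index(full_tokens: list[int], prefix_tokens: list[int]) -> int:
    """Index in ``full_tokens`` where completion tokens start.

    Tokenizing the prefix and the full prompt separately can disagree near
    the boundary when a BPE tokenizer merges a prefix fragment with the
    first completion characters, so we take the longest common prefix of
    the two token sequences instead of blindly using ``len(prefix_tokens)``.
    """
    i = 0
    limit = min(len(full_tokens), len(prefix_tokens))
    while i < limit and full_tokens[i] == prefix_tokens[i]:
        i += 1
    return i

# Token 99 merged prefix and completion text, so scoring starts at index 2,
# one position earlier than len(prefix_tokens) would suggest.
merged = completion_slice_index([5, 7, 99, 12], [5, 7, 9])      # -> 2
clean = completion_slice_index([5, 7, 9, 12], [5, 7, 9])        # -> 3
```

Scoring every token from this index onward keeps the log-probability of a merged boundary token on the completion side, which is the conservative choice.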

Source code in twinweaver/utils/tte_inference.py
def build_scored_prompt(
    instruction: str,
    tokenizer: Any,
    config: Config,
    *,
    system_prompt: str | None = None,
) -> tuple[str, list[tuple[str, str]]]:
    """Construct the full prompt prefix and the three scored completions.

    This is the **model-agnostic** equivalent of the prompt assembly that was
    previously hard-coded for Llama-3.1 chat tokens.  It uses the tokenizer's
    ``apply_chat_template`` when available, falling back to a simple
    concatenation otherwise.

    Parameters
    ----------
    instruction : str
        The user-facing instruction text (one patient / time-point).
    tokenizer
        A HuggingFace-compatible tokenizer (must support ``encode``; ideally
        also ``apply_chat_template``).
    config : Config
        The configuration object.
    system_prompt : str or None, optional
        An optional system prompt.  When *None* the chat template is built
        without a system message.

    Returns
    -------
    prompt_prefix : str
        The fully assembled prompt up to (and including) the
        ``config.target_prompt_start`` fragment.
    completions : list[tuple[str, str]]
        The three ``(label, suffix_text)`` pairs to score.

    Notes
    -----
    The slicing index (to separate prefix tokens from completion tokens)
    is **not** computed here because BPE tokenizers can merge tokens across
    the prefix/suffix boundary.  The caller should compute the slicing
    index from the full concatenated prompt instead.
    """

    # Do basic assertion that only the first task is present, since this approach doesn't support multiple tasks now.
    assert config.task_prompt_each_task.format(task_nr=1).strip() in instruction, "Task 1 not in instruction."
    assert config.task_prompt_each_task.format(task_nr=2).strip() not in instruction, (
        "Task 2 found in instruction (not supported)."
    )
    assert config.task_prompt_each_task.format(task_nr=3).strip() not in instruction, (
        "Task 3 found in instruction (not supported)."
    )

    # Extract the event name from the instruction to fill in the target prompt.
    event_name = _extract_event_name_from_instruction(instruction, config)
    target_start = config.target_prompt_start.format(event_name=event_name)
    completions = _build_tte_completion_strings(config)

    # --- Build the task target prefix that precedes the actual prediction ---
    # During training, _generate_target_string wraps each task's target with:
    #   task_target_start.format(task_nr=N) + task_type + task_text + task_target_end
    # For TTE inference there is always exactly one events task (task_nr=1),
    # so the target the model learned to produce starts with:
    #   "Task 1 is time to event prediction:\nHere is the prediction: ..."
    # We must replicate this prefix so that the scored prompt matches training.
    task_target_prefix = config.task_target_start.format(task_nr=1) + config.task_prompt_events + target_start

    instruction_clean = instruction.strip()

    # --- assemble the full text that precedes each scored suffix ----------
    has_chat_template = hasattr(tokenizer, "apply_chat_template")

    if has_chat_template:
        messages: list[dict[str, str]] = []
        if system_prompt is not None:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": instruction_clean})

        # Render the conversation as text (tokenize=False) *with* the
        # generation prompt so the assistant header is present; we then
        # append the assistant preamble (task_target_prefix) ourselves.
        chat_prefix = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        prompt_prefix = chat_prefix + task_target_prefix
    else:
        # Simple fallback – just concatenate.
        parts = []
        if system_prompt is not None:
            parts.append(system_prompt)
        parts.append(instruction_clean)
        prompt_prefix = "\n\n".join(parts) + task_target_prefix

    return prompt_prefix, completions

compute_length_normalized_probabilities

compute_length_normalized_probabilities(
    raw_results, *, drop_failures=False
)

Convert raw per-token log-probs into length-normalized softmax probabilities.

For each patient the function:

  1. Computes the mean log-probability (length-normalised) of each of the three completion strings.
  2. Applies a softmax over the three mean log-probs to obtain calibrated probabilities that sum to one.
  3. Derives a hard prediction (censored, occurred columns) by picking the class with the highest softmax score.
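With made-up per-token log-probs, the three steps reduce to the following stdlib sketch (the library itself uses numpy and scipy.special.softmax; the numbers here are illustrative only):

```python
import math

# Hypothetical per-token log-probs for one patient's three completions.
token_logprobs = {
    "occurred": [-0.2, -1.1, -0.4],
    "not_occurred": [-0.9, -0.8],
    "censored": [-2.0, -1.5, -1.7, -1.9],
}

# 1. Length-normalised mean log-probability per completion.
means = {label: sum(lp) / len(lp) for label, lp in token_logprobs.items()}

# 2. Numerically stable softmax over the three means (sums to one).
shift = max(means.values())
exps = {label: math.exp(m - shift) for label, m in means.items()}
total = sum(exps.values())
probs = {label: e / total for label, e in exps.items()}

# 3. Hard prediction: the class with the highest softmax score.
prediction = max(probs, key=probs.get)  # -> "occurred" for these numbers
```

The length normalisation in step 1 is what keeps a short completion string from being favoured simply because it has fewer tokens to pay log-probability for.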

Parameters:

- raw_results (list[dict or None], required): Output of run_tte_probability_estimation.
- drop_failures (bool, default False): If True, silently drop None entries (API failures). If False (the default), raise a ValueError when any entry is None.

Returns:

- DataFrame: Columns: patientid, censored (bool), occurred (bool), probability_occurrence (float), probability_no_occurrence (float), probability_censored (float), plus the intermediate avg_logprob_* and softmax_* columns.

Raises:

- ValueError: If drop_failures is False and any result is None.

Source code in twinweaver/utils/tte_inference.py
def compute_length_normalized_probabilities(
    raw_results: list[dict | None],
    *,
    drop_failures: bool = False,
) -> pd.DataFrame:
    """Convert raw per-token log-probs into length-normalized softmax probabilities.

    For each patient the function:

    1. Computes the **mean log-probability** (length-normalised) of each of the
       three completion strings.
    2. Applies a **softmax** over the three mean log-probs to obtain calibrated
       probabilities that sum to one.
    3. Derives a hard **prediction** (``censored``, ``occurred`` columns) by
       picking the class with the highest softmax score.

    Parameters
    ----------
    raw_results : list[dict or None]
        Output of :func:`run_tte_probability_estimation`.
    drop_failures : bool
        If *True*, silently drop ``None`` entries (API failures).
        If *False* (the default), raise a ``ValueError`` when any entry is
        ``None``.

    Returns
    -------
    pd.DataFrame
        Columns: ``patientid``, ``censored`` (bool), ``occurred`` (bool),
        ``probability_occurrence`` (float), ``probability_no_occurrence`` (float),
        ``probability_censored`` (float), plus the intermediate
        ``avg_logprob_*`` and ``softmax_*`` columns.

    Raises
    ------
    ValueError
        If *drop_failures* is *False* and any result is ``None``.
    """
    if drop_failures:
        valid = [r for r in raw_results if r is not None]
    else:
        if any(r is None for r in raw_results):
            raise ValueError("Some results are None (API failures). Set drop_failures=True to silently ignore them.")
        valid = raw_results  # type: ignore[assignment]

    if not valid:
        raise ValueError("No valid results to process.")

    df = pd.DataFrame(valid)
    df = df[["patientid", LABEL_OCCURRED, LABEL_NOT_OCCURRED, LABEL_CENSORED]]

    # --- 1. Length-normalised mean log-probability ---------------------------
    for label in (LABEL_OCCURRED, LABEL_NOT_OCCURRED, LABEL_CENSORED):
        df[f"avg_logprob_{label}"] = df[label].apply(np.mean)

    # --- 2. Softmax across the three states ----------------------------------
    logprob_cols = [
        f"avg_logprob_{LABEL_OCCURRED}",
        f"avg_logprob_{LABEL_NOT_OCCURRED}",
        f"avg_logprob_{LABEL_CENSORED}",
    ]
    softmax_cols = [
        f"softmax_{LABEL_OCCURRED}",
        f"softmax_{LABEL_NOT_OCCURRED}",
        f"softmax_{LABEL_CENSORED}",
    ]

    df[softmax_cols] = df[logprob_cols].apply(
        lambda row: pd.Series(scipy.special.softmax(row.values)),
        axis=1,
    )

    # --- 3. Hard prediction ---------------------------------------------------
    selection = [LABEL_OCCURRED, LABEL_NOT_OCCURRED, LABEL_CENSORED]

    def _hard_prediction(row):
        probs = [
            row[f"softmax_{LABEL_OCCURRED}"],
            row[f"softmax_{LABEL_NOT_OCCURRED}"],
            row[f"softmax_{LABEL_CENSORED}"],
        ]
        best = selection[int(np.argmax(probs))]
        return pd.Series(
            {
                "censored": best == LABEL_CENSORED,
                "occurred": best == LABEL_OCCURRED,
            }
        )

    df[["censored", "occurred"]] = df.apply(_hard_prediction, axis=1)

    # --- 4. Friendly probability columns --------------------------------------
    df["probability_occurrence"] = df[f"softmax_{LABEL_OCCURRED}"]
    df["probability_no_occurrence"] = df[f"softmax_{LABEL_NOT_OCCURRED}"]
    df["probability_censored"] = df[f"softmax_{LABEL_CENSORED}"]

    return df

run_tte_probability_estimation

run_tte_probability_estimation(
    instructions_with_ids,
    tokenizer,
    config,
    *,
    prediction_url="http://0.0.0.0:8000/v1/",
    prediction_model="default-model",
    max_concurrent_requests=40,
    system_prompt=None,
    api_key="EMPTY",
    timeout=600.0
)

Score all patients against an OpenAI-compatible API and return raw log-probs.

This is the main entry-point for running TTE probability estimation. It is synchronous (calls asyncio.run internally), so it can be used from plain scripts. In environments with a running event loop, such as Jupyter notebooks, use run_tte_probability_estimation_notebook instead.

Parameters:

- instructions_with_ids (list[tuple[str, str]], required): Each element is (patientid, instruction_text).
- tokenizer (Any, required): HuggingFace-compatible tokenizer (used for token counting).
- config (Config, required): TwinWeaver configuration object.
- prediction_url (str, default 'http://0.0.0.0:8000/v1/'): Base URL of the OpenAI-compatible inference server.
- prediction_model (str, default 'default-model'): Model name / path served by the inference server.
- max_concurrent_requests (int, default 40): Maximum number of concurrent API requests.
- system_prompt (str or None, default None): Optional system prompt.
- api_key (str, default 'EMPTY'): API key ("EMPTY" for local vLLM servers).
- timeout (float, default 600.0): Per-request timeout in seconds.

Returns:

- list[dict or None]: One dict per patient with keys "patientid", "occurred", "not_occurred", "censored" whose values are lists of per-token log-probabilities. None entries indicate API failures.
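The return value can be illustrated with a hand-written example (the values are made up; only the shape matches the description above):

```python
# One dict per patient; None marks an API failure for that patient.
raw_results = [
    {
        "patientid": "patient_1",
        "occurred": [-0.2, -1.1, -0.4],   # per-token log-probs of each suffix
        "not_occurred": [-0.9, -0.8],
        "censored": [-2.0, -1.5],
    },
    None,
]

# Failures must be handled before post-processing, e.g. by filtering;
# compute_length_normalized_probabilities(raw_results, drop_failures=True)
# applies the same filter internally.
valid = [r for r in raw_results if r is not None]
failed = raw_results.count(None)
```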

Source code in twinweaver/utils/tte_inference.py
def run_tte_probability_estimation(
    instructions_with_ids: list[tuple[str, str]],
    tokenizer: Any,
    config: Config,
    *,
    prediction_url: str = "http://0.0.0.0:8000/v1/",
    prediction_model: str = "default-model",
    max_concurrent_requests: int = 40,
    system_prompt: str | None = None,
    api_key: str = "EMPTY",
    timeout: float = 600.0,
) -> list[dict | None]:
    """Score all patients against an OpenAI-compatible API and return raw log-probs.

    This is the main entry-point for running TTE probability estimation.
    It is synchronous (calls ``asyncio.run`` internally), so it can be used
    from plain scripts.  In environments with a running event loop (e.g.
    Jupyter notebooks) use ``run_tte_probability_estimation_notebook``.

    Parameters
    ----------
    instructions_with_ids : list[tuple[str, str]]
        Each element is ``(patientid, instruction_text)``.
    tokenizer
        HuggingFace-compatible tokenizer (used for token counting).
    config : Config
        TwinWeaver configuration object.
    prediction_url : str
        Base URL of the OpenAI-compatible inference server.
    prediction_model : str
        Model name / path served by the inference server.
    max_concurrent_requests : int
        Maximum number of concurrent API requests.
    system_prompt : str or None
        Optional system prompt.
    api_key : str
        API key (``"EMPTY"`` for local vLLM servers).
    timeout : float
        Per-request timeout in seconds.

    Returns
    -------
    list[dict or None]
        One dict per patient with keys ``"patientid"``,
        ``"occurred"``, ``"not_occurred"``, ``"censored"``
        whose values are lists of per-token log-probabilities.
        ``None`` entries indicate API failures.
    """
    return asyncio.run(
        _run_tte_probability_estimation_async(
            instructions_with_ids,
            tokenizer,
            config,
            prediction_url=prediction_url,
            prediction_model=prediction_model,
            max_concurrent_requests=max_concurrent_requests,
            system_prompt=system_prompt,
            api_key=api_key,
            timeout=timeout,
        )
    )

run_tte_probability_estimation_notebook

run_tte_probability_estimation_notebook(
    instructions_with_ids,
    tokenizer,
    config,
    *,
    prediction_url="http://0.0.0.0:8000/v1/",
    prediction_model="default-model",
    max_concurrent_requests=40,
    system_prompt=None,
    api_key="EMPTY",
    timeout=600.0
)

Score all patients against an OpenAI-compatible API and return raw log-probs.

This is the notebook-friendly entry-point for running TTE probability estimation. It returns a coroutine rather than calling asyncio.run (which fails inside a running event loop), so it can be awaited directly in Jupyter notebooks.
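The sync/notebook split can be demonstrated generically; `_fake_estimation` below is a stand-in coroutine (the real call would be `await run_tte_probability_estimation_notebook(...)` in a notebook cell):

```python
import asyncio

async def _fake_estimation():
    # Stand-in for the async implementation: any coroutine works here.
    return [{"patientid": "patient_1"}]

async def notebook_cell():
    # In a notebook cell you would write:
    #   raw = await run_tte_probability_estimation_notebook(data, tokenizer, config)
    raw = await _fake_estimation()
    return raw

# Plain scripts need asyncio.run to drive the coroutine; a notebook already
# has a running event loop, so there you just `await` it directly.
raw = asyncio.run(notebook_cell())
```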

Parameters:

- instructions_with_ids (list[tuple[str, str]], required): Each element is (patientid, instruction_text).
- tokenizer (Any, required): HuggingFace-compatible tokenizer (used for token counting).
- config (Config, required): TwinWeaver configuration object.
- prediction_url (str, default 'http://0.0.0.0:8000/v1/'): Base URL of the OpenAI-compatible inference server.
- prediction_model (str, default 'default-model'): Model name / path served by the inference server.
- max_concurrent_requests (int, default 40): Maximum number of concurrent API requests.
- system_prompt (str or None, default None): Optional system prompt.
- api_key (str, default 'EMPTY'): API key ("EMPTY" for local vLLM servers).
- timeout (float, default 600.0): Per-request timeout in seconds.

Returns:

- Awaitable[list[dict or None]]: A coroutine that, when awaited, yields one dict per patient with keys "patientid", "occurred", "not_occurred", "censored" whose values are lists of per-token log-probabilities. None entries indicate API failures.

Source code in twinweaver/utils/tte_inference.py
def run_tte_probability_estimation_notebook(
    instructions_with_ids: list[tuple[str, str]],
    tokenizer: Any,
    config: Config,
    *,
    prediction_url: str = "http://0.0.0.0:8000/v1/",
    prediction_model: str = "default-model",
    max_concurrent_requests: int = 40,
    system_prompt: str | None = None,
    api_key: str = "EMPTY",
    timeout: float = 600.0,
) -> Awaitable[list[dict | None]]:
    """Score all patients against an OpenAI-compatible API and return raw log-probs.

    This is the notebook-friendly entry-point for running TTE probability
    estimation.  It returns a coroutine rather than calling ``asyncio.run``
    (which fails inside a running event loop), so it can be awaited directly
    in notebooks.

    Parameters
    ----------
    instructions_with_ids : list[tuple[str, str]]
        Each element is ``(patientid, instruction_text)``.
    tokenizer
        HuggingFace-compatible tokenizer (used for token counting).
    config : Config
        TwinWeaver configuration object.
    prediction_url : str
        Base URL of the OpenAI-compatible inference server.
    prediction_model : str
        Model name / path served by the inference server.
    max_concurrent_requests : int
        Maximum number of concurrent API requests.
    system_prompt : str or None
        Optional system prompt.
    api_key : str
        API key (``"EMPTY"`` for local vLLM servers).
    timeout : float
        Per-request timeout in seconds.

    Returns
    -------
    Awaitable[list[dict or None]]
        A coroutine that, when awaited, yields one dict per patient with keys
        ``"patientid"``, ``"occurred"``, ``"not_occurred"``, ``"censored"``
        whose values are lists of per-token log-probabilities.
        ``None`` entries indicate API failures.
    """

    # No need for a separate function in the notebook – just call the async version directly.

    return _run_tte_probability_estimation_async(
        instructions_with_ids,
        tokenizer,
        config,
        prediction_url=prediction_url,
        prediction_model=prediction_model,
        max_concurrent_requests=max_concurrent_requests,
        system_prompt=system_prompt,
        api_key=api_key,
        timeout=timeout,
    )