Forecasting Inference

twinweaver.utils.forecasting_inference

Forecasting inference helpers for vLLM-based text generation.

Provides functions to generate forecasting predictions (future lab values, vitals, etc.) by sending instruction prompts to an OpenAI-compatible API (e.g. a local vLLM server) and parsing the model's text output back into structured DataFrames via `twinweaver.instruction.converter_instruction.ConverterInstruction`.

The prompt construction is driven by `twinweaver.instruction.converter_instruction.ConverterInstruction` (specifically its `forward_conversion_inference` method), so the same code works for any dataset / prompt template.

Typical usage:

```python
>>> import asyncio
>>> from twinweaver.common.config import Config
>>> from twinweaver.utils.forecasting_inference import (
...     run_forecasting_inference,
...     parse_forecasting_results,
... )
>>> config = Config()
>>> # prompts_with_meta is a list of dicts with keys:
>>> # "patientid", "instruction", "split_date"
>>> results = asyncio.run(run_forecasting_inference(
...     prompts_with_meta,
...     prediction_url="http://localhost:8000/v1/",
...     prediction_model="my-model",
... ))
>>> df = parse_forecasting_results(results, converter, dm)
```

Functions

parse_forecasting_results

```python
parse_forecasting_results(
    raw_results,
    converter,
    data_manager,
    *,
    drop_failures=False,
    aggregate_samples=True
)
```

Parse raw generated texts into structured DataFrames via reverse conversion.

For each patient the function:

  1. Calls `converter.reverse_conversion` on every generated text to obtain structured forecasting DataFrames.
  2. When `n_samples > 1` and `aggregate_samples` is `True`, aggregates the multiple trajectories using `converter.aggregate_multiple_responses`.
  3. Returns a single long-format DataFrame with all patients' predictions.
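The aggregation step can be pictured with a small pandas sketch. The real logic lives in `converter.aggregate_multiple_responses`; taking the per-(date, variable) median across samples, as below, is only a plausible stand-in, not the library's actual behaviour, and the column names are hypothetical.

```python
import pandas as pd

# One forecast DataFrame per sampled completion (illustrative data).
samples = [
    pd.DataFrame({"date": ["2021-01-01", "2021-01-08"], "variable": ["hgb", "hgb"], "value": [11.0, 11.4]}),
    pd.DataFrame({"date": ["2021-01-01", "2021-01-08"], "variable": ["hgb", "hgb"], "value": [11.2, 11.8]}),
    pd.DataFrame({"date": ["2021-01-01", "2021-01-08"], "variable": ["hgb", "hgb"], "value": [10.8, 11.6]}),
]

# Concatenate the samples, then aggregate point-wise across trajectories.
combined = pd.concat(samples, ignore_index=True)
aggregated = combined.groupby(["date", "variable"], as_index=False)["value"].median()
print(aggregated)
```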

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `raw_results` | `list[dict or None]` | Output of `run_forecasting_inference`. | required |
| `converter` | `ConverterInstruction` | The same converter instance used to generate the instruction prompts. Must expose `reverse_conversion` and `aggregate_multiple_responses`. | required |
| `data_manager` | `DataManager` | The data manager instance (passed to `reverse_conversion`). | required |
| `drop_failures` | `bool` | If `True`, silently drop `None` entries (API failures). If `False`, raise a `ValueError` when any entry is `None`. | `False` |
| `aggregate_samples` | `bool` | If `True` (default) and multiple samples were generated per patient, aggregate them via `converter.aggregate_multiple_responses`. If `False`, each sample is returned as a separate row block with a `"sample_idx"` column. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `pd.DataFrame` | A long-format DataFrame with columns from the reverse-converted forecasting data plus `"patientid"`, `"task_type"`, and optionally `"sample_idx"`. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `drop_failures` is `False` and any result is `None`, or if there are no valid results to process. |

Source code in twinweaver/utils/forecasting_inference.py
```python
def parse_forecasting_results(
    raw_results: list[dict | None],
    converter: Any,
    data_manager: Any,
    *,
    drop_failures: bool = False,
    aggregate_samples: bool = True,
) -> pd.DataFrame:
    """Parse raw generated texts into structured DataFrames via reverse conversion.

    For each patient the function:

    1. Calls ``converter.reverse_conversion`` on every generated text to obtain
       structured forecasting DataFrames.
    2. When *n_samples > 1* and ``aggregate_samples`` is ``True``, aggregates the
       multiple trajectories using ``converter.aggregate_multiple_responses``.
    3. Returns a single long-format DataFrame with all patients' predictions.

    Parameters
    ----------
    raw_results : list[dict or None]
        Output of :func:`run_forecasting_inference`.
    converter : ConverterInstruction
        The same converter instance used to generate the instruction prompts.
        Must expose ``reverse_conversion`` and ``aggregate_multiple_responses``.
    data_manager : DataManager
        The data manager instance (passed to ``reverse_conversion``).
    drop_failures : bool
        If *True*, silently drop ``None`` entries (API failures).
        If *False*, raise a ``ValueError`` when any entry is ``None``.
    aggregate_samples : bool
        If *True* (default) and multiple samples were generated per patient,
        aggregate them via ``converter.aggregate_multiple_responses``.
        If *False*, each sample is returned as a separate row block with
        a ``"sample_idx"`` column.

    Returns
    -------
    pd.DataFrame
        A long-format DataFrame with columns from the reverse-converted
        forecasting data plus ``"patientid"`` and optionally ``"sample_idx"``.

    Raises
    ------
    ValueError
        If *drop_failures* is *False* and any result is ``None``.
    """
    if drop_failures:
        valid = [r for r in raw_results if r is not None]
    else:
        if any(r is None for r in raw_results):
            raise ValueError("Some results are None (API failures). Set drop_failures=True to silently ignore them.")
        valid = raw_results  # type: ignore[assignment]

    if not valid:
        raise ValueError("No valid results to process.")

    all_rows: list[pd.DataFrame] = []

    for result in valid:
        patientid = result["patientid"]
        split_date = result["split_date"]
        generated_texts = result["generated_texts"]

        # Reverse-convert each sample
        sample_dfs: list[pd.DataFrame] = []
        for sample_idx, text in enumerate(generated_texts):
            try:
                parsed_tasks = converter.reverse_conversion(
                    text,
                    data_manager,
                    split_date,
                    patientid=patientid,
                    inference_override=True,
                )
            except Exception as exc:
                print(f"Warning: reverse_conversion failed for patient {patientid} sample {sample_idx}: {exc}")
                continue

            # Collect forecasting task results
            for task_result in parsed_tasks:
                task_type = task_result.get("task_type", "")
                result_data = task_result.get("result")

                if isinstance(result_data, pd.DataFrame) and not result_data.empty:
                    df_task = result_data.copy()
                    df_task["patientid"] = patientid
                    df_task["sample_idx"] = sample_idx
                    df_task["task_type"] = task_type
                    sample_dfs.append(df_task)

        if not sample_dfs:
            continue

        if aggregate_samples and len(generated_texts) > 1 and len(sample_dfs) > 1:
            # Use the converter's built-in aggregation (groups by task type)
            # Separate by task type, aggregate each, then combine
            combined = pd.concat(sample_dfs, ignore_index=True)
            task_types = combined["task_type"].unique()
            agg_parts = []
            for tt in task_types:
                task_subset = combined[combined["task_type"] == tt]
                # Group into per-sample DataFrames for the aggregator
                per_sample = [
                    task_subset[task_subset["sample_idx"] == si].drop(
                        columns=["sample_idx", "task_type"], errors="ignore"
                    )
                    for si in task_subset["sample_idx"].unique()
                ]
                try:
                    agg_df, _meta = converter.aggregate_multiple_responses(per_sample)
                    agg_df["task_type"] = tt
                    agg_df["patientid"] = patientid
                    agg_parts.append(agg_df)
                except Exception as exc:
                    print(f"Warning: aggregation failed for patient {patientid}: {exc}")
                    # Fallback: just keep first sample
                    fallback = per_sample[0].copy()
                    fallback["task_type"] = tt
                    fallback["patientid"] = patientid
                    agg_parts.append(fallback)

            if agg_parts:
                all_rows.append(pd.concat(agg_parts, ignore_index=True))
        else:
            all_rows.extend(sample_dfs)

    if not all_rows:
        return pd.DataFrame()

    df_out = pd.concat(all_rows, ignore_index=True)
    return df_out
```

run_forecasting_inference

```python
run_forecasting_inference(
    prompts_with_meta,
    *,
    prediction_url="http://0.0.0.0:8000/v1/",
    prediction_model="default-model",
    max_concurrent_requests=40,
    system_prompt=None,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    n_samples=1,
    api_key="EMPTY",
    timeout=600.0
)
```

Generate forecasting predictions for all patients via an OpenAI-compatible API.

This is the main synchronous entry point. It calls `asyncio.run` internally, so it can be used from plain scripts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `prompts_with_meta` | `list[PromptPayload]` | Each element is a dict with at least the following keys:<br>• `"patientid"` – unique identifier (str)<br>• `"instruction"` – full instruction text produced by `ConverterInstruction.forward_conversion_inference`<br>• `"split_date"` – the reference date (datetime) used when building the split; needed later for `reverse_conversion`<br>Any extra keys are passed through unchanged to the results. | required |
| `prediction_url` | `str` | Base URL of the OpenAI-compatible inference server. | `'http://0.0.0.0:8000/v1/'` |
| `prediction_model` | `str` | Model name / path served by the inference server. | `'default-model'` |
| `max_concurrent_requests` | `int` | Maximum number of concurrent API requests. | `40` |
| `system_prompt` | `str or None` | Optional system prompt. | `None` |
| `max_new_tokens` | `int` | Maximum number of tokens to generate per completion. | `512` |
| `temperature` | `float` | Sampling temperature (0 = greedy). | `0.7` |
| `top_p` | `float` | Nucleus-sampling probability mass. | `0.9` |
| `n_samples` | `int` | Number of independent completions per prompt. Useful for trajectory aggregation (see `ConverterInstruction.aggregate_multiple_responses`). | `1` |
| `api_key` | `str` | API key (`"EMPTY"` for local vLLM servers). | `'EMPTY'` |
| `timeout` | `float` | Per-request timeout in seconds. | `600.0` |

Returns:

| Type | Description |
| --- | --- |
| `list[dict or None]` | One dict per patient. Each dict contains all keys from the input payload (except `"instruction"`) plus `"generated_texts"` – a list of `n_samples` generated completion strings. `None` entries indicate API failures. |
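The payload and result shapes described above can be sketched with plain dicts, without contacting a server. The field values here (patient id, instruction text, the extra `"cohort"` key) are hypothetical; only the key names come from this documentation.

```python
from datetime import datetime

# Hypothetical input payload: the three required keys plus one extra key
# ("cohort"), which is passed through to the results unchanged.
prompts_with_meta = [
    {
        "patientid": "p001",
        "instruction": "Forecast the next weekly hemoglobin values.",
        "split_date": datetime(2021, 1, 1),
        "cohort": "validation",
    },
]

# Shape of the corresponding output: input keys minus "instruction", plus
# "generated_texts" (one string per sample); a failed request is None.
results = [
    {
        "patientid": "p001",
        "split_date": datetime(2021, 1, 1),
        "cohort": "validation",
        "generated_texts": ["hgb: 11.0 g/dL on 2021-01-08"],
    },
    None,  # e.g. a request that timed out
]

# Count successes vs. failures before parsing.
ok = [r for r in results if r is not None]
failed = len(results) - len(ok)
print(f"{len(ok)} succeeded, {failed} failed")
```

With `None` entries present, `parse_forecasting_results` would need `drop_failures=True` to skip them instead of raising.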

Source code in twinweaver/utils/forecasting_inference.py
```python
def run_forecasting_inference(
    prompts_with_meta: list[PromptPayload],
    *,
    prediction_url: str = "http://0.0.0.0:8000/v1/",
    prediction_model: str = "default-model",
    max_concurrent_requests: int = 40,
    system_prompt: str | None = None,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    n_samples: int = 1,
    api_key: str = "EMPTY",
    timeout: float = 600.0,
) -> list[dict | None]:
    """Generate forecasting predictions for all patients via an OpenAI-compatible API.

    This is the main synchronous entry-point.  It calls ``asyncio.run``
    internally so it can be used from plain scripts.

    Parameters
    ----------
    prompts_with_meta : list[PromptPayload]
        Each element is a dict with **at least** the following keys:

        * ``"patientid"``   – unique identifier (str)
        * ``"instruction"`` – full instruction text produced by
          ``ConverterInstruction.forward_conversion_inference``
        * ``"split_date"``  – the reference date (datetime) used when
          building the split; needed later for ``reverse_conversion``

        Any extra keys are passed through unchanged to the results.

    prediction_url : str
        Base URL of the OpenAI-compatible inference server.
    prediction_model : str
        Model name / path served by the inference server.
    max_concurrent_requests : int
        Maximum number of concurrent API requests.
    system_prompt : str or None
        Optional system prompt.
    max_new_tokens : int
        Maximum number of tokens to generate per completion.
    temperature : float
        Sampling temperature (0 = greedy).
    top_p : float
        Nucleus-sampling probability mass.
    n_samples : int
        Number of independent completions per prompt.  Useful for
        trajectory aggregation (see
        :meth:`ConverterInstruction.aggregate_multiple_responses`).
    api_key : str
        API key (``"EMPTY"`` for local vLLM servers).
    timeout : float
        Per-request timeout in seconds.

    Returns
    -------
    list[dict or None]
        One dict per patient.  Each dict contains all keys from the input
        payload (except ``"instruction"``) plus ``"generated_texts"`` – a
        list of *n_samples* generated completion strings.
        ``None`` entries indicate API failures.
    """
    return asyncio.run(
        _run_forecasting_inference_async(
            prompts_with_meta,
            prediction_url=prediction_url,
            prediction_model=prediction_model,
            max_concurrent_requests=max_concurrent_requests,
            system_prompt=system_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            n_samples=n_samples,
            api_key=api_key,
            timeout=timeout,
        )
    )
```

run_forecasting_inference_notebook

```python
run_forecasting_inference_notebook(
    prompts_with_meta,
    *,
    prediction_url="http://0.0.0.0:8000/v1/",
    prediction_model="default-model",
    max_concurrent_requests=40,
    system_prompt=None,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    n_samples=1,
    api_key="EMPTY",
    timeout=600.0
)
```

Generate forecasting predictions – async variant for Jupyter notebooks.

Identical to `run_forecasting_inference` but returns a *coroutine* that can be `await`-ed directly in a notebook cell (which already has a running event loop).

Returns:

| Type | Description |
| --- | --- |
| `Coroutine[..., list[dict or None]]` | A coroutine resolving to the same result list as `run_forecasting_inference`. |
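The sync/notebook split exists because `asyncio.run` cannot be called from a cell whose event loop is already running. A minimal sketch of the pattern, using a stub coroutine in place of the real API calls (all names below are hypothetical, not twinweaver APIs):

```python
import asyncio

async def _fake_inference(prompts):
    # Stand-in for the concurrent API requests made by the real helper.
    await asyncio.sleep(0)
    return [{"patientid": p["patientid"], "generated_texts": ["..."]} for p in prompts]

def run_sync(prompts):
    # Script-style entry point: owns the event loop, like run_forecasting_inference.
    return asyncio.run(_fake_inference(prompts))

def run_notebook(prompts):
    # Notebook-style entry point: hands back the coroutine so the caller can
    # `await run_notebook(prompts)` on the cell's existing event loop.
    return _fake_inference(prompts)

results = run_sync([{"patientid": "p001"}])
print(results[0]["patientid"])
```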
Source code in twinweaver/utils/forecasting_inference.py
```python
def run_forecasting_inference_notebook(
    prompts_with_meta: list[PromptPayload],
    *,
    prediction_url: str = "http://0.0.0.0:8000/v1/",
    prediction_model: str = "default-model",
    max_concurrent_requests: int = 40,
    system_prompt: str | None = None,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    n_samples: int = 1,
    api_key: str = "EMPTY",
    timeout: float = 600.0,
) -> list[dict | None]:
    """Generate forecasting predictions – async variant for Jupyter notebooks.

    Identical to :func:`run_forecasting_inference` but returns a *coroutine*
    that can be ``await``-ed directly in a notebook cell (which already has
    a running event loop).

    Returns
    -------
    Coroutine[..., list[dict or None]]
    """
    return _run_forecasting_inference_async(
        prompts_with_meta,
        prediction_url=prediction_url,
        prediction_model=prediction_model,
        max_concurrent_requests=max_concurrent_requests,
        system_prompt=system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        n_samples=n_samples,
        api_key=api_key,
        timeout=timeout,
    )
```