# Source code for eclypse.policies.replay.interpolated_replay

"""Interpolated replay policies."""

from __future__ import annotations

from itertools import pairwise
from typing import TYPE_CHECKING

from eclypse.policies.replay._helpers import (
    infer_value_columns,
    normalise_records,
)
from eclypse.policies.replay.from_records import from_records

if TYPE_CHECKING:
    from eclypse.utils.types import (
        ReplayTarget,
        UpdatePolicy,
    )


def interpolated_replay(
    record_source,
    *,
    target: ReplayTarget,
    node_id_column: str = "node_id",
    source_column: str = "source",
    target_column: str = "target",
    time_column: str = "time",
    value_columns: list[str] | tuple[str, ...] | None = None,
    **kwargs,
) -> UpdatePolicy:
    """Build a replay policy from records densified by linear interpolation.

    The records are normalised, the columns to interpolate are resolved, the
    integer steps missing between consecutive records of the same node (or
    edge) are filled in, and the resulting records are handed to
    ``from_records``.

    Args:
        record_source (Any): Iterable of sparse replay records.
        target (ReplayTarget): Replay target. ``"nodes"`` identifies records
            by ``node_id_column``; any other value identifies them by
            ``source_column`` and ``target_column``.
        node_id_column (str): Column containing node identifiers.
        source_column (str): Column containing edge source identifiers.
        target_column (str): Column containing edge target identifiers.
        time_column (str): Column containing replay steps.
        value_columns (list[str] | tuple[str, ...] | None): Optional explicit
            columns to interpolate; forwarded to ``infer_value_columns``.
        kwargs (Any): Additional keyword arguments forwarded to
            ``from_records``.

    Returns:
        The policy produced by ``from_records`` for the interpolated records.
    """
    # The identity of a record depends on whether nodes or edges are replayed.
    if target == "nodes":
        key_columns = [node_id_column]
    else:
        key_columns = [source_column, target_column]

    normalised = normalise_records(record_source)

    # Identity and time columns are reserved: they are never interpolated.
    resolved_columns = infer_value_columns(
        normalised,
        reserved_columns=[*key_columns, time_column],
        value_columns=value_columns,
    )

    dense_records = _interpolate_records(
        normalised,
        identity_columns=key_columns,
        time_column=time_column,
        value_columns=resolved_columns,
    )

    return from_records(
        dense_records,
        target=target,
        node_id_column=node_id_column,
        source_column=source_column,
        target_column=target_column,
        time_column=time_column,
        value_columns=resolved_columns,
        **kwargs,
    )
def _interpolate_records( records: list[dict], *, identity_columns: list[str], time_column: str, value_columns: list[str], ) -> list[dict]: grouped: dict[tuple, list[dict]] = {} for record in records: identity = tuple(record[column] for column in identity_columns) grouped.setdefault(identity, []).append(record) result: list[dict] = [] for identity, group in grouped.items(): ordered = sorted(group, key=lambda record: int(record[time_column])) for left, right in pairwise(ordered): left_step = int(left[time_column]) right_step = int(right[time_column]) result.append(dict(left)) for step in range(left_step + 1, right_step): progress = (step - left_step) / (right_step - left_step) interpolated = { column: value for column, value in zip(identity_columns, identity, strict=True) } interpolated[time_column] = step for column in value_columns: if column in left and column in right: interpolated[column] = left[column] + ( (right[column] - left[column]) * progress ) result.append(interpolated) if ordered: result.append(dict(ordered[-1])) return sorted(result, key=lambda record: int(record[time_column]))