Source code for eclypse.policies.replay.replay_edges

"""Replay edge attributes from records."""

from __future__ import annotations

from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
)

from eclypse.policies._helpers import validate_missing_behaviour
from eclypse.policies.replay._helpers import (
    group_records_by_step,
    infer_value_columns,
    initial_step,
    normalise_records,
    resolve_replay_step,
)

if TYPE_CHECKING:
    from eclypse.graph.asset_graph import AssetGraph
    from eclypse.policies._filters import EdgeFilter
    from eclypse.utils.types import (
        MissingPolicyBehaviour,
        UpdatePolicy,
    )


@dataclass(slots=True)
class ReplayEdgesPolicy:
    """Replay edge attributes from time-indexed records.

    Instances are callables: each call replays the records of one step onto
    the graph's edges and then advances ``current_step`` by one.
    """

    # Replay records keyed by step (built via ``group_records_by_step``).
    records_by_step: dict[int, list[dict[str, Any]]]
    # Record columns copied onto a matching edge when present in the record.
    columns: list[str]
    # Record key holding the edge source identifier.
    source_column: str = "source"
    # Record key holding the edge target identifier.
    target_column: str = "target"
    # When not None, only these (source, target) pairs are updated.
    selected_edge_ids: set[tuple[str, str]] | None = None
    # Optional predicate called as ``edge_filter(source, target, data)``;
    # a falsy result leaves that edge untouched.
    edge_filter: EdgeFilter | None = None
    # "error" raises KeyError for records targeting absent edges; any other
    # value silently skips them.
    missing: MissingPolicyBehaviour = "ignore"
    # Forwarded to ``resolve_replay_step`` when choosing the step to replay.
    cyclic: bool = False
    # Call counter; incremented once per call.
    current_step: int = 0

    def __call__(self, graph: AssetGraph):
        """Apply the records of the resolved replay step, then advance.

        The step that is replayed is obtained from ``resolve_replay_step``
        (given ``current_step`` and ``cyclic``), so it may differ from
        ``current_step`` itself — presumably when wrapping in cyclic mode.

        Args:
            graph (AssetGraph): Graph whose edge attributes are updated in place.
        """
        step = resolve_replay_step(
            self.records_by_step, self.current_step, cyclic=self.cyclic
        )
        # Every record of a step shares the same update options.
        update_options = {
            "columns": self.columns,
            "source_column": self.source_column,
            "target_column": self.target_column,
            "selected_edge_ids": self.selected_edge_ids,
            "edge_filter": self.edge_filter,
            "missing": self.missing,
        }
        step_records = self.records_by_step.get(step, [])
        for record in step_records:
            _update_edge_from_record(graph, record, **update_options)

        graph.logger.trace(f"Applied replay_edges policy for step {step}.")
        self.current_step += 1


def replay_edges(
    record_source: Any,
    *,
    source_column: str = "source",
    target_column: str = "target",
    time_column: str = "time",
    value_columns: list[str] | tuple[str, ...] | None = None,
    edge_ids: list[tuple[str, str]] | None = None,
    edge_filter: EdgeFilter | None = None,
    missing: MissingPolicyBehaviour = "ignore",
    start_step: int | None = None,
    cyclic: bool = False,
) -> UpdatePolicy:
    """Replay edge attributes from time-indexed records.

    Args:
        record_source (Any): Iterable of mapping records to replay.
        source_column (str): Column containing edge source identifiers.
        target_column (str): Column containing edge target identifiers.
        time_column (str): Column containing replay steps.
        value_columns (list[str] | tuple[str, ...] | None): Optional explicit
            columns to copy from records.
        edge_ids (list[tuple[str, str]] | None): Optional explicit edge
            identifiers to mutate. Each entry is unpacked as a
            ``(source, target)`` pair, so two-element lists (e.g. decoded
            from JSON/YAML) are accepted as well as tuples.
        edge_filter (EdgeFilter | None): Optional predicate receiving
            ``(source, target, data)``.
        missing (MissingPolicyBehaviour): Behaviour when a replay record
            targets a missing edge.
        start_step (int | None): Optional starting replay step.
        cyclic (bool): Whether to wrap past the final available replay step.

    Returns:
        Stateful edge replay policy.

    Raises:
        ValueError: If an entry of ``edge_ids`` does not contain exactly two
            items.
    """
    validate_missing_behaviour(missing)
    records = normalise_records(record_source)
    columns = infer_value_columns(
        records,
        reserved_columns=[source_column, target_column, time_column],
        value_columns=value_columns,
    )
    records_by_step = group_records_by_step(records, time_column=time_column)

    # Records are matched by the tuple ``(record[source], record[target])``,
    # so normalise every selected id to a tuple: lists would be unhashable
    # (TypeError in ``set``) and could never compare equal to a tuple anyway.
    selected_edge_ids = (
        {(source, target) for source, target in edge_ids}
        if edge_ids is not None
        else None
    )
    current_step = initial_step(records_by_step, start_step)

    return ReplayEdgesPolicy(
        records_by_step=records_by_step,
        columns=columns,
        source_column=source_column,
        target_column=target_column,
        selected_edge_ids=selected_edge_ids,
        edge_filter=edge_filter,
        missing=missing,
        cyclic=cyclic,
        current_step=current_step,
    )
def _update_edge_from_record(
    graph: AssetGraph,
    record,
    *,
    columns: list[str],
    source_column: str,
    target_column: str,
    selected_edge_ids: set[tuple[str, str]] | None,
    edge_filter,
    missing: MissingPolicyBehaviour,
):
    """Copy the replayed columns of a single record onto its edge.

    The edge is identified by ``(record[source_column], record[target_column])``.
    The record is skipped when the edge is not in ``selected_edge_ids`` (if
    given), when the edge is absent from the graph (unless ``missing`` is
    ``"error"``), or when ``edge_filter`` rejects it.

    Args:
        graph (AssetGraph): Graph whose edge data is mutated in place.
        record: Replay record providing the edge endpoints and values.
        columns (list[str]): Candidate columns; only those present in the
            record are written to the edge.
        source_column (str): Record key holding the edge source.
        target_column (str): Record key holding the edge target.
        selected_edge_ids (set[tuple[str, str]] | None): Optional allow-list
            of ``(source, target)`` pairs.
        edge_filter: Optional predicate called as
            ``edge_filter(source, target, data)``.
        missing (MissingPolicyBehaviour): ``"error"`` to raise on absent edges.

    Raises:
        KeyError: If the edge is absent from the graph and ``missing`` is
            ``"error"``.
    """
    source, target = record[source_column], record[target_column]

    # The allow-list check runs before the existence check, so unselected
    # edges never trigger the "missing" behaviour.
    if not (selected_edge_ids is None or (source, target) in selected_edge_ids):
        return

    if not graph.has_edge(source, target):
        if missing != "error":
            return
        raise KeyError(f'Edge "{source} -> {target}" not found in the graph.')

    attributes = graph.edges[source, target]
    if edge_filter is not None and not edge_filter(source, target, attributes):
        return

    present = [name for name in columns if name in record]
    for name in present:
        attributes[name] = record[name]