Source code for eclypse.policies.distribution.discrete

"""Weighted discrete numeric multiplier distribution policy."""

from __future__ import annotations

from typing import TYPE_CHECKING

from eclypse.policies.distribution._helpers import (
    build_sampled_distribution_policy,
    normalize_distributions,
    validate_distributions,
)
from eclypse.utils.constants import MIN_FLOAT

if TYPE_CHECKING:
    from random import Random

    from eclypse.policies._filters import (
        EdgeFilter,
        NodeFilter,
    )
    from eclypse.utils.types import UpdatePolicy

DiscreteDistribution = list[tuple[float, float]] | tuple[tuple[float, float], ...]


[docs] def discrete( *, node_assets: str | list[str] | None = None, edge_assets: str | list[str] | None = None, node_distribution: DiscreteDistribution = ((1.0, 1.0),), edge_distribution: DiscreteDistribution | None = None, node_asset_distributions: dict[str, DiscreteDistribution] | None = None, edge_asset_distributions: dict[str, DiscreteDistribution] | None = None, minimum: float = MIN_FLOAT, node_ids: list[str] | None = None, node_filter: NodeFilter | None = None, edge_ids: list[tuple[str, str]] | None = None, edge_filter: EdgeFilter | None = None, ) -> UpdatePolicy: """Sample numeric multipliers from weighted choices. Args: node_assets (str | list[str] | None): Optional node asset key selector. edge_assets (str | list[str] | None): Optional edge asset key selector. node_distribution (DiscreteDistribution): Default ``(value, weight)`` choices for node assets. edge_distribution (DiscreteDistribution | None): Default choices for edge assets. When omitted, ``node_distribution`` is reused. node_asset_distributions (dict[str, DiscreteDistribution] | None): Optional per-node-asset choices. edge_asset_distributions (dict[str, DiscreteDistribution] | None): Optional per-edge-asset choices. minimum (float): Lower bound after applying the sampled multiplier. node_ids (list[str] | None): Optional explicit node identifiers to mutate. node_filter (NodeFilter | None): Optional predicate receiving ``(node_id, data)``. edge_ids (list[tuple[str, str]] | None): Optional explicit edge identifiers to mutate. edge_filter (EdgeFilter | None): Optional predicate receiving ``(source, target, data)``. Returns: Policy that multiplies selected numeric assets by discrete samples. """ effective_edge_distribution = ( node_distribution if edge_distribution is None else edge_distribution ) validate_distributions( { **normalize_distributions("node_distribution", node_distribution), **normalize_distributions("edge_distribution", effective_edge_distribution), **normalize_distributions( "node_asset_distributions", node_asset_distributions ), **normalize_distributions( "edge_asset_distributions", edge_asset_distributions ), }, checks=[ (lambda distribution: len(distribution) > 0, "must not be empty."), ( lambda distribution: all(weight >= 0 for _, weight in distribution), "must use non-negative weights.", ), ( lambda distribution: sum(weight for _, weight in distribution) > 0, "must include at least one positive weight.", ), ], ) return build_sampled_distribution_policy( kind="discrete", node_assets=node_assets, edge_assets=edge_assets, node_distribution=tuple(node_distribution), edge_distribution=tuple(effective_edge_distribution), node_asset_distributions=node_asset_distributions, edge_asset_distributions=edge_asset_distributions, minimum=minimum, node_ids=node_ids, node_filter=node_filter, edge_ids=edge_ids, edge_filter=edge_filter, sampler=_sample_discrete, )
def _sample_discrete(rnd: Random, distribution: DiscreteDistribution) -> float: values, weights = zip(*distribution, strict=True) return rnd.choices(values, weights=weights, k=1)[0]