# Source code for eclypse.report.reporters.tensorboard

# pylint: disable=no-member, unused-argument
"""Module for TensorBoardReporter class.

It is used to report the simulation metrics to a TensorBoard log directory, using the
TensorBoardX library. It creates a separate plot for each callback, where the x-axis is
the 'event_idx' step (the event name is not part of the axis) and the y-axis is the
value. Each plot contains multiple lines, one for each unique path in the data dictionary.
"""

from __future__ import annotations

from importlib import import_module
from typing import (
    TYPE_CHECKING,
    Any,
)

from eclypse.report.reporter import Reporter
from eclypse.utils.defaults import TENSORBOARD_REPORT_DIR

if TYPE_CHECKING:
    from collections.abc import (
        Generator,
    )
    from pathlib import Path

    from tensorboardX import SummaryWriter

    from eclypse.workflow.event import EclypseEvent


class TensorBoardReporter(Reporter):
    """Reporter that emits simulation metrics through a TensorBoardX writer.

    The writer is created lazily by :meth:`init` and released by :meth:`close`;
    until then, accessing :attr:`writer` raises ``RuntimeError``.
    """

    def __init__(self, report_path: str | Path):
        """Set up the reporter and point it at the TensorBoard sub-directory.

        Args:
            report_path (str | Path): Base location of the simulation report.
                The TensorBoard files go under ``TENSORBOARD_REPORT_DIR``
                inside it.
        """
        super().__init__(report_path)
        # NOTE(review): relies on the base class storing ``report_path`` as an
        # object supporting the ``/`` operator (presumably a ``Path``) - confirm.
        base_dir = self.report_path
        self.report_path = base_dir / TENSORBOARD_REPORT_DIR
        # No writer exists until ``init`` is awaited.
        self._writer = None

    async def init(self):
        """Create the TensorBoardX ``SummaryWriter`` for ``report_path``."""
        # Imported by name at call time so that tensorboardX is only needed
        # when this reporter is actually initialised.
        tensorboardx = import_module("tensorboardX")
        self._writer = tensorboardx.SummaryWriter(log_dir=self.report_path)

    async def close(self):
        """Close the underlying writer, if any, and forget it."""
        current = self._writer
        if current is None:
            return
        current.close()
        self._writer = None

    def report(
        self,
        _: str,
        event_idx: int,
        callback: EclypseEvent,
    ) -> Generator[Any, None, None]:
        """Produce TensorBoard-ready entries for one executed callback.

        Args:
            _ (str): The name of the event (unused).
            event_idx (int): The index of the event trigger, used as the step.
            callback (EclypseEvent): The executed callback holding the data.

        Yields:
            tuple: ``(callback_name, {metric_name: value}, event_idx)`` for each
            row whose last element is not ``None``. Nothing is yielded when
            ``callback.type`` is ``None``.
        """
        if callback.type is not None:
            for row in self.callback_rows(callback):
                value = row[-1]
                if value is not None:
                    # Everything before the last element names the metric;
                    # an empty name falls back to "value".
                    tag = "/".join(row[:-1]) or "value"
                    yield (callback.name, {tag: value}, event_idx)

    async def write(
        self, callback_type: str, data: list[tuple[str, dict[str, float], int]]
    ):
        """Send the collected entries to TensorBoard.

        Args:
            callback_type (str): The type of the callback; it prefixes the
                plot tag as ``"<callback_type>/<callback_name>"``.
            data (list[tuple[str, dict[str, float], int]]): Entries shaped as
                ``(callback_name, metric_dict, event_idx)``.

        Raises:
            RuntimeError: If there is at least one entry and the reporter has
                not been initialised.
        """
        for name, scalars, idx in data:
            # The writer is resolved per entry, so empty ``data`` never touches it.
            add_scalars = self.writer.add_scalars
            add_scalars(f"{callback_type}/{name}", scalars, idx)

    @property
    def writer(self) -> SummaryWriter:
        """Return the active TensorBoardX ``SummaryWriter``.

        Raises:
            RuntimeError: If no writer has been created yet (or it was closed).
        """
        if self._writer is not None:
            return self._writer
        raise RuntimeError("TensorBoard reporter is not initialised.")