Source code for eclypse.report.reporters.tensorboard
# pylint: disable=no-member, unused-argument
"""Module for the TensorBoardReporter class.

It reports the simulation metrics to TensorBoard event files using the tensorboardX
library. It creates a separate plot for each callback, tagged
``<callback_type>/<callback_name>``. The x-axis is the event index (used as the
TensorBoard step; the event name is not used) and the y-axis is the reported value.
Each plot contains multiple lines, one for each unique path in the callback's data.
"""
from __future__ import annotations
from importlib import import_module
from typing import (
TYPE_CHECKING,
Any,
)
from eclypse.report.reporter import Reporter
from eclypse.utils.defaults import TENSORBOARD_REPORT_DIR
if TYPE_CHECKING:
from collections.abc import (
Generator,
)
from pathlib import Path
from tensorboardX import SummaryWriter
from eclypse.workflow.event import EclypseEvent
[docs]
class TensorBoardReporter(Reporter):
    """Asynchronous reporter for simulation metrics in TensorBoardX format.

    Metrics are written through a tensorboardX ``SummaryWriter`` whose log
    directory is the ``TENSORBOARD_REPORT_DIR`` subdirectory of the report path.
    The writer is created by :meth:`init` and released by :meth:`close`.
    """

    def __init__(self, report_path: str | Path):
        """Initialize the TensorBoard reporter.

        Args:
            report_path (str | Path): Base directory of the simulation report.
                TensorBoard files are written to its ``TENSORBOARD_REPORT_DIR``
                subdirectory.
        """
        super().__init__(report_path)
        # NOTE(review): the ``/`` below relies on the base class storing
        # ``report_path`` as a ``Path`` — confirm in ``Reporter.__init__``.
        self.report_path = self.report_path / TENSORBOARD_REPORT_DIR
        # Created in ``init``; ``None`` means "not initialised / already closed".
        self._writer: SummaryWriter | None = None

    async def init(self):
        """Create the tensorboardX ``SummaryWriter`` for ``self.report_path``.

        A writer left open by a previous ``init`` call is closed first, so its
        event file is flushed and its handle is not leaked.
        """
        await self.close()
        # tensorboardX is imported lazily here: at module level it is only
        # imported under ``TYPE_CHECKING``.
        summary_writer_cls = import_module("tensorboardX").SummaryWriter
        # Pass a plain string path regardless of how the base class stores it.
        self._writer = summary_writer_cls(log_dir=str(self.report_path))

    async def close(self):
        """Close the TensorBoard writer, if one is open.

        Safe to call repeatedly. The stored reference is cleared before closing,
        so the reporter is left uninitialised even if ``close`` itself raises.
        """
        writer, self._writer = self._writer, None
        if writer is not None:
            writer.close()

    def report(
        self,
        _: str,
        event_idx: int,
        callback: EclypseEvent,
    ) -> Generator[Any, None, None]:
        """Generate TensorBoard-compatible metric tuples.

        Args:
            _ (str): The name of the event (unused).
            event_idx (int): The index of the event trigger, used as the step.
            callback (EclypseEvent):
                The executed callback containing the data to report.

        Yields:
            tuple[str, dict[str, Any], int]: ``(callback_name,
            {metric_name: value}, event_idx)`` for every row produced by
            ``callback_rows`` whose value is not ``None``. ``metric_name`` is the
            row's leading items joined with ``"/"``, or ``"value"`` when the row
            has no leading items. Nothing is yielded when ``callback.type`` is
            ``None``.
        """
        if callback.type is None:
            return
        for row in self.callback_rows(callback):
            # NOTE(review): rows are assumed to be sequences whose last item is
            # the metric value and whose leading items form its path — confirm
            # against ``Reporter.callback_rows``.
            if not row:
                continue
            *path, value = row
            if value is None:
                continue
            # ``str`` guards against non-string path components (e.g. numeric
            # keys), which ``str.join`` would otherwise reject.
            metric_name = "/".join(map(str, path)) or "value"
            yield (callback.name, {metric_name: value}, event_idx)

    async def write(
        self, callback_type: str, data: list[tuple[str, dict[str, float], int]]
    ):
        """Write the collected metrics to TensorBoard.

        Each tuple is logged with ``add_scalars`` under the main tag
        ``"<callback_type>/<callback_name>"`` at step ``event_idx``.

        Args:
            callback_type (str): The type of the callback (used for organizing plots).
            data (list[tuple[str, dict[str, float], int]]): Tuples
                containing (callback_name, metric_dict, event_idx).

        Raises:
            RuntimeError: If ``data`` is non-empty and the reporter has not been
                initialised with ``init``.
        """
        for cb_name, metric_dict, step in data:
            self.writer.add_scalars(f"{callback_type}/{cb_name}", metric_dict, step)

    @property
    def writer(self) -> SummaryWriter:
        """Get the TensorBoardX SummaryWriter.

        Raises:
            RuntimeError: If ``init`` has not been called, or ``close`` has.
        """
        if self._writer is None:
            raise RuntimeError("TensorBoard reporter is not initialised.")
        return self._writer