from __future__ import annotations
from collections.abc import Callable, MutableMapping
from typing import TYPE_CHECKING
from passengersim_core import (
Event,
)
from passengersim.tracers.generic import GenericTracer
if TYPE_CHECKING:
from passengersim_core import SimulationEngine
from passengersim.rm.systems import RmSys
# [docs]  (Sphinx "viewcode" link marker left over from HTML extraction; not code)
class CallbackMixin:
    """Mixin that registers callbacks and schedules them as engine events.

    Callbacks are kept in three per-instance lists named
    ``begin_sample_callbacks``, ``end_sample_callbacks`` and
    ``daily_callbacks``.  Each list is created lazily the first time a
    callback of that kind is registered, so an instance that never registers
    anything has none of these attributes.

    The host object is expected to provide ``eng`` (used to read the base
    time, the configured DCP hour and to enqueue events) and ``dcp_list``
    (whose first element is used as the starting day when events are added).
    """

    if TYPE_CHECKING:
        eng: SimulationEngine

    def _register_callback(self, kind: str, callback):
        """Append *callback* to the ``{kind}_callbacks`` list and return it.

        Shared implementation behind the three public registration methods.
        """
        attr_name = f"{kind}_callbacks"
        if not hasattr(self, attr_name):
            setattr(self, attr_name, [])
        if isinstance(callback, GenericTracer):
            # Tracers are not stored as given: the object returned by
            # ``fresh()`` is registered instead (presumably an independent,
            # empty tracer -- confirm against GenericTracer).
            callback = callback.fresh()
        getattr(self, attr_name).append(callback)
        return callback

    def end_sample_callback(self, callback: Callable[[CallbackMixin], dict | None] | GenericTracer):
        """Register a function to be triggered at the end of each sample.

        The callback function will be triggered before counters are reset or
        history buffers are rolled over.

        Parameters
        ----------
        callback : Callable[[Simulation], dict | None] or GenericTracer
            The callback function to register. It should accept a single argument,
            which will be the Simulation object, and return a dictionary of interesting
            things to store, or nothing.

        Returns
        -------
        The callback that was actually registered.  For a ``GenericTracer``
        this is the object returned by its ``fresh()`` method, not the
        instance that was passed in.
        """
        return self._register_callback("end_sample", callback)

    def begin_sample_callback(self, callback: Callable[[CallbackMixin], dict | None] | GenericTracer):
        """Register a function to be triggered at the beginning of each sample.

        The callback function will be triggered after initial setup including
        all RM steps for the initial DCP, but before any customers can arrive.

        Parameters
        ----------
        callback : Callable[[Simulation], dict | None] or GenericTracer
            The callback function to register. It should accept a single argument,
            which will be the Simulation object, and return a dictionary of interesting
            things to store, or nothing.

        Returns
        -------
        The callback that was actually registered.  For a ``GenericTracer``
        this is the object returned by its ``fresh()`` method, not the
        instance that was passed in.
        """
        return self._register_callback("begin_sample", callback)

    def daily_callback(self, callback: Callable[[CallbackMixin, int], dict | None] | GenericTracer | RmSys):
        """Register a function to be triggered each day during a sample.

        The callback function will be triggered after all RM steps when the day
        coincides with a DCP.

        Parameters
        ----------
        callback : Callable[[Simulation, int], dict | None] or GenericTracer or RmSys
            The callback function to register. It should accept two arguments,
            which will be the Simulation object and the days_prior, and return
            a dictionary of interesting things to store, or nothing.  An
            optional ``priority`` attribute on the callback shifts its
            trigger time (see `add_callback_events`).

        Returns
        -------
        The callback that was actually registered.  For a ``GenericTracer``
        this is the object returned by its ``fresh()`` method, not the
        instance that was passed in.
        """
        return self._register_callback("daily", callback)

    def callback_functions(self) -> dict[str, list[Callable]]:
        """Get all callback functions.

        Returns
        -------
        dict[str, list[Callable]]
            Maps ``"begin_sample_callbacks"``, ``"end_sample_callbacks"`` and
            ``"daily_callbacks"`` (in that order) to the corresponding list.
            Kinds for which nothing was ever registered are omitted.  The
            lists are the live internal lists, not copies.
        """
        found = {}
        for kind in ("begin_sample", "end_sample", "daily"):
            attr_name = f"{kind}_callbacks"
            if hasattr(self, attr_name):
                found[attr_name] = getattr(self, attr_name)
        return found

    def apply_callback_functions(self, sim: CallbackMixin):
        """Register every callback held by this object on *sim*.

        Registration goes through the public ``*_callback`` methods of *sim*,
        so tracers are re-created via ``fresh()`` for the target.
        """
        callbacks = self.callback_functions()
        # Iterate over snapshots: registering appends to the target's lists,
        # which are the very lists being read when ``sim`` is ``self``.
        for cb in list(callbacks.get("begin_sample_callbacks", [])):
            sim.begin_sample_callback(cb)
        for cb in list(callbacks.get("end_sample_callbacks", [])):
            sim.end_sample_callback(cb)
        for cb in list(callbacks.get("daily_callbacks", [])):
            sim.daily_callback(cb)

    def use_registered_callbacks(self):
        """Adopt all globally registered callbacks."""
        # Imported here rather than at module level (presumably to avoid an
        # import cycle with the registration module -- confirm).
        from .registration import _REGISTERED_CALLBACKS

        _REGISTERED_CALLBACKS.apply_callback_functions(self)

    def add_callback_events(self):
        """Add callback events to the simulation event queue.

        One event is enqueued per begin-sample callback, one per end-sample
        callback, and one per day (counting down from ``dcp_list[0]`` to 0)
        for every daily callback.
        """
        dcp_hour = self.eng.config.simulation_controls.dcp_hour

        begin_callbacks = getattr(self, "begin_sample_callbacks", [])
        if begin_callbacks:
            first_dcp = self.dcp_list[0]
            # no customers can arrive within 5 seconds of a DCP.
            # we want these callbacks to be triggered after the first DCP
            # but before any customers can arrive, so we add one second.
            event_time = int(self.eng.base_time - first_dcp * 86400 + 3600 * dcp_hour) + 1
            for callback in begin_callbacks:
                self.eng.add_event(Event(("callback_begin_sample", callback), event_time))

        end_callbacks = getattr(self, "end_sample_callbacks", [])
        if end_callbacks:
            # we want these callbacks to be triggered after the last DCP
            # so we add one second.
            event_time = int(self.eng.base_time + 3600 * dcp_hour) + 1
            for callback in end_callbacks:
                self.eng.add_event(Event(("callback_end_sample", callback), event_time))

        for callback in getattr(self, "daily_callbacks", []):
            # The priority is a number of seconds before or after the overnight
            # trigger time. If priority is negative, the daily callback will be
            # run before any RM system daily or same-time DCP events. If priority
            # is positive it will be run after other RM events.
            priority = getattr(callback, "priority", 0)
            day = self.dcp_list[0]
            while day >= 0:
                event_time = int(self.eng.base_time - day * 86400 + 3600 * dcp_hour + priority)
                self.eng.add_event(Event(("callback_daily", callback, day), event_time))
                day -= 1
def _flatten_dict_keys(d):
for k, v in d.items():
if isinstance(v, dict):
for sub_k, sub_v in _flatten_dict_keys(v):
yield f"{k}.{sub_k}", sub_v
elif isinstance(v, list | tuple):
for i, sub_v in enumerate(v):
yield f"{k}[{i}]", sub_v
else:
yield k, v
def _flatten_iter_of_dict(i):
    """Lazily yield a flat ``dict`` for each mapping in the iterable *i*.

    Each item is passed through `_flatten_dict_keys` and the resulting
    pairs are collected into a new dictionary.
    """
    yield from map(lambda record: dict(_flatten_dict_keys(record)), i)
# [docs]  (Sphinx "viewcode" link marker left over from HTML extraction; not code)
class CallbackData(MutableMapping):
    """Data collected during callbacks.

    A mutable mapping from a label to whatever was stored under it.  Data
    written through `get_data` / `update_data` is a list of record
    dictionaries, each carrying ``trial`` and ``sample`` (and optionally
    ``days_prior``) plus the stored values.  Labels that do not start with an
    underscore are also reachable as attributes.

    Two instances can be added together; see `__add__` for how values that
    exist under the same label are combined.
    """

    def __init__(self):
        self._data = {}
        # label -> DataFrame built by `to_dataframe`; dropped whenever the
        # data under that label may have changed.
        self._cached_dataframes = {}

    def _drop_cached_frame(self, label):
        """Forget the cached DataFrame for *label*, if there is one."""
        # `to_dataframe` tolerates instances that have no cache attribute, so
        # every other method that touches the cache has to tolerate it too.
        cache = getattr(self, "_cached_dataframes", None)
        if cache is not None:
            cache.pop(label, None)

    def get_data(self, label: str, trial: int, sample: int, days_prior: int | None = None):
        """Return the current record for *label*, starting a new one if needed.

        Only the most recent record under *label* is examined.  It is reused
        when its ``trial`` and ``sample`` (and ``days_prior``, when that is
        not None) equal the given values; otherwise a new record holding just
        those keys is appended and returned.

        Returns
        -------
        dict
            The live record; changes made to it are stored.
        """
        key_match = {"trial": trial, "sample": sample}
        if days_prior is not None:
            key_match["days_prior"] = days_prior
        # A mutable record is about to be handed out (and possibly appended),
        # so a previously built DataFrame can no longer be trusted.
        self._drop_cached_frame(label)
        if label not in self._data:
            self._data[label] = [key_match]
        store = self._data[label][-1]
        if any(store.get(k) != v for k, v in key_match.items()):
            self._data[label].append(key_match)
            store = self._data[label][-1]
        return store

    def update_data(
        self,
        label: str,
        trial: int,
        sample: int,
        days_prior: int | None = None,
        **kwargs,
    ):
        """Merge *kwargs* into the record selected by `get_data`."""
        store = self.get_data(label, trial, sample, days_prior)
        store.update(kwargs)
        self._drop_cached_frame(label)

    def to_dataframe(self, item: str):
        """Return the records stored under *item* as a pandas DataFrame.

        Nested dictionaries, lists and tuples inside each record are
        flattened into dotted / indexed column names.  The frame is cached
        and the same object is returned until the data under *item* changes.

        Raises
        ------
        KeyError
            If nothing is stored under *item*.
        """
        import pandas as pd

        cache = getattr(self, "_cached_dataframes", None)
        if cache is None:
            # instance without a cache attribute: create it on demand
            cache = self._cached_dataframes = {}
        if item in cache:
            return cache[item]
        if item not in self._data:
            raise KeyError(f"{item} not found in callback data")
        df = pd.DataFrame(_flatten_iter_of_dict(self._data[item]))
        cache[item] = df
        return df

    def __getitem__(self, item):
        return self._data[item]

    def __setitem__(self, key, value):
        self._data[key] = value
        self._drop_cached_frame(key)

    def __delitem__(self, key):
        del self._data[key]
        self._drop_cached_frame(key)

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails.  Names starting
        # with an underscore are never looked up in the data, which also
        # keeps a missing ``_data`` from recursing back into this method.
        if not item.startswith("_") and item in self._data:
            return self._data[item]
        raise AttributeError(f"{self.__class__.__name__} has no attribute '{item}'")

    def __repr__(self):
        if self._data:
            keys = ", ".join(self._data.keys())
            return f"<{self.__class__.__module__}.{self.__class__.__name__} from {keys}>"
        else:
            return f"<{self.__class__.__module__}.{self.__class__.__name__} with no data>"

    def __bool__(self):
        return bool(self._data)

    def __add__(self, other):
        """Combine two collections into a new one; neither operand is modified.

        For a label present in both operands:

        * if the left value has ``attrs["aggregation_process"]`` equal to
          ``"wgt_by_samples"`` or ``"wgt_by_trials"``, the result is the
          average of both values weighted by ``attrs["n_samples"]`` or
          ``attrs["n_trials"]`` respectively (default 1), and the combined
          count is written back under that same attribute name;
        * otherwise the result is ``left + right``.

        A label present in only one operand is carried over as the same
        object.  Adding ``0`` or ``None`` returns ``self`` so that instances
        can be fed to ``sum()``.
        """
        if other is None or (isinstance(other, int) and other == 0):
            return self
        if not isinstance(other, CallbackData):
            return NotImplemented

        new = CallbackData()
        for k, mine in self._data.items():
            if k not in other._data:
                new._data[k] = mine
                continue
            theirs = other._data[k]
            try:
                aggregation_process = mine.attrs.get("aggregation_process", None)
            except AttributeError:
                aggregation_process = None
            if aggregation_process == "wgt_by_samples":
                count_attr = "n_samples"
            elif aggregation_process == "wgt_by_trials":
                count_attr = "n_trials"
            else:
                count_attr = None

            if count_attr is None:
                # ``+`` rather than ``+=``: the in-place form would extend
                # the left operand's own list / frame.
                new._data[k] = mine + theirs
                continue

            n1 = mine.attrs.get(count_attr, 1)
            n2 = theirs.attrs.get(count_attr, 1)
            combined = ((mine * n1) + (theirs * n2)) / (n1 + n2)
            # Store the total under the attribute it was read from, so the
            # weight is still correct when this result is added again.
            combined.attrs[count_attr] = n1 + n2
            combined.attrs["aggregation_process"] = aggregation_process
            new._data[k] = combined

        for k, theirs in other._data.items():
            if k not in self._data:
                new._data[k] = theirs
        return new

    def __radd__(self, other):
        # ``sum()`` starts from 0, which ends up here as ``0 + first_item``.
        if isinstance(other, int) and other == 0:
            return self
        elif other is None:
            return self
        else:
            return NotImplemented

    def __dir__(self):
        # Offer the stored labels for completion without hiding the regular
        # attributes and methods.  Non-string keys are left out because
        # ``dir()`` sorts the names.
        names = set(super().__dir__())
        names.update(k for k in self._data if isinstance(k, str))
        return sorted(names)