from __future__ import annotations
from typing import TYPE_CHECKING
import altair as alt
import pandas as pd
from passengersim.reporting import report_figure
from .generic import GenericSimulationTables, SimulationTableItem
from .tools import aggregate_by_concat_dataframe, aggregate_by_summing_dataframe
if TYPE_CHECKING:
from passengersim import Simulation
[docs]
def extract_demands(sim: Simulation) -> pd.DataFrame | None:
    """Collect one summary row per demand held by the simulation engine.

    Parameters
    ----------
    sim : Simulation
        The simulation whose ``sim.eng.demands`` collection is read.

    Returns
    -------
    pd.DataFrame or None
        One row per demand, indexed by ``(orig, dest, segment)``, with one
        column for each of the remaining copied attributes.  ``None`` is
        returned when the engine has no demands at all.
    """
    # The first three attributes identify the demand and become the index.
    index_fields = ("orig", "dest", "segment")
    # The remaining attributes are copied through as ordinary columns.
    value_fields = (
        "base_demand",
        "reference_price",
        "distance",
        "gt_demand",
        "gt_revenue",
        "gt_sold",
        "gt_eliminated_no_offers",
        "gt_eliminated_chose_nothing",
        "gt_eliminated_wtp",
    )
    all_fields = index_fields + value_fields

    rows = [
        {name: getattr(dmd, name) for name in all_fields}
        for dmd in sim.eng.demands
    ]
    if not rows:
        return None
    return pd.DataFrame(rows).set_index(list(index_fields))
[docs]
def extract_demand_history(sim: Simulation) -> pd.DataFrame | None:
    """Gather the demand history records of every demand into one table.

    Parameters
    ----------
    sim : Simulation
        The simulation whose ``sim.eng.demands`` collection is read; each
        demand's ``get_demand_history()`` result is appended in order.

    Returns
    -------
    pd.DataFrame or None
        All history records combined, indexed by
        ``(trial, sample, orig, dest, segment)``.  ``None`` is returned when
        no demand yields any history records.
    """
    # Flatten the per-demand histories into a single list of records.
    # Assumes each record is a mapping carrying the five index keys used
    # below -- set_index will fail otherwise.
    records = [
        record
        for dmd in sim.eng.demands
        for record in dmd.get_demand_history()
    ]
    if not records:
        return None
    history = pd.DataFrame.from_dict(records)
    return history.set_index(["trial", "sample", "orig", "dest", "segment"])
[docs]
class SimTabDemands(GenericSimulationTables):
    """Demand-related tables and figures extracted from a Simulation.

    A subclass of GenericSimulationTables (from the ``generic`` module) that
    declares two table items, ``demands`` and ``demand_history``, each wired
    to its extraction function and its aggregation function, together with
    a figure built from the ``demands`` table.
    """

    demands: pd.DataFrame = SimulationTableItem(
        # base_demand, reference_price and distance should be identical over
        # trials, so they are passed as extra_idxs rather than being summed.
        aggregation_func=aggregate_by_summing_dataframe(
            "demands",
            extra_idxs=["base_demand", "reference_price", "distance"],
        ),
        extraction_func=extract_demands,
        doc="Demand-level summary data.",
    )

    demand_history: pd.DataFrame | None = SimulationTableItem(
        aggregation_func=aggregate_by_concat_dataframe("demand_history"),
        extraction_func=extract_demand_history,
        doc="Demand-level summary data from each sample.",
    )

    @report_figure
    def fig_demand_segmentation_distribution(
        self,
        x: str | None = None,
        y: str | None = None,
        *,
        raw_df: bool = False,
        also_df: bool = False,
    ) -> alt.Chart | pd.DataFrame | tuple[alt.Chart, pd.DataFrame]:
        """
        Plot base demand per origin-destination pair, split by segment.

        The ``demands`` table is pivoted so that each (orig, dest) pair is a
        row and each segment is a column holding the summed ``base_demand``.
        A ``total`` column and one ``<segment>_share`` column per segment are
        then appended.

        Parameters
        ----------
        x : str, optional
            Column plotted on the x-axis. Defaults to the first segment
            column.
        y : str, optional
            Column plotted on the y-axis. Defaults to the second segment
            column when there are exactly two segments, else to 'total'.
        raw_df : bool, default False
            If True, return the computed DataFrame (still indexed by orig
            and dest) and build no chart. Takes precedence over `also_df`.
        also_df : bool, default False
            If True, return a (chart, DataFrame) tuple, where the DataFrame
            is the one handed to the chart.

        Returns
        -------
        alt.Chart or pd.DataFrame or tuple[alt.Chart, pd.DataFrame]
            The interactive point chart, the DataFrame, or both, according
            to `raw_df` and `also_df`.
        """
        by_segment = self.demands.pivot_table(
            index=["orig", "dest"],
            columns="segment",
            values="base_demand",
            aggfunc="sum",
        ).fillna(0)

        # Resolve the default axes while only the segment columns exist.
        segment_names = by_segment.columns
        if x is None:
            x = segment_names[0]
        if y is None:
            if len(segment_names) == 2:
                y = segment_names[1]
            else:
                y = "total"

        # Shares are computed before 'total' is attached, so that 'total'
        # does not get a share column of its own.
        row_totals = by_segment.sum(axis=1)
        shares = by_segment.div(row_totals, axis=0).add_suffix("_share")
        by_segment["total"] = row_totals
        table = pd.concat([by_segment, shares], axis=1)

        if raw_df:
            return table

        table = table.reset_index()

        def _tooltip_for(col):
            # orig/dest are shown as-is; shares as percentages; everything
            # else as a two-decimal number.
            if col in ("orig", "dest"):
                return col
            number_format = ".2%" if col.endswith("_share") else ".2f"
            return alt.Tooltip(col, format=number_format)

        tooltips = [_tooltip_for(col) for col in table.columns]

        points = alt.Chart(table).mark_point()
        points = points.encode(x=x, y=y, tooltip=tooltips)
        points = points.properties(title="Demand Segmentation Distribution")
        interactive_chart = points.interactive()

        if also_df:
            return interactive_chart, table
        return interactive_chart