# Source code for flightrisk.data.splits

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold


@dataclass(frozen=True)
class TemporalSplit:
    """Indices for a strict temporal train / validation / test split.

    :param train_idx: Row positions in the train slice.
    :param val_idx: Row positions in the validation slice.
    :param test_idx: Row positions in the test slice.
    :param train_cutoff: Last timestamp included in train.
    :param val_cutoff: Last timestamp included in validation.
    """

    train_idx: np.ndarray
    val_idx: np.ndarray
    test_idx: np.ndarray
    train_cutoff: pd.Timestamp
    val_cutoff: pd.Timestamp

    def __eq__(self, other: object) -> bool:
        """Return True when both splits hold the same indices and cutoffs.

        The dataclass-generated ``__eq__`` compares a tuple of the fields,
        which calls ``bool()`` on the element-wise ``ndarray == ndarray``
        result and raises ``ValueError`` for arrays longer than one element.
        Index arrays are therefore compared by value with ``np.array_equal``.
        """
        # Mirror the generated __eq__: only same-class instances compare.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            np.array_equal(self.train_idx, other.train_idx)
            and np.array_equal(self.val_idx, other.val_idx)
            and np.array_equal(self.test_idx, other.test_idx)
            and self.train_cutoff == other.train_cutoff
            and self.val_cutoff == other.val_cutoff
        )


def temporal_split(
    timestamps: pd.Series | pd.Index | np.ndarray | list,
    *,
    train_cutoff: str | pd.Timestamp,
    val_cutoff: str | pd.Timestamp,
) -> TemporalSplit:
    """Split rows into past / near-past / future slices, in that order.

    No row may appear in more than one slice. Rows at or before
    ``train_cutoff`` go to train; rows in ``(train_cutoff, val_cutoff]`` go
    to validation; rows after ``val_cutoff`` go to test. Rows whose timestamp
    is missing (``NaT``) compare false against both cutoffs and therefore
    appear in *no* slice.

    :param timestamps: Per-row timestamps as a Series, Index, ndarray or
        list. Must be convertible to datetime. Only row order matters; any
        index labels are ignored.
    :param train_cutoff: Last timestamp kept in train.
    :param val_cutoff: Last timestamp kept in validation. Must be later than
        ``train_cutoff``.
    :returns: A :class:`TemporalSplit` with non-overlapping integer
        positions into ``timestamps``.
    :raises ValueError: If ``val_cutoff`` does not lie strictly after
        ``train_cutoff``.
    """
    train_cutoff_ts = pd.Timestamp(train_cutoff)
    val_cutoff_ts = pd.Timestamp(val_cutoff)
    if val_cutoff_ts <= train_cutoff_ts:
        raise ValueError("val_cutoff must be strictly after train_cutoff")

    # pd.to_datetime returns a DatetimeIndex (which has no reset_index) for
    # list/array/Index input, so normalise to a Series first. Labels are
    # dropped because the result is expressed in row positions.
    ts = pd.to_datetime(pd.Series(timestamps)).reset_index(drop=True)

    train_mask = ts <= train_cutoff_ts
    val_mask = (ts > train_cutoff_ts) & (ts <= val_cutoff_ts)
    test_mask = ts > val_cutoff_ts
    return TemporalSplit(
        train_idx=np.flatnonzero(train_mask.to_numpy()),
        val_idx=np.flatnonzero(val_mask.to_numpy()),
        test_idx=np.flatnonzero(test_mask.to_numpy()),
        train_cutoff=train_cutoff_ts,
        val_cutoff=val_cutoff_ts,
    )


def stratified_rct_folds(
    treatment: pd.Series,
    outcome: pd.Series,
    *,
    n_splits: int = 5,
    seed: int = 1337,
) -> list[tuple[np.ndarray, np.ndarray]]:
    """Build cross-validation folds that preserve the treatment/outcome ratio.

    Stratification is done on the joint ``treatment * 2 + outcome`` label so
    each fold sees a representative slice of the four cells. The two inputs
    are paired by position, not by index label.

    :param treatment: Binary 0/1 treatment indicator (Series, ndarray or
        list; booleans are accepted).
    :param outcome: Binary 0/1 outcome indicator (Series, ndarray or list;
        booleans are accepted).
    :param n_splits: Number of folds.
    :param seed: Random seed for the shuffle that assigns rows to folds.
    :returns: List of ``(train_idx, val_idx)`` tuples of row positions.
    :raises ValueError: If treatment and outcome have different lengths, or
        either contains a value other than 0 or 1 after integer casting.
    """
    if len(treatment) != len(outcome):
        raise ValueError("treatment and outcome must share the same length")

    treat = pd.Series(treatment).astype(int).to_numpy()
    out = pd.Series(outcome).astype(int).to_numpy()
    # The joint label only identifies the four cells when both inputs are
    # strictly 0/1; e.g. (treatment=0, outcome=2) would silently collide
    # with (treatment=1, outcome=0).
    for label, values in (("treatment", treat), ("outcome", out)):
        if not np.isin(values, (0, 1)).all():
            raise ValueError(f"{label} must contain only 0/1 values")

    joint = treat * 2 + out
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # Only y drives stratification; X is a placeholder of matching length.
    return list(skf.split(np.zeros(len(joint)), joint))