Source code for flightrisk.eval.survival_metrics

from __future__ import annotations

import math
from collections.abc import Callable, Mapping
from dataclasses import dataclass

import numpy as np
from sksurv.metrics import (
    concordance_index_censored,
    cumulative_dynamic_auc,
    integrated_brier_score,
)

from flightrisk.models.survival.labels import to_structured_array


@dataclass(frozen=True)
class SurvivalMetrics:
    """Metrics emitted by the survival track.

    :param c_index: Harrell's concordance on censored data.
    :param time_dependent_auc_mean: Mean of time-dependent AUC across horizons.
    :param integrated_brier: Integrated Brier score across horizons.
    """

    c_index: float
    time_dependent_auc_mean: float
    integrated_brier: float

    def as_dict(self) -> Mapping[str, float]:
        """Return the metrics as a plain ``dict`` for MLflow logging.

        Non-finite entries (``None``, ``NaN``, ``+inf``, ``-inf``) are dropped
        so MLflow's metric store accepts the call when a metric could not be
        computed (e.g. degenerate censoring in the IPCW estimator).

        :returns: Mapping from metric name to finite value.
        """
        candidates = {
            "c_index": self.c_index,
            "time_dependent_auc_mean": self.time_dependent_auc_mean,
            "integrated_brier": self.integrated_brier,
        }
        # ``None`` is screened first because ``math.isfinite`` raises
        # ``TypeError`` on it; ``isfinite`` then rejects NaN and +/-inf alike.
        return {
            name: value
            for name, value in candidates.items()
            if value is not None and math.isfinite(value)
        }
def _nan_on_degenerate(compute: Callable[[], float]) -> float:
    """Evaluate ``compute`` and map sksurv's degenerate-input errors to ``NaN``.

    sksurv signals that a metric is undefined for a split by raising
    ``ValueError`` (some versions use ``ZeroDivisionError``). Examples include
    every test sample being censored, no comparable pairs, or a horizon outside
    the test follow-up range. Such metrics are reported as ``NaN`` rather than
    aborting the whole suite.

    :param compute: Zero-argument callable returning the metric value.
    :returns: The metric as a ``float``, or ``NaN`` if it could not be computed.
    """
    try:
        return float(compute())
    except (ValueError, ZeroDivisionError):
        return float("nan")


def _align_horizons(
    horizons: np.ndarray, survival_at_horizons: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Sort ``horizons`` ascending and reorder ``survival_at_horizons`` columns to match.

    sksurv normalises ``times`` with ``np.unique`` (which sorts) but does not
    permute the columns of a 2-D estimate. Unsorted horizons would therefore
    silently pair each time point with the wrong S(t) column.

    :param horizons: 1-D array of horizons.
    :param survival_at_horizons: ``(n_samples, len(horizons))`` matrix of S(t).
    :returns: ``(sorted_horizons, reordered_survival)``; inputs are returned
        unchanged when their shapes do not line up, so sksurv reports the error.
    """
    horizons = np.asarray(horizons)
    survival_at_horizons = np.asarray(survival_at_horizons)
    shapes_match = (
        horizons.ndim == 1
        and survival_at_horizons.ndim == 2
        and survival_at_horizons.shape[1] == horizons.shape[0]
    )
    if not shapes_match:
        return horizons, survival_at_horizons
    order = np.argsort(horizons, kind="stable")
    return horizons[order], survival_at_horizons[:, order]


def survival_metrics(
    *,
    train_durations: np.ndarray,
    train_events: np.ndarray,
    test_durations: np.ndarray,
    test_events: np.ndarray,
    risk_scores: np.ndarray,
    survival_at_horizons: np.ndarray,
    horizons: np.ndarray,
) -> SurvivalMetrics:
    """Compute the survival metric suite.

    Any metric sksurv cannot compute on this split (it raises ``ValueError``
    or ``ZeroDivisionError``) is reported as ``NaN`` instead of raising.

    :param train_durations: Train follow-up times.
    :param train_events: Train 0/1 event indicators.
    :param test_durations: Test follow-up times.
    :param test_events: Test 0/1 event indicators.
    :param risk_scores: Per-row risk score on test (higher = earlier event).
    :param survival_at_horizons: ``(n_samples, len(horizons))`` matrix of S(t).
    :param horizons: 1-D array of horizons matching ``survival_at_horizons``;
        need not be sorted.
    :returns: A :class:`SurvivalMetrics` bundle.
    """
    train_struct = to_structured_array(train_durations, train_events)
    test_struct = to_structured_array(test_durations, test_events)
    horizons, survival_at_horizons = _align_horizons(horizons, survival_at_horizons)

    # Guarded like the IPCW metrics: an all-censored test fold would
    # otherwise crash here before the other metrics are reported.
    c_index = _nan_on_degenerate(
        lambda: concordance_index_censored(
            test_struct["event"], test_struct["time"], risk_scores
        )[0]
    )
    # Plain (unweighted) mean over the per-horizon AUCs; the summary value
    # sksurv returns alongside them is discarded.
    td_auc = _nan_on_degenerate(
        lambda: np.mean(
            cumulative_dynamic_auc(train_struct, test_struct, risk_scores, times=horizons)[0]
        )
    )
    brier = _nan_on_degenerate(
        lambda: integrated_brier_score(
            train_struct, test_struct, survival_at_horizons, times=horizons
        )
    )

    return SurvivalMetrics(
        c_index=c_index,
        time_dependent_auc_mean=td_auc,
        integrated_brier=brier,
    )