Benchmarking predictive survival analysis models

This repository is dedicated to the evaluation of predictive survival models on large-ish datasets.

Software dependencies

To run the jupytercon-2023 tutorial notebooks, you will need:

conda create -n jupytercon-survival -c conda-forge python jupyterlab scikit-learn lifelines scikit-survival matplotlib-base plotly seaborn pandas pyarrow ibis-duckdb polars 

conda activate jupytercon-survival
jupyter lab

Notebooks

The notebooks folder holds the two main notebooks for the jupytercon-2023 tutorial, namely:

  • tutorial_part_1.ipynb
  • tutorial_part_2.ipynb

and the ancillary notebook used to generate the dataset for "part 1", namely:

  • truck_dataset.ipynb

Note that running truck_dataset.ipynb consumes a significant amount of RAM, so you might prefer downloading the pre-generated datasets as a zip archive (500 MB) instead of generating them yourself.

The following notebooks display our benchmark results and show how to use our wrappers to cross-validate various models:

  • kkbox_cv_benchmark.ipynb

    Benchmark on the KKBox challenge, inspired by the pycox paper.

  • msk_mettropism.ipynb

    Exploration of the MSK cancer dataset and survival probability predictions using our models.

Datasets

WSDM - KKBox's Churn Prediction Challenge (from Kaggle)

The datasets/kkbox_churn folder contains Python code to efficiently preprocess the raw transaction logs of the KKBox's Churn Prediction Challenge using ibis and duckdb.

The objectives are to:

  • make everything reproducible from the event-based logs;
  • implement efficient, parallel and out-of-core "sessionization" of the past transactions for all members: here, a "session" is an uninterrupted sequence of transactions (see the sketch after this list);
  • implement efficient, parallel and out-of-core tabularization (features and churn target with censoring);
  • make it possible to compute the cumulative state of the subscription data and the censored churn events at any point in time.
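
For illustration, here is a minimal sketch of the sessionization step, written against duckdb directly (the repository expresses this logic with ibis instead; the table layout and column names below are assumptions, not the actual KKBox schema):

import duckdb

con = duckdb.connect()
# Tiny in-memory stand-in for the KKBox transaction log.
con.execute("""
    CREATE TABLE transactions AS SELECT * FROM (VALUES
        ('member_a', DATE '2017-01-01', DATE '2017-02-01'),
        ('member_a', DATE '2017-02-01', DATE '2017-03-01'),
        ('member_a', DATE '2017-06-01', DATE '2017-07-01'),
        ('member_b', DATE '2017-01-15', DATE '2017-02-15')
    ) AS t(msno, transaction_date, membership_expire_date)
""")

# A new session starts whenever a transaction happens after the previous
# subscription has already expired; a running sum of those flags yields a
# per-member session id, computed by duckdb window functions.
sessions = con.execute("""
    WITH flagged AS (
        SELECT *,
            CASE WHEN transaction_date > LAG(membership_expire_date)
                OVER (PARTITION BY msno ORDER BY transaction_date)
            THEN 1 ELSE 0 END AS new_session
        FROM transactions
    )
    SELECT *,
        SUM(new_session) OVER (
            PARTITION BY msno ORDER BY transaction_date
        ) AS session_id
    FROM flagged
    ORDER BY msno, transaction_date
""").df()
print(sessions)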

Alternatively, kkbox_cv_benchmark.ipynb directly reuses the preprocessing steps from pycox to ensure reproducibility, at the cost of memory usage and speed.

Models

This repository introduces a novel survival estimator named Gradient Boosting CIF, built on top of scikit-learn's HistGradientBoostingClassifier.

It is named CIF because it can estimate cause-specific Cumulative Incidence Functions in a competing risks setting by minimizing a cause-specific Integrated Brier Score (IBS) objective function:

from models.gradient_boosted_cif import GradientBoostedCIF
import numpy as np

# Toy training set with 3 numerical features per sample.
X_train = np.array([
  [5.0, 0.1, 2.0],
  [3.0, 1.1, 2.2],
  [2.0, 0.3, 1.1],
  [4.0, 1.0, 0.9],
])
# Competing risks target: "event" holds the event type (0 means
# censored), so it needs an integer dtype rather than a boolean one.
y_train_multi_event = np.array([
  (2, 33.2),
  (0, 10.1),
  (0, 50.0),
  (1, 20.0),
  ], dtype=[("event", np.int32), ("duration", np.float64)]
)
time_grid = np.linspace(0.0, 30.0, 10)

gb_cif = GradientBoostedCIF(event_of_interest=1)
gb_cif.fit(X_train, y_train_multi_event, time_grid)

# Predict the cause-specific cumulative incidence curve of event 1,
# evaluated at each point of the time grid.
X_test = np.array([[3.0, 1.0, 9.0]])
cif_curves = gb_cif.predict_cumulative_incidence(X_test, time_grid)
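
For instance, the predicted curve can be visualized as a step function over the time grid (a plotting sketch, assuming cif_curves is an array of shape (n_samples, n_times)):

import matplotlib.pyplot as plt

# One cumulative incidence value per time grid point for the test sample.
plt.step(time_grid, cif_curves[0], where="post")
plt.xlabel("time")
plt.ylabel("cumulative incidence of event 1")
plt.ylim(0, 1)
plt.show()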

Alternatively, you can estimate the probability that the event of interest has been experienced by a specific time horizon:

gb_cif.predict_proba(X_test, time_horizon=20)

Conversely, you can estimate the conditional quantile of the time to event, e.g. answering the question "at which time horizon does the CIF reach 50%?":

gb_cif.predict_quantile(X_test, quantile=0.5)

You can also estimate the survival function in the single event setting (binary event indicator). Warning: the survival function only makes sense when y is binary or when setting event_of_interest='any'.

# Single-event target: "event" is a binary indicator
# (False means the sample is censored).
y_train_single_event = np.array([
  (1, 12.0),
  (0, 5.1),
  (1, 1.1),
  (0, 29.0),
  ], dtype=[("event", np.bool_), ("duration", np.float64)]
)
gb_cif.fit(X_train, y_train_single_event, time_grid)
survival_curves = gb_cif.predict_survival_function(X_test, time_grid)
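
Since the model is trained by minimizing an IBS objective, a natural way to evaluate the predicted curves on held-out data is the integrated_brier_score function from scikit-survival. Here is a minimal sketch continuing the example above, where y_test, X_eval and eval_times are made-up values for illustration:

from sksurv.metrics import integrated_brier_score

# Hypothetical held-out data with the same structured dtype as the
# training target.
y_test = np.array([
  (1, 7.0),
  (0, 15.0),
  ], dtype=[("event", np.bool_), ("duration", np.float64)]
)
X_eval = np.array([
  [3.0, 1.0, 9.0],
  [2.5, 0.2, 1.5],
])

# Evaluation times must stay within the observed follow-up of y_test.
eval_times = np.linspace(8.0, 14.0, 5)
surv_probs = gb_cif.predict_survival_function(X_eval, eval_times)
ibs = integrated_brier_score(y_train_single_event, y_test, surv_probs, eval_times)
print(f"IBS: {ibs:.3f}")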