Methods for statistical inference on the generalization error.
Package website: release | dev
# Install release from CRAN
install.packages("mlr3inferr")
# Install development version from GitHub
pak::pkg_install("mlr-org/mlr3inferr")
What is mlr3inferr?
The main purpose of the package is to make it possible to obtain confidence intervals for the generalization error for a number of resampling methods. Below, we evaluate a decision tree on the sonar task using holdout resampling and obtain a confidence interval for the generalization error. This is achieved using the `msr("ci.holdout")` measure, to which we pass another `mlr3::Measure` that determines the loss function.
library(mlr3) # provides resample(), tsk(), lrn(), rsmp() and msr()
library(mlr3inferr)
rr = resample(tsk("sonar"), lrn("classif.rpart"), rsmp("holdout"))
# 0.05 is also the default
ci = msr("ci.holdout", "classif.acc", alpha = 0.05)
rr$aggregate(ci)
#> classif.acc classif.acc.lower classif.acc.upper
#> 0.7391304 0.6347628 0.8434981
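To make the holdout interval more concrete, the base R sketch below computes a Wald-type interval by hand: the point estimate is the mean of the pointwise scores on the test set, and the half-width is a normal quantile times the standard error of that mean. This assumes `ci.holdout` uses this textbook construction; the helper `holdout_wald_ci` is hypothetical and not part of mlr3inferr. The input is a hypothetical vector of 69 correctness indicators (51 correct), chosen so that its mean equals the accuracy printed above; 69 is the test-set size for the 208-observation sonar task, assuming mlr3's default holdout ratio of 2/3.

```r
# Hypothetical helper (not part of mlr3inferr): Wald-type confidence
# interval for the mean of pointwise scores on a single test set.
holdout_wald_ci = function(scores, alpha = 0.05) {
  n = length(scores)
  estimate = mean(scores)
  # Standard error of the mean, using the sample standard deviation (n - 1)
  se = sd(scores) / sqrt(n)
  # Two-sided normal quantile, about 1.96 for alpha = 0.05
  z = qnorm(1 - alpha / 2)
  c(estimate = estimate, lower = estimate - z * se, upper = estimate + z * se)
}

# Hypothetical test-set correctness indicators: 51 correct out of 69
correct = c(rep(1, 51), rep(0, 18))
holdout_wald_ci(correct)
```

Rounded to seven digits, this sketch gives 0.7391304, 0.6347628 and 0.8434981, the same values printed above. That is consistent with, but does not by itself confirm, `ci.holdout` using this construction.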
It is also possible to select the default inference method for a given `Resampling` method using `msr("ci")`:
ci_default = msr("ci", "classif.acc")
rr$aggregate(ci_default)
#> classif.acc classif.acc.lower classif.acc.upper
#> 0.7391304 0.6347628 0.8434981
With `mlr3viz`, it is also possible to visualize multiple confidence intervals. Below, we compare a random forest with a decision tree and a featureless learner on the sonar and german_credit tasks:
library(mlr3learners)
library(mlr3viz)
bmr = benchmark(benchmark_grid(
  tsks(c("sonar", "german_credit")),
  lrns(c("classif.rpart", "classif.ranger", "classif.featureless")),
  rsmp("subsampling")
))
autoplot(bmr, "ci", msr("ci", "classif.ce"))
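For subsampling, the table of inference methods lists the Corrected-T interval (`ci.cor_t`). Assuming it follows the corrected resampled t-interval of Nadeau and Bengio (2003), the base R sketch below shows the idea: average the test errors of J random splits and inflate the naive variance of that average by a term that depends on the test/train size ratio. The helper `corrected_t_ci` and the split-wise errors are hypothetical; the 139/69 split assumes a 208-observation data set and a 2/3 training ratio.

```r
# Hypothetical helper (not part of mlr3inferr): corrected resampled
# t-interval of Nadeau and Bengio (2003) for J random train/test splits.
corrected_t_ci = function(split_errors, n_train, n_test, alpha = 0.05) {
  J = length(split_errors)
  estimate = mean(split_errors)
  # The naive variance of the mean would be var / J; the extra term
  # n_test / n_train accounts for overlapping training sets across splits.
  se = sqrt((1 / J + n_test / n_train) * var(split_errors))
  q = qt(1 - alpha / 2, df = J - 1)
  c(estimate = estimate, lower = estimate - q * se, upper = estimate + q * se)
}

# Hypothetical classification errors from 5 subsampling splits
errors = c(0.20, 0.25, 0.18, 0.22, 0.24)
corrected_t_ci(errors, n_train = 139, n_test = 69)
```

Compared with the naive standard error `sqrt(var(errors) / J)`, the correction widens the interval to reflect that the training sets of different splits overlap, which makes the split-wise errors positively correlated.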
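Note that some inference methods only support measures with a pointwise loss, i.e., measures that have an `$obs_loss` field (see the "Only Pointwise Loss" column of the table below).
Warning: Different point estimates for the same measure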
In general, the point estimate of `$aggregate(msr("ci", <key>))` is not always identical to the point estimate of `$aggregate(msr(<key>))`. This is because the point estimate of the former is defined by the inference method and can, for example, contain a bias correction term (as is the case for nested cross-validation) or use a different aggregation method. This is demonstrated in the example below.
rr = resample(tsk("iris"), lrn("classif.rpart"), rsmp("ncv", folds = 5L, repeats = 20L))
ce = msr("classif.ce")
ci = msr("ci", ce)
c(rr$aggregate(ce)[[1]], rr$aggregate(ci)[[1]])
#> [1] 0.06466667 0.06646667
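As a minimal illustration of how the aggregation method alone can change a point estimate, the base R sketch below compares averaging per-fold errors (macro averaging) with averaging all pointwise losses at once (micro averaging). The losses are hypothetical 0/1 values from three test folds of unequal size; this does not reproduce what any particular inference method in mlr3inferr does.

```r
# Hypothetical pointwise 0/1 losses from three test folds of unequal size
fold_losses = list(
  c(0, 0, 1, 0),        # 4 observations, fold error 0.25
  c(0, 1, 1),           # 3 observations, fold error 2/3
  c(0, 0, 0, 0, 0, 1)   # 6 observations, fold error 1/6
)

# Macro average: mean of the per-fold errors (each fold weighted equally)
macro = mean(sapply(fold_losses, mean))

# Micro average: mean over all pointwise losses (each observation weighted equally)
micro = mean(unlist(fold_losses))

c(macro = macro, micro = micro)
```

Here the macro average is 13/36 (about 0.361) while the micro average is 4/13 (about 0.308): the small, error-heavy second fold counts for a third of the macro average but only for 3 of 13 observations in the micro average.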
The package implements the following inference methods:
| Key | Label | Resamplings | Only Pointwise Loss |
|---|---|---|---|
| ci.con_z | Conservative-Z Interval | PairedSubsampling | no |
| ci.cor_t | Corrected-T Interval | Subsampling | no |
| ci.holdout | Holdout Interval | Holdout | yes |
| ci.ncv | Nested CV Interval | NestedCV | yes |
| ci.wald_cv | Wald CV Interval | CV, LOO | yes |
If you use mlr3inferr, please cite our paper:
@misc{kuempelfischer2024ciforge,
title={Constructing Confidence Intervals for 'the' Generalization Error -- a Comprehensive Benchmark Study},
author={Hannah Schulz-Kümpel and Sebastian Fischer and Thomas Nagler and Anne-Laure Boulesteix and Bernd Bischl and Roman Hornung},
year={2024},
eprint={2409.18836},
archivePrefix={arXiv},
primaryClass={stat.ML},
url={https://arxiv.org/abs/2409.18836},
}
This R package is developed as part of the Mathematical Research Data Initiative.
mlr3inferr is a free and open source software project that encourages participation and feedback. If you have any issues, questions, suggestions or feedback, please do not hesitate to open an “issue” about it on the GitHub page!
In case of problems / bugs, it is often helpful if you provide a “minimum working example” that showcases the behaviour (but don’t worry about this if the bug is obvious).
Please understand that the resources of the project are limited: response may sometimes be delayed by a few days, and some feature suggestions may be rejected if they are deemed too tangential to the vision behind the project.