View Jupyter notebook on the GitHub.

Prediction intervals#

Binder

This notebook contains overview of prediction intervals functionality in ETNA library.

Table of contents

  • Loading and preparing data

  • Estimating intervals using builtin method

    • Accessing prediction intervals in ``TSDataset` <#chapter2_1>`__

    • Computing interval metrics

  • Estimating prediction intervals using ``experimental.prediction_intervals` module <#chapter3>`__

    • `NaiveVariancePredictionIntervals <#chapter3_1>`__

    • `ConformalPredictionIntervals <#chapter3_2>`__

    • `EmpiricalPredictionIntervals <#chapter3_3>`__

    • Prediction intervals for ensembles

  • Custom prediction interval method

    • Non-parametric method

    • Estimating historical residuals

[1]:
import warnings
from copy import deepcopy

import numpy as np
import pandas as pd

from etna.analysis.forecast import plot_forecast
from etna.datasets import TSDataset
from etna.metrics import Coverage
from etna.metrics import Width
from etna.models import CatBoostMultiSegmentModel
from etna.pipeline import Pipeline
from etna.transforms import DateFlagsTransform
from etna.transforms import LagTransform
from etna.transforms import SegmentEncoderTransform

warnings.filterwarnings("ignore")
[2]:
HORIZON = 30

Loading and preparing data#

Consider the dataset data/example_dataset.csv.

This data will be used to show how prediction intervals could be estimated and accessed in ETNA library.

The first step is to load data and convert it to the TSDataset.

[3]:
df = pd.read_csv("data/example_dataset.csv")
df = TSDataset.to_dataset(df=df)
ts = TSDataset(df=df, freq="D")

ts
[3]:
segment segment_a segment_b segment_c segment_d
feature target target target target
timestamp
2019-01-01 170 102 92 238
2019-01-02 243 123 107 358
2019-01-03 267 130 103 366
2019-01-04 287 138 103 385
2019-01-05 279 137 104 384
... ... ... ... ...
2019-11-26 591 259 196 941
2019-11-27 606 264 196 949
2019-11-28 555 242 207 896
2019-11-29 581 247 186 905
2019-11-30 502 206 169 721

334 rows × 4 columns

[4]:
ts.plot()
../_images/tutorials_306-prediction_intervals_6_0.png

Here we have four segments in the dataset. All segments have seasonalities, and some of them show signs of trend. Note that segment C contains an obvious outlier, that may affect quality of estimated intervals.

In the next step, we split our dataset into two parts: train and test. The test part will be used as a hold-out dataset for metrics computation and result analysis.

[5]:
train_ts, test_ts = ts.train_test_split(test_size=HORIZON)

Estimating intervals using builtin method#

Prediction interval is an estimation of the range in which a future observation will fall, with a certain probability, given historical observations.

There are several ways of estimation: model-specific and model-agnostic methods. Model-specific methods use features of underlying models, that are able to produce probabilistic estimates. Examples of such models are SARIMAX, Holt-Winters and TBATS. Model-agnostic methods treat models as black boxes and implement separate methods to do the estimation. The topic of this notebook is model-agnostic methods only.

Currently there are several types of prediction intervals in the library: 1. Quantiles estimates 2. Arbitrary interval borders, that tend to provide desired coverage

Quantiles estimation methods, implemented in the library, use univariate distribution to estimate quantiles at each timestamp in the horizon. There is the possibility of treating all timestamps in the horizon jointly as multivariate random variable to estimate quantiles, but this approach is not implemented right now. The extension of current method pool will be discussed in the last section of this notebook.

So there are some naming convention to achieve distinction between two types of intervals. Borders that approximate quantiles named using the following format {target_{q:.4g}}, where q is the corresponding quantile level. And there are no particular rules for the arbitrary borders. But it is implementation responsibility to name them appropriately.

Before estimating prediction intervals we need to fit a model. Here CatBoostMultiSegmentModel is used with lag and date features. This model requires computed features, so we add corresponding transforms to the pipeline.

[6]:
seg = SegmentEncoderTransform()
lags = LagTransform(in_column="target", lags=list(range(HORIZON, 20 + HORIZON)), out_column="lag")
date_flags = DateFlagsTransform(
    day_number_in_week=True,
    day_number_in_month=True,
    week_number_in_month=True,
    week_number_in_year=True,
    month_number_in_year=True,
    year_number=True,
    is_weekend=True,
    out_column="flag",
)

transforms = [lags, date_flags, seg]
[7]:
model = CatBoostMultiSegmentModel()
[8]:
pipeline = Pipeline(model=model, transforms=transforms, horizon=HORIZON)

pipeline.fit(ts=train_ts);

After the pipeline is defined and fitted, we are able to estimate prediction intervals with the default method. To do so set the prediction_interval=True parameter of the forecast method.

This method is based on residual variance estimation and \(z\)-scores. Variance estimation is done via running historical backtest on non-overlapping folds. Number of folds is controlled by the n_folds parameter.

[9]:
forecast = pipeline.forecast(ts=train_ts, prediction_interval=True, n_folds=7)
forecast
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    5.8s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:   10.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   16.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:   22.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:   27.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:   33.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:   39.6s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:   39.6s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.6s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    0.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    1.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    1.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    0.2s finished
[9]:
segment segment_a ... segment_d
feature flag_day_number_in_month flag_day_number_in_week flag_is_weekend flag_month_number_in_year flag_week_number_in_month flag_week_number_in_year flag_year_number lag_30 lag_31 lag_32 ... lag_44 lag_45 lag_46 lag_47 lag_48 lag_49 segment_code target target_0.025 target_0.975
timestamp
2019-11-01 1 4 False 11 1 44 2019 516.0 558.0 551.0 ... 860.0 859.0 833.0 592.0 616.0 824.0 3 884.890130 733.931966 1035.848293
2019-11-02 2 5 True 11 1 44 2019 489.0 516.0 558.0 ... 822.0 860.0 859.0 833.0 592.0 616.0 3 766.540795 615.582631 917.498959
2019-11-03 3 6 True 11 1 44 2019 471.0 489.0 516.0 ... 908.0 822.0 860.0 859.0 833.0 592.0 3 735.275270 584.317106 886.233433
2019-11-04 4 0 False 11 2 45 2019 371.0 471.0 489.0 ... 648.0 908.0 822.0 860.0 859.0 833.0 3 871.086070 720.127906 1022.044233
2019-11-05 5 1 False 11 2 45 2019 359.0 371.0 471.0 ... 599.0 648.0 908.0 822.0 860.0 859.0 3 854.657442 703.699278 1005.615606
2019-11-06 6 2 False 11 2 45 2019 499.0 359.0 371.0 ... 821.0 599.0 648.0 908.0 822.0 860.0 3 845.998526 695.040362 996.956690
2019-11-07 7 3 False 11 2 45 2019 528.0 499.0 359.0 ... 883.0 821.0 599.0 648.0 908.0 822.0 3 863.000343 712.042180 1013.958507
2019-11-08 8 4 False 11 2 45 2019 550.0 528.0 499.0 ... 923.0 883.0 821.0 599.0 648.0 908.0 3 886.018643 735.060480 1036.976807
2019-11-09 9 5 True 11 2 45 2019 547.0 550.0 528.0 ... 908.0 923.0 883.0 821.0 599.0 648.0 3 786.671129 635.712965 937.629293
2019-11-10 10 6 True 11 2 45 2019 544.0 547.0 550.0 ... 874.0 908.0 923.0 883.0 821.0 599.0 3 770.142993 619.184829 921.101157
2019-11-11 11 0 False 11 3 46 2019 423.0 544.0 547.0 ... 712.0 874.0 908.0 923.0 883.0 821.0 3 883.026096 732.067932 1033.984259
2019-11-12 12 1 False 11 3 46 2019 402.0 423.0 544.0 ... 691.0 712.0 874.0 908.0 923.0 883.0 3 867.552488 716.594324 1018.510651
2019-11-13 13 2 False 11 3 46 2019 550.0 402.0 423.0 ... 980.0 691.0 712.0 874.0 908.0 923.0 3 850.850720 699.892556 1001.808883
2019-11-14 14 3 False 11 3 46 2019 582.0 550.0 402.0 ... 1037.0 980.0 691.0 712.0 874.0 908.0 3 866.155902 715.197738 1017.114066
2019-11-15 15 4 False 11 3 46 2019 559.0 582.0 550.0 ... 969.0 1037.0 980.0 691.0 712.0 874.0 3 891.686379 740.728216 1042.644543
2019-11-16 16 5 True 11 3 46 2019 543.0 559.0 582.0 ... 929.0 969.0 1037.0 980.0 691.0 712.0 3 795.462403 644.504240 946.420567
2019-11-17 17 6 True 11 3 46 2019 523.0 543.0 559.0 ... 874.0 929.0 969.0 1037.0 980.0 691.0 3 794.104098 643.145934 945.062262
2019-11-18 18 0 False 11 4 47 2019 422.0 523.0 543.0 ... 664.0 874.0 929.0 969.0 1037.0 980.0 3 869.079738 718.121574 1020.037902
2019-11-19 19 1 False 11 4 47 2019 403.0 422.0 523.0 ... 623.0 664.0 874.0 929.0 969.0 1037.0 3 855.177747 704.219583 1006.135911
2019-11-20 20 2 False 11 4 47 2019 538.0 403.0 422.0 ... 841.0 623.0 664.0 874.0 929.0 969.0 3 850.930059 699.971896 1001.888223
2019-11-21 21 3 False 11 4 47 2019 532.0 538.0 403.0 ... 879.0 841.0 623.0 664.0 874.0 929.0 3 870.278782 719.320619 1021.236946
2019-11-22 22 4 False 11 4 47 2019 515.0 532.0 538.0 ... 886.0 879.0 841.0 623.0 664.0 874.0 3 890.245006 739.286843 1041.203170
2019-11-23 23 5 True 11 4 47 2019 520.0 515.0 532.0 ... 934.0 886.0 879.0 841.0 623.0 664.0 3 775.886589 624.928425 926.844753
2019-11-24 24 6 True 11 4 47 2019 511.0 520.0 515.0 ... 885.0 934.0 886.0 879.0 841.0 623.0 3 751.033775 600.075611 901.991938
2019-11-25 25 0 False 11 5 48 2019 502.0 511.0 520.0 ... 672.0 885.0 934.0 886.0 879.0 841.0 3 874.921950 723.963786 1025.880114
2019-11-26 26 1 False 11 5 48 2019 499.0 502.0 511.0 ... 621.0 672.0 885.0 934.0 886.0 879.0 3 878.395149 727.436985 1029.353313
2019-11-27 27 2 False 11 5 48 2019 534.0 499.0 502.0 ... 859.0 621.0 672.0 885.0 934.0 886.0 3 860.259958 709.301795 1011.218122
2019-11-28 28 3 False 11 5 48 2019 502.0 534.0 499.0 ... 931.0 859.0 621.0 672.0 885.0 934.0 3 889.115142 738.156978 1040.073305
2019-11-29 29 4 False 11 5 48 2019 497.0 502.0 534.0 ... 897.0 931.0 859.0 621.0 672.0 885.0 3 891.706478 740.748315 1042.664642
2019-11-30 30 5 True 11 5 48 2019 501.0 497.0 502.0 ... 882.0 897.0 931.0 859.0 621.0 672.0 3 779.465792 628.507628 930.423955

30 rows × 124 columns

Here we have a point forecast for the full horizon, along with an estimated prediction interval for each segment.

The section below describes how one can perform manipulations with intervals in the dataset with forecasts.

Accessing prediction intervals in TSDataset#

Column names for the estimated prediction intervals can be obtained using the TSDataset.prediction_intervals_names property.

[10]:
forecast.prediction_intervals_names
[10]:
('target_0.025', 'target_0.975')

Here segment names are omitted, because they share interval estimation method. So column names identical for all the segments.

A dataframe with prediction intervals for each segment can be obtained by using the TSDataset.get_prediction_intervals() method.

Here, we save such dataframe to a separate object to use later.

[11]:
prediction_intervals = forecast.get_prediction_intervals()
prediction_intervals
[11]:
segment segment_a segment_b segment_c segment_d
feature target_0.025 target_0.975 target_0.025 target_0.975 target_0.025 target_0.975 target_0.025 target_0.975
timestamp
2019-11-01 470.753353 654.622675 182.998175 287.116897 33.867819 260.423229 733.931966 1035.848293
2019-11-02 339.787734 523.657056 137.674295 241.793017 5.273997 231.829407 615.582631 917.498959
2019-11-03 328.057195 511.926517 136.459431 240.578153 8.896867 235.452277 584.317106 886.233433
2019-11-04 464.257076 648.126398 193.140438 297.259160 41.574061 268.129472 720.127906 1022.044233
2019-11-05 471.667537 655.536859 197.889160 302.007882 46.067009 272.622419 703.699278 1005.615606
2019-11-06 458.729249 642.598571 189.420442 293.539164 44.816511 271.371921 695.040362 996.956690
2019-11-07 477.280688 661.150010 197.562499 301.681221 46.092846 272.648256 712.042180 1013.958507
2019-11-08 474.032182 657.901504 189.213408 293.332130 48.396520 274.951930 735.060480 1036.976807
2019-11-09 342.156468 526.025790 145.463552 249.582274 19.344079 245.899489 635.712965 937.629293
2019-11-10 335.175499 519.044821 143.336403 247.455125 17.194591 243.750001 619.184829 921.101157
2019-11-11 488.455591 672.324913 202.226734 306.345456 52.974858 279.530268 732.067932 1033.984259
2019-11-12 483.315864 667.185186 203.543648 307.662371 55.758080 282.313490 716.594324 1018.510651
2019-11-13 466.986283 650.855605 194.021186 298.139908 50.141150 276.696560 699.892556 1001.808883
2019-11-14 487.949334 671.818656 199.896469 304.015191 52.315531 278.870942 715.197738 1017.114066
2019-11-15 482.792684 666.662006 193.939262 298.057985 61.206762 287.762172 740.728216 1042.644543
2019-11-16 341.268023 525.137344 149.079266 253.197988 23.712223 250.267633 644.504240 946.420567
2019-11-17 342.778361 526.647683 152.365006 256.483728 30.379443 256.934853 643.145934 945.062262
2019-11-18 479.061701 662.931023 202.628841 306.747563 62.492294 289.047705 718.121574 1020.037902
2019-11-19 499.022561 682.891883 198.347958 302.466681 57.142646 283.698056 704.219583 1006.135911
2019-11-20 479.542225 663.411547 203.284886 307.403608 54.747406 281.302817 699.971896 1001.888223
2019-11-21 481.230732 665.100054 205.877484 309.996207 63.412175 289.967586 719.320619 1021.236946
2019-11-22 464.535473 648.404795 190.887708 295.006430 62.151579 288.706989 739.286843 1041.203170
2019-11-23 346.853867 530.723188 159.171331 263.290053 32.471348 259.026758 624.928425 926.844753
2019-11-24 343.750821 527.620143 157.939006 262.057728 34.721080 261.276491 600.075611 901.991938
2019-11-25 492.355361 676.224683 212.837537 316.956260 80.679910 307.235320 723.963786 1025.880114
2019-11-26 495.916553 679.785875 212.490423 316.609145 72.072010 298.627420 727.436985 1029.353313
2019-11-27 503.081899 686.951221 216.030382 320.149104 83.249364 309.804775 709.301795 1011.218122
2019-11-28 504.440254 688.309575 209.345608 313.464330 80.577678 307.133088 738.156978 1040.073305
2019-11-29 495.299379 679.168701 199.689472 303.808194 76.070869 302.626279 740.748315 1042.664642
2019-11-30 379.529016 563.398338 148.540798 252.659521 37.882227 264.437637 628.507628 930.423955

If estimated intervals are no longer needed or there is a necessity to remove prediction intervals from the dataset use TSDataset.drop_prediction_intervals() method.

[12]:
forecast.drop_prediction_intervals()
forecast.prediction_intervals_names
[12]:
()

Here we see that property contains an empty tuple now. It is an indication that no intervals are registered.

[13]:
forecast.get_prediction_intervals()

Calling TSDataset.get_prediction_intervals() in such a case will return None.

There is a possibility of adding existing prediction intervals to the dataset. To do so, one should use TSDataset.add_prediction_intervals() method.

There are a couple requirements when adding existing intervals to the dataset. 1. Absence of the intervals in the dataset. This could be checked via the prediction_intervals_names property. 2. The dataframe with intervals should be in ETNA wide format. 3. All segments should be matched between the dataset and intervals dataframe. 4. Interval borders names are matched across all the segments.

[14]:
forecast.add_prediction_intervals(prediction_intervals_df=prediction_intervals)
forecast.prediction_intervals_names
[14]:
('target_0.025', 'target_0.975')
[15]:
forecast
[15]:
segment segment_a ... segment_d
feature flag_day_number_in_month flag_day_number_in_week flag_is_weekend flag_month_number_in_year flag_week_number_in_month flag_week_number_in_year flag_year_number lag_30 lag_31 lag_32 ... lag_44 lag_45 lag_46 lag_47 lag_48 lag_49 segment_code target target_0.025 target_0.975
timestamp
2019-11-01 1 4 False 11 1 44 2019 516.0 558.0 551.0 ... 860.0 859.0 833.0 592.0 616.0 824.0 3 884.890130 733.931966 1035.848293
2019-11-02 2 5 True 11 1 44 2019 489.0 516.0 558.0 ... 822.0 860.0 859.0 833.0 592.0 616.0 3 766.540795 615.582631 917.498959
2019-11-03 3 6 True 11 1 44 2019 471.0 489.0 516.0 ... 908.0 822.0 860.0 859.0 833.0 592.0 3 735.275270 584.317106 886.233433
2019-11-04 4 0 False 11 2 45 2019 371.0 471.0 489.0 ... 648.0 908.0 822.0 860.0 859.0 833.0 3 871.086070 720.127906 1022.044233
2019-11-05 5 1 False 11 2 45 2019 359.0 371.0 471.0 ... 599.0 648.0 908.0 822.0 860.0 859.0 3 854.657442 703.699278 1005.615606
2019-11-06 6 2 False 11 2 45 2019 499.0 359.0 371.0 ... 821.0 599.0 648.0 908.0 822.0 860.0 3 845.998526 695.040362 996.956690
2019-11-07 7 3 False 11 2 45 2019 528.0 499.0 359.0 ... 883.0 821.0 599.0 648.0 908.0 822.0 3 863.000343 712.042180 1013.958507
2019-11-08 8 4 False 11 2 45 2019 550.0 528.0 499.0 ... 923.0 883.0 821.0 599.0 648.0 908.0 3 886.018643 735.060480 1036.976807
2019-11-09 9 5 True 11 2 45 2019 547.0 550.0 528.0 ... 908.0 923.0 883.0 821.0 599.0 648.0 3 786.671129 635.712965 937.629293
2019-11-10 10 6 True 11 2 45 2019 544.0 547.0 550.0 ... 874.0 908.0 923.0 883.0 821.0 599.0 3 770.142993 619.184829 921.101157
2019-11-11 11 0 False 11 3 46 2019 423.0 544.0 547.0 ... 712.0 874.0 908.0 923.0 883.0 821.0 3 883.026096 732.067932 1033.984259
2019-11-12 12 1 False 11 3 46 2019 402.0 423.0 544.0 ... 691.0 712.0 874.0 908.0 923.0 883.0 3 867.552488 716.594324 1018.510651
2019-11-13 13 2 False 11 3 46 2019 550.0 402.0 423.0 ... 980.0 691.0 712.0 874.0 908.0 923.0 3 850.850720 699.892556 1001.808883
2019-11-14 14 3 False 11 3 46 2019 582.0 550.0 402.0 ... 1037.0 980.0 691.0 712.0 874.0 908.0 3 866.155902 715.197738 1017.114066
2019-11-15 15 4 False 11 3 46 2019 559.0 582.0 550.0 ... 969.0 1037.0 980.0 691.0 712.0 874.0 3 891.686379 740.728216 1042.644543
2019-11-16 16 5 True 11 3 46 2019 543.0 559.0 582.0 ... 929.0 969.0 1037.0 980.0 691.0 712.0 3 795.462403 644.504240 946.420567
2019-11-17 17 6 True 11 3 46 2019 523.0 543.0 559.0 ... 874.0 929.0 969.0 1037.0 980.0 691.0 3 794.104098 643.145934 945.062262
2019-11-18 18 0 False 11 4 47 2019 422.0 523.0 543.0 ... 664.0 874.0 929.0 969.0 1037.0 980.0 3 869.079738 718.121574 1020.037902
2019-11-19 19 1 False 11 4 47 2019 403.0 422.0 523.0 ... 623.0 664.0 874.0 929.0 969.0 1037.0 3 855.177747 704.219583 1006.135911
2019-11-20 20 2 False 11 4 47 2019 538.0 403.0 422.0 ... 841.0 623.0 664.0 874.0 929.0 969.0 3 850.930059 699.971896 1001.888223
2019-11-21 21 3 False 11 4 47 2019 532.0 538.0 403.0 ... 879.0 841.0 623.0 664.0 874.0 929.0 3 870.278782 719.320619 1021.236946
2019-11-22 22 4 False 11 4 47 2019 515.0 532.0 538.0 ... 886.0 879.0 841.0 623.0 664.0 874.0 3 890.245006 739.286843 1041.203170
2019-11-23 23 5 True 11 4 47 2019 520.0 515.0 532.0 ... 934.0 886.0 879.0 841.0 623.0 664.0 3 775.886589 624.928425 926.844753
2019-11-24 24 6 True 11 4 47 2019 511.0 520.0 515.0 ... 885.0 934.0 886.0 879.0 841.0 623.0 3 751.033775 600.075611 901.991938
2019-11-25 25 0 False 11 5 48 2019 502.0 511.0 520.0 ... 672.0 885.0 934.0 886.0 879.0 841.0 3 874.921950 723.963786 1025.880114
2019-11-26 26 1 False 11 5 48 2019 499.0 502.0 511.0 ... 621.0 672.0 885.0 934.0 886.0 879.0 3 878.395149 727.436985 1029.353313
2019-11-27 27 2 False 11 5 48 2019 534.0 499.0 502.0 ... 859.0 621.0 672.0 885.0 934.0 886.0 3 860.259958 709.301795 1011.218122
2019-11-28 28 3 False 11 5 48 2019 502.0 534.0 499.0 ... 931.0 859.0 621.0 672.0 885.0 934.0 3 889.115142 738.156978 1040.073305
2019-11-29 29 4 False 11 5 48 2019 497.0 502.0 534.0 ... 897.0 931.0 859.0 621.0 672.0 885.0 3 891.706478 740.748315 1042.664642
2019-11-30 30 5 True 11 5 48 2019 501.0 497.0 502.0 ... 882.0 897.0 931.0 859.0 621.0 672.0 3 779.465792 628.507628 930.423955

30 rows × 124 columns

We called prediction_intervals_names here to ensure that intervals were correctly added and printed out the resulting dataset.

Results visualization could be done using the plot_forecast function. Setting parameter prediction_intervals=True will enable plotting estimated prediction intervals.

[16]:
plot_forecast(forecast, test_ts, train_ts, prediction_intervals=True, n_train_samples=30)
../_images/tutorials_306-prediction_intervals_33_0.png

Computing interval metrics#

There are a couple of metrics in the library that can help estimate the quality of computed prediction intervals: * Coverage - percentage of points in the horizon that fall between interval borders * Width - mean distance between intervals borders at each timestamp of the horizon.

These metrics require initialization. To specify which interval to use provide border names by setting lower_name and upper_name parameters. After initialization these metrics will try to find specified borders in the dataset with predicted values. If provided names are not found, a corresponding error will be raised.

If estimated intervals borders are quantiles desired levels could be selected by setting the quantiles parameter. Usage of both ways of interval selection will lead to an error.

Here we wrap metrics estimation in one function.

[17]:
def interval_metrics(test_ts, forecast):
    lower_name, upper_name = forecast.prediction_intervals_names

    coverage = Coverage(lower_name=lower_name, upper_name=upper_name)(test_ts, forecast)
    width = Width(lower_name=lower_name, upper_name=upper_name)(test_ts, forecast)

    return coverage, width
[18]:
coverage, width = interval_metrics(test_ts=test_ts, forecast=forecast)
[19]:
coverage
[19]:
{'segment_a': 0.9666666666666667,
 'segment_b': 1.0,
 'segment_c': 1.0,
 'segment_d': 0.9666666666666667}
[20]:
width
[20]:
{'segment_a': 183.8693219098439,
 'segment_b': 104.11872216880786,
 'segment_c': 226.55541023779915,
 'segment_d': 301.9163274096744}

Estimating prediction intervals using experimental.prediction_intervals module#

The ETNA library provides several alternative methods for prediction interval estimation. All necessary functionality is in the etna.experimental.prediction_intervals module.

This section covers currently implemented methods. Also, the module provides the possibility to easily extend the method list by implementing a custom one. This topic will be discussed in the last section.

Prediction interval functionality is implemented via wrapper classes for the ETNA pipelines. During initialization, such methods require pipeline instances and necessary hyperparameters. Provided pipeline can be fitted before or after wrapping with the interval estimation method.

NaiveVariancePredictionIntervals#

This method estimates prediction quantiles using the following algorithm:

  1. Compute the residuals matrix \(r_{it} = \hat y_{it} - y_{it}\) using k-fold backtest, where \(i\) is fold index.

  2. Estimate variance for each step in the prediction horizon \(v_t = \frac{1}{k} \sum_{i = 1}^k r_{it}^2\).

  3. Use \(z\)-scores and estimated variance to compute corresponding quantiles.

Desired quantiles levels for the prediction interval can be set via quantiles of the forecast method.

[21]:
from etna.experimental.prediction_intervals import NaiveVariancePredictionIntervals

pipeline = NaiveVariancePredictionIntervals(pipeline=pipeline)

forecast = pipeline.forecast(quantiles=(0.025, 0.975), prediction_interval=True, n_folds=40)

forecast
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    6.8s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:   12.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   18.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:   24.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:   32.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:   38.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:   43.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:   48.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:   54.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:  1.1min remaining:    0.0s
[Parallel(n_jobs=1)]: Done  40 out of  40 | elapsed:  4.2min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.6s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    1.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    1.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    1.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:    1.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:    1.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  40 out of  40 | elapsed:    5.6s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  40 out of  40 | elapsed:    0.8s finished
[21]:
segment segment_a ... segment_d
feature flag_day_number_in_month flag_day_number_in_week flag_is_weekend flag_month_number_in_year flag_week_number_in_month flag_week_number_in_year flag_year_number lag_30 lag_31 lag_32 ... lag_44 lag_45 lag_46 lag_47 lag_48 lag_49 segment_code target target_0.025 target_0.975
timestamp
2019-11-01 1 4 False 11 1 44 2019 516.0 558.0 551.0 ... 860.0 859.0 833.0 592.0 616.0 824.0 3 884.890130 748.811092 1020.969168
2019-11-02 2 5 True 11 1 44 2019 489.0 516.0 558.0 ... 822.0 860.0 859.0 833.0 592.0 616.0 3 766.540795 613.388663 919.692927
2019-11-03 3 6 True 11 1 44 2019 471.0 489.0 516.0 ... 908.0 822.0 860.0 859.0 833.0 592.0 3 735.275270 577.335580 893.214960
2019-11-04 4 0 False 11 2 45 2019 371.0 471.0 489.0 ... 648.0 908.0 822.0 860.0 859.0 833.0 3 871.086070 720.004922 1022.167218
2019-11-05 5 1 False 11 2 45 2019 359.0 371.0 471.0 ... 599.0 648.0 908.0 822.0 860.0 859.0 3 854.657442 700.824652 1008.490232
2019-11-06 6 2 False 11 2 45 2019 499.0 359.0 371.0 ... 821.0 599.0 648.0 908.0 822.0 860.0 3 845.998526 696.777355 995.219696
2019-11-07 7 3 False 11 2 45 2019 528.0 499.0 359.0 ... 883.0 821.0 599.0 648.0 908.0 822.0 3 863.000343 709.835775 1016.164912
2019-11-08 8 4 False 11 2 45 2019 550.0 528.0 499.0 ... 923.0 883.0 821.0 599.0 648.0 908.0 3 886.018643 713.554861 1058.482426
2019-11-09 9 5 True 11 2 45 2019 547.0 550.0 528.0 ... 908.0 923.0 883.0 821.0 599.0 648.0 3 786.671129 597.917516 975.424742
2019-11-10 10 6 True 11 2 45 2019 544.0 547.0 550.0 ... 874.0 908.0 923.0 883.0 821.0 599.0 3 770.142993 578.696059 961.589928
2019-11-11 11 0 False 11 3 46 2019 423.0 544.0 547.0 ... 712.0 874.0 908.0 923.0 883.0 821.0 3 883.026096 688.904334 1077.147858
2019-11-12 12 1 False 11 3 46 2019 402.0 423.0 544.0 ... 691.0 712.0 874.0 908.0 923.0 883.0 3 867.552488 668.860909 1066.244066
2019-11-13 13 2 False 11 3 46 2019 550.0 402.0 423.0 ... 980.0 691.0 712.0 874.0 908.0 923.0 3 850.850720 652.719134 1048.982305
2019-11-14 14 3 False 11 3 46 2019 582.0 550.0 402.0 ... 1037.0 980.0 691.0 712.0 874.0 908.0 3 866.155902 669.821064 1062.490740
2019-11-15 15 4 False 11 3 46 2019 559.0 582.0 550.0 ... 969.0 1037.0 980.0 691.0 712.0 874.0 3 891.686379 680.016449 1103.356310
2019-11-16 16 5 True 11 3 46 2019 543.0 559.0 582.0 ... 929.0 969.0 1037.0 980.0 691.0 712.0 3 795.462403 576.255367 1014.669440
2019-11-17 17 6 True 11 3 46 2019 523.0 543.0 559.0 ... 874.0 929.0 969.0 1037.0 980.0 691.0 3 794.104098 572.084078 1016.124118
2019-11-18 18 0 False 11 4 47 2019 422.0 523.0 543.0 ... 664.0 874.0 929.0 969.0 1037.0 980.0 3 869.079738 646.518611 1091.640865
2019-11-19 19 1 False 11 4 47 2019 403.0 422.0 523.0 ... 623.0 664.0 874.0 929.0 969.0 1037.0 3 855.177747 631.622915 1078.732579
2019-11-20 20 2 False 11 4 47 2019 538.0 403.0 422.0 ... 841.0 623.0 664.0 874.0 929.0 969.0 3 850.930059 631.485430 1070.374688
2019-11-21 21 3 False 11 4 47 2019 532.0 538.0 403.0 ... 879.0 841.0 623.0 664.0 874.0 929.0 3 870.278782 649.083711 1091.473853
2019-11-22 22 4 False 11 4 47 2019 515.0 532.0 538.0 ... 886.0 879.0 841.0 623.0 664.0 874.0 3 890.245006 653.296062 1127.193951
2019-11-23 23 5 True 11 4 47 2019 520.0 515.0 532.0 ... 934.0 886.0 879.0 841.0 623.0 664.0 3 775.886589 528.054183 1023.718995
2019-11-24 24 6 True 11 4 47 2019 511.0 520.0 515.0 ... 885.0 934.0 886.0 879.0 841.0 623.0 3 751.033775 502.507435 999.560114
2019-11-25 25 0 False 11 5 48 2019 502.0 511.0 520.0 ... 672.0 885.0 934.0 886.0 879.0 841.0 3 874.921950 622.112739 1127.731161
2019-11-26 26 1 False 11 5 48 2019 499.0 502.0 511.0 ... 621.0 672.0 885.0 934.0 886.0 879.0 3 878.395149 619.881779 1136.908519
2019-11-27 27 2 False 11 5 48 2019 534.0 499.0 502.0 ... 859.0 621.0 672.0 885.0 934.0 886.0 3 860.259958 603.009967 1117.509950
2019-11-28 28 3 False 11 5 48 2019 502.0 534.0 499.0 ... 931.0 859.0 621.0 672.0 885.0 934.0 3 889.115142 627.464437 1150.765846
2019-11-29 29 4 False 11 5 48 2019 497.0 502.0 534.0 ... 897.0 931.0 859.0 621.0 672.0 885.0 3 891.706478 637.348950 1146.064007
2019-11-30 30 5 True 11 5 48 2019 501.0 497.0 502.0 ... 882.0 897.0 931.0 859.0 621.0 672.0 3 779.465792 521.183909 1037.747675

30 rows × 124 columns

[22]:
plot_forecast(forecast, test_ts, train_ts, prediction_intervals=True, n_train_samples=30)
../_images/tutorials_306-prediction_intervals_42_0.png
[23]:
coverage, width = interval_metrics(test_ts=test_ts, forecast=forecast)
[24]:
coverage
[24]:
{'segment_a': 0.8333333333333334,
 'segment_b': 0.9333333333333333,
 'segment_c': 0.8666666666666667,
 'segment_d': 0.9666666666666667}
[25]:
width
[25]:
{'segment_a': 108.50105291374686,
 'segment_b': 46.55147330372556,
 'segment_c': 72.3575509237814,
 'segment_d': 414.01584353697956}

ConformalPredictionIntervals#

Estimates conformal prediction intervals:

  1. Compute matrix of absolute residuals \(r_{it} = |\hat y_{it} - y_{it}|\) using k-fold historical backtest, where \(i\) is fold index.

  2. Estimate corresponding quantiles levels using the provided coverage (e.g. apply Bonferroni correction).

  3. Estimate quantiles for each horizon step separately using computed absolute residuals and levels.

Note: this method estimates arbitrary interval bounds that tend to provide a given coverage rate. So this method ignores the quantiles parameter of the forecast method.

Coverage rate and correction option should be set at the method initialization step.

[26]:
from etna.experimental.prediction_intervals import ConformalPredictionIntervals

pipeline = ConformalPredictionIntervals(pipeline=pipeline, coverage=0.95, bonferroni_correction=True)

forecast = pipeline.forecast(prediction_interval=True, n_folds=40)

forecast
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    6.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:   13.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   19.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:   25.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:   30.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:   39.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:   46.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:   52.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:   57.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:  1.0min remaining:    0.0s
[Parallel(n_jobs=1)]: Done  40 out of  40 | elapsed:  4.0min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.6s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.8s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    1.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    1.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    1.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    1.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:    1.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:    1.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  40 out of  40 | elapsed:    6.1s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  40 out of  40 | elapsed:    0.9s finished
[26]:
segment segment_a ... segment_d
feature flag_day_number_in_month flag_day_number_in_week flag_is_weekend flag_month_number_in_year flag_week_number_in_month flag_week_number_in_year flag_year_number lag_30 lag_31 lag_32 ... lag_44 lag_45 lag_46 lag_47 lag_48 lag_49 segment_code target target_lower target_upper
timestamp
2019-11-01 1 4 False 11 1 44 2019 516.0 558.0 551.0 ... 860.0 859.0 833.0 592.0 616.0 824.0 3 884.890130 649.883469 1119.896790
2019-11-02 2 5 True 11 1 44 2019 489.0 516.0 558.0 ... 822.0 860.0 859.0 833.0 592.0 616.0 3 766.540795 510.541365 1022.540225
2019-11-03 3 6 True 11 1 44 2019 471.0 489.0 516.0 ... 908.0 822.0 860.0 859.0 833.0 592.0 3 735.275270 505.051234 965.499305
2019-11-04 4 0 False 11 2 45 2019 371.0 471.0 489.0 ... 648.0 908.0 822.0 860.0 859.0 833.0 3 871.086070 673.381908 1068.790232
2019-11-05 5 1 False 11 2 45 2019 359.0 371.0 471.0 ... 599.0 648.0 908.0 822.0 860.0 859.0 3 854.657442 638.063875 1071.251009
2019-11-06 6 2 False 11 2 45 2019 499.0 359.0 371.0 ... 821.0 599.0 648.0 908.0 822.0 860.0 3 845.998526 633.125621 1058.871431
2019-11-07 7 3 False 11 2 45 2019 528.0 499.0 359.0 ... 883.0 821.0 599.0 648.0 908.0 822.0 3 863.000343 644.282218 1081.718469
2019-11-08 8 4 False 11 2 45 2019 550.0 528.0 499.0 ... 923.0 883.0 821.0 599.0 648.0 908.0 3 886.018643 631.511196 1140.526090
2019-11-09 9 5 True 11 2 45 2019 547.0 550.0 528.0 ... 908.0 923.0 883.0 821.0 599.0 648.0 3 786.671129 526.983537 1046.358721
2019-11-10 10 6 True 11 2 45 2019 544.0 547.0 550.0 ... 874.0 908.0 923.0 883.0 821.0 599.0 3 770.142993 552.186714 988.099272
2019-11-11 11 0 False 11 3 46 2019 423.0 544.0 547.0 ... 712.0 874.0 908.0 923.0 883.0 821.0 3 883.026096 642.857515 1123.194676
2019-11-12 12 1 False 11 3 46 2019 402.0 423.0 544.0 ... 691.0 712.0 874.0 908.0 923.0 883.0 3 867.552488 611.475150 1123.629825
2019-11-13 13 2 False 11 3 46 2019 550.0 402.0 423.0 ... 980.0 691.0 712.0 874.0 908.0 923.0 3 850.850720 588.452817 1113.248622
2019-11-14 14 3 False 11 3 46 2019 582.0 550.0 402.0 ... 1037.0 980.0 691.0 712.0 874.0 908.0 3 866.155902 636.397113 1095.914691
2019-11-15 15 4 False 11 3 46 2019 559.0 582.0 550.0 ... 969.0 1037.0 980.0 691.0 712.0 874.0 3 891.686379 653.680738 1129.692021
2019-11-16 16 5 True 11 3 46 2019 543.0 559.0 582.0 ... 929.0 969.0 1037.0 980.0 691.0 712.0 3 795.462403 549.666908 1041.257899
2019-11-17 17 6 True 11 3 46 2019 523.0 543.0 559.0 ... 874.0 929.0 969.0 1037.0 980.0 691.0 3 794.104098 552.971103 1035.237093
2019-11-18 18 0 False 11 4 47 2019 422.0 523.0 543.0 ... 664.0 874.0 929.0 969.0 1037.0 980.0 3 869.079738 638.546838 1099.612638
2019-11-19 19 1 False 11 4 47 2019 403.0 422.0 523.0 ... 623.0 664.0 874.0 929.0 969.0 1037.0 3 855.177747 611.844429 1098.511065
2019-11-20 20 2 False 11 4 47 2019 538.0 403.0 422.0 ... 841.0 623.0 664.0 874.0 929.0 969.0 3 850.930059 619.085280 1082.774839
2019-11-21 21 3 False 11 4 47 2019 532.0 538.0 403.0 ... 879.0 841.0 623.0 664.0 874.0 929.0 3 870.278782 627.132036 1113.425529
2019-11-22 22 4 False 11 4 47 2019 515.0 532.0 538.0 ... 886.0 879.0 841.0 623.0 664.0 874.0 3 890.245006 625.590242 1154.899770
2019-11-23 23 5 True 11 4 47 2019 520.0 515.0 532.0 ... 934.0 886.0 879.0 841.0 623.0 664.0 3 775.886589 497.658120 1054.115058
2019-11-24 24 6 True 11 4 47 2019 511.0 520.0 515.0 ... 885.0 934.0 886.0 879.0 841.0 623.0 3 751.033775 478.400911 1023.666638
2019-11-25 25 0 False 11 5 48 2019 502.0 511.0 520.0 ... 672.0 885.0 934.0 886.0 879.0 841.0 3 874.921950 606.783185 1143.060715
2019-11-26 26 1 False 11 5 48 2019 499.0 502.0 511.0 ... 621.0 672.0 885.0 934.0 886.0 879.0 3 878.395149 595.791580 1160.998718
2019-11-27 27 2 False 11 5 48 2019 534.0 499.0 502.0 ... 859.0 621.0 672.0 885.0 934.0 886.0 3 860.259958 566.670708 1153.849209
2019-11-28 28 3 False 11 5 48 2019 502.0 534.0 499.0 ... 931.0 859.0 621.0 672.0 885.0 934.0 3 889.115142 603.341819 1174.888464
2019-11-29 29 4 False 11 5 48 2019 497.0 502.0 534.0 ... 897.0 931.0 859.0 621.0 672.0 885.0 3 891.706478 604.005315 1179.407642
2019-11-30 30 5 True 11 5 48 2019 501.0 497.0 502.0 ... 882.0 897.0 931.0 859.0 621.0 672.0 3 779.465792 483.160342 1075.771241

30 rows × 124 columns

[27]:
plot_forecast(forecast, test_ts, train_ts, prediction_intervals=True, n_train_samples=30)
../_images/tutorials_306-prediction_intervals_48_0.png
[28]:
coverage, width = interval_metrics(test_ts=test_ts, forecast=forecast)
[29]:
coverage
[29]:
{'segment_a': 0.9333333333333333,
 'segment_b': 0.9666666666666667,
 'segment_c': 0.9666666666666667,
 'segment_d': 0.9666666666666667}
[30]:
width
[30]:
{'segment_a': 141.19971750106404,
 'segment_b': 54.86083262147848,
 'segment_c': 77.48869154071357,
 'segment_d': 499.4061536460892}

EmpiricalPredictionIntervals#

Estimates prediction intervals via historical residuals:

  1. Compute matrix of residuals \(r_{it} = |\hat y_{it} - y_{it}|\) using k-fold backtest, where \(i\) is fold index.

  2. Estimate quantiles levels, that satisfy the provided coverage, for the corresponding residuals distributions.

  3. Estimate quantiles for each timestamp using computed residuals and levels.

Note: this method estimates arbitrary interval bounds that tend to provide a given coverage rate. So this method ignores the quantiles parameter of the forecast method.

Coverage rate and correction option should be set at method initialization step.

[31]:
from etna.experimental.prediction_intervals import EmpiricalPredictionIntervals

pipeline = EmpiricalPredictionIntervals(pipeline=pipeline)

forecast = pipeline.forecast(prediction_interval=True, n_folds=40)

forecast
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    5.8s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:   11.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   16.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:   22.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:   30.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:   36.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:   41.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:   47.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:   53.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:   58.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  40 out of  40 | elapsed:  4.5min finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.6s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    1.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    1.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    1.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:    1.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:    1.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  40 out of  40 | elapsed:    6.1s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   6 out of   6 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   7 out of   7 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  10 out of  10 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  40 out of  40 | elapsed:    1.1s finished
[31]:
segment segment_a ... segment_d
feature flag_day_number_in_month flag_day_number_in_week flag_is_weekend flag_month_number_in_year flag_week_number_in_month flag_week_number_in_year flag_year_number lag_30 lag_31 lag_32 ... lag_44 lag_45 lag_46 lag_47 lag_48 lag_49 segment_code target target_lower target_upper
timestamp
2019-11-01 1 4 False 11 1 44 2019 516.0 558.0 551.0 ... 860.0 859.0 833.0 592.0 616.0 824.0 3 884.890130 663.833307 912.762716
2019-11-02 2 5 True 11 1 44 2019 489.0 516.0 558.0 ... 822.0 860.0 859.0 833.0 592.0 616.0 3 766.540795 569.030264 801.586843
2019-11-03 3 6 True 11 1 44 2019 471.0 489.0 516.0 ... 908.0 822.0 860.0 859.0 833.0 592.0 3 735.275270 507.103746 778.679820
2019-11-04 4 0 False 11 2 45 2019 371.0 471.0 489.0 ... 648.0 908.0 822.0 860.0 859.0 833.0 3 871.086070 681.515744 901.930019
2019-11-05 5 1 False 11 2 45 2019 359.0 371.0 471.0 ... 599.0 648.0 908.0 822.0 860.0 859.0 3 854.657442 641.407846 866.047086
2019-11-06 6 2 False 11 2 45 2019 499.0 359.0 371.0 ... 821.0 599.0 648.0 908.0 822.0 860.0 3 845.998526 660.217161 867.762434
2019-11-07 7 3 False 11 2 45 2019 528.0 499.0 359.0 ... 883.0 821.0 599.0 648.0 908.0 822.0 3 863.000343 659.673561 882.054801
2019-11-08 8 4 False 11 2 45 2019 550.0 528.0 499.0 ... 923.0 883.0 821.0 599.0 648.0 908.0 3 886.018643 679.139431 886.018643
2019-11-09 9 5 True 11 2 45 2019 547.0 550.0 528.0 ... 908.0 923.0 883.0 821.0 599.0 648.0 3 786.671129 575.006386 786.671129
2019-11-10 10 6 True 11 2 45 2019 544.0 547.0 550.0 ... 874.0 908.0 923.0 883.0 821.0 599.0 3 770.142993 557.823008 770.142993
2019-11-11 11 0 False 11 3 46 2019 423.0 544.0 547.0 ... 712.0 874.0 908.0 923.0 883.0 821.0 3 883.026096 685.016228 883.026096
2019-11-12 12 1 False 11 3 46 2019 402.0 423.0 544.0 ... 691.0 712.0 874.0 908.0 923.0 883.0 3 867.552488 643.820651 867.552488
2019-11-13 13 2 False 11 3 46 2019 550.0 402.0 423.0 ... 980.0 691.0 712.0 874.0 908.0 923.0 3 850.850720 621.234961 850.850720
2019-11-14 14 3 False 11 3 46 2019 582.0 550.0 402.0 ... 1037.0 980.0 691.0 712.0 874.0 908.0 3 866.155902 659.104947 866.155902
2019-11-15 15 4 False 11 3 46 2019 559.0 582.0 550.0 ... 969.0 1037.0 980.0 691.0 712.0 874.0 3 891.686379 661.206654 891.686379
2019-11-16 16 5 True 11 3 46 2019 543.0 559.0 582.0 ... 929.0 969.0 1037.0 980.0 691.0 712.0 3 795.462403 567.609030 795.462403
2019-11-17 17 6 True 11 3 46 2019 523.0 543.0 559.0 ... 874.0 929.0 969.0 1037.0 980.0 691.0 3 794.104098 575.237147 794.104098
2019-11-18 18 0 False 11 4 47 2019 422.0 523.0 543.0 ... 664.0 874.0 929.0 969.0 1037.0 980.0 3 869.079738 672.336771 869.079738
2019-11-19 19 1 False 11 4 47 2019 403.0 422.0 523.0 ... 623.0 664.0 874.0 929.0 969.0 1037.0 3 855.177747 640.234380 855.177747
2019-11-20 20 2 False 11 4 47 2019 538.0 403.0 422.0 ... 841.0 623.0 664.0 874.0 929.0 969.0 3 850.930059 629.823414 850.930059
2019-11-21 21 3 False 11 4 47 2019 532.0 538.0 403.0 ... 879.0 841.0 623.0 664.0 874.0 929.0 3 870.278782 638.128801 870.278782
2019-11-22 22 4 False 11 4 47 2019 515.0 532.0 538.0 ... 886.0 879.0 841.0 623.0 664.0 874.0 3 890.245006 654.294956 890.245006
2019-11-23 23 5 True 11 4 47 2019 520.0 515.0 532.0 ... 934.0 886.0 879.0 841.0 623.0 664.0 3 775.886589 518.248701 777.189834
2019-11-24 24 6 True 11 4 47 2019 511.0 520.0 515.0 ... 885.0 934.0 886.0 879.0 841.0 623.0 3 751.033775 489.768310 751.033775
2019-11-25 25 0 False 11 5 48 2019 502.0 511.0 520.0 ... 672.0 885.0 934.0 886.0 879.0 841.0 3 874.921950 620.295448 874.921950
2019-11-26 26 1 False 11 5 48 2019 499.0 502.0 511.0 ... 621.0 672.0 885.0 934.0 886.0 879.0 3 878.395149 635.385892 878.395149
2019-11-27 27 2 False 11 5 48 2019 534.0 499.0 502.0 ... 859.0 621.0 672.0 885.0 934.0 886.0 3 860.259958 605.707920 860.259958
2019-11-28 28 3 False 11 5 48 2019 502.0 534.0 499.0 ... 931.0 859.0 621.0 672.0 885.0 934.0 3 889.115142 619.360181 889.115142
2019-11-29 29 4 False 11 5 48 2019 497.0 502.0 534.0 ... 897.0 931.0 859.0 621.0 672.0 885.0 3 891.706478 633.501917 891.706478
2019-11-30 30 5 True 11 5 48 2019 501.0 497.0 502.0 ... 882.0 897.0 931.0 859.0 621.0 672.0 3 779.465792 516.031185 779.465792

30 rows × 124 columns

[32]:
plot_forecast(forecast, test_ts, train_ts, prediction_intervals=True, n_train_samples=30)
../_images/tutorials_306-prediction_intervals_54_0.png
[33]:
coverage, width = interval_metrics(test_ts=test_ts, forecast=forecast)
[34]:
coverage
[34]:
{'segment_a': 0.8666666666666667,
 'segment_b': 0.9333333333333333,
 'segment_c': 0.16666666666666666,
 'segment_d': 0.4666666666666667}
[35]:
width
[35]:
{'segment_a': 97.76013455528955,
 'segment_b': 40.370329153680345,
 'segment_c': 40.2553666590423,
 'segment_d': 231.97320117702435}

Prediction intervals for ensembles#

Pipeline ensembles could be passed to interval methods as well.

Consider a short usage example with VotingEnsemble.

[36]:
from etna.ensembles import VotingEnsemble

ensemble = VotingEnsemble(pipelines=[deepcopy(pipeline), deepcopy(pipeline)])
ensemble = NaiveVariancePredictionIntervals(pipeline=ensemble, stride=HORIZON)

forecast = pipeline.forecast(prediction_interval=True, n_folds=5)

forecast.prediction_intervals_names
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    7.8s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:   14.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   21.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:   27.6s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:   35.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:   35.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.6s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.9s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.1s finished
[36]:
('target_lower', 'target_upper')
[37]:
plot_forecast(forecast, test_ts, train_ts, prediction_intervals=True, n_train_samples=30)
../_images/tutorials_306-prediction_intervals_60_0.png

Custom prediction interval method#

There is a possibility in the library to extend the set of prediction intervals methods by implementing the desired algorithm. This section demonstrates how it can be done. Examples of interface and utilities usage are provided as well.

BasePredictionIntervals - base class for prediction intervals methods.

This class implements a wrapper interface for pipelines and ensembles that provides the ability to estimate prediction intervals. So it requires a pipeline instance to be provided to the __init__ method for proper initialization.

To add a particular method for pipelines, one must inherit from this class and provide an implementation for the abstract method _forecast_prediction_interval. This method should estimate and store prediction intervals for out-of-sample forecasts.

Limitations In-sample prediction is not supported by default and will raise a corresponding error while attempting to do so. This functionality could be implemented if needed by overriding the _predict method, which is responsible for building an in-sample point forecast and adding prediction intervals.

Non-parametric method#

The example below demonstrates how the interval method could be implemented.

Consider ConstantWidthInterval, which simply adds constant width to a point forecast. Here width is a hyperparameter that will be set on the method initialization step.

[38]:
from typing import Sequence

from etna.experimental.prediction_intervals import BasePredictionIntervals
from etna.pipeline import BasePipeline
[39]:
class ConstantWidthInterval(BasePredictionIntervals):
    def __init__(self, pipeline: BasePipeline, interval_width: float):
        assert interval_width > 0

        self.interval_width = interval_width
        super().__init__(pipeline=pipeline)

    def _forecast_prediction_interval(
        self, ts: TSDataset, predictions: TSDataset, quantiles: Sequence[float], n_folds: int
    ) -> TSDataset:
        predicted_target = predictions[..., "target"]

        lower_border = predicted_target - self.interval_width / 2
        upper_border = predicted_target + self.interval_width / 2

        upper_border.rename({"target": "target_upper"}, inplace=True, axis=1)
        lower_border.rename({"target": "target_lower"}, inplace=True, axis=1)

        predictions.add_prediction_intervals(prediction_intervals_df=pd.concat([lower_border, upper_border], axis=1))
        return predictions
[40]:
pipeline = ConstantWidthInterval(pipeline=pipeline, interval_width=150)

forecast = pipeline.forecast(prediction_interval=True, n_folds=40)

forecast
[40]:
segment segment_a ... segment_d
feature flag_day_number_in_month flag_day_number_in_week flag_is_weekend flag_month_number_in_year flag_week_number_in_month flag_week_number_in_year flag_year_number lag_30 lag_31 lag_32 ... lag_44 lag_45 lag_46 lag_47 lag_48 lag_49 segment_code target target_lower target_upper
timestamp
2019-11-01 1 4 False 11 1 44 2019 516.0 558.0 551.0 ... 860.0 859.0 833.0 592.0 616.0 824.0 3 884.890130 809.890130 959.890130
2019-11-02 2 5 True 11 1 44 2019 489.0 516.0 558.0 ... 822.0 860.0 859.0 833.0 592.0 616.0 3 766.540795 691.540795 841.540795
2019-11-03 3 6 True 11 1 44 2019 471.0 489.0 516.0 ... 908.0 822.0 860.0 859.0 833.0 592.0 3 735.275270 660.275270 810.275270
2019-11-04 4 0 False 11 2 45 2019 371.0 471.0 489.0 ... 648.0 908.0 822.0 860.0 859.0 833.0 3 871.086070 796.086070 946.086070
2019-11-05 5 1 False 11 2 45 2019 359.0 371.0 471.0 ... 599.0 648.0 908.0 822.0 860.0 859.0 3 854.657442 779.657442 929.657442
2019-11-06 6 2 False 11 2 45 2019 499.0 359.0 371.0 ... 821.0 599.0 648.0 908.0 822.0 860.0 3 845.998526 770.998526 920.998526
2019-11-07 7 3 False 11 2 45 2019 528.0 499.0 359.0 ... 883.0 821.0 599.0 648.0 908.0 822.0 3 863.000343 788.000343 938.000343
2019-11-08 8 4 False 11 2 45 2019 550.0 528.0 499.0 ... 923.0 883.0 821.0 599.0 648.0 908.0 3 886.018643 811.018643 961.018643
2019-11-09 9 5 True 11 2 45 2019 547.0 550.0 528.0 ... 908.0 923.0 883.0 821.0 599.0 648.0 3 786.671129 711.671129 861.671129
2019-11-10 10 6 True 11 2 45 2019 544.0 547.0 550.0 ... 874.0 908.0 923.0 883.0 821.0 599.0 3 770.142993 695.142993 845.142993
2019-11-11 11 0 False 11 3 46 2019 423.0 544.0 547.0 ... 712.0 874.0 908.0 923.0 883.0 821.0 3 883.026096 808.026096 958.026096
2019-11-12 12 1 False 11 3 46 2019 402.0 423.0 544.0 ... 691.0 712.0 874.0 908.0 923.0 883.0 3 867.552488 792.552488 942.552488
2019-11-13 13 2 False 11 3 46 2019 550.0 402.0 423.0 ... 980.0 691.0 712.0 874.0 908.0 923.0 3 850.850720 775.850720 925.850720
2019-11-14 14 3 False 11 3 46 2019 582.0 550.0 402.0 ... 1037.0 980.0 691.0 712.0 874.0 908.0 3 866.155902 791.155902 941.155902
2019-11-15 15 4 False 11 3 46 2019 559.0 582.0 550.0 ... 969.0 1037.0 980.0 691.0 712.0 874.0 3 891.686379 816.686379 966.686379
2019-11-16 16 5 True 11 3 46 2019 543.0 559.0 582.0 ... 929.0 969.0 1037.0 980.0 691.0 712.0 3 795.462403 720.462403 870.462403
2019-11-17 17 6 True 11 3 46 2019 523.0 543.0 559.0 ... 874.0 929.0 969.0 1037.0 980.0 691.0 3 794.104098 719.104098 869.104098
2019-11-18 18 0 False 11 4 47 2019 422.0 523.0 543.0 ... 664.0 874.0 929.0 969.0 1037.0 980.0 3 869.079738 794.079738 944.079738
2019-11-19 19 1 False 11 4 47 2019 403.0 422.0 523.0 ... 623.0 664.0 874.0 929.0 969.0 1037.0 3 855.177747 780.177747 930.177747
2019-11-20 20 2 False 11 4 47 2019 538.0 403.0 422.0 ... 841.0 623.0 664.0 874.0 929.0 969.0 3 850.930059 775.930059 925.930059
2019-11-21 21 3 False 11 4 47 2019 532.0 538.0 403.0 ... 879.0 841.0 623.0 664.0 874.0 929.0 3 870.278782 795.278782 945.278782
2019-11-22 22 4 False 11 4 47 2019 515.0 532.0 538.0 ... 886.0 879.0 841.0 623.0 664.0 874.0 3 890.245006 815.245006 965.245006
2019-11-23 23 5 True 11 4 47 2019 520.0 515.0 532.0 ... 934.0 886.0 879.0 841.0 623.0 664.0 3 775.886589 700.886589 850.886589
2019-11-24 24 6 True 11 4 47 2019 511.0 520.0 515.0 ... 885.0 934.0 886.0 879.0 841.0 623.0 3 751.033775 676.033775 826.033775
2019-11-25 25 0 False 11 5 48 2019 502.0 511.0 520.0 ... 672.0 885.0 934.0 886.0 879.0 841.0 3 874.921950 799.921950 949.921950
2019-11-26 26 1 False 11 5 48 2019 499.0 502.0 511.0 ... 621.0 672.0 885.0 934.0 886.0 879.0 3 878.395149 803.395149 953.395149
2019-11-27 27 2 False 11 5 48 2019 534.0 499.0 502.0 ... 859.0 621.0 672.0 885.0 934.0 886.0 3 860.259958 785.259958 935.259958
2019-11-28 28 3 False 11 5 48 2019 502.0 534.0 499.0 ... 931.0 859.0 621.0 672.0 885.0 934.0 3 889.115142 814.115142 964.115142
2019-11-29 29 4 False 11 5 48 2019 497.0 502.0 534.0 ... 897.0 931.0 859.0 621.0 672.0 885.0 3 891.706478 816.706478 966.706478
2019-11-30 30 5 True 11 5 48 2019 501.0 497.0 502.0 ... 882.0 897.0 931.0 859.0 621.0 672.0 3 779.465792 704.465792 854.465792

30 rows × 124 columns

[41]:
plot_forecast(forecast, test_ts, train_ts, prediction_intervals=True, n_train_samples=30)
../_images/tutorials_306-prediction_intervals_67_0.png
[42]:
coverage, width = interval_metrics(test_ts=test_ts, forecast=forecast)
[43]:
coverage
[43]:
{'segment_a': 0.9333333333333333,
 'segment_b': 1.0,
 'segment_c': 1.0,
 'segment_d': 0.5333333333333333}
[44]:
width
[44]:
{'segment_a': 150.0,
 'segment_b': 150.0,
 'segment_c': 150.0,
 'segment_d': 150.0}

Estimating historical residuals#

Some prediction intervals methods require doing historical forecasts. This could be done by using the pipeline’s get_historical_forecasts method. As BasePredictionIntervals wraps pipelines, this method is implemented here as well.

Consider the example MaxAbsResidInterval. This method estimates intervals based on the maximum absolute values of historical residuals for each segment. So we can break down this algorithm into the following steps:

  1. Estimate historical forecasts by calling the get_historical_forecasts method.

  2. For each segment estimate residuals, find the maximum absolute value and add to the point forecast.

[45]:
class MaxAbsResidInterval(BasePredictionIntervals):
    def __init__(self, pipeline: BasePipeline, coverage: float = 0.95, stride: int = 1):
        assert stride > 0
        assert 0 < coverage <= 1

        self.stride = stride
        self.coverage = coverage
        super().__init__(pipeline=pipeline)

    def _forecast_prediction_interval(
        self, ts: TSDataset, predictions: TSDataset, quantiles: Sequence[float], n_folds: int
    ) -> TSDataset:
        predicted_target = predictions[..., "target"]

        lower_border = predicted_target.copy()
        upper_border = predicted_target.copy()

        fold_forecast = self.get_historical_forecasts(ts=ts, n_folds=n_folds, stride=self.stride)

        for segment in ts.segments:
            residuals = (
                ts.loc[:, pd.IndexSlice[segment, "target"]] - fold_forecast.loc[:, pd.IndexSlice[segment, "target"]]
            )
            width = np.max(np.abs(residuals))

            lower_border.loc[:, pd.IndexSlice[segment, "target"]] -= self.coverage * width / 2
            upper_border.loc[:, pd.IndexSlice[segment, "target"]] += self.coverage * width / 2

        upper_border.rename({"target": "target_upper"}, inplace=True, axis=1)
        lower_border.rename({"target": "target_lower"}, inplace=True, axis=1)

        predictions.add_prediction_intervals(prediction_intervals_df=pd.concat([lower_border, upper_border], axis=1))
        return predictions
[46]:
pipeline = MaxAbsResidInterval(pipeline=pipeline)

forecast = pipeline.forecast(prediction_interval=True, n_folds=5)

forecast
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    7.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:   13.5s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   19.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:   26.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:   32.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:   32.3s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.7s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.9s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    1.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    1.2s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   5 out of   5 | elapsed:    0.2s finished
[46]:
segment segment_a ... segment_d
feature flag_day_number_in_month flag_day_number_in_week flag_is_weekend flag_month_number_in_year flag_week_number_in_month flag_week_number_in_year flag_year_number lag_30 lag_31 lag_32 ... lag_44 lag_45 lag_46 lag_47 lag_48 lag_49 segment_code target target_lower target_upper
timestamp
2019-11-01 1 4 False 11 1 44 2019 516.0 558.0 551.0 ... 860.0 859.0 833.0 592.0 616.0 824.0 3 884.890130 761.305956 1008.474304
2019-11-02 2 5 True 11 1 44 2019 489.0 516.0 558.0 ... 822.0 860.0 859.0 833.0 592.0 616.0 3 766.540795 642.956621 890.124969
2019-11-03 3 6 True 11 1 44 2019 471.0 489.0 516.0 ... 908.0 822.0 860.0 859.0 833.0 592.0 3 735.275270 611.691096 858.859443
2019-11-04 4 0 False 11 2 45 2019 371.0 471.0 489.0 ... 648.0 908.0 822.0 860.0 859.0 833.0 3 871.086070 747.501896 994.670244
2019-11-05 5 1 False 11 2 45 2019 359.0 371.0 471.0 ... 599.0 648.0 908.0 822.0 860.0 859.0 3 854.657442 731.073268 978.241616
2019-11-06 6 2 False 11 2 45 2019 499.0 359.0 371.0 ... 821.0 599.0 648.0 908.0 822.0 860.0 3 845.998526 722.414352 969.582700
2019-11-07 7 3 False 11 2 45 2019 528.0 499.0 359.0 ... 883.0 821.0 599.0 648.0 908.0 822.0 3 863.000343 739.416169 986.584517
2019-11-08 8 4 False 11 2 45 2019 550.0 528.0 499.0 ... 923.0 883.0 821.0 599.0 648.0 908.0 3 886.018643 762.434469 1009.602817
2019-11-09 9 5 True 11 2 45 2019 547.0 550.0 528.0 ... 908.0 923.0 883.0 821.0 599.0 648.0 3 786.671129 663.086955 910.255303
2019-11-10 10 6 True 11 2 45 2019 544.0 547.0 550.0 ... 874.0 908.0 923.0 883.0 821.0 599.0 3 770.142993 646.558819 893.727167
2019-11-11 11 0 False 11 3 46 2019 423.0 544.0 547.0 ... 712.0 874.0 908.0 923.0 883.0 821.0 3 883.026096 759.441922 1006.610270
2019-11-12 12 1 False 11 3 46 2019 402.0 423.0 544.0 ... 691.0 712.0 874.0 908.0 923.0 883.0 3 867.552488 743.968314 991.136661
2019-11-13 13 2 False 11 3 46 2019 550.0 402.0 423.0 ... 980.0 691.0 712.0 874.0 908.0 923.0 3 850.850720 727.266546 974.434894
2019-11-14 14 3 False 11 3 46 2019 582.0 550.0 402.0 ... 1037.0 980.0 691.0 712.0 874.0 908.0 3 866.155902 742.571728 989.740076
2019-11-15 15 4 False 11 3 46 2019 559.0 582.0 550.0 ... 969.0 1037.0 980.0 691.0 712.0 874.0 3 891.686379 768.102206 1015.270553
2019-11-16 16 5 True 11 3 46 2019 543.0 559.0 582.0 ... 929.0 969.0 1037.0 980.0 691.0 712.0 3 795.462403 671.878230 919.046577
2019-11-17 17 6 True 11 3 46 2019 523.0 543.0 559.0 ... 874.0 929.0 969.0 1037.0 980.0 691.0 3 794.104098 670.519924 917.688272
2019-11-18 18 0 False 11 4 47 2019 422.0 523.0 543.0 ... 664.0 874.0 929.0 969.0 1037.0 980.0 3 869.079738 745.495564 992.663912
2019-11-19 19 1 False 11 4 47 2019 403.0 422.0 523.0 ... 623.0 664.0 874.0 929.0 969.0 1037.0 3 855.177747 731.593573 978.761921
2019-11-20 20 2 False 11 4 47 2019 538.0 403.0 422.0 ... 841.0 623.0 664.0 874.0 929.0 969.0 3 850.930059 727.345885 974.514233
2019-11-21 21 3 False 11 4 47 2019 532.0 538.0 403.0 ... 879.0 841.0 623.0 664.0 874.0 929.0 3 870.278782 746.694608 993.862956
2019-11-22 22 4 False 11 4 47 2019 515.0 532.0 538.0 ... 886.0 879.0 841.0 623.0 664.0 874.0 3 890.245006 766.660832 1013.829180
2019-11-23 23 5 True 11 4 47 2019 520.0 515.0 532.0 ... 934.0 886.0 879.0 841.0 623.0 664.0 3 775.886589 652.302415 899.470763
2019-11-24 24 6 True 11 4 47 2019 511.0 520.0 515.0 ... 885.0 934.0 886.0 879.0 841.0 623.0 3 751.033775 627.449601 874.617949
2019-11-25 25 0 False 11 5 48 2019 502.0 511.0 520.0 ... 672.0 885.0 934.0 886.0 879.0 841.0 3 874.921950 751.337776 998.506124
2019-11-26 26 1 False 11 5 48 2019 499.0 502.0 511.0 ... 621.0 672.0 885.0 934.0 886.0 879.0 3 878.395149 754.810975 1001.979323
2019-11-27 27 2 False 11 5 48 2019 534.0 499.0 502.0 ... 859.0 621.0 672.0 885.0 934.0 886.0 3 860.259958 736.675784 983.844132
2019-11-28 28 3 False 11 5 48 2019 502.0 534.0 499.0 ... 931.0 859.0 621.0 672.0 885.0 934.0 3 889.115142 765.530968 1012.699316
2019-11-29 29 4 False 11 5 48 2019 497.0 502.0 534.0 ... 897.0 931.0 859.0 621.0 672.0 885.0 3 891.706478 768.122305 1015.290652
2019-11-30 30 5 True 11 5 48 2019 501.0 497.0 502.0 ... 882.0 897.0 931.0 859.0 621.0 672.0 3 779.465792 655.881618 903.049966

30 rows × 124 columns

[47]:
plot_forecast(forecast, test_ts, train_ts, prediction_intervals=True, n_train_samples=30)
../_images/tutorials_306-prediction_intervals_74_0.png
[48]:
coverage, width = interval_metrics(test_ts=test_ts, forecast=forecast)
[49]:
coverage
[49]:
{'segment_a': 0.8666666666666667,
 'segment_b': 0.6666666666666666,
 'segment_c': 0.8666666666666667,
 'segment_d': 0.8}
[50]:
width
[50]:
{'segment_a': 106.01589052275085,
 'segment_b': 29.46380276236236,
 'segment_c': 50.17717248209039,
 'segment_d': 247.16834782574568}

Obtaining historical residuals for prediction intervals estimation can be simplified by using the more efficient utility function residuals_matrices. This function accepts pipeline, data, parameters for backtest and computes residuals for each segment on every fold.

Note that residuals_matrices function returns a 3 dimensional array with axes sizes (num_folds, horizon, num_segments).

Here we use this function to optimize the proposed method. Consider OptimizedMaxAbsResidInterval.

[51]:
from etna.experimental.prediction_intervals.utils import residuals_matrices


class OptimizedMaxAbsResidInterval(BasePredictionIntervals):
    def __init__(self, pipeline: BasePipeline, coverage: float = 0.95, stride: int = 1):
        assert stride > 0
        assert 0 < coverage <= 1

        self.stride = stride
        self.coverage = coverage
        super().__init__(pipeline=pipeline)

    def _forecast_prediction_interval(
        self, ts: TSDataset, predictions: TSDataset, quantiles: Sequence[float], n_folds: int
    ) -> TSDataset:
        residuals = residuals_matrices(pipeline=self, ts=ts, n_folds=n_folds, stride=self.stride)

        predicted_target = predictions[..., "target"]

        width = np.max(np.abs(residuals), axis=(0, 1)).reshape(1, -1)

        lower_border = predicted_target - self.coverage * width / 2
        upper_border = predicted_target + self.coverage * width / 2

        upper_border.rename({"target": "target_upper"}, inplace=True, axis=1)
        lower_border.rename({"target": "target_lower"}, inplace=True, axis=1)

        predictions.add_prediction_intervals(prediction_intervals_df=pd.concat([lower_border, upper_border], axis=1))
        return predictions
[52]:
pipeline = OptimizedMaxAbsResidInterval(pipeline=pipeline)

forecast = pipeline.forecast(prediction_interval=True, n_folds=3)

forecast
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    6.6s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:   12.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   17.8s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   17.8s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.4s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.1s finished
[52]:
segment segment_a ... segment_d
feature flag_day_number_in_month flag_day_number_in_week flag_is_weekend flag_month_number_in_year flag_week_number_in_month flag_week_number_in_year flag_year_number lag_30 lag_31 lag_32 ... lag_44 lag_45 lag_46 lag_47 lag_48 lag_49 segment_code target target_lower target_upper
timestamp
2019-11-01 1 4 False 11 1 44 2019 516.0 558.0 551.0 ... 860.0 859.0 833.0 592.0 616.0 824.0 3 884.890130 761.305956 1008.474304
2019-11-02 2 5 True 11 1 44 2019 489.0 516.0 558.0 ... 822.0 860.0 859.0 833.0 592.0 616.0 3 766.540795 642.956621 890.124969
2019-11-03 3 6 True 11 1 44 2019 471.0 489.0 516.0 ... 908.0 822.0 860.0 859.0 833.0 592.0 3 735.275270 611.691096 858.859443
2019-11-04 4 0 False 11 2 45 2019 371.0 471.0 489.0 ... 648.0 908.0 822.0 860.0 859.0 833.0 3 871.086070 747.501896 994.670244
2019-11-05 5 1 False 11 2 45 2019 359.0 371.0 471.0 ... 599.0 648.0 908.0 822.0 860.0 859.0 3 854.657442 731.073268 978.241616
2019-11-06 6 2 False 11 2 45 2019 499.0 359.0 371.0 ... 821.0 599.0 648.0 908.0 822.0 860.0 3 845.998526 722.414352 969.582700
2019-11-07 7 3 False 11 2 45 2019 528.0 499.0 359.0 ... 883.0 821.0 599.0 648.0 908.0 822.0 3 863.000343 739.416169 986.584517
2019-11-08 8 4 False 11 2 45 2019 550.0 528.0 499.0 ... 923.0 883.0 821.0 599.0 648.0 908.0 3 886.018643 762.434469 1009.602817
2019-11-09 9 5 True 11 2 45 2019 547.0 550.0 528.0 ... 908.0 923.0 883.0 821.0 599.0 648.0 3 786.671129 663.086955 910.255303
2019-11-10 10 6 True 11 2 45 2019 544.0 547.0 550.0 ... 874.0 908.0 923.0 883.0 821.0 599.0 3 770.142993 646.558819 893.727167
2019-11-11 11 0 False 11 3 46 2019 423.0 544.0 547.0 ... 712.0 874.0 908.0 923.0 883.0 821.0 3 883.026096 759.441922 1006.610270
2019-11-12 12 1 False 11 3 46 2019 402.0 423.0 544.0 ... 691.0 712.0 874.0 908.0 923.0 883.0 3 867.552488 743.968314 991.136661
2019-11-13 13 2 False 11 3 46 2019 550.0 402.0 423.0 ... 980.0 691.0 712.0 874.0 908.0 923.0 3 850.850720 727.266546 974.434894
2019-11-14 14 3 False 11 3 46 2019 582.0 550.0 402.0 ... 1037.0 980.0 691.0 712.0 874.0 908.0 3 866.155902 742.571728 989.740076
2019-11-15 15 4 False 11 3 46 2019 559.0 582.0 550.0 ... 969.0 1037.0 980.0 691.0 712.0 874.0 3 891.686379 768.102206 1015.270553
2019-11-16 16 5 True 11 3 46 2019 543.0 559.0 582.0 ... 929.0 969.0 1037.0 980.0 691.0 712.0 3 795.462403 671.878230 919.046577
2019-11-17 17 6 True 11 3 46 2019 523.0 543.0 559.0 ... 874.0 929.0 969.0 1037.0 980.0 691.0 3 794.104098 670.519924 917.688272
2019-11-18 18 0 False 11 4 47 2019 422.0 523.0 543.0 ... 664.0 874.0 929.0 969.0 1037.0 980.0 3 869.079738 745.495564 992.663912
2019-11-19 19 1 False 11 4 47 2019 403.0 422.0 523.0 ... 623.0 664.0 874.0 929.0 969.0 1037.0 3 855.177747 731.593573 978.761921
2019-11-20 20 2 False 11 4 47 2019 538.0 403.0 422.0 ... 841.0 623.0 664.0 874.0 929.0 969.0 3 850.930059 727.345885 974.514233
2019-11-21 21 3 False 11 4 47 2019 532.0 538.0 403.0 ... 879.0 841.0 623.0 664.0 874.0 929.0 3 870.278782 746.694608 993.862956
2019-11-22 22 4 False 11 4 47 2019 515.0 532.0 538.0 ... 886.0 879.0 841.0 623.0 664.0 874.0 3 890.245006 766.660832 1013.829180
2019-11-23 23 5 True 11 4 47 2019 520.0 515.0 532.0 ... 934.0 886.0 879.0 841.0 623.0 664.0 3 775.886589 652.302415 899.470763
2019-11-24 24 6 True 11 4 47 2019 511.0 520.0 515.0 ... 885.0 934.0 886.0 879.0 841.0 623.0 3 751.033775 627.449601 874.617949
2019-11-25 25 0 False 11 5 48 2019 502.0 511.0 520.0 ... 672.0 885.0 934.0 886.0 879.0 841.0 3 874.921950 751.337776 998.506124
2019-11-26 26 1 False 11 5 48 2019 499.0 502.0 511.0 ... 621.0 672.0 885.0 934.0 886.0 879.0 3 878.395149 754.810975 1001.979323
2019-11-27 27 2 False 11 5 48 2019 534.0 499.0 502.0 ... 859.0 621.0 672.0 885.0 934.0 886.0 3 860.259958 736.675784 983.844132
2019-11-28 28 3 False 11 5 48 2019 502.0 534.0 499.0 ... 931.0 859.0 621.0 672.0 885.0 934.0 3 889.115142 765.530968 1012.699316
2019-11-29 29 4 False 11 5 48 2019 497.0 502.0 534.0 ... 897.0 931.0 859.0 621.0 672.0 885.0 3 891.706478 768.122305 1015.290652
2019-11-30 30 5 True 11 5 48 2019 501.0 497.0 502.0 ... 882.0 897.0 931.0 859.0 621.0 672.0 3 779.465792 655.881618 903.049966

30 rows × 124 columns

[53]:
plot_forecast(forecast, test_ts, train_ts, prediction_intervals=True, n_train_samples=30)
../_images/tutorials_306-prediction_intervals_81_0.png
[54]:
coverage, width = interval_metrics(test_ts=test_ts, forecast=forecast)
[55]:
coverage
[55]:
{'segment_a': 0.8666666666666667,
 'segment_b': 0.6666666666666666,
 'segment_c': 0.8666666666666667,
 'segment_d': 0.8}
[56]:
width
[56]:
{'segment_a': 99.74407563536734,
 'segment_b': 29.46380276236236,
 'segment_c': 49.833980779974766,
 'segment_d': 247.16834782574568}