{ "cells": [ { "cell_type": "markdown", "id": "2c8984e0-0792-4cf8-b3c6-446b45b717f2", "metadata": {}, "source": [ "# Embedding models\n", "\n", "[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/etna-team/etna/master?filepath=examples/210-embedding_models.ipynb)" ] }, { "cell_type": "markdown", "id": "94e7669f-de54-4df8-86ba-aa72c6d5fb55", "metadata": {}, "source": [ "This notebook contains examples of working with embedding models.\n", "\n", "**Table of contents**\n", "\n", "* [Using embedding models directly](#chapter1) \n", "* [Using embedding models with transforms](#chapter2)\n", "    * [Baseline](#section_2_1)\n", "    * [EmbeddingSegmentTransform](#section_2_2)\n", "    * [EmbeddingWindowTransform](#section_2_3)\n", "* [Saving and loading models](#chapter3)" ] }, { "cell_type": "code", "execution_count": 1, "id": "bf32c6a9-f920-4888-ac9d-f4a1c454cd91", "metadata": { "tags": [] }, "outputs": [], "source": [ "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "id": "d732e5b1-2c10-4de3-93ce-c6395ddbd4f1", "metadata": {}, "source": [ "## 1. Using embedding models directly " ] }, { "cell_type": "markdown", "id": "4c63da5a-eed8-472b-9786-9884a5bb78d1", "metadata": {}, "source": [ "We have two models to generate embeddings for time series: `TS2VecEmbeddingModel` and `TSTCCEmbeddingModel`.\n", "\n", "Each model has the following methods:\n", "\n", "- `fit` to train the model.\n", "- `encode_segment` to generate embeddings for the whole series. These features are regressors.\n", "- `encode_window` to generate embeddings for each timestamp. These features aren't regressors, and a lag transformation should be applied to them before they are used in forecasting.\n", "- `freeze` to enable or disable skipping training in the `fit` method. It is useful, for example, when you have a pretrained model and only want to generate embeddings without retraining during `backtest`.\n", "- `save` and `load` to save and load pretrained models, respectively." ] }, { "cell_type": "code", "execution_count": 2, "id": "d5ec9757-dd5a-423c-9be1-e4835b4b2a03", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Disabling SSL verification. Connections to this server are not verified and may be insecure!\n", "Global seed set to 42\n" ] }, { "data": { "text/plain": [ "42" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pytorch_lightning import seed_everything\n", "\n", "seed_everything(42, workers=True)" ] }, { "cell_type": "code", "execution_count": 3, "id": "f99c90c5-8a8b-481a-848f-ebcb00b22bb0", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
segmentsegment_0segment_1segment_2
featuretargettargettarget
timestamp
2001-01-011.6243451.462108-1.100619
2001-01-021.012589-0.5980330.044105
2001-01-030.484417-0.9204500.945695
2001-01-04-0.588551-1.3045041.448190
2001-01-050.276856-0.1707352.349046
\n", "
" ], "text/plain": [ "segment segment_0 segment_1 segment_2\n", "feature target target target\n", "timestamp \n", "2001-01-01 1.624345 1.462108 -1.100619\n", "2001-01-02 1.012589 -0.598033 0.044105\n", "2001-01-03 0.484417 -0.920450 0.945695\n", "2001-01-04 -0.588551 -1.304504 1.448190\n", "2001-01-05 0.276856 -0.170735 2.349046" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from etna.datasets import TSDataset\n", "from etna.datasets import generate_ar_df\n", "\n", "df = generate_ar_df(periods=10, start_time=\"2001-01-01\", n_segments=3)\n", "ts = TSDataset(df, freq=\"D\")\n", "ts.head()" ] }, { "cell_type": "markdown", "id": "9712e58c-73fe-475e-807b-ae082752fcf8", "metadata": {}, "source": [ "Now let's work with models directly.\n", "\n", "They are expecting array with shapes\n", "(n_segments, n_timestamps, num_features). The example shows working with `TS2VecEmbeddingModel`, it is all the same with `TSTCCEmbeddingModel`." ] }, { "cell_type": "code", "execution_count": 4, "id": "05a191ee-17dd-4cb1-a993-73aee7706272", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(3, 10, 1)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = ts.df.values.reshape(ts.size()).transpose(1, 0, 2)\n", "x.shape" ] }, { "cell_type": "code", "execution_count": 5, "id": "0263277f-b642-4c1b-8f19-a42520d6d09e", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(3, 2)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from etna.transforms.embeddings.models import TS2VecEmbeddingModel\n", "from etna.transforms.embeddings.models import TSTCCEmbeddingModel\n", "\n", "model_ts2vec = TS2VecEmbeddingModel(input_dims=1, output_dims=2)\n", "model_ts2vec.fit(x, n_epochs=1)\n", "segment_embeddings = model_ts2vec.encode_segment(x)\n", "segment_embeddings.shape" ] }, { "cell_type": "markdown", "id": "26329bf0-e955-46ad-9962-4ea1295ef671", "metadata": {}, "source": [ "As we are using `encode_segment` we get `output_dims` features consisting of one value for each segment.\n", "\n", "And what about `encode_window`?" ] }, { "cell_type": "code", "execution_count": 6, "id": "9a307886-cdf2-4e98-9a8e-3917741f287c", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "(3, 10, 2)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "window_embeddings = model_ts2vec.encode_window(x)\n", "window_embeddings.shape" ] }, { "cell_type": "markdown", "id": "aded40fc-382c-4f7a-901b-8498f9258c3b", "metadata": {}, "source": [ "We get `output_dims` features consisting of `n_timestamps` values for each segment." ] }, { "cell_type": "markdown", "id": "b2d599a0", "metadata": {}, "source": [ "You can change some attributes of the model after initialization, for example `device`, `batch_size` or `num_workers`." ] }, { "cell_type": "code", "execution_count": 7, "id": "6ae8c79d", "metadata": {}, "outputs": [], "source": [ "model_ts2vec.device = \"cuda\"" ] }, { "cell_type": "markdown", "id": "ffbb2210-d77f-426e-91b2-3729544ce872", "metadata": {}, "source": [ "## 2. Using embedding models with transforms " ] }, { "cell_type": "markdown", "id": "459e90a9-97fb-4922-bc6a-52500b3a132e", "metadata": {}, "source": [ "In this section we will test our models on example." 
] }, { "cell_type": "code", "execution_count": 8, "id": "b7827e25-4597-451a-88f8-5e0475556041", "metadata": { "tags": [] }, "outputs": [], "source": [ "HORIZON = 6" ] }, { "cell_type": "markdown", "id": "17955757-7585-4db0-889b-dbd978339822", "metadata": { "tags": [] }, "source": [ "### 2.1 Baseline " ] }, { "cell_type": "markdown", "id": "8c92c86f-fce7-442b-a7f4-0c024344bec9", "metadata": {}, "source": [ "Before working with embedding features, let's make forecasts using usual features." ] }, { "cell_type": "code", "execution_count": 9, "id": "21ac6694-1c3c-4fdc-a96a-3d0544ee90df", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
segmentM1000_MACROM1001_MACROM1002_MACROM1003_MACROM1004_MACROM1005_MACROM1006_MACROM1007_MACROM1008_MACROM1009_MACRO...M992_MACROM993_MACROM994_MACROM995_MACROM996_MACROM997_MACROM998_MACROM999_MACROM99_MICROM9_MICRO
featuretargettargettargettargettargettargettargettargettargettarget...targettargettargettargettargettargettargettargettargettarget
timestamp
0NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
1NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
2NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
4NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
\n", "

5 rows × 1428 columns

\n", "
" ], "text/plain": [ "segment M1000_MACRO M1001_MACRO M1002_MACRO M1003_MACRO M1004_MACRO \\\n", "feature target target target target target \n", "timestamp \n", "0 NaN NaN NaN NaN NaN \n", "1 NaN NaN NaN NaN NaN \n", "2 NaN NaN NaN NaN NaN \n", "3 NaN NaN NaN NaN NaN \n", "4 NaN NaN NaN NaN NaN \n", "\n", "segment M1005_MACRO M1006_MACRO M1007_MACRO M1008_MACRO M1009_MACRO ... \\\n", "feature target target target target target ... \n", "timestamp ... \n", "0 NaN NaN NaN NaN NaN ... \n", "1 NaN NaN NaN NaN NaN ... \n", "2 NaN NaN NaN NaN NaN ... \n", "3 NaN NaN NaN NaN NaN ... \n", "4 NaN NaN NaN NaN NaN ... \n", "\n", "segment M992_MACRO M993_MACRO M994_MACRO M995_MACRO M996_MACRO M997_MACRO \\\n", "feature target target target target target target \n", "timestamp \n", "0 NaN NaN NaN NaN NaN NaN \n", "1 NaN NaN NaN NaN NaN NaN \n", "2 NaN NaN NaN NaN NaN NaN \n", "3 NaN NaN NaN NaN NaN NaN \n", "4 NaN NaN NaN NaN NaN NaN \n", "\n", "segment M998_MACRO M999_MACRO M99_MICRO M9_MICRO \n", "feature target target target target \n", "timestamp \n", "0 NaN NaN NaN NaN \n", "1 NaN NaN NaN NaN \n", "2 NaN NaN NaN NaN \n", "3 NaN NaN NaN NaN \n", "4 NaN NaN NaN NaN \n", "\n", "[5 rows x 1428 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from etna.datasets import load_dataset\n", "\n", "ts = load_dataset(\"m3_monthly\")\n", "ts.drop_features(features=[\"origin_timestamp\"])\n", "ts.df_exog = None\n", "ts.head()" ] }, { "cell_type": "code", "execution_count": 10, "id": "fe224f12-6b86-4513-8d61-3fa0cb895eb1", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 4.3s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 8.4s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 13.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 13.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.4s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.8s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" ] } ], "source": [ "from etna.metrics import SMAPE\n", "from etna.models import CatBoostMultiSegmentModel\n", "from etna.pipeline import Pipeline\n", "from etna.transforms import LagTransform\n", "\n", "model = CatBoostMultiSegmentModel()\n", "\n", "lag_transform = LagTransform(in_column=\"target\", lags=list(range(HORIZON, HORIZON + 6)), out_column=\"lag\")\n", "\n", "pipeline = Pipeline(model=model, transforms=[lag_transform], horizon=HORIZON)\n", "metrics_df, _, _ = pipeline.backtest(ts, metrics=[SMAPE()], n_folds=3)" ] }, { "cell_type": "code", "execution_count": 11, "id": "bfbab09d-eb27-4529-8954-3dc0e471668a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SMAPE: 
14.719683971886594\n" ] } ], "source": [ "print(\"SMAPE: \", metrics_df[\"SMAPE\"].mean())" ] }, { "cell_type": "markdown", "id": "efa358d6-c1a0-460a-b1a2-7a123b5b4eec", "metadata": {}, "source": [ "### 2.2 EmbeddingSegmentTransform " ] }, { "cell_type": "markdown", "id": "8f7ca5fd-4186-4bf4-ac32-0bace7802ca9", "metadata": {}, "source": [ "`EmbeddingSegmentTransform` calls the model's `encode_segment` method inside." ] }, { "cell_type": "code", "execution_count": 12, "id": "f05f8f02-4d24-4438-ac15-9ed45e2e4f78", "metadata": { "tags": [] }, "outputs": [], "source": [ "from etna.transforms import EmbeddingSegmentTransform\n", "from etna.transforms.embeddings.models import BaseEmbeddingModel\n", "\n", "\n", "def forecast_with_segment_embeddings(emb_model: BaseEmbeddingModel, training_params: dict) -> float:\n", "    model = CatBoostMultiSegmentModel()\n", "\n", "    emb_transform = EmbeddingSegmentTransform(\n", "        in_columns=[\"target\"], embedding_model=emb_model, training_params=training_params, out_column=\"emb\"\n", "    )\n", "    pipeline = Pipeline(model=model, transforms=[lag_transform, emb_transform], horizon=HORIZON)\n", "    metrics_df, _, _ = pipeline.backtest(ts, metrics=[SMAPE()], n_folds=3)\n", "    smape_score = metrics_df[\"SMAPE\"].mean()\n", "    return smape_score" ] }, { "cell_type": "markdown", "id": "6bc237e1-d2e3-48ee-99b5-ac35b957717b", "metadata": {}, "source": [ "You can look at the training parameters of the model to decide which ones to pass to the transform.\n", "\n", "Let's begin with `TSTCCEmbeddingModel`." ] }, { "cell_type": "code", "execution_count": 13, "id": "3e6cc297-48bd-4614-bcbe-44e5732bf3a8", "metadata": { "tags": [] }, "outputs": [], "source": [ "?TSTCCEmbeddingModel.fit" ] }, { "cell_type": "code", "execution_count": 14, "id": "516b209e-7bd2-45c6-8db0-b1708ffda0fc", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 34.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.1min remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 2.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 3.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" ] } ], "source": [ "import torch\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "emb_model = TSTCCEmbeddingModel(input_dims=1, tc_hidden_dim=16, depth=3, output_dims=6, device=device)\n", "training_params = {\"n_epochs\": 10}\n", "smape_score = forecast_with_segment_embeddings(emb_model, training_params)" ] }, { "cell_type": "code", "execution_count": 15, "id": "e673ff95-fdc2-4751-9025-98940f73211d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SMAPE: 14.214904390075835\n" ] } ], "source": [ "print(\"SMAPE: \", smape_score)" ] }, { "cell_type": "markdown", "id": "e5e35346-6f16-4d42-8a33-b98bf5679046", "metadata": {}, "source": [ "Better than without embeddings. Let's try `TS2VecEmbeddingModel`." ] }, { "cell_type": "code", "execution_count": 16, "id": "6b7e43f6-9ce3-4a3f-b6e9-70ed46b272e5", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 27.7s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 58.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.7min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.7s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 3.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 4.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 4.1s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.2s finished\n" ] } ], "source": [ "emb_model = TS2VecEmbeddingModel(input_dims=1, hidden_dims=16, depth=3, output_dims=6, device=device)\n", "training_params = {\"n_epochs\": 10}\n", "smape_score = forecast_with_segment_embeddings(emb_model, training_params)" ] }, { "cell_type": "code", "execution_count": 17, "id": "5688ea80-5d6c-414a-89a7-7ec144d09b4f", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SMAPE: 13.549340740762041\n" ] } ], "source": [ "print(\"SMAPE: \", smape_score)" ] }, { "cell_type": "markdown", "id": "da58c577-3519-41a7-b63c-1da896121954", "metadata": {}, "source": [ "Much better. Now let's try another transform." ] }, { "cell_type": "markdown", "id": "949a07ba-548f-4b09-bcba-a07f76c9d501", "metadata": { "tags": [] }, "source": [ "### 2.3 EmbeddingWindowTransform " ] }, { "cell_type": "markdown", "id": "dc3ef834-fcd7-4ee1-85e9-a8ace8dde8a4", "metadata": {}, "source": [ "`EmbeddingWindowTransform` calls the model's `encode_window` method inside. As we have discussed, these features are not regressors, and lagged versions of them should be used for forecasting the future."
] }, { "cell_type": "code", "execution_count": 18, "id": "b39d1abe-42f9-44ff-af5e-f93d08c0ac02", "metadata": { "tags": [] }, "outputs": [], "source": [ "from etna.transforms import EmbeddingWindowTransform\n", "from etna.transforms import FilterFeaturesTransform\n", "\n", "\n", "def forecast_with_window_embeddings(emb_model: BaseEmbeddingModel, training_params: dict) -> float:\n", " model = CatBoostMultiSegmentModel()\n", "\n", " output_dims = emb_model.output_dims\n", "\n", " emb_transform = EmbeddingWindowTransform(\n", " in_columns=[\"target\"], embedding_model=emb_model, training_params=training_params, out_column=\"embedding_window\"\n", " )\n", " lag_emb_transforms = [\n", " LagTransform(in_column=f\"embedding_window_{i}\", lags=[HORIZON], out_column=f\"lag_emb_{i}\")\n", " for i in range(output_dims)\n", " ]\n", " filter_transforms = FilterFeaturesTransform(exclude=[f\"embedding_window_{i}\" for i in range(output_dims)])\n", "\n", " transforms = [lag_transform] + [emb_transform] + lag_emb_transforms + [filter_transforms]\n", "\n", " pipeline = Pipeline(model=model, transforms=transforms, horizon=HORIZON)\n", " metrics_df, _, _ = pipeline.backtest(ts, metrics=[SMAPE()], n_folds=3)\n", " smape_score = metrics_df[\"SMAPE\"].mean()\n", " return smape_score" ] }, { "cell_type": "code", "execution_count": 19, "id": "5e663aa0-778d-4393-80e6-7ef888210ec5", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 53.9s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.8min remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.7min remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 2.7min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 10.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 20.9s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 31.6s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 31.6s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" ] } ], "source": [ "emb_model = TSTCCEmbeddingModel(input_dims=1, tc_hidden_dim=16, depth=3, output_dims=6, device=device)\n", "training_params = {\"n_epochs\": 10}\n", "smape_score = forecast_with_window_embeddings(emb_model, training_params)" ] }, { "cell_type": "code", "execution_count": 20, "id": "d07041ad-698c-4b07-b339-16aed9856129", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SMAPE: 104.68988621650867\n" ] } ], "source": [ "print(\"SMAPE: \", smape_score)" ] }, { "cell_type": "markdown", "id": "41d78b2c-7574-4b0b-806c-adb5db324998", "metadata": {}, "source": [ "Oops... What about `TS2VecEmbeddingModel`?" 
] }, { "cell_type": "code", "execution_count": 21, "id": "fab4711f-7b9d-4263-abbd-05a3cedcaaef", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 34.5s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 1.2min remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 1.8min finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 8.6s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 17.4s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 26.3s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 26.3s finished\n", "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n", "[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s remaining: 0.0s\n", "[Parallel(n_jobs=1)]: Done 3 out of 3 | elapsed: 0.1s finished\n" ] } ], "source": [ "emb_model = TS2VecEmbeddingModel(input_dims=1, hidden_dims=16, depth=3, output_dims=6, device=device)\n", "training_params = {\"n_epochs\": 10}\n", "smape_score = forecast_with_window_embeddings(emb_model, training_params)" ] }, { "cell_type": "code", "execution_count": 22, "id": "fd890c55-ea57-4f51-a2e1-320cc4111b46", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SMAPE: 29.776520212845234\n" ] } ], "source": [ "print(\"SMAPE: \", smape_score)" ] }, { "cell_type": "markdown", "id": "3a73e55e-1e1b-4ca2-a42e-62a68142c517", "metadata": {}, "source": [ "Window embeddings don't help with this dataset. This means that you should try both models and both transforms to get the best results." ] }, { "cell_type": "markdown", "id": "84fcd1b8-e61a-40d4-a80a-9c558637a8d4", "metadata": {}, "source": [ "## 3. Saving and loading models \n" ] }, { "cell_type": "markdown", "id": "c5f0bc45-6388-4e92-bdc1-d8090af66b26", "metadata": {}, "source": [ "If you have a pretrained embedding model and aren't going to train it when calling `fit`, you should \"freeze\" its training loop. This is helpful when using the model inside transforms, which call the model's `fit` method on each `fit` of the pipeline." ] }, { "cell_type": "code", "execution_count": 23, "id": "1d5fb109-b1c7-431c-a5bc-b2eeddc311f3", "metadata": { "tags": [] }, "outputs": [], "source": [ "MODEL_PATH = \"model.zip\"" ] }, { "cell_type": "code", "execution_count": 24, "id": "24229c75-5e9a-4ff8-a7f4-1c723c62fc9e", "metadata": { "tags": [] }, "outputs": [], "source": [ "emb_model.freeze()\n", "emb_model.save(MODEL_PATH)" ] }, { "cell_type": "markdown", "id": "e7a0275d-47ca-4aa4-a312-d02fcd06a7ae", "metadata": {}, "source": [ "Now you are ready to load the pretrained model." ] }, { "cell_type": "code", "execution_count": 25, "id": "5d3f522a-dc2d-46d4-a28b-8a1524793874", "metadata": { "tags": [] }, "outputs": [], "source": [ "model_loaded = TS2VecEmbeddingModel.load(MODEL_PATH)" ] }
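, { "cell_type": "markdown", "id": "4b5c6d7e", "metadata": {}, "source": [ "Since the model was \"freezed\" before saving, the loaded model skips training when `fit` is called. This is exactly the scenario mentioned in section 1: a pretrained model can be reused inside a transform to generate embeddings without retraining during `backtest`. Below is a minimal sketch of this workflow (variable names are illustrative); such a pipeline can be backtested in the same way as in section 2." ] }, { "cell_type": "code", "execution_count": null, "id": "4b5c6d7f", "metadata": {}, "outputs": [], "source": [ "# A minimal sketch: reuse the loaded (frozen) model inside a transform.\n", "# Training is skipped in `fit` while the model is frozen, so backtesting\n", "# only generates embeddings and does not retrain the embedding model.\n", "frozen_emb_transform = EmbeddingSegmentTransform(\n", "    in_columns=[\"target\"],\n", "    embedding_model=model_loaded,\n", "    training_params={},  # nothing to configure: training is skipped while frozen\n", "    out_column=\"emb\",\n", ")\n", "pipeline_frozen = Pipeline(\n", "    model=CatBoostMultiSegmentModel(),\n", "    transforms=[lag_transform, frozen_emb_transform],\n", "    horizon=HORIZON,\n", ")" ] }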
, { "cell_type": "markdown", "id": "98e5325a-abe1-4c88-9f2b-fb61ef5d110e", "metadata": { "tags": [] }, "source": [ "If you need to fine-tune the pretrained model, you should \"unfreeze\" its training loop. After that, it will start fitting again when the `fit` method is called." ] }, { "cell_type": "code", "execution_count": 26, "id": "f9961758-6f5b-42f6-92f1-0aa68bb0a677", "metadata": { "tags": [] }, "outputs": [], "source": [ "model_loaded.freeze(is_freezed=False)" ] }, { "cell_type": "markdown", "id": "9ff472a4", "metadata": {}, "source": [ "To check whether the model is \"freezed\" or not, use the `is_freezed` property." ] }, { "cell_type": "code", "execution_count": 27, "id": "eba6d010", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_loaded.is_freezed" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 5 }