Tasks & Benchmarks
This notebook covers the following topics:
- Defining a time series forecasting `Task` consisting of multiple `EvaluationWindow`s
- Multivariate and univariate forecasting
- Evaluation on a `Benchmark` consisting of multiple tasks
- Aggregating benchmark results
import warnings
from pathlib import Path
import datasets
import numpy as np
from tqdm.auto import tqdm
import fev
warnings.simplefilter("ignore")
datasets.disable_progress_bars()
Main classes
The `fev` package provides 3 core classes for evaluating time series forecasting models:
- `Task` - Defines a single forecasting task with dataset path, forecast horizon, and evaluation settings. Each `Task` contains one or more evaluation windows.
- `EvaluationWindow` - Represents a single train/test split of the data at a specific cutoff point. Model performance is averaged across all windows within a `Task`.
- `Benchmark` - A collection of multiple tasks (e.g., different datasets). Individual task results are aggregated to compute overall benchmark scores.
In short, the hierarchy is `Benchmark` -> `Task` -> `EvaluationWindow`.
This tutorial demonstrates the functionality of these classes.
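As a preview of how these pieces fit together, here is a minimal sketch of evaluating a model on a single `Task` (not part of the original workflow; the dataset path and the naive last-value model are illustrative assumptions):
# Minimal end-to-end sketch: one Task, its EvaluationWindows, and an evaluation summary
task = fev.Task(dataset_path="autogluon/chronos_datasets", dataset_config="monash_m1_yearly", horizon=6)
predictions_per_window = []
for window in task.iter_windows():  # Task -> EvaluationWindow
    past_data, future_data = window.get_input_data()
    # Naive model: repeat the last observed target value across the horizon
    preds = [{"predictions": [ts["target"][-1]] * task.horizon} for ts in past_data]
    predictions_per_window.append(datasets.Dataset.from_list(preds))
summary = task.evaluation_summary(predictions_per_window, model_name="naive")
A `Benchmark` then simply repeats this loop over `benchmark.tasks` and aggregates the resulting summaries.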
Data sources
Dataset stored on Hugging Face Hub: https://huggingface.co/datasets/autogluon/chronos_datasets
task = fev.Task(
dataset_path="autogluon/chronos_datasets",
dataset_config="monash_cif_2016",
horizon=12,
)
Dataset stored on S3
# Dataset consisting of a single parquet / arrow file
task = fev.Task(
dataset_path="s3://autogluon/datasets/timeseries/m1_monthly/data.parquet",
horizon=12,
)
# Dataset consisting of multiple parquet / arrow files
task = fev.Task(
dataset_path="s3://autogluon/datasets/timeseries/m1_monthly/*.parquet",
horizon=12,
)
Dataset stored locally
# Download dataset from HF Hub and save it locally
ds = datasets.load_dataset("autogluon/chronos_datasets", name="m4_hourly", split="train")
local_path = "/tmp/m4_hourly/data.parquet"
Path(local_path).parent.mkdir(parents=True, exist_ok=True)  # make sure the target directory exists
ds.to_parquet(local_path)
task = fev.Task(
dataset_path=local_path,
horizon=48,
)
Evaluation windows
A single `Task` consists of one or more `EvaluationWindow`s. Each `EvaluationWindow` represents a single train/test split of the time series data at a specific cutoff point.
We'll create a task with a toy dataset to demonstrate how evaluation windows work.
import pandas as pd
# Create a toy dataset with a single time series
ts = {
"id": "A",
"timestamp": pd.date_range("2025-01-01", freq="D", periods=10),
"target": list(range(10)),
}
ds = datasets.Dataset.from_list([ts])
dataset_path = "/tmp/toy_dataset.parquet"
ds.to_parquet(dataset_path);
We now construct a `Task` with 2 evaluation windows based on this toy dataset.
task = fev.Task(
dataset_path=dataset_path,
horizon=3,
num_windows=2,
)
# Show the original dataset before any splits (for reference only)
full_dataset = task.load_full_dataset()
print(full_dataset)
print(full_dataset[0])
Dataset({
    features: ['id', 'timestamp', 'target'],
    num_rows: 1
})
{'id': np.str_('A'),
 'timestamp': array(['2025-01-01T00:00:00.000000000', '2025-01-02T00:00:00.000000000',
        '2025-01-03T00:00:00.000000000', '2025-01-04T00:00:00.000000000',
        '2025-01-05T00:00:00.000000000', '2025-01-06T00:00:00.000000000',
        '2025-01-07T00:00:00.000000000', '2025-01-08T00:00:00.000000000',
        '2025-01-09T00:00:00.000000000', '2025-01-10T00:00:00.000000000'],
       dtype='datetime64[ns]'),
 'target': array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])}
Now let's examine how the data is split across the 2 evaluation windows:
# Show how data is split across the 2 evaluation windows
for window_index, window in enumerate(task.iter_windows()):
past, future = window.get_input_data()
ground_truth = window.get_ground_truth()
print(f"Window {window_index} (cutoff={window.cutoff}):")
print(f" Past data: {past[0]['target']}")
print(f" Ground truth: {ground_truth[0]['target']}")
Window 0 (cutoff=-6):
  Past data: [0 1 2 3]
  Ground truth: [4 5 6]
Window 1 (cutoff=-3):
  Past data: [0 1 2 3 4 5 6]
  Ground truth: [7 8 9]
Customizing evaluation window parameters
You can control how evaluation windows are positioned using the `initial_cutoff` and `window_step_size` parameters.
# Example 1: Start evaluation earlier with initial_cutoff
task = fev.Task(
dataset_path=dataset_path,
horizon=3,
num_windows=2,
initial_cutoff=-8,
)
for window_index, window in enumerate(task.iter_windows()):
past, future = window.get_input_data()
ground_truth = window.get_ground_truth()
print(f"Window {window_index} (cutoff={window.cutoff}):")
print(f" Past data: {past[0]['target']}")
print(f" Ground truth: {ground_truth[0]['target']}")
Window 0 (cutoff=-8):
  Past data: [0 1]
  Ground truth: [2 3 4]
Window 1 (cutoff=-5):
  Past data: [0 1 2 3 4]
  Ground truth: [5 6 7]
# Example 2: Use smaller window_step_size
task = fev.Task(
dataset_path=dataset_path,
horizon=3,
num_windows=2,
window_step_size=1,
)
for window_index, window in enumerate(task.iter_windows()):
past, future = window.get_input_data()
ground_truth = window.get_ground_truth()
print(f"Window {window_index} (cutoff={window.cutoff}):")
print(f" Past data: {past[0]['target']}")
print(f" Ground truth: {ground_truth[0]['target']}")
Window 0 (cutoff=-4):
  Past data: [0 1 2 3 4 5]
  Ground truth: [6 7 8]
Window 1 (cutoff=-3):
  Past data: [0 1 2 3 4 5 6]
  Ground truth: [7 8 9]
You can also set `initial_cutoff` and `window_step_size` to pandas-compatible timestamp and timedelta strings.
# Example 3: Use pandas timestamp-like strings
task = fev.Task(
dataset_path=dataset_path,
horizon=3,
num_windows=2,
initial_cutoff="2025-01-05",
window_step_size="2D",
)
for window_index, window in enumerate(task.iter_windows()):
past, future = window.get_input_data()
ground_truth = window.get_ground_truth()
print(f"Window {window_index} (cutoff={window.cutoff}):")
print(f" Past data: {past[0]['target']}")
print(f" Ground truth: {ground_truth[0]['target']}")
Window 0 (cutoff=2025-01-05T00:00:00):
  Past data: [0 1 2 3 4]
  Ground truth: [5 6 7]
Window 1 (cutoff=2025-01-07T00:00:00):
  Past data: [0 1 2 3 4 5 6]
  Ground truth: [7 8 9]
Univariate forecasting
The simplest kind of forecasting task is univariate forecasting, where the goal is to predict a single `target` for each time series in the dataset.
task = fev.Task(
dataset_path="autogluon/chronos_datasets",
dataset_config="m4_hourly",
horizon=24,
num_windows=2,
)
To evaluate a forecasting model on this task, we need to make predictions for each `EvaluationWindow`.
Predictions format
Predictions must follow the format specified by `task.predictions_schema`.
For point forecasting tasks (i.e., if `quantile_levels=None`), predictions must contain a single array of length `horizon` for each time series.
task.predictions_schema
{'predictions': Sequence(feature=Value(dtype='float64', id=None), length=24, id=None)}
Here is an example of a function that makes predictions for a single `EvaluationWindow` and formats them as a `datasets.Dataset`.
def naive_forecast(window: fev.EvaluationWindow) -> datasets.Dataset:
    """Predicts the last observed value for every step of the forecast horizon."""
assert len(window.target_columns) == 1, "only univariate forecasting supported"
predictions: list[dict[str, np.ndarray]] = []
past_data, future_data = window.get_input_data()
for ts in past_data:
y = ts[window.target_columns[0]]
predictions.append(
{"predictions": np.array([y[-1] for _ in range(window.horizon)])}
)
return datasets.Dataset.from_list(predictions)
window = task.get_window(0)
predictions = naive_forecast(window)
predictions
Dataset({
    features: ['predictions'],
    num_rows: 414
})
Each entry in `predictions` is a dictionary where the key `"predictions"` corresponds to an array with 24 values.
print(predictions[0])
{'predictions': [684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0, 684.0]}
Once we have predictions for each evaluation window, we can compute the metrics and generate an evaluation summary
predictions_per_window = [naive_forecast(window) for window in task.iter_windows()]
task.evaluation_summary(predictions_per_window, model_name="naive")
{'model_name': 'naive', 'dataset_path': 'autogluon/chronos_datasets', 'dataset_config': 'm4_hourly', 'horizon': 24, 'num_windows': 2, 'initial_cutoff': -48, 'window_step_size': 24, 'min_context_length': 1, 'max_context_length': None, 'seasonality': 1, 'eval_metric': 'MASE', 'extra_metrics': [], 'quantile_levels': [], 'id_column': 'id', 'timestamp_column': 'timestamp', 'target': 'target', 'generate_univariate_targets_from': None, 'known_dynamic_columns': [], 'past_dynamic_columns': [], 'static_columns': [], 'task_name': 'm4_hourly', 'test_error': 3.8851860313085385, 'training_time_s': None, 'inference_time_s': None, 'dataset_fingerprint': '19e36bb78b718d8d', 'trained_on_this_dataset': False, 'fev_version': '0.6.0', 'MASE': 3.8851860313085385}
Probabilistic forecasting
For probabilistic forecasting tasks (i.e., if `quantile_levels` contains at least one value), predictions must additionally contain a prediction for each quantile level.
task = fev.Task(
dataset_path="autogluon/chronos_datasets",
dataset_config="m4_hourly",
horizon=24,
quantile_levels=[0.1, 0.5, 0.9],
eval_metric="WQL",
)
task.predictions_schema
{'predictions': Sequence(feature=Value(dtype='float64', id=None), length=24, id=None),
 '0.1': Sequence(feature=Value(dtype='float64', id=None), length=24, id=None),
 '0.5': Sequence(feature=Value(dtype='float64', id=None), length=24, id=None),
 '0.9': Sequence(feature=Value(dtype='float64', id=None), length=24, id=None)}
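The notebook does not show a probabilistic forecaster, so here is a minimal hedged sketch of predictions matching this schema. The function `naive_quantile_forecast` is illustrative (not part of `fev`); it simply reuses the naive point forecast for every quantile level and assumes the task's quantile levels are available as `task.quantile_levels`.
def naive_quantile_forecast(window: fev.EvaluationWindow, quantile_levels: list[float]) -> datasets.Dataset:
    """Toy example: use the last observed value as the point forecast and for every quantile."""
    past_data, future_data = window.get_input_data()
    predictions = []
    for ts in past_data:
        forecast = [float(ts[window.target_columns[0]][-1])] * window.horizon
        row = {"predictions": forecast}
        for q in quantile_levels:
            row[str(q)] = forecast  # a real model would output a distinct estimate per quantile
        predictions.append(row)
    return datasets.Dataset.from_list(predictions)
predictions_per_window = [naive_quantile_forecast(w, task.quantile_levels) for w in task.iter_windows()]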
Covariates
By default, only the `id_column`, `timestamp_column` and `target` columns are loaded from the dataset.
task = fev.Task(
dataset_path="autogluon/fev_datasets",
dataset_config="favorita_transactions_1D",
horizon=7,
target="transactions",
)
past_data, future_data = task.get_window(0).get_input_data()
print(past_data)
print(future_data)
Dataset({
    features: ['id', 'timestamp', 'transactions'],
    num_rows: 51
})
Dataset({
    features: ['id', 'timestamp'],
    num_rows: 51
})
We can view all the columns available in the dataset with `Task.load_full_dataset()`
full_ds = task.load_full_dataset()
print(full_ds)
full_ds.features
Dataset({
    features: ['id', 'timestamp', 'transactions', 'oil_price', 'holiday', 'store_nbr', 'city', 'state', 'type', 'cluster'],
    num_rows: 51
})
{'id': Value(dtype='string', id=None),
 'timestamp': Sequence(feature=Value(dtype='timestamp[ms]', id=None), length=-1, id=None),
 'transactions': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
 'oil_price': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
 'holiday': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'store_nbr': Value(dtype='float32', id=None),
 'city': Value(dtype='string', id=None),
 'state': Value(dtype='string', id=None),
 'type': Value(dtype='string', id=None),
 'cluster': Value(dtype='float32', id=None)}
We can configure the task to use the additional columns as covariates. There are 3 types of covariates:
- Static covariates (`static_columns`) are the time-independent attributes of the time series, e.g.
  - Location (country, state, city)
  - Product properties (brand, color, size)
  - IDs and constant metadata
- Known dynamic covariates (`known_dynamic_columns`) are time-varying features available for both past and future periods, e.g.
  - Holidays, calendar features
  - Planned promotions
- Past dynamic covariates (`past_dynamic_columns`) are time-varying features only available until the forecast start, e.g.
  - Weather data, economic indicators
  - Related product sales
Dynamic covariates must have feature type `Sequence`, and their length must match the target length for each row. Static covariates must have feature type `Value` (not `Sequence`).
task = fev.Task(
dataset_path="autogluon/fev_datasets",
dataset_config="favorita_transactions_1D",
horizon=7,
target="transactions",
known_dynamic_columns=["holiday"], # time-dependent, known in the future and in the past
past_dynamic_columns=["oil_price"], # time-dependent, known only in the past
static_columns=["city", "state"], # time-independent
)
past_data, future_data = task.get_window(0).get_input_data()
print(past_data)
print(future_data)
Dataset({
    features: ['id', 'timestamp', 'transactions', 'holiday', 'oil_price', 'city', 'state'],
    num_rows: 51
})
Dataset({
    features: ['id', 'timestamp', 'holiday', 'city', 'state'],
    num_rows: 51
})
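As a hedged illustration (not part of the original notebook), a forecasting function can consume these covariates through the rows of `past_data` and `future_data`; this sketch assumes the two datasets list the time series in the same order, and it only inspects the covariates before producing a naive forecast.
def naive_forecast_with_covariates(window: fev.EvaluationWindow) -> datasets.Dataset:
    """Toy example: naive forecast with access to past and known future covariates."""
    past_data, future_data = window.get_input_data()
    predictions = []
    for ts_past, ts_future in zip(past_data, future_data):
        last_value = float(ts_past[window.target_columns[0]][-1])
        holidays_over_horizon = ts_future["holiday"]  # known dynamic covariate for the forecast horizon
        # A real model would condition on holidays_over_horizon, ts_past["oil_price"], etc.
        predictions.append({"predictions": [last_value] * window.horizon})
    return datasets.Dataset.from_list(predictions)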
Multivariate forecasting
In all previous examples we considered univariate forecasting tasks, where the goal was to predict a single `target` into the future. `fev` also supports multivariate tasks, where the goal is to simultaneously predict multiple target columns.
"Real" multivariate tasks¶
We can define multivariate forecasting tasks by setting the target
attribute to a list
of column names.
task = fev.Task(
dataset_path="autogluon/fev_datasets",
dataset_config="ETT_1H",
horizon=3,
target=["OT", "LUFL", "LULL"],
)
The input data created by the task in this case is identical to what would happen if we used `["OT", "LUFL", "LULL"]` as `past_dynamic_columns`. That is, the target columns `["OT", "LUFL", "LULL"]` are available in `past_data` but not in `future_data`.
past_data, future_data = task.get_window(0).get_input_data()
print(past_data)
print(future_data)
Dataset({
    features: ['id', 'timestamp', 'LUFL', 'LULL', 'OT'],
    num_rows: 2
})
Dataset({
    features: ['id', 'timestamp'],
    num_rows: 2
})
The only difference in a multivariate task is that the predictions must be formatted as a `datasets.DatasetDict` where
- each key corresponds to the name of the target column
- each value is a `datasets.Dataset` containing the predictions for this column in a format compatible with `task.predictions_schema`
def naive_forecast_multivariate(window: fev.EvaluationWindow) -> datasets.DatasetDict:
"""Predicts the last observed value in each multivariate column."""
past_data, future_data = window.get_input_data()
predictions = datasets.DatasetDict()
for col in window.target_columns:
predictions_for_column = []
for ts in past_data:
predictions_for_column.append({"predictions": [ts[col][-1] for _ in range(window.horizon)]})
predictions[col] = datasets.Dataset.from_list(predictions_for_column)
return predictions
window = task.get_window(0)
predictions_per_window = naive_forecast_multivariate(window).cast(task.predictions_schema)
predictions_per_window
DatasetDict({
    LUFL: Dataset({
        features: ['predictions'],
        num_rows: 2
    })
    LULL: Dataset({
        features: ['predictions'],
        num_rows: 2
    })
    OT: Dataset({
        features: ['predictions'],
        num_rows: 2
    })
})
We can also look at the individual values in the `Dataset` objects
for col in task.target_columns:
print(f"Predictions for column '{col}'")
print(f"\t{predictions_per_window[col].to_list()}")
Predictions for column 'LUFL'
	[{'predictions': [3.5329999923706055, 3.5329999923706055, 3.5329999923706055]}, {'predictions': [-10.331000328063965, -10.331000328063965, -10.331000328063965]}]
Predictions for column 'LULL'
	[{'predictions': [1.6749999523162842, 1.6749999523162842, 1.6749999523162842]}, {'predictions': [-1.2899999618530273, -1.2899999618530273, -1.2899999618530273]}]
Predictions for column 'OT'
	[{'predictions': [11.043999671936035, 11.043999671936035, 11.043999671936035]}, {'predictions': [48.18349838256836, 48.18349838256836, 48.18349838256836]}]
The rest of the code can stay the same.
task.evaluation_summary([predictions_per_window], model_name="naive")
{'model_name': 'naive', 'dataset_path': 'autogluon/fev_datasets', 'dataset_config': 'ETT_1H', 'horizon': 3, 'num_windows': 1, 'initial_cutoff': -3, 'window_step_size': 3, 'min_context_length': 1, 'max_context_length': None, 'seasonality': 1, 'eval_metric': 'MASE', 'extra_metrics': [], 'quantile_levels': [], 'id_column': 'id', 'timestamp_column': 'timestamp', 'target': ['LUFL', 'LULL', 'OT'], 'generate_univariate_targets_from': None, 'known_dynamic_columns': [], 'past_dynamic_columns': [], 'static_columns': [], 'task_name': 'ETT_1H', 'test_error': 1.1921320279836811, 'training_time_s': None, 'inference_time_s': None, 'dataset_fingerprint': '305bfc1cf6779b47', 'trained_on_this_dataset': False, 'fev_version': '0.6.0', 'MASE': 1.1921320279836811}
Converting multivariate tasks into univariate tasks
Alternatively, we can convert a multivariate task into a univariate one by creating multiple univariate time series from each multivariate time series.
The original `ETTh` dataset contains two multivariate time series with the following ids:
past_data["id"]
array(['ETTh1', 'ETTh2'], dtype='<U5')
If we set `generate_univariate_targets_from=["OT", "LUFL", "LULL"]`, `fev` will create 3 univariate time series from each time series in the original dataset.
task = fev.Task(
dataset_path="autogluon/fev_datasets",
dataset_config="ETT_1H",
horizon=3,
generate_univariate_targets_from=["OT", "LUFL", "LULL"],
target="target", # new name for the target columns ['OT', 'LUFL', 'LULL'] after splitting
)
past_data, future_data = task.get_window(0).get_input_data()
print(past_data)
print(future_data)
Dataset({
    features: ['id', 'timestamp', 'target'],
    num_rows: 6
})
Dataset({
    features: ['id', 'timestamp'],
    num_rows: 6
})
The new dataset contains 6 items (2 original ids $\times$ 3 target columns).
past_data["id"]
array(['ETTh1_LUFL', 'ETTh1_LULL', 'ETTh1_OT', 'ETTh2_LUFL', 'ETTh2_LULL', 'ETTh2_OT'], dtype='<U10')
We can confirm that the naive forecast achieves the same MASE score on this equivalent representation of the multivariate task.
def naive_forecast_univariate(window: fev.EvaluationWindow) -> datasets.Dataset:
"""Predicts the last observed value."""
past_data, future_data = window.get_input_data()
predictions = []
for ts in past_data:
predictions.append({"predictions": [ts[window.target_columns[0]][-1] for _ in range(window.horizon)]})
return datasets.Dataset.from_list(predictions)
predictions_per_window = []
for window in task.iter_windows():
predictions_per_window.append(naive_forecast_univariate(window))
task.evaluation_summary(predictions_per_window, model_name="naive")
{'model_name': 'naive', 'dataset_path': 'autogluon/fev_datasets', 'dataset_config': 'ETT_1H', 'horizon': 3, 'num_windows': 1, 'initial_cutoff': -3, 'window_step_size': 3, 'min_context_length': 1, 'max_context_length': None, 'seasonality': 1, 'eval_metric': 'MASE', 'extra_metrics': [], 'quantile_levels': [], 'id_column': 'id', 'timestamp_column': 'timestamp', 'target': 'target', 'generate_univariate_targets_from': ['OT', 'LUFL', 'LULL'], 'known_dynamic_columns': [], 'past_dynamic_columns': [], 'static_columns': [], 'task_name': 'ETT_1H', 'test_error': 1.1921320279836811, 'training_time_s': None, 'inference_time_s': None, 'dataset_fingerprint': '81591e84125d0e33', 'trained_on_this_dataset': False, 'fev_version': '0.6.0', 'MASE': 1.1921320279836811}
Evaluation on a Benchmark consisting of multiple tasks
A `fev.Benchmark` object is essentially a collection of `Task`s.
We can create a benchmark from a list of dictionaries. Each dictionary is interpreted as a `fev.TaskGenerator`.
tasks_configs = [
{
"dataset_path": "autogluon/chronos_datasets",
"dataset_config": "monash_m1_quarterly",
"horizon": 8,
"seasonality": 4,
"eval_metric": "MASE",
},
{
"dataset_path": "autogluon/chronos_datasets",
"dataset_config": "monash_electricity_weekly",
"horizon": 8,
"num_windows": 2,
},
{
"dataset_path": "autogluon/chronos_datasets",
"dataset_config": "monash_m1_yearly",
"horizon": 6,
},
]
benchmark = fev.Benchmark.from_list(tasks_configs)
Or from a YAML file
benchmark_path = Path(fev.__file__).parents[2] / "benchmarks" / "example" / "tasks.yaml"
# Show contents of the benchmark YAML file
!cat {benchmark_path}
tasks:
- dataset_path: autogluon/chronos_datasets
  dataset_config: monash_m1_quarterly
  horizon: 8
  seasonality: 4
- dataset_path: autogluon/chronos_datasets
  dataset_config: monash_electricity_weekly
  horizon: 8
  num_windows: 2
- dataset_path: autogluon/chronos_datasets
  dataset_config: monash_m1_yearly
  horizon: 6
  seasonality: 1
benchmark = fev.Benchmark.from_yaml(benchmark_path)
benchmark.tasks
[Task(dataset_path='autogluon/chronos_datasets', dataset_config='monash_m1_quarterly', horizon=8, num_windows=1, initial_cutoff=-8, window_step_size=8, min_context_length=1, max_context_length=None, seasonality=4, eval_metric='MASE', extra_metrics=[], quantile_levels=[], id_column='id', timestamp_column='timestamp', target='target', generate_univariate_targets_from=None, known_dynamic_columns=[], past_dynamic_columns=[], static_columns=[], task_name='monash_m1_quarterly'),
 Task(dataset_path='autogluon/chronos_datasets', dataset_config='monash_electricity_weekly', horizon=8, num_windows=2, initial_cutoff=-16, window_step_size=8, min_context_length=1, max_context_length=None, seasonality=1, eval_metric='MASE', extra_metrics=[], quantile_levels=[], id_column='id', timestamp_column='timestamp', target='target', generate_univariate_targets_from=None, known_dynamic_columns=[], past_dynamic_columns=[], static_columns=[], task_name='monash_electricity_weekly'),
 Task(dataset_path='autogluon/chronos_datasets', dataset_config='monash_m1_yearly', horizon=6, num_windows=1, initial_cutoff=-6, window_step_size=6, min_context_length=1, max_context_length=None, seasonality=1, eval_metric='MASE', extra_metrics=[], quantile_levels=[], id_column='id', timestamp_column='timestamp', target='target', generate_univariate_targets_from=None, known_dynamic_columns=[], past_dynamic_columns=[], static_columns=[], task_name='monash_m1_yearly')]
Now let's evaluate some simple forecasting models on this toy benchmark.
# You might need to restart the notebook after installing the dependencies
!pip install -q statsforecast "numpy<=2.2" "scipy<1.16"
from statsforecast.models import AutoETS, SeasonalNaive, Theta
def predict_with_model(task: fev.Task, model_name: str = "seasonal_naive") -> list[datasets.Dataset]:
assert len(task.target_columns) == 1, "only univariate forecasting supported"
if model_name == "seasonal_naive":
model = SeasonalNaive(season_length=task.seasonality)
elif model_name == "theta":
model = Theta(season_length=task.seasonality)
elif model_name == "ets":
model = AutoETS(season_length=task.seasonality)
else:
raise ValueError(f"Unknown model_name: {model_name}")
predictions_per_window = []
for window in task.iter_windows():
past_data, future_data = window.get_input_data()
predictions = [
{"predictions": model.forecast(y=ts[task.target], h=task.horizon)["mean"]}
for ts in past_data
]
predictions_per_window.append(datasets.Dataset.from_list(predictions))
return predictions_per_window
import time
summaries = []
for task in tqdm(benchmark.tasks, desc="Tasks completed"):
for model_name in ["seasonal_naive", "ets", "theta"]:
start_time = time.time()
predictions_per_window = predict_with_model(task, model_name=model_name)
infer_time_s = time.time() - start_time
eval_summary = task.evaluation_summary(
predictions_per_window,
model_name=model_name,
inference_time_s=infer_time_s,
training_time_s=0.0,
)
summaries.append(eval_summary)
fev.leaderboard(summaries, baseline_model="seasonal_naive")
| model_name | skill_score | win_rate | median_training_time_s | median_inference_time_s | training_corpus_overlap | num_failures |
|---|---|---|---|---|---|---|
| ets | 0.133483 | 0.833333 | 0.0 | 3.637132 | 0.0 | 0 |
| theta | 0.105932 | 0.333333 | 0.0 | 0.131447 | 0.0 | 0 |
| seasonal_naive | 0.000000 | 0.333333 | 0.0 | 1.638496 | 0.0 | 0 |
The `leaderboard` method aggregates the performance into a single number.
We can investigate the performance for individual tasks using the `pivot_table` method
fev.pivot_table(summaries, task_columns=["dataset_config"])
| dataset_config | ets | seasonal_naive | theta |
|---|---|---|---|
| monash_electricity_weekly | 2.552429 | 2.535526 | 2.557008 |
| monash_m1_quarterly | 1.660810 | 2.077537 | 1.705247 |
| monash_m1_yearly | 3.957011 | 4.894322 | 4.225722 |
fev.pairwise_comparison(summaries)
| model_1 | model_2 | skill_score | win_rate |
|---|---|---|---|
| ets | ets | 0.000000 | 0.500000 |
| ets | seasonal_naive | 0.133483 | 0.666667 |
| ets | theta | 0.030815 | 1.000000 |
| seasonal_naive | ets | -0.154045 | 0.333333 |
| seasonal_naive | seasonal_naive | 0.000000 | 0.500000 |
| seasonal_naive | theta | -0.118484 | 0.333333 |
| theta | ets | -0.031794 | 0.000000 |
| theta | seasonal_naive | 0.105932 | 0.666667 |
| theta | theta | 0.000000 | 0.500000 |
Both `leaderboard()` and `pivot_table()` methods can handle single or multiple evaluation summaries in different formats:
- `pandas.DataFrame`
- list of dictionaries
- paths to JSONL (`orient="records"`) or CSV files
Here is an example of how we can work with URLs of CSV files:
summaries = [
"https://raw.githubusercontent.com/autogluon/fev/refs/heads/main/benchmarks/chronos_zeroshot/results/auto_arima.csv",
"https://raw.githubusercontent.com/autogluon/fev/refs/heads/main/benchmarks/chronos_zeroshot/results/chronos_bolt_base.csv",
"https://raw.githubusercontent.com/autogluon/fev/refs/heads/main/benchmarks/chronos_zeroshot/results/seasonal_naive.csv",
]
fev.leaderboard(summaries, metric_column="MASE")
| model_name | skill_score | win_rate | median_training_time_s | median_inference_time_s | training_corpus_overlap | num_failures |
|---|---|---|---|---|---|---|
| chronos_bolt_base | 0.204508 | 0.722222 | NaN | 0.406413 | 0.0 | 0 |
| auto_arima | 0.130551 | 0.648148 | NaN | 75.883700 | 0.0 | 0 |
| seasonal_naive | 0.000000 | 0.129630 | NaN | 0.096449 | 0.0 | 0 |
We can also compute the 95% confidence intervals for the `skill_score` and `win_rate` columns via bootstrap by setting the `n_resamples` parameter.
fev.leaderboard(summaries, metric_column="MASE", n_resamples=1000).round(3)
| model_name | skill_score | skill_score_lower | skill_score_upper | win_rate | win_rate_lower | win_rate_upper | median_training_time_s | median_inference_time_s | training_corpus_overlap | num_failures |
|---|---|---|---|---|---|---|---|---|---|---|
| chronos_bolt_base | 0.205 | 0.147 | 0.265 | 0.722 | 0.611 | 0.833 | NaN | 0.406 | 0.0 | 0 |
| auto_arima | 0.131 | 0.053 | 0.204 | 0.648 | 0.500 | 0.778 | NaN | 75.884 | 0.0 | 0 |
| seasonal_naive | 0.000 | 0.000 | 0.000 | 0.130 | 0.056 | 0.204 | NaN | 0.096 | 0.0 | 0 |
fev.pairwise_comparison(summaries, metric_column="MASE", n_resamples=1000).round(3)
| model_1 | model_2 | skill_score | skill_score_lower | skill_score_upper | win_rate | win_rate_lower | win_rate_upper |
|---|---|---|---|---|---|---|---|
| chronos_bolt_base | chronos_bolt_base | 0.000 | 0.000 | 0.000 | 0.500 | 0.500 | 0.500 |
| chronos_bolt_base | auto_arima | 0.085 | 0.000 | 0.162 | 0.519 | 0.333 | 0.704 |
| chronos_bolt_base | seasonal_naive | 0.205 | 0.147 | 0.265 | 0.926 | 0.815 | 1.000 |
| auto_arima | chronos_bolt_base | -0.093 | -0.194 | -0.000 | 0.481 | 0.296 | 0.667 |
| auto_arima | auto_arima | 0.000 | 0.000 | 0.000 | 0.500 | 0.500 | 0.500 |
| auto_arima | seasonal_naive | 0.131 | 0.053 | 0.204 | 0.815 | 0.667 | 0.963 |
| seasonal_naive | chronos_bolt_base | -0.257 | -0.360 | -0.173 | 0.074 | 0.000 | 0.185 |
| seasonal_naive | auto_arima | -0.150 | -0.256 | -0.056 | 0.185 | 0.037 | 0.333 |
| seasonal_naive | seasonal_naive | 0.000 | 0.000 | 0.000 | 0.500 | 0.500 | 0.500 |
Like before, we can view the scores for individual tasks with the `pivot_table` method.
fev.pivot_table(summaries, task_columns=["dataset_config"], metric_column="WQL").round(3)
| dataset_config | auto_arima | chronos_bolt_base | seasonal_naive |
|---|---|---|---|
| ETTh | 0.089 | 0.071 | 0.122 |
| ETTm | 0.105 | 0.052 | 0.141 |
| dominick | 0.485 | 0.345 | 0.453 |
| ercot | 0.041 | 0.021 | 0.037 |
| exchange_rate | 0.011 | 0.012 | 0.013 |
| m4_quarterly | 0.079 | 0.077 | 0.119 |
| m4_yearly | 0.125 | 0.121 | 0.161 |
| m5 | 0.617 | 0.562 | 1.024 |
| monash_australian_electricity | 0.067 | 0.036 | 0.084 |
| monash_car_parts | 1.333 | 0.995 | 1.600 |
| monash_cif_2016 | 0.033 | 0.016 | 0.015 |
| monash_covid_deaths | 0.029 | 0.047 | 0.133 |
| monash_fred_md | 0.035 | 0.042 | 0.122 |
| monash_hospital | 0.059 | 0.057 | 0.073 |
| monash_m1_monthly | 0.154 | 0.139 | 0.191 |
| monash_m1_quarterly | 0.088 | 0.101 | 0.150 |
| monash_m1_yearly | 0.133 | 0.151 | 0.209 |
| monash_m3_monthly | 0.098 | 0.093 | 0.149 |
| monash_m3_quarterly | 0.077 | 0.076 | 0.101 |
| monash_m3_yearly | 0.156 | 0.129 | 0.167 |
| monash_nn5_weekly | 0.084 | 0.084 | 0.123 |
| monash_tourism_monthly | 0.091 | 0.090 | 0.104 |
| monash_tourism_quarterly | 0.100 | 0.065 | 0.119 |
| monash_tourism_yearly | 0.129 | 0.166 | 0.209 |
| monash_traffic | 0.354 | 0.231 | 0.362 |
| monash_weather | 0.215 | 0.134 | 0.217 |
| nn5 | 0.248 | 0.150 | 0.425 |