Lag-Llama:第一个时间序列预测的开源基础模型介绍和性能测试-Lag-Llama测试

时间:2024-02-19 08:48:48

因为代码已经开源,所以我们可以直接测试,我们首先使用Lag-Llama的零样本预测能力,并将其性能与特定数据模型(如TFT和DeepAR)进行比较。

Lag-Llama的实现是建立在GluonTS之上的,所以我们还需要安装这个库。实验使用了澳大利亚电力需求数据集,该数据集包含五个单变量时间序列,以半小时的频率跟踪能源需求。

这里有个说明:Lag-Llama目前的实现是初期阶段。并且存还在积极开发中,后面可能还会有很大的调整,因为目前还没加入微调的功能。

1、环境设置

 !git clone https://github.com/time-series-foundation-models/lag-llama/ 
 cd lag-llama 
 pip install -r requirements.txt --quiet

然后需要我们从HuggingFace下载模型的权重。

 !huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama

2、加载数据集

 import pandas as pd 
 import matplotlib.pyplot as plt 
 import matplotlib.dates as mdates 
 import torch
 
 from itertools import islice
 
 from gluonts.evaluation import make_evaluation_predictions, Evaluator 
 from gluonts.dataset.repository.datasets import get_dataset 
 from lag_llama.gluon.estimator import LagLlamaEstimator

可以直接从GluonTS加载数据集。

 dataset = get_dataset("australian_electricity_demand") 
 backtest_dataset = dataset.test prediction_length = dataset.metadata.prediction_length 
 context_length = 3 * prediction_length

3、使用Lag-Llama预测

简单地初始化模型并使用LagLlamaEstimator对象。

 ckpt = torch.load("lag-llama.ckpt", map_location=torch.device('cuda:0')) 
 estimator_args = ckpt["hyper_parameters"]["model_kwargs"] 
 estimator = LagLlamaEstimator( ckpt_path="lag-llama.ckpt", 
   prediction_length=prediction_length, 
   context_length=context_length, 
   input_size=estimator_args["input_size"], 
   n_layer=estimator_args["n_layer"], 
   n_embd_per_head=estimator_args["n_embd_per_head"], 
   n_head=estimator_args["n_head"], 
   scaling=estimator_args["scaling"], 
   time_feat=estimator_args["time_feat"]) 
 
 lightning_module = estimator.create_lightning_module() 
 transformation = estimator.create_transformation() 
 predictor = estimator.create_predictor(transformation, lightning_module)

使用make_evaluation_predictions函数生成零样本的预测。

 forecast_it, ts_it = make_evaluation_predictions(
   dataset=backtest_dataset, 
   predictor=predictor)

这个函数返回生成器。我们需要把它们转换成列表。

 forecasts = list(forecast_it) 
 tss = list(ts_it)

4、评估

GluonTS可以使用Evaluator对象方便地计算不同的性能指标。

 evaluator = Evaluator() 
 
 agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))

RMSE为481.57。

我们还可以随意地将预测可视化。

 plt.figure(figsize=(20, 15)) 
 date_formater = mdates.DateFormatter('%b, %d') 
 plt.rcParams.update({'font.size': 15}) 
 
 for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 4): 
   ax = plt.subplot(2, 2, idx+1) 
   plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target") 
   forecast.plot( color='g')
 
   plt.xticks(rotation=60) 
   ax.xaxis.set_major_formatter(date_formater) 
   ax.set_title(forecast.item_id) 
 
 plt.gcf().tight_layout() 
 plt.legend() 
 plt.show()

上图可以看到模型对数据做出了合理的预测,尽管它在第四个序列(图的右下角)上确实存在问题。

另外由于 Lag-Llama实现了概率预测,可以得到预测的不确定性区间。

5、与TFT和DeepAR相比

我们在数据集上训练TFT和DeepAR模型,看看它们是否能表现得更好。

为了节省时间,我们将训练设置为5个epoch。

 from gluonts.torch import TemporalFusionTransformerEstimator, DeepAREstimator 
 
 tft_estimator = TemporalFusionTransformerEstimator(
   prediction_length=prediction_length, 
   context_length=context_length, 
   freq="30min", 
   trainer_kwargs={"max_epochs": 5}) 
 
 deepar_estimator = DeepAREstimator(
   prediction_length=prediction_length, 
   context_length=context_length, 
   freq="30min", 
   trainer_kwargs={"max_epochs": 5})

训练过程。

 tft_predictor = tft_estimator.train(dataset.train) 
 deepar_predictor = deepar_estimator.train(dataset.train)

训练完成后,生成预测并计算RMSE。

 
 tft_forecast_it, tft_ts_it = make_evaluation_predictions(
   dataset=backtest_dataset, 
   predictor=tft_predictor) 
 
 deepar_forecast_it, deepar_ts_it = make_evaluation_predictions(
   dataset=backtest_dataset, 
   predictor=deepar_predictor) 
 
 tft_forecasts = list(tft_forecast_it) 
 tft_tss = list(tft_ts_it) 
 
 deepar_forecasts = list(deepar_forecast_it) 
 deepar_tss = list(deepar_ts_it) 
 
 # Get evaluation metrics
 tft_agg_metrics, tft_ts_metrics = evaluator(iter(tft_tss), iter(tft_forecasts)) 
 deepar_agg_metrics, deepar_ts_metrics = evaluator(iter(deepar_tss), iter(deepar_forecasts))

下表突出显示了性能最好的模型。

可以看到TFT是目前表现最好的模型,DeepAR的表现也优于laglama。

虽然laglllama的表现似乎不尽如人意,但该模型没有经过微调,而且零样本测本身就比较困难。

有趣的是,只训练了5个epoch这两个模型都取得了比Lag-Llama更好的结果。虽然样本预测可以节省时间,但训练五个epoch在时间和计算能力方面的要求应该不是很苛刻。所以目前可能零样本学习方面还需要很大的提升。