Status: Open
Labels: documentation
## Description
When I finished training the model and got the prediction results, I wanted to evaluate them with `forecast_it, ts_it = make_evaluation_predictions()`, as the tutorial suggests. But when I tried to convert `forecast_it` and `ts_it` with `list()`, an error was raised:
## To Reproduce
```python
forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_data,  # test dataset
    predictor=predictor,  # predictor
    num_samples=100,  # number of sample paths we want for evaluation
)
forecasts_ev = list(forecast_it)
tss = list(ts_it)
```
```python
# Imports (not shown in the original snippet)
import json
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.evaluation import Evaluator, make_evaluation_predictions
from gluonts.model.predictor import Predictor
from gluonts.torch import DeepAREstimator

df = pd.read_csv(
    "C:/Users/leo/Downloads/AirPassengers.csv",
    index_col=0,
    parse_dates=True,
)
dataset = PandasDataset(df, target="target", freq="d")

# Split the data for training and testing
training_data, test_gen = split(dataset, date=pd.Period("2002/10/25", freq="1D"))
test_data = test_gen.generate_instances(prediction_length=168, windows=1)

# Train the model and make predictions
model = DeepAREstimator(
    prediction_length=168,
    freq="d",
    trainer_kwargs={"max_epochs": 1},
    num_layers=5,
    batch_size=100,
    hidden_size=336,
    context_length=24,
)
predictor = model.train(training_data)
forecasts = list(predictor.predict(test_data.input))
predictor.serialize(Path("D:/BaiduNetdiskDownload/auto_bnn/tmp"))

# Plot predictions
plt.plot(df["1954":], color="black")
a = []
for forecast in forecasts:
    forecast.plot()
plt.legend(["True values"], loc="upper left", fontsize="xx-large")
plt.show()

# Predictor saved and metrics evaluated
predictor_deserialized = Predictor.deserialize(Path("D:/BaiduNetdiskDownload/auto_bnn/tmp"))
forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_data,  # test dataset
    predictor=predictor,  # predictor
    num_samples=100,  # number of sample paths we want for evaluation
)
forecasts_ev = list(forecast_it)  # fails with the TypeError shown in the traceback
tss = list(ts_it)
evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])
agg_metrics, item_metrics = evaluator(tss, forecasts_ev)
print(json.dumps(agg_metrics, indent=4))
```
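The script builds `test_data` with `split(...)` followed by `generate_instances(prediction_length=168, windows=1)` and then passes that object to `make_evaluation_predictions`. As a rough mental model of what such an object contains, here is a pure-Python stand-in. It does not call GluonTS; it assumes (from the GluonTS split documentation, not from this report) that the training part runs up to and including the split point and that each test instance is an `(input, label)` pair whose label holds the next `prediction_length` values. `split_at` and the toy series are hypothetical.

```python
from typing import Any, Dict, List, Tuple

Entry = Dict[str, Any]


def split_at(entries: List[Entry], split_index: int, prediction_length: int):
    """Hypothetical stand-in: return (training entries, list of (input, label) pairs)."""
    training: List[Entry] = []
    test_pairs: List[Tuple[Entry, Entry]] = []
    for entry in entries:
        target = entry["target"]
        # History up to and including the split point (assumed inclusive).
        history = target[: split_index + 1]
        training.append({**entry, "target": history})
        # One test window: the history as input, the next values as label.
        test_input = {**entry, "target": history}
        label = {"target": target[split_index + 1 : split_index + 1 + prediction_length]}
        test_pairs.append((test_input, label))
    return training, test_pairs


entries = [{"start": "2002-01-01", "target": list(range(10))}]
training, test_pairs = split_at(entries, split_index=5, prediction_length=3)
print(len(training[0]["target"]))    # 6
print(test_pairs[0][1]["target"])    # [6, 7, 8]
print(type(test_pairs[0]).__name__)  # tuple
```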
## Error message or code output
```
Traceback (most recent call last):
  File "D:\BaiduNetdiskDownload\auto_bnn\c.py", line 55, in <module>
    forecasts_ev =list(forecast_it)
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\torch\model\predictor.py", line 90, in predict
    yield from self.forecast_generator(
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\model\forecast_generator.py", line 172, in __call__
    for batch in inference_data_loader:
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\transform\_base.py", line 111, in __iter__
    yield from self.transformation(
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\transform\_base.py", line 132, in __call__
    for data_entry in data_it:
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\dataset\loader.py", line 50, in __call__
    yield from batcher(data, self.batch_size)
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\itertools.py", line 128, in get_batch
    return list(itertools.islice(it, batch_size))
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\transform\_base.py", line 132, in __call__
    for data_entry in data_it:
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\transform\_base.py", line 186, in __call__
    for data_entry in data_it:
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\transform\_base.py", line 132, in __call__
    for data_entry in data_it:
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\transform\_base.py", line 132, in __call__
    for data_entry in data_it:
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\transform\_base.py", line 132, in __call__
    for data_entry in data_it:
  [Previous line repeated 8 more times]
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\dataset\split.py", line 414, in __iter__
    for input, _label in self.test_data:
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\dataset\split.py", line 386, in __iter__
    yield from self.splitter.generate_test_pairs(
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\dataset\split.py", line 251, in generate_test_pairs
    test = self.test_pair(
  File "D:\anaconda3\envs\glu39\lib\site-packages\gluonts\dataset\split.py", line 289, in test_pair
    offset_ += entry[FieldName.TARGET].shape[-1]
TypeError: tuple indices must be integers or slices, not str
```
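The last frames of the traceback show `test_pair` in `split.py` indexing each entry with `entry[FieldName.TARGET]` (the string key `"target"`), while the frame at `split.py` line 414 iterates `for input, _label in self.test_data`, i.e. over `(input, label)` tuples. A minimal pure-Python sketch of the resulting mismatch, assuming the object passed as `dataset=` yields such tuples where a dict-like data entry is expected; `entry`, `test_pair` and `target_length` are hypothetical stand-ins, not GluonTS objects:

```python
# Hypothetical stand-ins, not GluonTS objects: a data entry is a dict keyed
# by field name, while iterating a test split yields (input, label) tuples.
entry = {"start": "2002-01-01", "target": [1.0, 2.0, 3.0]}
test_pair = (entry, {"target": [4.0]})


def target_length(data_entry):
    # Mirrors the failing line, entry[FieldName.TARGET].shape[-1],
    # but uses len() on a plain list instead of a NumPy array.
    return len(data_entry["target"])


print(target_length(entry))  # 3

try:
    target_length(test_pair)
except TypeError as err:
    message = str(err)
    print(message)  # tuple indices must be integers or slices, not str
```

Indexing a dict with `"target"` works, while the same lookup on a tuple raises the same `TypeError` as in the report.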
## Environment
- Operating system: Windows 11
- Python version: 3.10
- GluonTS version: 1.6
- MXNet version: not installed