Description
🐛 Bug Description
To Reproduce
Steps to reproduce the behavior:
1. qrun benchmarks/TFT/workflow_config_tft_Alpha158.yaml
Expected Behavior
Screenshot
*** Fitting TemporalFusionTransformer ***
Getting batched_data
Using keras standard fit
WARNING:tensorflow:From /root/anaconda3/envs/myqlib/lib/python3.7/site-packages/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Train on 633886 samples, validate on 122170 samples
Epoch 1/100
2021-04-14 11:11:05.560207: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0
2021-04-14 11:11:06.273970: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7
633886/633886 [==============================] - 248s 391us/sample - loss: 0.6394 - val_loss: 0.5856
Epoch 2/100
633886/633886 [==============================] - 231s 364us/sample - loss: 0.5897 - val_loss: 0.5797
Epoch 3/100
633886/633886 [==============================] - 231s 364us/sample - loss: 0.5825 - val_loss: 0.5794
Epoch 4/100
633886/633886 [==============================] - 232s 365us/sample - loss: 0.5771 - val_loss: 0.5772
Epoch 5/100
633886/633886 [==============================] - 231s 364us/sample - loss: 0.5722 - val_loss: 0.5835
Epoch 6/100
633886/633886 [==============================] - 231s 364us/sample - loss: 0.5677 - val_loss: 0.5923
Epoch 7/100
633886/633886 [==============================] - 231s 364us/sample - loss: 0.5631 - val_loss: 0.5884
Epoch 8/100
633886/633886 [==============================] - 231s 364us/sample - loss: 0.5587 - val_loss: 0.6014
Epoch 9/100
633886/633886 [==============================] - 231s 364us/sample - loss: 0.5540 - val_loss: 0.6064
Epoch 10/100
633886/633886 [==============================] - 231s 364us/sample - loss: 0.5492 - val_loss: 0.6066
Epoch 11/100
633886/633886 [==============================] - 231s 364us/sample - loss: 0.5446 - val_loss: 0.6194
Epoch 12/100
633886/633886 [==============================] - 231s 365us/sample - loss: 0.5399 - val_loss: 0.6256
Epoch 13/100
633886/633886 [==============================] - 231s 365us/sample - loss: 0.5355 - val_loss: 0.6257
Epoch 14/100
633886/633886 [==============================] - 231s 365us/sample - loss: 0.5313 - val_loss: 0.6424
Cannot load from qlib_tft_model/tmp, skipping ...
*** Finished training ***
WARNING:tensorflow:From /home/disk/model/qlib/examples/benchmarks/TFT/libs/utils.py:168: The name tf.get_collection is deprecated. Please use tf.compat.v1.get_collection instead.
WARNING:tensorflow:From /home/disk/model/qlib/examples/benchmarks/TFT/libs/utils.py:168: The name tf.GraphKeys is deprecated. Please use tf.compat.v1.GraphKeys instead.
WARNING:tensorflow:From /home/disk/model/qlib/examples/benchmarks/TFT/libs/utils.py:169: The name tf.train.Saver is deprecated. Please use tf.compat.v1.train.Saver instead.
Model saved to: qlib_tft_model/saved_model/TemporalFusionTransformer.ckpt
Training completed.
[2613:MainThread](2021-04-14 12:05:09,635) ERROR - qlib.workflow - [utils.py:35] - An exception has been raised[TypeError: can't pickle _thread.RLock objects].
  File "/root/anaconda3/envs/myqlib/bin/qrun", line 8, in <module>
    sys.exit(run())
  File "/root/anaconda3/envs/myqlib/lib/python3.7/site-packages/qlib/workflow/cli.py", line 61, in run
    fire.Fire(workflow)
  File "/root/anaconda3/envs/myqlib/lib/python3.7/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/anaconda3/envs/myqlib/lib/python3.7/site-packages/fire/core.py", line 471, in _Fire
    target=component.__name__)
  File "/root/anaconda3/envs/myqlib/lib/python3.7/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/root/anaconda3/envs/myqlib/lib/python3.7/site-packages/qlib/workflow/cli.py", line 56, in workflow
    task_train(config.get("task"), experiment_name=experiment_name)
  File "/root/anaconda3/envs/myqlib/lib/python3.7/site-packages/qlib/model/trainer.py", line 29, in task_train
    R.save_objects({"params.pkl": model})
  File "/root/anaconda3/envs/myqlib/lib/python3.7/site-packages/qlib/workflow/__init__.py", line 404, in save_objects
    self.get_exp().get_recorder().save_objects(local_path, artifact_path, **kwargs)
  File "/root/anaconda3/envs/myqlib/lib/python3.7/site-packages/qlib/workflow/recorder.py", line 294, in save_objects
    pickle.dump(data, f)
TypeError: can't pickle _thread.RLock objects
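The traceback shows that training itself finishes and the TF checkpoint is written; the failure comes afterwards, when task_train calls R.save_objects({"params.pkl": model}) and the recorder runs pickle.dump on the whole model object. Presumably the TFT model keeps a reference to some TensorFlow object (session, graph or similar) that internally holds a _thread.RLock, but the log does not say which attribute that is, so this is an assumption. The sketch below reproduces the same class of error with a plain threading.RLock, adds a small helper to locate the offending attribute, and shows one common workaround (__getstate__/__setstate__ that drops the unpicklable handle). The classes SessionHolder and PicklableSessionHolder, the helpers is_picklable and find_unpicklable_attrs, and the "epochs" field are all hypothetical illustrations, not qlib's or the TFT benchmark's actual code.

```python
import pickle
import threading


def is_picklable(value):
    """Return True if pickle can serialize value, False otherwise."""
    try:
        pickle.dumps(value)
        return True
    except (TypeError, pickle.PicklingError):
        return False


def find_unpicklable_attrs(obj):
    """List the instance attributes of obj that pickle refuses to serialize."""
    return [name for name, value in vars(obj).items() if not is_picklable(value)]


class SessionHolder:
    """Hypothetical stand-in for a model that keeps a framework handle.

    The RLock plays the role of the unpicklable TensorFlow internals.
    """

    def __init__(self, params):
        self.params = params              # plain, picklable hyperparameters
        self._lock = threading.RLock()    # unpicklable: pickle.dumps raises TypeError


class PicklableSessionHolder(SessionHolder):
    """Same object, but excludes the handle from the pickled state."""

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_lock", None)          # drop the unpicklable handle
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._lock = threading.RLock()    # recreate a fresh handle after loading


if __name__ == "__main__":
    broken = SessionHolder({"epochs": 100})
    print("unpicklable attributes:", find_unpicklable_attrs(broken))

    fixed = PicklableSessionHolder({"epochs": 100})
    restored = pickle.loads(pickle.dumps(fixed))
    print("restored params:", restored.params)
```

For the real model, calling find_unpicklable_attrs(model) right before R.save_objects would be one way to confirm which attribute holds the TensorFlow object. Note that a __getstate__-style fix only makes params.pkl writable; after unpickling, the TensorFlow state would still have to be rebuilt separately, for example from the checkpoint the log reports at qlib_tft_model/saved_model/TemporalFusionTransformer.ckpt.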
Environment
Note: Users can run cd scripts && python collect_info.py all under the project directory to get system information and paste it here directly.
- Qlib version: pyqlib 0.6.3.99
- Python version: 3.7
- OS (Windows, Linux, MacOS): Linux
- Commit number (optional, please provide it if you are using the dev version):