diff --git a/.github/workflows/test_macos.yml b/.github/workflows/test_macos.yml index 46e5eaeb3a..e684115966 100644 --- a/.github/workflows/test_macos.yml +++ b/.github/workflows/test_macos.yml @@ -60,7 +60,7 @@ jobs: python -m pip install --upgrade cython python -m pip install numpy jupyter jupyter_contrib_nbextensions python -m pip install -U scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't - python setup.py install + pip install -e . - name: Install test dependencies run: | python -m pip install --upgrade pip diff --git a/README.md b/README.md index 62005ada72..4df8a0ca37 100644 --- a/README.md +++ b/README.md @@ -11,23 +11,24 @@ Recent released features | Feature | Status | | -- | ------ | -| Meta-Learning-based framework & DDG-DA | [Released](https://github.com/microsoft/qlib/pull/743) on Jan 10, 2022 | -| Planning-based portfolio optimization | [Released](https://github.com/microsoft/qlib/pull/754) on Dec 28, 2021 | -| Release Qlib v0.8.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.8.0) on Dec 8, 2021 | -| ADD model | [Released](https://github.com/microsoft/qlib/pull/704) on Nov 22, 2021 | -| ADARNN model | [Released](https://github.com/microsoft/qlib/pull/689) on Nov 14, 2021 | -| TCN model | [Released](https://github.com/microsoft/qlib/pull/668) on Nov 4, 2021 | -| Nested Decision Framework | [Released](https://github.com/microsoft/qlib/pull/438) on Oct 1, 2021. [Example](https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py) and [Doc](https://qlib.readthedocs.io/en/latest/component/highfreq.html) | -|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 | -| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 | -| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 | -| TCTS Model | [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 | -| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 | -| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 | -| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 | -| High-frequency trading example | [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 | -| High-frequency data(1min) | [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 | -| Tabnet Model | [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 | +| Arctic Provider Backend & Orderbook data example | :hammer: [Rleased](https://github.com/microsoft/qlib/pull/744) on Jan 17, 2022 | +| Meta-Learning-based framework & DDG-DA | :chart_with_upwards_trend: :hammer: [Released](https://github.com/microsoft/qlib/pull/743) on Jan 10, 2022 | +| Planning-based portfolio optimization | :hammer: [Released](https://github.com/microsoft/qlib/pull/754) on Dec 
28, 2021 | +| Release Qlib v0.8.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.8.0) on Dec 8, 2021 | +| ADD model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/704) on Nov 22, 2021 | +| ADARNN model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/689) on Nov 14, 2021 | +| TCN model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/668) on Nov 4, 2021 | +| Nested Decision Framework | :hammer: [Released](https://github.com/microsoft/qlib/pull/438) on Oct 1, 2021. [Example](https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py) and [Doc](https://qlib.readthedocs.io/en/latest/component/highfreq.html) | +| Temporal Routing Adaptor (TRA) | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 | +| Transformer & Localformer | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 | +| Release Qlib v0.7.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 | +| TCTS Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/491) on July 1, 2021 | +| Online serving and automatic model rolling | :hammer: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 | +| DoubleEnsemble Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 | +| High-frequency data processing example | :hammer: [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 | +| High-frequency trading example | :chart_with_upwards_trend: [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 | +| High-frequency data(1min) | :rice: [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 | +| Tabnet Model | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 | Features released before 2021 are not listed here. @@ -72,7 +73,6 @@ Your feedbacks about the features are very important. | Feature | Status | | -- | ------ | | Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 | -| Orderbook database | Under review: https://github.com/microsoft/qlib/pull/744 | # Framework of Qlib diff --git a/examples/orderbook_data/README.md b/examples/orderbook_data/README.md new file mode 100644 index 0000000000..a4b25ac1a1 --- /dev/null +++ b/examples/orderbook_data/README.md @@ -0,0 +1,51 @@ +# Introduction + +This example tries to demonstrate how Qlib supports data without fixed shared frequency. + +For example, +- Daily prices volume data are fixed-frequency data. The data comes in a fixed frequency (i.e. daily) +- Orders are not fixed data and they may come at any time point + +To support such non-fixed-frequency, Qlib implements an Arctic-based backend. +Here is an example to import and query data based on this backend. + +# Installation + +Please refer to [the installation docs](https://docs.mongodb.com/manual/installation/) of mongodb. 
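The import scripts below assume a local MongoDB instance; you can verify the connection with a quick check like the one below. This is a minimal sketch, not part of the example scripts; it assumes the `arctic` package is already installed and that MongoDB is listening on localhost with the default port and no authentication.

```python
from arctic import Arctic

store = Arctic("127.0.0.1")    # same server address used by the example scripts
print(store.list_libraries())  # a fresh installation prints an empty list
```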
+By default, the current version of the script tries to connect to localhost **via the default port without authentication**. + +Run the following command to install the necessary libraries +``` +pip install pytest +``` + +# Importing example data + + +1. (Optional) Please follow the first part of [this section](https://github.com/microsoft/qlib#data-preparation) to **get 1min data** of Qlib. +2. Please follow the steps below to download the example data +```bash +cd examples/orderbook_data/ +wget http://fintech.msra.cn/stock_data/downloads/highfreq_orderboook_example_data.tar.bz2 +tar xf highfreq_orderboook_example_data.tar.bz2 +``` + +3. Please import the example data into your MongoDB +```bash +cd examples/orderbook_data/ +python create_dataset.py initialize_library # Initialize the Arctic libraries +python create_dataset.py import_data # Import the example data +``` + +# Query Examples + +After importing the data, you can run `example.py` to create some high-frequency features. +```bash +cd examples/orderbook_data/ +pytest -s --disable-warnings example.py # If you want to run all examples +pytest -s --disable-warnings example.py::TestClass::test_exp_10 # If you want to run a specific example +``` + + +# Known limitations +Computing expressions across different frequencies is not supported yet diff --git a/examples/orderbook_data/create_dataset.py b/examples/orderbook_data/create_dataset.py new file mode 100755 index 0000000000..f2b7a8a680 --- /dev/null +++ b/examples/orderbook_data/create_dataset.py @@ -0,0 +1,315 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" + NOTE: + - This script is a demo for importing the example data into Qlib + - !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!: + - Its structure is not well designed and very ugly; contributions that make importing datasets easier are welcome +""" +from datetime import date, datetime as dt +import os +from pathlib import Path +import random +import shutil +import time +import traceback + +from arctic import Arctic, chunkstore +import arctic +from arctic import Arctic, CHUNK_STORE +from arctic.chunkstore.chunkstore import CHUNK_SIZE +import fire +from joblib import Parallel, delayed, parallel +import numpy as np +import pandas as pd +from pandas import DataFrame +from pandas.core.indexes.datetimes import date_range +from pymongo.mongo_client import MongoClient + +DIRNAME = Path(__file__).absolute().resolve().parent + +# CONFIG +N_JOBS = -1 # -1 means using all available CPUs (joblib convention) +LOG_FILE_PATH = DIRNAME / "log_file" +DATA_PATH = DIRNAME / "raw_data" +DATABASE_PATH = DIRNAME / "orig_data" +DATA_INFO_PATH = DIRNAME / "data_info" +DATA_FINISH_INFO_PATH = DIRNAME / "./data_finish_info" +DOC_TYPE = ["Tick", "Order", "OrderQueue", "Transaction", "Day", "Minute"] +MAX_SIZE = 3000 * 1024 * 1024 * 1024 +ALL_STOCK_PATH = DATABASE_PATH / "all.txt" +ARCTIC_SRV = "127.0.0.1" + + +def get_library_name(doc_type): + if str.lower(doc_type) == str.lower("Tick"): + return "ticks" + else: + return str.lower(doc_type) + + +def is_stock(exchange_place, code): + if exchange_place == "SH" and code[0] != "6": + return False + if exchange_place == "SZ" and code[0] != "0" and code[:2] != "30": + return False + return True + + +def add_one_stock_daily_data(filepath, type, exchange_place, arc, date): + """ + exchange_place: "SZ" OR "SH" + type: "tick", "orderbook", ...
+ filepath: the path of csv + arc: arclink created by a process + """ + code = os.path.split(filepath)[-1].split(".csv")[0] + if exchange_place == "SH" and code[0] != "6": + return + if exchange_place == "SZ" and code[0] != "0" and code[:2] != "30": + return + + df = pd.read_csv(filepath, encoding="gbk", dtype={"code": str}) + code = os.path.split(filepath)[-1].split(".csv")[0] + + def format_time(day, hms): + day = str(day) + hms = str(hms) + if hms[0] == "1": # >=10, + return ( + "-".join([day[0:4], day[4:6], day[6:8]]) + " " + ":".join([hms[:2], hms[2:4], hms[4:6] + "." + hms[6:]]) + ) + else: + return ( + "-".join([day[0:4], day[4:6], day[6:8]]) + " " + ":".join([hms[:1], hms[1:3], hms[3:5] + "." + hms[5:]]) + ) + + ## Discard the entire row if wrong data timestamp encoutered. + timestamp = list(zip(list(df["date"]), list(df["time"]))) + error_index_list = [] + for index, t in enumerate(timestamp): + try: + pd.Timestamp(format_time(t[0], t[1])) + except Exception: + error_index_list.append(index) ## The row number of the error line + + # to-do: writting to logs + + if len(error_index_list) > 0: + print("error: {}, {}".format(filepath, len(error_index_list))) + + df = df.drop(error_index_list) + timestamp = list(zip(list(df["date"]), list(df["time"]))) ## The cleaned timestamp + # generate timestamp + pd_timestamp = pd.DatetimeIndex( + [pd.Timestamp(format_time(timestamp[i][0], timestamp[i][1])) for i in range(len(df["date"]))] + ) + df = df.drop(columns=["date", "time", "name", "code", "wind_code"]) + # df = pd.DataFrame(data=df.to_dict("list"), index=pd_timestamp) + df["date"] = pd.to_datetime(pd_timestamp) + df.set_index("date", inplace=True) + + if str.lower(type) == "orderqueue": + ## extract ab1~ab50 + df["ab"] = [ + ",".join([str(int(row["ab" + str(i + 1)])) for i in range(0, row["ab_items"])]) + for timestamp, row in df.iterrows() + ] + df = df.drop(columns=["ab" + str(i) for i in range(1, 51)]) + + type = get_library_name(type) + # arc.initialize_library(type, lib_type=CHUNK_STORE) + lib = arc[type] + + symbol = "".join([exchange_place, code]) + if symbol in lib.list_symbols(): + print("update {0}, date={1}".format(symbol, date)) + if df.empty == True: + return error_index_list + lib.update(symbol, df, chunk_size="D") + else: + print("write {0}, date={1}".format(symbol, date)) + lib.write(symbol, df, chunk_size="D") + return error_index_list + + +def add_one_stock_daily_data_wrapper(filepath, type, exchange_place, index, date): + pid = os.getpid() + code = os.path.split(filepath)[-1].split(".csv")[0] + arc = Arctic(ARCTIC_SRV) + try: + if index % 100 == 0: + print("index = {}, filepath = {}".format(index, filepath)) + error_index_list = add_one_stock_daily_data(filepath, type, exchange_place, arc, date) + if error_index_list is not None and len(error_index_list) > 0: + f = open(os.path.join(LOG_FILE_PATH, "temp_timestamp_error_{0}_{1}_{2}.txt".format(pid, date, type)), "a+") + f.write("{}, {}, {}\n".format(filepath, error_index_list, exchange_place + "_" + code)) + f.close() + + except Exception as e: + info = traceback.format_exc() + print("error:" + str(e)) + f = open(os.path.join(LOG_FILE_PATH, "temp_fail_{0}_{1}_{2}.txt".format(pid, date, type)), "a+") + f.write("fail:" + str(filepath) + "\n" + str(e) + "\n" + str(info) + "\n") + f.close() + + finally: + arc.reset() + + +def add_data(tick_date, doc_type, stock_name_dict): + pid = os.getpid() + + if doc_type not in DOC_TYPE: + print("doc_type not in {}".format(DOC_TYPE)) + return + try: + begin_time = time.time() + 
os.system(f"cp {DATABASE_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} {DATA_PATH}/") + + os.system( + f"tar -xvzf {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} -C {DATA_PATH}/ {tick_date + '_' + doc_type}/SH" + ) + os.system( + f"tar -xvzf {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} -C {DATA_PATH}/ {tick_date + '_' + doc_type}/SZ" + ) + os.system(f"chmod 777 {DATA_PATH}") + os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}") + os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SH") + os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SZ") + os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SH/{tick_date}") + os.system(f"chmod 777 {DATA_PATH}/{tick_date + '_' + doc_type}/SZ/{tick_date}") + + print("tick_date={}".format(tick_date)) + + temp_data_path_sh = os.path.join(DATA_PATH, tick_date + "_" + doc_type, "SH", tick_date) + temp_data_path_sz = os.path.join(DATA_PATH, tick_date + "_" + doc_type, "SZ", tick_date) + is_files_exist = {"sh": os.path.exists(temp_data_path_sh), "sz": os.path.exists(temp_data_path_sz)} + + sz_files = ( + ( + set([i.split(".csv")[0] for i in os.listdir(temp_data_path_sz) if i[:2] == "30" or i[0] == "0"]) + & set(stock_name_dict["SZ"]) + ) + if is_files_exist["sz"] + else set() + ) + sz_file_nums = len(sz_files) if is_files_exist["sz"] else 0 + sh_files = ( + ( + set([i.split(".csv")[0] for i in os.listdir(temp_data_path_sh) if i[0] == "6"]) + & set(stock_name_dict["SH"]) + ) + if is_files_exist["sh"] + else set() + ) + sh_file_nums = len(sh_files) if is_files_exist["sh"] else 0 + print("sz_file_nums:{}, sh_file_nums:{}".format(sz_file_nums, sh_file_nums)) + + f = (DATA_INFO_PATH / "data_info_log_{}_{}".format(doc_type, tick_date)).open("w+") + f.write("sz:{}, sh:{}, date:{}:".format(sz_file_nums, sh_file_nums, tick_date) + "\n") + f.close() + + if sh_file_nums > 0: + # write is not thread-safe, update may be thread-safe + Parallel(n_jobs=N_JOBS)( + delayed(add_one_stock_daily_data_wrapper)( + os.path.join(temp_data_path_sh, name + ".csv"), doc_type, "SH", index, tick_date + ) + for index, name in enumerate(list(sh_files)) + ) + if sz_file_nums > 0: + # write is not thread-safe, update may be thread-safe + Parallel(n_jobs=N_JOBS)( + delayed(add_one_stock_daily_data_wrapper)( + os.path.join(temp_data_path_sz, name + ".csv"), doc_type, "SZ", index, tick_date + ) + for index, name in enumerate(list(sz_files)) + ) + + os.system(f"rm -f {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)}") + os.system(f"rm -rf {DATA_PATH}/{tick_date + '_' + doc_type}") + total_time = time.time() - begin_time + f = (DATA_FINISH_INFO_PATH / "data_info_finish_log_{}_{}".format(doc_type, tick_date)).open("w+") + f.write("finish: date:{}, consume_time:{}, end_time: {}".format(tick_date, total_time, time.time()) + "\n") + f.close() + + except Exception as e: + info = traceback.format_exc() + print("date error:" + str(e)) + f = open(os.path.join(LOG_FILE_PATH, "temp_fail_{0}_{1}_{2}.txt".format(pid, tick_date, doc_type)), "a+") + f.write("fail:" + str(tick_date) + "\n" + str(e) + "\n" + str(info) + "\n") + f.close() + + +class DSCreator: + """Dataset creator""" + + def clear(self): + client = MongoClient(ARCTIC_SRV) + client.drop_database("arctic") + + def initialize_library(self): + arc = Arctic(ARCTIC_SRV) + for doc_type in DOC_TYPE: + arc.initialize_library(get_library_name(doc_type), lib_type=CHUNK_STORE) + + def _get_empty_folder(self, fp: Path): + fp = Path(fp) + if fp.exists(): + shutil.rmtree(fp) + 
fp.mkdir(parents=True, exist_ok=True) + + def import_data(self, doc_type_l=["Tick", "Transaction", "Order"]): + # clear all the old files + for fp in LOG_FILE_PATH, DATA_INFO_PATH, DATA_FINISH_INFO_PATH, DATA_PATH: + self._get_empty_folder(fp) + + arc = Arctic(ARCTIC_SRV) + for doc_type in DOC_TYPE: + # arc.initialize_library(get_library_name(doc_type), lib_type=CHUNK_STORE) + arc.set_quota(get_library_name(doc_type), MAX_SIZE) + arc.reset() + + # doc_type = 'Day' + for doc_type in doc_type_l: + date_list = list(set([int(path.split("_")[0]) for path in os.listdir(DATABASE_PATH) if doc_type in path])) + date_list.sort() + date_list = [str(date) for date in date_list] + + f = open(ALL_STOCK_PATH, "r") + stock_name_list = [lines.split("\t")[0] for lines in f.readlines()] + f.close() + stock_name_dict = { + "SH": [stock_name[2:] for stock_name in stock_name_list if "SH" in stock_name], + "SZ": [stock_name[2:] for stock_name in stock_name_list if "SZ" in stock_name], + } + + lib_name = get_library_name(doc_type) + a = Arctic(ARCTIC_SRV) + # a.initialize_library(lib_name, lib_type=CHUNK_STORE) + + stock_name_exist = a[lib_name].list_symbols() + lib = a[lib_name] + initialize_count = 0 + for stock_name in stock_name_list: + if stock_name not in stock_name_exist: + initialize_count += 1 + # A placeholder for stocks + pdf = pd.DataFrame(index=[pd.Timestamp("1900-01-01")]) + pdf.index.name = "date" # an col named date is necessary + lib.write(stock_name, pdf) + print("initialize count: {}".format(initialize_count)) + print("tasks: {}".format(date_list)) + a.reset() + + # date_list = [files.split("_")[0] for files in os.listdir("./raw_data_price") if "tar" in files] + # print(len(date_list)) + date_list = ["20201231"] # for test + Parallel(n_jobs=min(2, len(date_list)))( + delayed(add_data)(date, doc_type, stock_name_dict) for date in date_list + ) + + +if __name__ == "__main__": + fire.Fire(DSCreator) diff --git a/examples/orderbook_data/example.py b/examples/orderbook_data/example.py new file mode 100644 index 0000000000..6e3232229d --- /dev/null +++ b/examples/orderbook_data/example.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
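For orientation before reading `example.py`, here is a minimal sketch of the Arctic `CHUNK_STORE` round trip that `create_dataset.py` performs for each symbol. The library name, symbol, and sample DataFrame below are made-up illustrations; the daily `chunk_size="D"` and the `DatetimeIndex` named `date` mirror the script above, and a local MongoDB on `127.0.0.1` is assumed.

```python
import pandas as pd
from arctic import Arctic, CHUNK_STORE

a = Arctic("127.0.0.1")
a.initialize_library("ticks", lib_type=CHUNK_STORE)  # done once by `create_dataset.py initialize_library`
lib = a["ticks"]

# the chunkstore backend expects a DatetimeIndex named "date"; chunks are keyed per day here
df = pd.DataFrame(
    {"ask1": [10.0, 10.1], "bid1": [9.9, 10.0]},
    index=pd.DatetimeIndex(["2020-12-31 09:30:00", "2020-12-31 09:30:03"], name="date"),
)
lib.write("SZ000725", df, chunk_size="D")

# read back only the chunks overlapping a time range, similar to what ArcticFeatureProvider does
print(lib.read("SZ000725", chunk_range=(pd.Timestamp("2020-12-31"), pd.Timestamp("2021-01-01"))))
a.reset()
```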
+ +from arctic.arctic import Arctic +import qlib +from qlib.data import D +import unittest + + +class TestClass(unittest.TestCase): + """ + Useful commands + - run all tests: pytest examples/orderbook_data/example.py + - run a single test: pytest -s --pdb --disable-warnings examples/orderbook_data/example.py::TestClass::test_basic01 + """ + + def setUp(self): + """ + Configure for arctic + """ + provider_uri = "~/.qlib/qlib_data/yahoo_cn_1min" + qlib.init( + provider_uri=provider_uri, + mem_cache_size_limit=1024 ** 3 * 2, + mem_cache_type="sizeof", + kernels=1, + expression_provider={"class": "LocalExpressionProvider", "kwargs": {"time2idx": False}}, + feature_provider={"class": "ArcticFeatureProvider", "kwargs": {"uri": "127.0.0.1"}}, + dataset_provider={ + "class": "LocalDatasetProvider", + "kwargs": { + "align_time": False, # Order book is not fixed, so it can't be align to a shared fixed frequency calendar + }, + }, + ) + # self.stocks_list = ["SH600519"] + self.stocks_list = ["SZ000725"] + + def test_basic(self): + # NOTE: this data contains a lot of zeros in $askX and $bidX + df = D.features( + self.stocks_list, + fields=["$ask1", "$ask2", "$bid1", "$bid2"], + freq="ticks", + start_time="20201230", + end_time="20210101", + ) + print(df) + + def test_basic_without_time(self): + df = D.features(self.stocks_list, fields=["$ask1"], freq="ticks") + print(df) + + def test_basic01(self): + df = D.features( + self.stocks_list, + fields=["TResample($ask1, '1min', 'last')"], + freq="ticks", + start_time="20201230", + end_time="20210101", + ) + print(df) + + def test_basic02(self): + df = D.features( + self.stocks_list, + fields=["$function_code"], + freq="transaction", + start_time="20201230", + end_time="20210101", + ) + print(df) + + def test_basic03(self): + df = D.features( + self.stocks_list, + fields=["$function_code"], + freq="order", + start_time="20201230", + end_time="20210101", + ) + print(df) + + # Here are some popular expressions for high-frequency + # 1) some shared expression + expr_sum_buy_ask_1 = "(TResample($ask1, '1min', 'last') + TResample($bid1, '1min', 'last'))" + total_volume = ( + "TResample(" + + "+".join([f"${name}{i}" for i in range(1, 11) for name in ["asize", "bsize"]]) + + ", '1min', 'sum')" + ) + + @staticmethod + def total_func(name, method): + return "TResample(" + "+".join([f"${name}{i}" for i in range(1, 11)]) + ",'1min', '{}')".format(method) + + def test_exp_01(self): + exprs = [] + names = [] + for name in ["asize", "bsize"]: + for i in range(1, 11): + exprs.append(f"TResample(${name}{i}, '1min', 'mean') / ({self.total_volume})") + names.append(f"v_{name}_{i}") + df = D.features(self.stocks_list, fields=exprs, freq="ticks") + df.columns = names + print(df) + + # 2) some often used papers; + def test_exp_02(self): + spread_func = ( + lambda index: f"2 * TResample($ask{index} - $bid{index}, '1min', 'last') / {self.expr_sum_buy_ask_1}" + ) + mid_func = ( + lambda index: f"2 * TResample(($ask{index} + $bid{index})/2, '1min', 'last') / {self.expr_sum_buy_ask_1}" + ) + + exprs = [] + names = [] + for i in range(1, 11): + exprs.extend([spread_func(i), mid_func(i)]) + names.extend([f"p_spread_{i}", f"p_mid_{i}"]) + df = D.features(self.stocks_list, fields=exprs, freq="ticks") + df.columns = names + print(df) + + def test_exp_03(self): + expr3_func1 = ( + lambda name, index_left, index_right: f"2 * TResample(Abs(${name}{index_left} - ${name}{index_right}), '1min', 'last') / {self.expr_sum_buy_ask_1}" + ) + for name in ["ask", "bid"]: + for i in range(1, 10): + 
exprs = [expr3_func1(name, i + 1, i)] + names = [f"p_diff_{name}_{i}_{i+1}"] + exprs.extend([expr3_func1("ask", 10, 1), expr3_func1("bid", 1, 10)]) + names.extend(["p_diff_ask_10_1", "p_diff_bid_1_10"]) + df = D.features(self.stocks_list, fields=exprs, freq="ticks") + df.columns = names + print(df) + + def test_exp_04(self): + exprs = [] + names = [] + for name in ["asize", "bsize"]: + exprs.append(f"(({ self.total_func(name, 'mean')}) / 10) / {self.total_volume}") + names.append(f"v_avg_{name}") + + df = D.features(self.stocks_list, fields=exprs, freq="ticks") + df.columns = names + print(df) + + def test_exp_05(self): + exprs = [ + f"2 * Sub({ self.total_func('ask', 'last')}, {self.total_func('bid', 'last')})/{self.expr_sum_buy_ask_1}", + f"Sub({ self.total_func('asize', 'mean')}, {self.total_func('bsize', 'mean')})/{self.total_volume}", + ] + names = ["p_accspread", "v_accspread"] + + df = D.features(self.stocks_list, fields=exprs, freq="ticks") + df.columns = names + print(df) + + # (p|v)_diff_(ask|bid|asize|bsize)_(time_interval) + def test_exp_06(self): + t = 3 + expr6_price_func = ( + lambda name, index, method: f'2 * (TResample(${name}{index}, "{t}s", "{method}") - Ref(TResample(${name}{index}, "{t}s", "{method}"), 1)) / {t}' + ) + exprs = [] + names = [] + for i in range(1, 11): + for name in ["bid", "ask"]: + exprs.append( + f"TResample({expr6_price_func(name, i, 'last')}, '1min', 'mean') / {self.expr_sum_buy_ask_1}" + ) + names.append(f"p_diff_{name}{i}_{t}s") + + for i in range(1, 11): + for name in ["asize", "bsize"]: + exprs.append(f"TResample({expr6_price_func(name, i, 'mean')}, '1min', 'mean') / {self.total_volume}") + names.append(f"v_diff_{name}{i}_{t}s") + + df = D.features(self.stocks_list, fields=exprs, freq="ticks") + df.columns = names + print(df) + + # TODOs: + # Following expressions may be implemented in the future + # expr7_2 = lambda funccode, bsflag, time_interval: \ + # "TResample(TRolling(TEq(@transaction.function_code, {}) & TEq(@transaction.bs_flag ,{}), '{}s', 'sum') / \ + # TRolling(@transaction.function_code, '{}s', 'count') , '1min', 'mean')".format(ord(funccode), bsflag,time_interval,time_interval) + # create_dataset(7, "SH600000", [expr7_2("C")] + [expr7(funccode, ordercode) for funccode in ['B','S'] for ordercode in ['0','1']]) + # create_dataset(7, ["SH600000"], [expr7_2("C", 48)] ) + + @staticmethod + def expr7_init(funccode, ordercode, time_interval): + # NOTE: based on on order frequency (i.e. freq="order") + return f"Rolling(Eq($function_code, {ord(funccode)}) & Eq($order_kind ,{ord(ordercode)}), '{time_interval}s', 'sum') / Rolling($function_code, '{time_interval}s', 'count')" + + # (la|lb|ma|mb|ca|cb)_intensity_(time_interval) + def test_exp_07_1(self): + # NOTE: based on transaction frequency (i.e. 
freq="transaction") + expr7_3 = ( + lambda funccode, code, time_interval: f"TResample(Rolling(Eq($function_code, {ord(funccode)}) & {code}($ask_order, $bid_order) , '{time_interval}s', 'sum') / Rolling($function_code, '{time_interval}s', 'count') , '1min', 'mean')" + ) + + exprs = [expr7_3("C", "Gt", "3"), expr7_3("C", "Lt", "3")] + names = ["ca_intensity_3s", "cb_intensity_3s"] + + df = D.features(self.stocks_list, fields=exprs, freq="transaction") + df.columns = names + print(df) + + trans_dict = {"B": "a", "S": "b", "0": "l", "1": "m"} + + def test_exp_07_2(self): + # NOTE: based on on order frequency + expr7 = ( + lambda funccode, ordercode, time_interval: f"TResample({self.expr7_init(funccode, ordercode, time_interval)}, '1min', 'mean')" + ) + + exprs = [] + names = [] + for funccode in ["B", "S"]: + for ordercode in ["0", "1"]: + exprs.append(expr7(funccode, ordercode, "3")) + names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + "_intensity_3s") + df = D.features(self.stocks_list, fields=exprs, freq="transaction") + df.columns = names + print(df) + + @staticmethod + def expr7_3_init(funccode, code, time_interval): + # NOTE: It depends on transaction frequency + return f"Rolling(Eq($function_code, {ord(funccode)}) & {code}($ask_order, $bid_order) , '{time_interval}s', 'sum') / Rolling($function_code, '{time_interval}s', 'count')" + + # (la|lb|ma|mb|ca|cb)_relative_intensity_(time_interval_small)_(time_interval_big) + def test_exp_08_1(self): + expr8_1 = ( + lambda funccode, ordercode, time_interval_short, time_interval_long: f"TResample(Gt({self.expr7_init(funccode, ordercode, time_interval_short)},{self.expr7_init(funccode, ordercode, time_interval_long)}), '1min', 'mean')" + ) + + exprs = [] + names = [] + for funccode in ["B", "S"]: + for ordercode in ["0", "1"]: + exprs.append(expr8_1(funccode, ordercode, "10", "900")) + names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + "_relative_intensity_10s_900s") + + df = D.features(self.stocks_list, fields=exprs, freq="order") + df.columns = names + print(df) + + def test_exp_08_2(self): + # NOTE: It depends on transaction frequency + expr8_2 = ( + lambda funccode, ordercode, time_interval_short, time_interval_long: f"TResample(Gt({self.expr7_3_init(funccode, ordercode, time_interval_short)},{self.expr7_3_init(funccode, ordercode, time_interval_long)}), '1min', 'mean')" + ) + + exprs = [expr8_2("C", "Gt", "10", "900"), expr8_2("C", "Lt", "10", "900")] + names = ["ca_relative_intensity_10s_900s", "cb_relative_intensity_10s_900s"] + + df = D.features(self.stocks_list, fields=exprs, freq="transaction") + df.columns = names + print(df) + + ## v9(la|lb|ma|mb|ca|cb)_diff_intensity_(time_interval1)_(time_interval2) + # 1) calculating the original data + # 2) Resample data to 3s and calculate the changing rate + # 3) Resample data to 1min + + def test_exp_09_trans(self): + exprs = [ + f'TResample(Div(Sub(TResample({self.expr7_3_init("C", "Gt", "3")}, "3s", "last"), Ref(TResample({self.expr7_3_init("C", "Gt", "3")}, "3s","last"), 1)), 3), "1min", "mean")', + f'TResample(Div(Sub(TResample({self.expr7_3_init("C", "Lt", "3")}, "3s", "last"), Ref(TResample({self.expr7_3_init("C", "Lt", "3")}, "3s","last"), 1)), 3), "1min", "mean")', + ] + names = ["ca_diff_intensity_3s_3s", "cb_diff_intensity_3s_3s"] + df = D.features(self.stocks_list, fields=exprs, freq="transaction") + df.columns = names + print(df) + + def test_exp_09_order(self): + exprs = [] + names = [] + for funccode in ["B", "S"]: + for ordercode in ["0", 
"1"]: + exprs.append( + f'TResample(Div(Sub(TResample({self.expr7_init(funccode, ordercode, "3")}, "3s", "last"), Ref(TResample({self.expr7_init(funccode, ordercode, "3")},"3s", "last"), 1)), 3) ,"1min", "mean")' + ) + names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + "_diff_intensity_3s_3s") + df = D.features(self.stocks_list, fields=exprs, freq="order") + df.columns = names + print(df) + + def test_exp_10(self): + exprs = [] + names = [] + for i in [5, 10, 30, 60]: + exprs.append( + f'TResample(Ref(TResample($ask1 + $bid1, "1s", "ffill"), {-i}) / TResample($ask1 + $bid1, "1s", "ffill") - 1, "1min", "mean" )' + ) + names.append(f"lag_{i}_change_rate" for i in [5, 10, 30, 60]) + df = D.features(self.stocks_list, fields=exprs, freq="ticks") + df.columns = names + print(df) + + +if __name__ == "__main__": + unittest.main() diff --git a/qlib/__init__.py b/qlib/__init__.py index 13d8b3590c..134028c51d 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -12,7 +12,6 @@ import subprocess from .log import get_module_logger - # init qlib def init(default_conf="client", **kwargs): """ diff --git a/qlib/config.py b/qlib/config.py index 54af9b954c..91c24ca22f 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -19,7 +19,7 @@ import platform import multiprocessing from pathlib import Path -from typing import Optional, Union +from typing import Callable, Optional, Union from typing import TYPE_CHECKING from qlib.constant import REG_CN, REG_US @@ -40,7 +40,7 @@ def __getattr__(self, attr): if attr in self.__dict__["_config"]: return self.__dict__["_config"][attr] - raise AttributeError(f"No such {attr} in self._config") + raise AttributeError(f"No such `{attr}` in self._config") def get(self, key, default=None): return self.__dict__["_config"].get(key, default) @@ -112,6 +112,8 @@ def set_conf_from_C(self, config_c): "calendar_cache": None, # for simple dataset cache "local_cache_path": None, + # kernels can be a fixed value or a callable function lie `def (freq: str) -> int` + # If the kernels are arctic_kernels, `min(NUM_USABLE_CPU, 30)` may be a good value "kernels": NUM_USABLE_CPU, # pickle.dump protocol version "dump_protocol_version": PROTOCOL_VERSION, @@ -121,11 +123,10 @@ def set_conf_from_C(self, config_c): "joblib_backend": "multiprocessing", "default_disk_cache": 1, # 0:skip/1:use "mem_cache_size_limit": 500, + "mem_cache_limit_type": "length", # memory cache expire second, only in used 'DatasetURICache' and 'client D.calendar' # default 1 hour "mem_cache_expire": 60 * 60, - # memory cache space limit, default 5GB, only in used client - "mem_cache_space_limit": 1024 * 1024 * 1024 * 5, # cache dir name "dataset_cache_dir_name": "dataset_cache", "features_cache_dir_name": "features_cache", @@ -462,6 +463,12 @@ def reset_qlib_version(self): # Due to a bug? 
that converting __version__ to _QlibConfig__version__bak # Using __version__bak instead of __version__ + def get_kernels(self, freq: str): + """get number of processors given frequency""" + if isinstance(self["kernels"], Callable): + return self["kernels"](freq) + return self["kernels"] + @property def registered(self): return self._registered diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 5849e613dc..62e6096ca7 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -27,7 +27,6 @@ class DNNModelPytorch(Model): """DNN Model - Parameters ---------- input_dim : int diff --git a/qlib/data/__init__.py b/qlib/data/__init__.py index ef5fe4708e..6549d16f7a 100644 --- a/qlib/data/__init__.py +++ b/qlib/data/__init__.py @@ -15,6 +15,7 @@ LocalCalendarProvider, LocalInstrumentProvider, LocalFeatureProvider, + ArcticFeatureProvider, LocalExpressionProvider, LocalDatasetProvider, ClientCalendarProvider, diff --git a/qlib/data/base.py b/qlib/data/base.py index f768f70674..b18e7aa476 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -150,7 +150,7 @@ def load(self, instrument, start_index, end_index, freq): args = str(self), instrument, start_index, end_index, freq if args in H["f"]: return H["f"][args] - if start_index is None or end_index is None or start_index > end_index: + if start_index is not None and end_index is not None and start_index > end_index: raise ValueError("Invalid index range: {} {}".format(start_index, end_index)) try: series = self._load_internal(instrument, start_index, end_index, freq) diff --git a/qlib/data/cache.py b/qlib/data/cache.py index c33fa655b5..a156bded42 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -147,6 +147,7 @@ def __init__(self, mem_cache_size_limit=None, limit_type="length"): """ size_limit = C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit + limit_type = C.mem_cache_limit_type if limit_type is None else limit_type if limit_type == "length": klass = MemCacheLengthUnit @@ -1198,7 +1199,4 @@ def calendar(self, start_time=None, end_time=None, freq="day", future=False): return result -# MemCache sizeof -HZ = MemCache(C.mem_cache_space_limit, limit_type="sizeof") -# MemCache length -H = MemCache(limit_type="length") +H = MemCache() diff --git a/qlib/data/data.py b/qlib/data/data.py index 186e907f13..9849f36ed1 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -5,8 +5,10 @@ from __future__ import division from __future__ import print_function +import os import re import abc +import time import copy import queue import bisect @@ -15,9 +17,11 @@ from multiprocessing import Pool from typing import Iterable, Union from typing import List, Union +from arctic import Arctic # For supporting multiprocessing in outer code, joblib is used from joblib import delayed +import pymongo from .cache import H from ..config import C @@ -38,11 +42,17 @@ normalize_cache_fields, code_to_fname, set_log_with_config, + time_to_slc_point, ) from ..utils.paral import ParallelExt class ProviderBackendMixin: + """ + This helper class tries to make the provider based on storage backend more convenient + It is not necessary to inherent this class if that provider don't rely on the backend storage + """ + def get_default_backend(self): backend = {} provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] @@ -59,15 +69,12 @@ def backend_obj(self, **kwargs): return init_instance_by_config(backend) -class CalendarProvider(abc.ABC, 
ProviderBackendMixin): +class CalendarProvider(abc.ABC): """Calendar provider base class Provide calendar data. """ - def __init__(self, *args, **kwargs): - self.backend = kwargs.get("backend", {}) - def calendar(self, start_time=None, end_time=None, freq="day", future=False): """Get calendar of certain market in given time range. @@ -194,15 +201,12 @@ def load_calendar(self, freq, future): raise NotImplementedError("Subclass of CalendarProvider must implement `load_calendar` method") -class InstrumentProvider(abc.ABC, ProviderBackendMixin): +class InstrumentProvider(abc.ABC): """Instrument provider base class Provide instrument data. """ - def __init__(self, *args, **kwargs): - self.backend = kwargs.get("backend", {}) - @staticmethod def instruments(market: Union[List, str] = "all", filter_pipe: Union[List, None] = None): """Get the general config dictionary for a base market adding several dynamic filters. @@ -304,15 +308,12 @@ def get_inst_type(cls, inst): raise ValueError(f"Unknown instrument type {inst}") -class FeatureProvider(abc.ABC, ProviderBackendMixin): +class FeatureProvider(abc.ABC): """Feature provider class Provide feature data. """ - def __init__(self, *args, **kwargs): - self.backend = kwargs.get("backend", {}) - @abc.abstractmethod def feature(self, instrument, field, start_time, end_time, freq): """Get feature data. @@ -365,9 +366,13 @@ def get_expression_instance(self, field): return expression @abc.abstractmethod - def expression(self, instrument, field, start_time=None, end_time=None, freq="day"): + def expression(self, instrument, field, start_time=None, end_time=None, freq="day") -> pd.Series: """Get Expression data. + The responsibility of `expression` + - parse the `field` and `load` the according data. + - When loading the data, it should handle the time dependency of the data. `get_expression_instance` is commonly used in this method + Parameters ---------- instrument : str @@ -385,6 +390,11 @@ def expression(self, instrument, field, start_time=None, end_time=None, freq="da ------- pd.Series data of a certain expression + + The data has two types of format + 1) expression with datetime index + 2) expression with integer index + - because the datetime is not as good as """ raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method") @@ -500,7 +510,7 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i """ normalize_column_names = normalize_cache_fields(column_names) # One process for one task, so that the memory will be freed quicker. 
- workers = max(min(C.kernels, len(instruments_d)), 1) + workers = max(min(C.get_kernels(freq), len(instruments_d)), 1) # create iterator if isinstance(instruments_d, dict): @@ -513,7 +523,7 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i for inst, spans in it: inst_l.append(inst) task_l.append( - delayed(DatasetProvider.expression_calculator)( + delayed(DatasetProvider.inst_calculator)( inst, start_time, end_time, freq, normalize_column_names, spans, C, inst_processors ) ) @@ -536,17 +546,17 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i data = DiskDatasetCache.cache_to_origin_data(data, column_names) else: data = pd.DataFrame( - index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names + index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), + columns=column_names, + dtype=np.float32, ) return data @staticmethod - def expression_calculator( - inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[] - ): + def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]): """ - Calculate the expressions for one instrument, return a df result. + Calculate the expressions for **one** instrument, return a df result. If the expression has been calculated before, load from cache. return value: A data frame with index 'datetime' and other data columns. @@ -566,8 +576,10 @@ def expression_calculator( obj[field] = ExpressionD.expression(inst, field, start_time, end_time, freq) data = pd.DataFrame(obj) - _calendar = Cal.calendar(freq=freq) - data.index = _calendar[data.index.values.astype(int)] + if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")): + # If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format + _calendar = Cal.calendar(freq=freq) + data.index = _calendar[data.index.values.astype(int)] data.index.names = ["datetime"] if spans is not None: @@ -583,15 +595,16 @@ def expression_calculator( return data -class LocalCalendarProvider(CalendarProvider): +class LocalCalendarProvider(CalendarProvider, ProviderBackendMixin): """Local calendar data provider class Provide calendar data from local data source. """ - def __init__(self, **kwargs): - super(LocalCalendarProvider, self).__init__(**kwargs) - self.remote = kwargs.get("remote", False) + def __init__(self, remote=False, backend={}): + super().__init__() + self.remote = remote + self.backend = backend def load_calendar(self, freq, future): """Load original calendar timestamp from file. @@ -623,12 +636,16 @@ def load_calendar(self, freq, future): return [pd.Timestamp(x) for x in backend_obj] -class LocalInstrumentProvider(InstrumentProvider): +class LocalInstrumentProvider(InstrumentProvider, ProviderBackendMixin): """Local instrument data provider class Provide instrument data from local data source. """ + def __init__(self, backend={}) -> None: + super().__init__() + self.backend = backend + def _load_instruments(self, market, freq): return self.backend_obj(market=market, freq=freq).data @@ -667,15 +684,16 @@ def list_instruments(self, instruments, start_time=None, end_time=None, freq="da return _instruments_filtered -class LocalFeatureProvider(FeatureProvider): +class LocalFeatureProvider(FeatureProvider, ProviderBackendMixin): """Local feature data provider class Provide feature data from local data source. 
""" - def __init__(self, **kwargs): - super(LocalFeatureProvider, self).__init__(**kwargs) - self.remote = kwargs.get("remote", False) + def __init__(self, remote=False, backend={}): + super().__init__() + self.remote = remote + self.backend = backend def feature(self, instrument, field, start_index, end_index, freq): # validate @@ -684,20 +702,72 @@ def feature(self, instrument, field, start_index, end_index, freq): return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1] +class ArcticFeatureProvider(FeatureProvider): + def __init__( + self, uri="127.0.0.1", retry_time=0, market_transaction_time_list=[("09:15", "11:30"), ("13:00", "15:00")] + ): + super().__init__() + self.uri = uri + # TODO: + # retry connecting if error occurs + # does it real matters? + self.retry_time = retry_time + # NOTE: this is especially important for TResample operator + self.market_transaction_time_list = market_transaction_time_list + + def feature(self, instrument, field, start_index, end_index, freq): + field = str(field)[1:] + with pymongo.MongoClient(self.uri) as client: + # TODO: this will result in frequently connecting the server and performance issue + arctic = Arctic(client) + + if freq not in arctic.list_libraries(): + raise ValueError("lib {} not in arctic".format(freq)) + + if instrument not in arctic[freq].list_symbols(): + # instruments does not exist + return pd.Series() + else: + df = arctic[freq].read(instrument, columns=[field], chunk_range=(start_index, end_index)) + s = df[field] + + if not s.empty: + s = pd.concat( + [ + s.between_time(time_tuple[0], time_tuple[1]) + for time_tuple in self.market_transaction_time_list + ] + ) + return s + + class LocalExpressionProvider(ExpressionProvider): """Local expression data provider class Provide expression data from local data source. """ + def __init__(self, time2idx=True): + super().__init__() + self.time2idx = time2idx + def expression(self, instrument, field, start_time=None, end_time=None, freq="day"): expression = self.get_expression_instance(field) - start_time = pd.Timestamp(start_time) - end_time = pd.Timestamp(end_time) - _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False) - lft_etd, rght_etd = expression.get_extended_window_size() + start_time = time_to_slc_point(start_time) + end_time = time_to_slc_point(end_time) + + # Two kinds of queries are supported + # - Index-based expression: this may save a lot of memory because the datetime index is not saved on the disk + # - Data with datetime index expression: this will make it more convenient to integrating with some existing databases + if self.time2idx: + _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False) + lft_etd, rght_etd = expression.get_extended_window_size() + query_start, query_end = max(0, start_index - lft_etd), end_index + rght_etd + else: + start_index, end_index = query_start, query_end = start_time, end_time + try: - series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq) + series = expression.load(instrument, query_start, query_end, freq) except Exception as e: get_module_logger("data").debug( f"Loading expression error: " @@ -726,8 +796,18 @@ class LocalDatasetProvider(DatasetProvider): Provide dataset data from local data source. 
""" - def __init__(self): - pass + def __init__(self, align_time: bool = True): + """ + Parameters + ---------- + align_time : bool + Will we align the time to calendar + the frequency is flexible in some dataset and can't be aligned. + For the data with fixed frequency with a shared calendar, the align data to the calendar will provides following benefits + - Align queries to the same parameters, so the cache can be shared. + """ + super().__init__() + self.align_time = align_time def dataset( self, @@ -740,14 +820,16 @@ def dataset( ): instruments_d = self.get_instruments_d(instruments, freq) column_names = self.get_column_names(fields) - cal = Cal.calendar(start_time, end_time, freq) - if len(cal) == 0: - return pd.DataFrame( - index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names - ) - start_time = cal[0] - end_time = cal[-1] - + if self.align_time: + # NOTE: if the frequency is a fixed value. + # align the data to fixed calendar point + cal = Cal.calendar(start_time, end_time, freq) + if len(cal) == 0: + return pd.DataFrame( + index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names + ) + start_time = cal[0] + end_time = cal[-1] data = self.dataset_processor( instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors ) diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 049adece9d..dfbe013f79 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -721,9 +721,9 @@ def _load_internal(self, instrument, start_index, end_index, freq): # NOTE: remove all null check, # now it's user's responsibility to decide whether use features in null days # isnull = series.isnull() # NOTE: isnull = NaN, inf is not null - if self.N == 0: + if isinstance(self.N, int) and self.N == 0: series = getattr(series.expanding(min_periods=1), self.func)() - elif 0 < self.N < 1: + elif isinstance(self.N, int) and 0 < self.N < 1: series = series.ewm(alpha=self.N, min_periods=1).mean() else: series = getattr(series.rolling(self.N, min_periods=1), self.func)() @@ -1380,6 +1380,7 @@ class PairRolling(ExpressionOps): """ def __init__(self, feature_left, feature_right, N, func): + # TODO: in what case will a const be passed into `__init__` as `feature_left` or `feature_right` self.feature_left = feature_left self.feature_right = feature_right self.N = N @@ -1389,8 +1390,19 @@ def __str__(self): return "{}({},{},{})".format(type(self).__name__, self.feature_left, self.feature_right, self.N) def _load_internal(self, instrument, start_index, end_index, freq): - series_left = self.feature_left.load(instrument, start_index, end_index, freq) - series_right = self.feature_right.load(instrument, start_index, end_index, freq) + assert any( + [isinstance(self.feature_left, Expression), self.feature_right, Expression] + ), "at least one of two inputs is Expression instance" + + if isinstance(self.feature_left, Expression): + series_left = self.feature_left.load(instrument, start_index, end_index, freq) + else: + series_left = self.feature_left # numeric value + if isinstance(self.feature_right, Expression): + series_right = self.feature_right.load(instrument, start_index, end_index, freq) + else: + series_right = self.feature_right + if self.N == 0: series = getattr(series_left.expanding(min_periods=1), self.func)(series_right) else: @@ -1400,21 +1412,33 @@ def _load_internal(self, instrument, start_index, end_index, freq): def get_longest_back_rolling(self): if self.N == 0: return np.inf - return ( - 
max(self.feature_left.get_longest_back_rolling(), self.feature_right.get_longest_back_rolling()) - + self.N - - 1 - ) + if isinstance(self.feature_left, Expression): + left_br = self.feature_left.get_longest_back_rolling() + else: + left_br = 0 + + if isinstance(self.feature_right, Expression): + right_br = self.feature_right.get_longest_back_rolling() + else: + right_br = 0 + return max(left_br, right_br) def get_extended_window_size(self): - ll, lr = self.feature_left.get_extended_window_size() - rl, rr = self.feature_right.get_extended_window_size() + if isinstance(self.feature_left, Expression): + ll, lr = self.feature_left.get_extended_window_size() + else: + ll, lr = 0, 0 + + if isinstance(self.feature_right, Expression): + rl, rr = self.feature_right.get_extended_window_size() + else: + rl, rr = 0, 0 if self.N == 0: get_module_logger(self.__class__.__name__).warning( "The PairRolling(ATTR, 0) will not be accurately calculated" ) return -np.inf, max(lr, rr) else: return max(ll, rl) + self.N - 1, max(lr, rr) @@ -1474,7 +1498,50 @@ def __init__(self, feature_left, feature_right, N): super(Cov, self).__init__(feature_left, feature_right, N, "cov") +#################### Operators which only support data with a time index #################### +# Convention +# - The name of the operators in this section will start with "T" + + +class TResample(ElemOperator): + def __init__(self, feature, freq, func): + """ + Resample the data to the target frequency. + The resample function of pandas is used. + - the timestamp will be at the start of the time span after resampling. + + Parameters + ---------- + feature : Expression + An expression for calculating the feature + freq : str + It will be passed into the resample method for resampling based on the given frequency + func : method + The method to get the resampled values + Some frequently used methods are "last", "sum" and "mean" + """ + self.feature = feature + self.freq = freq + self.func = func + + def __str__(self): + return "{}({},{})".format(type(self).__name__, self.feature, self.freq) + + def _load_internal(self, instrument, start_index, end_index, freq): + series = self.feature.load(instrument, start_index, end_index, freq) + + if series.empty: + return series + else: + if self.func == "sum": + return getattr(series.resample(self.freq), self.func)(min_count=1) + else: + return getattr(series.resample(self.freq), self.func)() + + +TOpsList = [TResample] OpsList = [ + Rolling, Ref, Max, Min, @@ -1521,7 +1588,7 @@ def __init__(self, feature_left, feature_right, N): IdxMin, If, Feature, -] +] + [TResample] class OpsWrapper: diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 04a1931b40..13f202a318 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -167,9 +167,14 @@ def parse_field(field): # - $close -> Feature("close") # - $close5 -> Feature("close5") # - $open+$close -> Feature("open")+Feature("close") + # TODO: this may be used in the future if we want to support the computation of data with different frequencies + # - $close@5min -> Feature("close", "5min") + if not isinstance(field, str): field = str(field) - return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field)) + for pattern, new in [(r"\$(\w+)", rf'Feature("\1")'), (r"(\w+\s*)\(", r"Operators.\1(")]: # Features # Operators + field = re.sub(pattern, new, field) + return field def get_module_by_module_path(module_path: Union[str, ModuleType]): diff --git a/setup.py b/setup.py index ab397e1cf5..1b9d4f490a 100644 --- a/setup.py +++ b/setup.py @@ -78,6 
+78,7 @@ def get_version(rel_path: str) -> str: "dill", "dataclasses;python_version<'3.7'", "filelock", + "arctic", ] # Numpy include diff --git a/tests/ops/test_elem_operator.py b/tests/ops/test_elem_operator.py index 0e21e53548..e641b1ac2e 100644 --- a/tests/ops/test_elem_operator.py +++ b/tests/ops/test_elem_operator.py @@ -19,7 +19,7 @@ def setUp(self) -> None: "Abs($change)", ] columns = ["change", "abs"] - self.data = DatasetProvider.expression_calculator( + self.data = DatasetProvider.inst_calculator( self.inst, self.start_time, self.end_time, freq, expressions, self.spans, C, [] ) self.data.columns = columns
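Putting the pieces above together, here is a minimal sketch of how the new providers and operators are meant to be used. It mirrors the `setUp` of `examples/orderbook_data/example.py` and the `kernels` comment added to `qlib/config.py`; the provider URI, the Arctic host, and the stock code are placeholders for your own setup, not fixed values from this PR.

```python
import qlib
from qlib.data import D


def kernels_by_freq(freq: str) -> int:
    # `kernels` may now be a callable `freq -> int` (resolved via C.get_kernels);
    # using fewer workers for the Arctic-backed frequencies avoids overloading MongoDB
    return 4 if freq in ("ticks", "order", "transaction") else 16


qlib.init(
    provider_uri="~/.qlib/qlib_data/yahoo_cn_1min",  # placeholder path
    kernels=kernels_by_freq,
    expression_provider={"class": "LocalExpressionProvider", "kwargs": {"time2idx": False}},
    feature_provider={"class": "ArcticFeatureProvider", "kwargs": {"uri": "127.0.0.1"}},
    dataset_provider={"class": "LocalDatasetProvider", "kwargs": {"align_time": False}},
)

# TResample (added in qlib/data/ops.py above) downsamples raw ticks to 1-minute bars
df = D.features(["SZ000725"], ["TResample($ask1, '1min', 'last')"], freq="ticks")
print(df.head())
```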