|
27 | 27 | import logging |
28 | 28 | import os as _os |
29 | 29 | import re as _re |
30 | | -import sqlite3 as _sqlite3 |
| 30 | +import peewee as _peewee |
31 | 31 | import sys as _sys |
32 | 32 | import threading |
33 | 33 | from functools import lru_cache |
@@ -901,136 +901,60 @@ def __str__(self): |
901 | 901 | # TimeZone cache related code |
902 | 902 | # --------------------------------- |
903 | 903 |
|
904 | | -class _KVStore: |
905 | | - """Simple Sqlite backed key/value store, key and value are strings. Should be thread safe.""" |
906 | 904 |
|
907 | | - def __init__(self, filename): |
908 | | - self._cache_mutex = Lock() |
909 | | - with self._cache_mutex: |
910 | | - self.conn = _sqlite3.connect(filename, timeout=10, check_same_thread=False) |
911 | | - self.conn.execute('pragma journal_mode=wal') |
912 | | - try: |
913 | | - self.conn.execute('create table if not exists "kv" (key TEXT primary key, value TEXT) without rowid') |
914 | | - except Exception as e: |
915 | | - if 'near "without": syntax error' in str(e): |
916 | | - # "without rowid" requires sqlite 3.8.2. Older versions will raise exception |
917 | | - self.conn.execute('create table if not exists "kv" (key TEXT primary key, value TEXT)') |
918 | | - else: |
919 | | - raise |
920 | | - self.conn.commit() |
921 | | - _atexit.register(self.close) |
922 | | - |
923 | | - def close(self): |
924 | | - if self.conn is not None: |
925 | | - with self._cache_mutex: |
926 | | - self.conn.close() |
927 | | - self.conn = None |
928 | | - |
929 | | - def get(self, key: str) -> Union[str, None]: |
930 | | - """Get value for key if it exists else returns None""" |
931 | | - try: |
932 | | - item = self.conn.execute('select value from "kv" where key=?', (key,)) |
933 | | - except _sqlite3.IntegrityError as e: |
934 | | - self.delete(key) |
935 | | - return None |
936 | | - if item: |
937 | | - return next(item, (None,))[0] |
938 | | - |
939 | | - def set(self, key: str, value: str) -> None: |
940 | | - if value is None: |
941 | | - self.delete(key) |
942 | | - else: |
943 | | - with self._cache_mutex: |
944 | | - self.conn.execute('replace into "kv" (key, value) values (?,?)', (key, value)) |
945 | | - self.conn.commit() |
946 | | - |
947 | | - def bulk_set(self, kvdata: Dict[str, str]): |
948 | | - records = tuple(i for i in kvdata.items()) |
949 | | - with self._cache_mutex: |
950 | | - self.conn.executemany('replace into "kv" (key, value) values (?,?)', records) |
951 | | - self.conn.commit() |
952 | | - |
953 | | - def delete(self, key: str): |
954 | | - with self._cache_mutex: |
955 | | - self.conn.execute('delete from "kv" where key=?', (key,)) |
956 | | - self.conn.commit() |
| 905 | +_cache_dir = _os.path.join(_ad.user_cache_dir(), "py-yfinance") |
| 906 | +DB_PATH = _os.path.join(_cache_dir, 'tkr-tz.db') |
| 907 | +db = _peewee.SqliteDatabase(DB_PATH, pragmas={'journal_mode': 'wal', 'cache_size': -64}) |
| 908 | +_tz_cache = None |
957 | 909 |
|
958 | 910 |
|
959 | 911 | class _TzCacheException(Exception): |
960 | 912 | pass |
961 | 913 |
|
962 | 914 |
|
963 | | -class _TzCache: |
964 | | - """Simple sqlite file cache of ticker->timezone""" |
965 | | - |
966 | | - def __init__(self): |
967 | | - self._setup_cache_folder() |
968 | | - # Must init db here, where is thread-safe |
969 | | - try: |
970 | | - self._tz_db = _KVStore(_os.path.join(self._db_dir, "tkr-tz.db")) |
971 | | - except _sqlite3.DatabaseError as err: |
972 | | - raise _TzCacheException(f"Error creating TzCache folder: '{self._db_dir}' reason: {err}") |
973 | | - self._migrate_cache_tkr_tz() |
974 | | - |
975 | | - def _setup_cache_folder(self): |
976 | | - if not _os.path.isdir(self._db_dir): |
977 | | - try: |
978 | | - _os.makedirs(self._db_dir) |
979 | | - except OSError as err: |
980 | | - raise _TzCacheException(f"Error creating TzCache folder: '{self._db_dir}' reason: {err}") |
981 | | - |
982 | | - elif not (_os.access(self._db_dir, _os.R_OK) and _os.access(self._db_dir, _os.W_OK)): |
983 | | - raise _TzCacheException(f"Cannot read and write in TzCache folder: '{self._db_dir}'") |
984 | | - |
985 | | - def lookup(self, tkr): |
986 | | - return self.tz_db.get(tkr) |
987 | | - |
988 | | - def store(self, tkr, tz): |
989 | | - if tz is None: |
990 | | - self.tz_db.delete(tkr) |
991 | | - else: |
992 | | - tz_db = self.tz_db.get(tkr) |
993 | | - if tz_db is not None: |
994 | | - if tz != tz_db: |
995 | | - get_yf_logger().debug(f'{tkr}: Overwriting cached TZ "{tz_db}" with different TZ "{tz}"') |
996 | | - self.tz_db.set(tkr, tz) |
997 | | - else: |
998 | | - self.tz_db.set(tkr, tz) |
| 915 | +class KV(_peewee.Model): |
| 916 | + key = _peewee.CharField(primary_key=True) |
| 917 | + value = _peewee.CharField(null=True) |
| 918 | + |
| 919 | + class Meta: |
| 920 | + database = db |
| 921 | + without_rowid = True |
999 | 922 |
|
1000 | | - @property |
1001 | | - def _db_dir(self): |
1002 | | - global _cache_dir |
1003 | | - return _os.path.join(_cache_dir, "py-yfinance") |
1004 | 923 |
|
1005 | | - @property |
1006 | | - def tz_db(self): |
1007 | | - return self._tz_db |
| 924 | +class _TzCache: |
| 925 | + def __init__(self): |
| 926 | + db.connect() |
| 927 | + db.create_tables([KV]) |
1008 | 928 |
|
1009 | | - def _migrate_cache_tkr_tz(self): |
1010 | | - """Migrate contents from old ticker CSV-cache to SQLite db""" |
1011 | | - old_cache_file_path = _os.path.join(self._db_dir, "tkr-tz.csv") |
| 929 | + old_cache_file_path = _os.path.join(_cache_dir, "tkr-tz.csv") |
| 930 | + if _os.path.isfile(old_cache_file_path): |
| 931 | + _os.remove(old_cache_file_path) |
1012 | 932 |
|
1013 | | - if not _os.path.isfile(old_cache_file_path): |
| 933 | + def lookup(self, key): |
| 934 | + try: |
| 935 | + return KV.get(KV.key == key).value |
| 936 | + except KV.DoesNotExist: |
1014 | 937 | return None |
| 938 | + |
| 939 | + def store(self, key, value): |
1015 | 940 | try: |
1016 | | - df = _pd.read_csv(old_cache_file_path, index_col="Ticker", on_bad_lines="skip") |
1017 | | - except _pd.errors.EmptyDataError: |
1018 | | - _os.remove(old_cache_file_path) |
1019 | | - except TypeError: |
1020 | | - _os.remove(old_cache_file_path) |
1021 | | - else: |
1022 | | - # Discard corrupt data: |
1023 | | - df = df[~df["Tz"].isna().to_numpy()] |
1024 | | - df = df[~(df["Tz"] == '').to_numpy()] |
1025 | | - df = df[~df.index.isna()] |
1026 | | - if not df.empty: |
1027 | | - try: |
1028 | | - self.tz_db.bulk_set(df.to_dict()['Tz']) |
1029 | | - except Exception as e: |
1030 | | - # Ignore |
1031 | | - pass |
| 941 | + if value is None: |
| 942 | + q = KV.delete().where(KV.key == key) |
| 943 | + q.execute() |
| 944 | + return |
| 945 | + with db.atomic(): |
| 946 | + KV.insert(key=key, value=value).execute() |
| 947 | + except IntegrityError: |
| 948 | + # Integrity error means the key already exists. Try updating the key. |
| 949 | + old_value = self.lookup(key) |
| 950 | + if old_value != value: |
| 951 | + get_yf_logger().debug(f"Value for key {key} changed from {old_value} to {value}.") |
| 952 | + with db.atomic(): |
| 953 | + q = KV.update(value=value).where(KV.key == key) |
| 954 | + q.execute() |
1032 | 955 |
|
1033 | | - _os.remove(old_cache_file_path) |
| 956 | + def close(self): |
| 957 | + db.close() |
1034 | 958 |
|
1035 | 959 |
|
1036 | 960 | class _TzCacheDummy: |
@@ -1068,9 +992,7 @@ def get_tz_cache(): |
1068 | 992 | return _tz_cache |
1069 | 993 |
|
1070 | 994 |
|
1071 | | -_cache_dir = _ad.user_cache_dir() |
1072 | 995 | _cache_init_lock = Lock() |
1073 | | -_tz_cache = None |
1074 | 996 |
|
1075 | 997 |
|
1076 | 998 | def set_tz_cache_location(cache_dir: str): |
|
0 commit comments