diff --git a/chdb/__init__.py b/chdb/__init__.py index f2ded0467e9..a581f5eb3cb 100644 --- a/chdb/__init__.py +++ b/chdb/__init__.py @@ -1,5 +1,6 @@ import sys import os +import pyarrow as pa chdb_version = (0, 1, 0) if sys.version_info[:2] >= (3, 7): @@ -22,7 +23,19 @@ except: # pragma: no cover __version__ = "unknown" -# wrap _chdb functions +def _to_arrowTable(res): + """convert res to arrow table""" + return pa.RecordBatchFileReader(res.get_memview()).read_all() + +def to_df(r): + """"convert arrow table to Dataframe""" + t = _to_arrowTable(r) + return t.to_pandas(use_threads=True) + +# wrap _chdb functions def query(sql, output_format="CSV", **kwargs): + if output_format.lower() == "dataframe": + r = _chdb.query(sql, "Arrow", **kwargs) + return to_df(r) return _chdb.query(sql, output_format, **kwargs) diff --git a/chdb/__main__.py b/chdb/__main__.py new file mode 100644 index 00000000000..f962da08a43 --- /dev/null +++ b/chdb/__main__.py @@ -0,0 +1,31 @@ +import sys +import argparse +from .__init__ import query + +def main(): + prog = 'python -m chdb' + description = ('''A simple command line interface for chdb + to run SQL and output in specified format''') + parser = argparse.ArgumentParser(prog=prog, description=description) + parser.add_argument('sql', nargs=1, + type=str, + help='sql, e.g: select 1112222222,555') + parser.add_argument('format', nargs='?', + type=str, + help='''sql result output format, + e.g: CSV, Dataframe, JSON etc, + more format checkout on + https://clickhouse.com/docs/en/interfaces/formats''', + default="CSV") + options = parser.parse_args() + sql = options.sql[0] + output_format = options.format + res = query(sql, output_format) + if output_format.lower() == 'dataframe': + temp = res + else: + temp = res.data() + print(temp, end="") + +if __name__ == '__main__': + main() diff --git a/chdb/test_smoke.sh b/chdb/test_smoke.sh index 0f9db1b79e3..52a067e06b2 100755 --- a/chdb/test_smoke.sh +++ b/chdb/test_smoke.sh @@ -24,3 +24,5 @@ python3 -c \ python3 -c \ "import chdb; res = chdb.query('select version()', 'CSV'); print(str(res.get_memview().tobytes()))" +# test cli +python3 -m chdb "select 1112222222,555" Dataframe diff --git a/setup.py b/setup.py index 58f93f011be..355f4020798 100644 --- a/setup.py +++ b/setup.py @@ -155,6 +155,7 @@ def build_extensions(self): exclude_package_data={'': ['*.pyc', 'src/**']}, ext_modules=ext_modules, python_requires='>=3.7', + install_requires=['pyarrow', 'pandas'], cmdclass={'build_ext': BuildExt}, test_suite="tests", zip_safe=False,