diff --git a/CHANGELOG.md b/CHANGELOG.md index 9614d0b0..52c28763 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,9 @@ ## 3.0.1b1 ## * start 3.0 beta branch +## 2.13.4 ## +* fixed sqlalchemy get_columns method with not null columns + ## 2.13.3 ## * fixed use transaction object when commit with flag diff --git a/examples/_sqlalchemy_example/example.py b/examples/_sqlalchemy_example/example.py index 00cd80d3..96f47820 100644 --- a/examples/_sqlalchemy_example/example.py +++ b/examples/_sqlalchemy_example/example.py @@ -196,7 +196,7 @@ def run_example_core(engine): def main(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, - description="""\033[92mYandex.Database examples sqlalchemy usage.\x1b[0m\n""", + description="""\033[92mYandex.Database examples _sqlalchemy usage.\x1b[0m\n""", ) parser.add_argument( "-d", @@ -219,7 +219,7 @@ def main(): ) logging.basicConfig(level=logging.INFO) - logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + logging.getLogger("_sqlalchemy.engine").setLevel(logging.INFO) run_example_core(engine) # run_example_orm(engine) diff --git a/tests/_sqlalchemy/conftest.py b/tests/_sqlalchemy/conftest.py new file mode 100644 index 00000000..6ebabac3 --- /dev/null +++ b/tests/_sqlalchemy/conftest.py @@ -0,0 +1,22 @@ +import pytest +import sqlalchemy as sa + +from ydb._sqlalchemy import register_dialect + + +@pytest.fixture(scope="module") +def engine(endpoint, database): + register_dialect() + engine = sa.create_engine( + "yql:///ydb/", + connect_args={"database": database, "endpoint": endpoint}, + ) + + yield engine + engine.dispose() + + +@pytest.fixture(scope="module") +def connection(engine): + with engine.connect() as conn: + yield conn diff --git a/tests/sqlalchemy/test_dbapi.py b/tests/_sqlalchemy/test_dbapi.py similarity index 100% rename from tests/sqlalchemy/test_dbapi.py rename to tests/_sqlalchemy/test_dbapi.py diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/_sqlalchemy/test_sqlalchemy.py similarity index 100% rename from tests/sqlalchemy/test_sqlalchemy.py rename to tests/_sqlalchemy/test_sqlalchemy.py diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py index 6ebabac3..9b7d99c1 100644 --- a/tests/sqlalchemy/conftest.py +++ b/tests/sqlalchemy/conftest.py @@ -1,11 +1,11 @@ import pytest import sqlalchemy as sa -from ydb._sqlalchemy import register_dialect +from ydb.sqlalchemy import register_dialect -@pytest.fixture(scope="module") -def engine(endpoint, database): +@pytest.fixture +def sa_engine(endpoint, database): register_dialect() engine = sa.create_engine( "yql:///ydb/", @@ -14,9 +14,3 @@ def engine(endpoint, database): yield engine engine.dispose() - - -@pytest.fixture(scope="module") -def connection(engine): - with engine.connect() as conn: - yield conn diff --git a/tests/sqlalchemy/test_inspect.py b/tests/sqlalchemy/test_inspect.py new file mode 100644 index 00000000..ea0fe9e2 --- /dev/null +++ b/tests/sqlalchemy/test_inspect.py @@ -0,0 +1,24 @@ +import ydb + +import sqlalchemy as sa + + +def test_get_columns(driver_sync, sa_engine): + session = ydb.retry_operation_sync( + lambda: driver_sync.table_client.session().create() + ) + session.execute_scheme( + "CREATE TABLE test(id Int64 NOT NULL, value TEXT, num DECIMAL(22, 9), PRIMARY KEY (id))" + ) + inspect = sa.inspect(sa_engine) + columns = inspect.get_columns("test") + for c in columns: + c["type"] = type(c["type"]) + + assert columns == [ + {"name": "id", "type": sa.INTEGER, "nullable": False}, + {"name": "value", "type": sa.TEXT, "nullable": True}, + {"name": "num", "type": sa.DECIMAL, "nullable": True}, + ] + + session.execute_scheme("DROP TABLE test") diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 00000000..93f6f4a3 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,17 @@ +import ydb +import pytest + +from os import path + + +def test_scheme_error(driver_sync, database): + session = driver_sync.table_client.session().create() + with pytest.raises(ydb.issues.SchemeError) as exc: + session.drop_table(path.join(database, "foobartable")) + + server_code = ydb.issues.StatusCode.SCHEME_ERROR + + assert type(exc.value) == ydb.issues.SchemeError + assert exc.value.status == server_code + assert f"server_code: {server_code}" in str(exc.value) + assert "Path does not exist" in str(exc.value) diff --git a/ydb/issues.py b/ydb/issues.py index 55c14cea..100af01d 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -159,11 +159,18 @@ class SessionPoolEmpty(Error, queue.Empty): def _format_issues(issues): if not issues: return "" + return " ,".join( - [text_format.MessageToString(issue, False, True) for issue in issues] + text_format.MessageToString(issue, as_utf8=False, as_one_line=True) + for issue in issues ) +def _format_response(response): + fmt_issues = _format_issues(response.issues) + return f"{fmt_issues} (server_code: {response.status})" + + _success_status_codes = {StatusCode.STATUS_CODE_UNSPECIFIED, StatusCode.SUCCESS} _server_side_error_map = { StatusCode.BAD_REQUEST: BadRequest, @@ -190,4 +197,4 @@ def _format_issues(issues): def _process_response(response_proto): if response_proto.status not in _success_status_codes: exc_obj = _server_side_error_map.get(response_proto.status) - raise exc_obj(_format_issues(response_proto.issues), response_proto.issues) + raise exc_obj(_format_response(response_proto), response_proto.issues) diff --git a/ydb/sqlalchemy/__init__.py b/ydb/sqlalchemy/__init__.py index 9e065d8f..aa9b2d00 100644 --- a/ydb/sqlalchemy/__init__.py +++ b/ydb/sqlalchemy/__init__.py @@ -191,11 +191,16 @@ def visit_function(self, func, add_to_result_map=None, **kwargs): ydb.PrimitiveType.DyNumber: sa.TEXT, } - def _get_column_type(t): - if isinstance(t.item, ydb.DecimalType): - return sa.DECIMAL(precision=t.item.precision, scale=t.item.scale) + def _get_column_info(t): + nullable = False + if isinstance(t, ydb.OptionalType): + nullable = True + t = t.item - return COLUMN_TYPES[t.item] + if isinstance(t, ydb.DecimalType): + return sa.DECIMAL(precision=t.precision, scale=t.scale), nullable + + return COLUMN_TYPES[t], nullable class YqlDialect(DefaultDialect): name = "yql" @@ -250,11 +255,12 @@ def get_columns(self, connection, table_name, schema=None, **kw): columns = raw_conn.describe(qt) as_compatible = [] for column in columns: + col_type, nullable = _get_column_info(column.type) as_compatible.append( { "name": column.name, - "type": _get_column_type(column.type), - "nullable": True, + "type": col_type, + "nullable": nullable, } )