Skip to content

Skeleton for testing new API #507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions nutkit/frontend/_transaction_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from .. import protocol
from .exceptions import ApplicationCodeError
from .transaction import Transaction


def handle_retry_func(retry_func, res, driver):
assert isinstance(res, protocol.RetryFunc)
if not retry_func:
raise Exception("No retry function was registered")
res = retry_func(res.exception, res.attempt, res.max_attempts)
try:
retry, delay_ms = res
except (TypeError, ValueError) as exc:
raise ValueError(
"retry_func must return retry(bool), delay_ms(int)"
) from exc
driver.send(
protocol.RetryFuncResult(retry=bool(retry), delay_ms=int(delay_ms))
)


def run_tx_loop(fn, req, driver, retry_func=None, hooks=None):
driver.send(req, hooks=hooks)
x = None
while True:
res = driver.receive(hooks=hooks, allow_resolution=True)
if isinstance(res, protocol.RetryableTry):
retryable_id = res.id
tx = Transaction(driver, retryable_id)
try:
# Invoke the frontend test function until we succeed, note
# that the frontend test function makes calls to the
# backend itself.
x = fn(tx)
except (ApplicationCodeError, protocol.DriverError) as e:
# If this is an error originating from the driver in the
# backend, retrieve the id of the error and send that,
# this saves us from having to recreate errors on backend
# side, backend just needs to track the returned errors.
error_id = ""
if isinstance(e, protocol.DriverError):
error_id = e.id
driver.send(
protocol.RetryableNegative(retryable_id,
error_id=error_id),
hooks=hooks
)
except Exception as e:
# If this fails any other way, we still want the backend
# to rollback the transaction.
try:
res = driver.send_and_receive(
protocol.RetryableNegative(retryable_id),
allow_resolution=False, hooks=hooks
)
except protocol.FrontendError:
raise e
else:
raise Exception("Should be FrontendError but was: %s"
% res)
else:
# The frontend test function were fine with the
# interaction, notify backend that we're happy to go.
driver.send(
protocol.RetryablePositive(retryable_id),
hooks=hooks
)
elif isinstance(res, protocol.RetryFunc):
handle_retry_func(retry_func, res, driver)
elif isinstance(res, protocol.RetryableDone):
return x
else:
allowed = ["RetryableTry", "RetryableDone", "RetryFunc"]
raise Exception("Should be one of %s but was: %s" % (allowed, res))
118 changes: 118 additions & 0 deletions nutkit/frontend/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,129 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Optional,
Tuple,
)

# exception, attempt, maxAttempts
_T_RetryFunc = Callable[[Exception, int, int], Tuple[bool, int]]


class ClusterAccessMode(Enum):
NAIVE = "naive"
READERS = "readers"
WRITERS = "writers"


class TxClusterMemberAccess(Enum):
READERS = "readers"
WRITERS = "writers"


class _BaseConf:
_attr_to_conf_key = {}

def to_protocol(self) -> Dict[str, Any]:
return {
self._attr_to_conf_key[k]: v
for k, v in vars(self).items() if v is not None
}


@dataclass
class QueryConfig(_BaseConf):
max_record_count: Optional[int] = None
skip_records: Optional[bool] = None

_attr_to_conf_key = {
**_BaseConf._attr_to_conf_key,
"max_record_count": "maxRecordCount",
"skip_records": "skipRecords",
}


@dataclass
class SessionQueryConfig(QueryConfig):
cluster_access_mode: Optional[ClusterAccessMode] = None
timeout: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
max_retries: Optional[int] = None
# TODO: what is RetryInfo?
retry_function: Optional[_T_RetryFunc] = None
execute_in_transaction: Optional[bool] = None

_attr_to_conf_key = {
**QueryConfig._attr_to_conf_key,
"cluster_access_mode": "clusterAccessMode",
"timeout": "timeout",
"metadata": "metadata",
"max_retries": "maxRetries",
"retry_function": "retryFunctionRegistered",
"execute_in_transaction": "executeInTransaction",
}

def to_protocol(self) -> Dict[str, Any]:
res = super().to_protocol()
res["retryFunctionRegistered"] = self.retry_function is not None
return res


@dataclass
class DriverQueryConfig(SessionQueryConfig):
database: Optional[str] = None
bookmarks: Optional[Collection[str]] = None
impersonated_user: Optional[str] = None

_attr_to_conf_key = {
**SessionQueryConfig._attr_to_conf_key,
"database": "database",
"bookmarks": "bookmarks",
"impersonated_user": "impersonatedUser",
}


@dataclass
class SessionTxConfig(_BaseConf):
timeout: Optional[int] = None
metadata: Optional[Dict[str, Any]] = None
max_retries: Optional[int] = None
# TODO: what is RetryInfo?
retry_function: Optional[_T_RetryFunc] = None

_attr_to_conf_key = {
**_BaseConf._attr_to_conf_key,
"timeout": "timeout",
"metadata": "metadata",
"max_retries": "maxRetries",
"retry_function": "retryFunctionRegistered",
}

def to_protocol(self) -> Dict[str, Any]:
res = super().to_protocol()
res["retryFunctionRegistered"] = self.retry_function is not None
return res


@dataclass
class DriverTxConfig(SessionTxConfig):
database: Optional[str] = None
bookmarks: Optional[Collection[str]] = None
impersonated_user: Optional[str] = None

_attr_to_conf_key = {
**SessionTxConfig._attr_to_conf_key,
"database": "database",
"bookmarks": "bookmarks",
"impersonated_user": "impersonatedUser",
}


@dataclass
class Neo4jBookmarkManagerConfig:
Expand Down
43 changes: 41 additions & 2 deletions nutkit/frontend/driver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from typing import Optional
from __future__ import annotations

from typing import (
Optional,
TYPE_CHECKING,
)

from .. import protocol
from ._transaction_loop import (
handle_retry_func,
run_tx_loop,
)
from .config import from_bookmark_manager_config_to_protocol
from .config import Neo4jBookmarkManagerConfig as BMMConfig
from .eager_result import EagerResult
from .session import Session

if TYPE_CHECKING:
from .config import Neo4jBookmarkManagerConfig as BMMConfig


class Driver:
def __init__(self, backend, uri, auth_token, user_agent=None,
Expand Down Expand Up @@ -90,6 +102,13 @@ def supports_multi_db(self):
raise Exception("Should be MultiDBSupport")
return res.available

def supports_auto_query_routing(self):
req = protocol.CheckAutoQueryRoutingSupport(self._driver.id)
res = self.send_and_receive(req, allow_resolution=False)
if not isinstance(res, protocol.AutoQueryRoutingSupport):
raise Exception("Should be AutoQueryRoutingSupport")
return res.available

def is_encrypted(self):
req = protocol.CheckDriverIsEncrypted(self._driver.id)
res = self.send_and_receive(req, allow_resolution=False)
Expand Down Expand Up @@ -117,6 +136,26 @@ def session(self, access_mode, bookmarks=None, database=None,
raise Exception("Should be session")
return Session(self, res)

def query(self, query, params=None, config=None, hooks=None):
retry_func = None
if config:
retry_func = getattr(config, "retry_function", None)
req = protocol.DriverQuery(self._driver.id, query, params, config)
res = self.send_and_receive(req, hooks=hooks, allow_resolution=True)
while isinstance(res, protocol.RetryFunc):
handle_retry_func(retry_func, res, self)
res = self.receive(hooks=hooks, allow_resolution=True)
if not isinstance(res, protocol.EagerResult):
raise Exception("Should be EagerResult or RetryFunc")
return EagerResult(self, res)

def execute(self, fn, config=None, hooks=None):
retry_func = None
if config:
retry_func = getattr(config, "retry_function", None)
req = protocol.DriverExecute(self._driver.id, config)
return run_tx_loop(fn, req, self, retry_func=retry_func, hooks=hooks)

def resolve(self, address):
return self._resolver_fn(address)

Expand Down
27 changes: 27 additions & 0 deletions nutkit/frontend/eager_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from .. import protocol


class EagerResult:
def __init__(self, driver, eager_result):
self._driver = driver
self._eager_result = eager_result
self.keys = eager_result.keys
self.records = eager_result.records
self.summary = eager_result.summary

def single(self):
"""Return one record if there is exactly one.

Raises error otherwise.
"""
req = protocol.EagerResultSingle(self._eager_result.id)
return self._driver.send_and_receive(req, allow_resolution=True)

def scalar(self):
"""Unpack a single record with a single value.

Raise error if there are not exactly one record or not exactly one
value.
"""
req = protocol.EagerResultScalar(self._eager_result.id)
return self._driver.send_and_receive(req, allow_resolution=True)
9 changes: 9 additions & 0 deletions nutkit/frontend/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def single_optional(self):
req = protocol.ResultSingleOptional(self._result.id)
return self._driver.send_and_receive(req, allow_resolution=True)

def scalar(self):
"""Unpack a single record with a single value.

Raise error if there are not exactly one record or not exactly one
value.
"""
req = protocol.EagerResultScalar(self._result.id)
return self._driver.send_and_receive(req, allow_resolution=True)

def peek(self):
"""Return the next Record or NullRecord without consuming it."""
req = protocol.ResultPeek(self._result.id)
Expand Down
Loading