diff --git a/arangoasync/aql.py b/arangoasync/aql.py index 0e00c57..804744e 100644 --- a/arangoasync/aql.py +++ b/arangoasync/aql.py @@ -1,7 +1,7 @@ __all__ = ["AQL", "AQLQueryCache"] -from typing import Optional +from typing import Optional, cast from arangoasync.cursor import Cursor from arangoasync.errno import HTTP_NOT_FOUND @@ -10,6 +10,9 @@ AQLCacheConfigureError, AQLCacheEntriesError, AQLCachePropertiesError, + AQLFunctionCreateError, + AQLFunctionDeleteError, + AQLFunctionListError, AQLQueryClearError, AQLQueryExecuteError, AQLQueryExplainError, @@ -634,3 +637,117 @@ def response_handler(resp: Response) -> Jsons: return self.deserializer.loads_many(resp.raw_body) return await self._executor.execute(request, response_handler) + + async def functions(self, namespace: Optional[str] = None) -> Result[Jsons]: + """List the registered used-defined AQL functions. + + Args: + namespace (str | None): Returns all registered AQL user functions from + the specified namespace. + + Returns: + list: List of the AQL functions defined in the database. + + Raises: + AQLFunctionListError: If retrieval fails. + + References: + - `list-the-registered-user-defined-aql-functions `__ + """ # noqa: E501 + params: Json = dict() + if namespace is not None: + params["namespace"] = namespace + request = Request( + method=Method.GET, + endpoint="/_api/aqlfunction", + params=params, + ) + + def response_handler(resp: Response) -> Jsons: + if not resp.is_success: + raise AQLFunctionListError(resp, request) + result = cast(Jsons, self.deserializer.loads(resp.raw_body).get("result")) + if result is None: + raise AQLFunctionListError(resp, request) + return result + + return await self._executor.execute(request, response_handler) + + async def create_function( + self, + name: str, + code: str, + is_deterministic: Optional[bool] = None, + ) -> Result[Json]: + """Registers a user-defined AQL function (UDF) written in JavaScript. + + Args: + name (str): Name of the function. + code (str): JavaScript code of the function. + is_deterministic (bool | None): If set to `True`, the function is + deterministic. + + Returns: + dict: Information about the registered function. + + Raises: + AQLFunctionCreateError: If registration fails. + + References: + - `create-a-user-defined-aql-function `__ + """ # noqa: E501 + request = Request( + method=Method.POST, + endpoint="/_api/aqlfunction", + data=self.serializer.dumps( + dict(name=name, code=code, isDeterministic=is_deterministic) + ), + ) + + def response_handler(resp: Response) -> Json: + if not resp.is_success: + raise AQLFunctionCreateError(resp, request) + return self.deserializer.loads(resp.raw_body) + + return await self._executor.execute(request, response_handler) + + async def delete_function( + self, + name: str, + group: Optional[bool] = None, + ignore_missing: bool = False, + ) -> Result[Json]: + """Remove a user-defined AQL function. + + Args: + name (str): Name of the function. + group (bool | None): If set to `True`, the function name is treated + as a namespace prefix. + ignore_missing (bool): If set to `True`, will not raise an exception + if the function is not found. + + Returns: + dict: Information about the removed functions (their count). + + Raises: + AQLFunctionDeleteError: If removal fails. + + References: + - `remove-a-user-defined-aql-function `__ + """ # noqa: E501 + params: Json = dict() + if group is not None: + params["group"] = group + request = Request( + method=Method.DELETE, + endpoint=f"/_api/aqlfunction/{name}", + params=params, + ) + + def response_handler(resp: Response) -> Json: + if not resp.is_success: + if not (resp.status_code == HTTP_NOT_FOUND and ignore_missing): + raise AQLFunctionDeleteError(resp, request) + return self.deserializer.loads(resp.raw_body) + + return await self._executor.execute(request, response_handler) diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index ff3e0d1..f78b7fb 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -87,6 +87,18 @@ class AQLCachePropertiesError(ArangoServerError): """Failed to retrieve query cache properties.""" +class AQLFunctionCreateError(ArangoServerError): + """Failed to create AQL user function.""" + + +class AQLFunctionDeleteError(ArangoServerError): + """Failed to delete AQL user function.""" + + +class AQLFunctionListError(ArangoServerError): + """Failed to retrieve AQL user functions.""" + + class AQLQueryClearError(ArangoServerError): """Failed to clear slow AQL queries.""" diff --git a/tests/test_aql.py b/tests/test_aql.py index 9176974..c09450f 100644 --- a/tests/test_aql.py +++ b/tests/test_aql.py @@ -4,12 +4,20 @@ import pytest from packaging import version -from arangoasync.errno import FORBIDDEN, QUERY_PARSE +from arangoasync.errno import ( + FORBIDDEN, + QUERY_FUNCTION_INVALID_CODE, + QUERY_FUNCTION_NOT_FOUND, + QUERY_PARSE, +) from arangoasync.exceptions import ( AQLCacheClearError, AQLCacheConfigureError, AQLCacheEntriesError, AQLCachePropertiesError, + AQLFunctionCreateError, + AQLFunctionDeleteError, + AQLFunctionListError, AQLQueryClearError, AQLQueryExecuteError, AQLQueryExplainError, @@ -276,3 +284,82 @@ async def test_cache_plan_management(db, bad_db, doc_col, docs, db_version): with pytest.raises(AQLCacheClearError) as err: await bad_db.aql.cache.clear_plan() assert err.value.error_code == FORBIDDEN + + +@pytest.mark.asyncio +async def test_aql_function_management(db, bad_db): + fn_group = "functions::temperature" + fn_name_1 = "functions::temperature::celsius_to_fahrenheit" + fn_body_1 = "function (celsius) { return celsius * 1.8 + 32; }" + fn_name_2 = "functions::temperature::fahrenheit_to_celsius" + fn_body_2 = "function (fahrenheit) { return (fahrenheit - 32) / 1.8; }" + bad_fn_name = "functions::temperature::should_not_exist" + bad_fn_body = "function (celsius) { invalid syntax }" + + aql = db.aql + # List AQL functions + assert await aql.functions() == [] + + # List AQL functions with bad database + with pytest.raises(AQLFunctionListError) as err: + await bad_db.aql.functions() + assert err.value.error_code == FORBIDDEN + + # Create invalid AQL function + with pytest.raises(AQLFunctionCreateError) as err: + await aql.create_function(bad_fn_name, bad_fn_body) + assert err.value.error_code == QUERY_FUNCTION_INVALID_CODE + + # Create first AQL function + result = await aql.create_function(fn_name_1, fn_body_1, is_deterministic=True) + assert result["isNewlyCreated"] is True + functions = await aql.functions() + assert len(functions) == 1 + assert functions[0]["name"] == fn_name_1 + assert functions[0]["code"] == fn_body_1 + assert functions[0]["isDeterministic"] is True + + # Create same AQL function again + result = await aql.create_function(fn_name_1, fn_body_1, is_deterministic=True) + assert result["isNewlyCreated"] is False + functions = await aql.functions() + assert len(functions) == 1 + assert functions[0]["name"] == fn_name_1 + assert functions[0]["code"] == fn_body_1 + assert functions[0]["isDeterministic"] is True + + # Create second AQL function + result = await aql.create_function(fn_name_2, fn_body_2, is_deterministic=False) + assert result["isNewlyCreated"] is True + functions = await aql.functions() + assert len(functions) == 2 + assert functions[0]["name"] == fn_name_1 + assert functions[0]["code"] == fn_body_1 + assert functions[0]["isDeterministic"] is True + assert functions[1]["name"] == fn_name_2 + assert functions[1]["code"] == fn_body_2 + assert functions[1]["isDeterministic"] is False + + # Delete first function + result = await aql.delete_function(fn_name_1) + assert result["deletedCount"] == 1 + functions = await aql.functions() + assert len(functions) == 1 + + # Delete missing function + with pytest.raises(AQLFunctionDeleteError) as err: + await aql.delete_function(fn_name_1) + assert err.value.error_code == QUERY_FUNCTION_NOT_FOUND + result = await aql.delete_function(fn_name_1, ignore_missing=True) + assert "deletedCount" not in result + + # Delete function from bad db + with pytest.raises(AQLFunctionDeleteError) as err: + await bad_db.aql.delete_function(fn_name_2) + assert err.value.error_code == FORBIDDEN + + # Delete function group + result = await aql.delete_function(fn_group, group=True) + assert result["deletedCount"] == 1 + functions = await aql.functions() + assert len(functions) == 0 diff --git a/tests/test_cursor.py b/tests/test_cursor.py index f998836..fd0237f 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -128,6 +128,7 @@ async def test_cursor_write_query(db, doc_col, docs): cursor = await aql.execute( """ FOR d IN {col} FILTER d.val == @first OR d.val == @second + SORT d.val UPDATE {{_key: d._key, _val: @val }} IN {col} RETURN NEW """.format(