diff --git a/.coveragerc-py37 b/.coveragerc-py37 index fb6dbb6e..b1c98d23 100644 --- a/.coveragerc-py37 +++ b/.coveragerc-py37 @@ -18,4 +18,8 @@ exclude_lines = # Don't complain about async-specific imports and code from functions_framework.aio import from functions_framework._http.asgi import - from functions_framework._http.gunicorn import UvicornApplication \ No newline at end of file + from functions_framework._http.gunicorn import UvicornApplication + + # Exclude async-specific classes and functions in execution_id.py + class AsgiMiddleware: + def set_execution_context_async \ No newline at end of file diff --git a/conftest.py b/conftest.py index 1d17e9bf..257f60d4 100644 --- a/conftest.py +++ b/conftest.py @@ -50,8 +50,12 @@ def pytest_ignore_collect(collection_path, config): if sys.version_info >= (3, 8): return None - # Skip test_aio.py and test_asgi.py entirely on Python 3.7 - if collection_path.name in ["test_aio.py", "test_asgi.py"]: + # Skip test_aio.py, test_asgi.py, and test_execution_id_async.py entirely on Python 3.7 + if collection_path.name in [ + "test_aio.py", + "test_asgi.py", + "test_execution_id_async.py", + ]: return True return None diff --git a/pyproject.toml b/pyproject.toml index 350d8997..2b6e0639 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,3 +61,13 @@ functions_framework = ["py.typed"] [tool.setuptools.package-dir] "" = "src" + +[dependency-groups] +dev = [ + "black>=23.3.0", + "build>=1.1.1", + "isort>=5.11.5", + "pretend>=1.0.9", + "pytest>=7.4.4", + "pytest-asyncio>=0.21.2", +] diff --git a/src/functions_framework/aio/__init__.py b/src/functions_framework/aio/__init__.py index 21f12754..4245f2d1 100644 --- a/src/functions_framework/aio/__init__.py +++ b/src/functions_framework/aio/__init__.py @@ -13,16 +13,24 @@ # limitations under the License. import asyncio +import contextvars import functools import inspect +import logging +import logging.config import os +import traceback from typing import Any, Awaitable, Callable, Dict, Tuple, Union from cloudevents.http import from_http from cloudevents.http.event import CloudEvent -from functions_framework import _function_registry +from functions_framework import ( + _enable_execution_id_logging, + _function_registry, + execution_id, +) from functions_framework.exceptions import ( FunctionsFrameworkException, MissingSourceException, @@ -31,6 +39,7 @@ try: from starlette.applications import Starlette from starlette.exceptions import HTTPException + from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route @@ -96,29 +105,27 @@ def wrapper(*args, **kwargs): return wrapper -async def _crash_handler(request, exc): - headers = {_FUNCTION_STATUS_HEADER_FIELD: _CRASH} - return Response(f"Internal Server Error: {exc}", status_code=500, headers=headers) - - -def _http_func_wrapper(function, is_async): +def _http_func_wrapper(function, is_async, enable_id_logging=False): + @execution_id.set_execution_context_async(enable_id_logging) @functools.wraps(function) async def handler(request): if is_async: result = await function(request) else: # TODO: Use asyncio.to_thread when we drop Python 3.8 support - # Python 3.8 compatible version of asyncio.to_thread loop = asyncio.get_event_loop() - result = await loop.run_in_executor(None, function, request) + ctx = contextvars.copy_context() + result = await loop.run_in_executor(None, ctx.run, function, request) if isinstance(result, str): return Response(result) elif isinstance(result, dict): return JSONResponse(result) elif isinstance(result, tuple) and len(result) == 2: - # Support Flask-style tuple response content, status_code = result - return Response(content, status_code=status_code) + if isinstance(content, dict): + return JSONResponse(content, status_code=status_code) + else: + return Response(content, status_code=status_code) elif result is None: raise HTTPException(status_code=500, detail="No response returned") else: @@ -127,7 +134,8 @@ async def handler(request): return handler -def _cloudevent_func_wrapper(function, is_async): +def _cloudevent_func_wrapper(function, is_async, enable_id_logging=False): + @execution_id.set_execution_context_async(enable_id_logging) @functools.wraps(function) async def handler(request): data = await request.body() @@ -142,9 +150,9 @@ async def handler(request): await function(event) else: # TODO: Use asyncio.to_thread when we drop Python 3.8 support - # Python 3.8 compatible version of asyncio.to_thread loop = asyncio.get_event_loop() - await loop.run_in_executor(None, function, event) + ctx = contextvars.copy_context() + await loop.run_in_executor(None, ctx.run, function, event) return Response("OK") return handler @@ -154,6 +162,64 @@ async def _handle_not_found(request: Request): raise HTTPException(status_code=404, detail="Not Found") +def _configure_app_execution_id_logging(): + logging.config.dictConfig( + { + "version": 1, + "handlers": { + "asgi": { + "class": "logging.StreamHandler", + "stream": "ext://functions_framework.execution_id.logging_stream", + }, + }, + "root": {"level": "INFO", "handlers": ["asgi"]}, + } + ) + + +class ExceptionHandlerMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": # pragma: no cover + await self.app(scope, receive, send) + return + + try: + await self.app(scope, receive, send) + except Exception as exc: + logger = logging.getLogger() + tb_lines = traceback.format_exception(type(exc), exc, exc.__traceback__) + tb_text = "".join(tb_lines) + + path = scope.get("path", "/") + method = scope.get("method", "GET") + error_msg = f"Exception on {path} [{method}]\n{tb_text}".rstrip() + + logger.error(error_msg) + + headers = [ + [b"content-type", b"text/plain"], + [_FUNCTION_STATUS_HEADER_FIELD.encode(), _CRASH.encode()], + ] + + await send( + { + "type": "http.response.start", + "status": 500, + "headers": headers, + } + ) + await send( + { + "type": "http.response.body", + "body": b"Internal Server Error", + } + ) + # Don't re-raise to prevent starlette from printing traceback again + + def create_asgi_app(target=None, source=None, signature_type=None): """Create an ASGI application for the function. @@ -175,6 +241,11 @@ def create_asgi_app(target=None, source=None, signature_type=None): ) source_module, spec = _function_registry.load_function_module(source) + + enable_id_logging = _enable_execution_id_logging() + if enable_id_logging: + _configure_app_execution_id_logging() + spec.loader.exec_module(source_module) function = _function_registry.get_user_function(source, source_module, target) signature_type = _function_registry.get_func_signature_type(target, signature_type) @@ -182,7 +253,7 @@ def create_asgi_app(target=None, source=None, signature_type=None): is_async = inspect.iscoroutinefunction(function) routes = [] if signature_type == _function_registry.HTTP_SIGNATURE_TYPE: - http_handler = _http_func_wrapper(function, is_async) + http_handler = _http_func_wrapper(function, is_async, enable_id_logging) routes.append( Route( "/", @@ -202,7 +273,9 @@ def create_asgi_app(target=None, source=None, signature_type=None): ) ) elif signature_type == _function_registry.CLOUDEVENT_SIGNATURE_TYPE: - cloudevent_handler = _cloudevent_func_wrapper(function, is_async) + cloudevent_handler = _cloudevent_func_wrapper( + function, is_async, enable_id_logging + ) routes.append( Route("/{path:path}", endpoint=cloudevent_handler, methods=["POST"]) ) @@ -221,10 +294,14 @@ def create_asgi_app(target=None, source=None, signature_type=None): f"Unsupported signature type for ASGI server: {signature_type}" ) - exception_handlers = { - 500: _crash_handler, - } - app = Starlette(routes=routes, exception_handlers=exception_handlers) + app = Starlette( + routes=routes, + middleware=[ + Middleware(ExceptionHandlerMiddleware), + Middleware(execution_id.AsgiMiddleware), + ], + ) + return app diff --git a/src/functions_framework/execution_id.py b/src/functions_framework/execution_id.py index 2b106531..df412187 100644 --- a/src/functions_framework/execution_id.py +++ b/src/functions_framework/execution_id.py @@ -13,7 +13,9 @@ # limitations under the License. import contextlib +import contextvars import functools +import inspect import io import json import logging @@ -38,6 +40,9 @@ logger = logging.getLogger(__name__) +# Context variable for async execution context +execution_context_var = contextvars.ContextVar("execution_context", default=None) + class ExecutionContext: def __init__(self, execution_id=None, span_id=None): @@ -46,7 +51,10 @@ def __init__(self, execution_id=None, span_id=None): def _get_current_context(): - return ( + context = execution_context_var.get() + if context is not None: + return context + return ( # pragma: no cover flask.g.execution_id_context if flask.has_request_context() and "execution_id_context" in flask.g else None @@ -54,6 +62,8 @@ def _get_current_context(): def _set_current_context(context): + execution_context_var.set(context) + # Also set in Flask context if available for sync if flask.has_request_context(): flask.g.execution_id_context = context @@ -65,6 +75,18 @@ def _generate_execution_id(): ) +def _extract_context_from_headers(headers): + """Extract execution context from request headers.""" + trace_context = re.match( + _TRACE_CONTEXT_REGEX_PATTERN, + headers.get(TRACE_CONTEXT_REQUEST_HEADER, ""), + ) + execution_id = headers.get(EXECUTION_ID_REQUEST_HEADER) + span_id = trace_context.group("span_id") if trace_context else None + + return ExecutionContext(execution_id, span_id) + + # Middleware to add execution id to request header if one does not already exist class WsgiMiddleware: def __init__(self, wsgi_app): @@ -78,8 +100,42 @@ def __call__(self, environ, start_response): return self.wsgi_app(environ, start_response) -# Sets execution id and span id for the request +class AsgiMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": # pragma: no branch + execution_id_header = b"function-execution-id" + execution_id = None + + for name, value in scope.get("headers", []): + if name.lower() == execution_id_header: + execution_id = value.decode("latin-1") + break + + if not execution_id: + execution_id = _generate_execution_id() + new_headers = list(scope.get("headers", [])) + new_headers.append( + (execution_id_header, execution_id.encode("latin-1")) + ) + scope["headers"] = new_headers + + await self.app(scope, receive, send) + + def set_execution_context(request, enable_id_logging=False): + """Decorator for Flask/WSGI handlers that sets execution context. + + Takes request object at decoration time (Flask pattern where request is available + via thread-local context when decorator is applied). + + Usage: + @set_execution_context(request, enable_id_logging=True) + def view_func(path): + ... + """ if enable_id_logging: stdout_redirect = contextlib.redirect_stdout( LoggingHandlerAddExecutionId(sys.stdout) @@ -94,22 +150,71 @@ def set_execution_context(request, enable_id_logging=False): def decorator(view_function): @functools.wraps(view_function) def wrapper(*args, **kwargs): - trace_context = re.match( - _TRACE_CONTEXT_REGEX_PATTERN, - request.headers.get(TRACE_CONTEXT_REQUEST_HEADER, ""), - ) - execution_id = request.headers.get(EXECUTION_ID_REQUEST_HEADER) - span_id = trace_context.group("span_id") if trace_context else None - _set_current_context(ExecutionContext(execution_id, span_id)) + context = _extract_context_from_headers(request.headers) + _set_current_context(context) with stderr_redirect, stdout_redirect: - return view_function(*args, **kwargs) + result = view_function(*args, **kwargs) + return result return wrapper return decorator +def set_execution_context_async(enable_id_logging=False): + """Decorator for ASGI/async handlers that sets execution context. + + Unlike set_execution_context which takes request at decoration time (Flask pattern), + this expects the decorated function to receive request as its first parameter (ASGI pattern). + + Usage: + @set_execution_context_async(enable_id_logging=True) + async def handler(request, *args, **kwargs): + ... + """ + if enable_id_logging: + stdout_redirect = contextlib.redirect_stdout( + LoggingHandlerAddExecutionId(sys.stdout) + ) + stderr_redirect = contextlib.redirect_stderr( + LoggingHandlerAddExecutionId(sys.stderr) + ) + else: + stdout_redirect = contextlib.nullcontext() + stderr_redirect = contextlib.nullcontext() + + def decorator(func): + @functools.wraps(func) + async def async_wrapper(request, *args, **kwargs): + context = _extract_context_from_headers(request.headers) + token = execution_context_var.set(context) + + with stderr_redirect, stdout_redirect: + result = await func(request, *args, **kwargs) + + execution_context_var.reset(token) + return result + + @functools.wraps(func) + def sync_wrapper(request, *args, **kwargs): + context = _extract_context_from_headers(request.headers) + token = execution_context_var.set(context) + + with stderr_redirect, stdout_redirect: + result = func(request, *args, **kwargs) + + execution_context_var.reset(token) + return result + + if inspect.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + return decorator + + @LocalProxy def logging_stream(): return LoggingHandlerAddExecutionId(stream=flask.logging.wsgi_errors_stream) diff --git a/tests/test_aio.py b/tests/test_aio.py index cf69479a..4f34c279 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -143,6 +143,8 @@ async def http_func(request): wrapper = _http_func_wrapper(http_func, is_async=True) request = Mock() + request.headers = Mock() + request.headers.get = Mock(return_value="") response = await wrapper(request) assert response.__class__.__name__ == "JSONResponse" @@ -158,6 +160,8 @@ def sync_http_func(request): wrapper = _http_func_wrapper(sync_http_func, is_async=False) request = Mock() + request.headers = Mock() + request.headers.get = Mock(return_value="") response = await wrapper(request) assert response.__class__.__name__ == "Response" diff --git a/tests/test_execution_id.py b/tests/test_execution_id.py index a2601817..b8c5b9f0 100644 --- a/tests/test_execution_id.py +++ b/tests/test_execution_id.py @@ -223,6 +223,7 @@ def view_func(): monkeypatch.setattr( execution_id, "_generate_execution_id", lambda: TEST_EXECUTION_ID ) + mock_g = Mock() monkeypatch.setattr(execution_id.flask, "g", mock_g) monkeypatch.setattr(execution_id.flask, "has_request_context", lambda: True) diff --git a/tests/test_execution_id_async.py b/tests/test_execution_id_async.py new file mode 100644 index 00000000..01e638a1 --- /dev/null +++ b/tests/test_execution_id_async.py @@ -0,0 +1,365 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import json +import pathlib +import re + +from functools import partial +from unittest.mock import Mock + +import pytest + +from starlette.testclient import TestClient + +from functions_framework import execution_id +from functions_framework.aio import create_asgi_app + +TEST_FUNCTIONS_DIR = pathlib.Path(__file__).resolve().parent / "test_functions" +TEST_EXECUTION_ID = "test_execution_id" +TEST_SPAN_ID = "123456" + + +def test_user_function_can_retrieve_execution_id_from_header(): + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_function" + app = create_asgi_app(target, source) + client = TestClient(app) + resp = client.post( + "/", + headers={ + "Function-Execution-Id": TEST_EXECUTION_ID, + "Content-Type": "application/json", + }, + ) + + assert resp.json()["execution_id"] == TEST_EXECUTION_ID + + +def test_uncaught_exception_in_user_function_sets_execution_id(capsys, monkeypatch): + monkeypatch.setenv("LOG_EXECUTION_ID", "true") + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_error" + app = create_asgi_app(target, source) + # Don't raise server exceptions so we can capture the logs + client = TestClient(app, raise_server_exceptions=False) + resp = client.post( + "/", + headers={ + "Function-Execution-Id": TEST_EXECUTION_ID, + "Content-Type": "application/json", + }, + ) + assert resp.status_code == 500 + record = capsys.readouterr() + assert f'"execution_id": "{TEST_EXECUTION_ID}"' in record.err + assert '"logging.googleapis.com/labels"' in record.err + assert "ZeroDivisionError" in record.err + + +def test_print_from_user_function_sets_execution_id(capsys, monkeypatch): + monkeypatch.setenv("LOG_EXECUTION_ID", "true") + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_print_message" + app = create_asgi_app(target, source) + client = TestClient(app) + client.post( + "/", + headers={ + "Function-Execution-Id": TEST_EXECUTION_ID, + "Content-Type": "application/json", + }, + json={"message": "some-message"}, + ) + record = capsys.readouterr() + assert f'"execution_id": "{TEST_EXECUTION_ID}"' in record.out + assert '"message": "some-message"' in record.out + + +def test_log_from_user_function_sets_execution_id(capsys, monkeypatch): + monkeypatch.setenv("LOG_EXECUTION_ID", "true") + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_log_message" + app = create_asgi_app(target, source) + client = TestClient(app) + client.post( + "/", + headers={ + "Function-Execution-Id": TEST_EXECUTION_ID, + "Content-Type": "application/json", + }, + json={"message": json.dumps({"custom-field": "some-message"})}, + ) + record = capsys.readouterr() + assert f'"execution_id": "{TEST_EXECUTION_ID}"' in record.err + assert '"custom-field": "some-message"' in record.err + assert '"logging.googleapis.com/labels"' in record.err + + +def test_user_function_can_retrieve_generated_execution_id(monkeypatch): + monkeypatch.setattr( + execution_id, "_generate_execution_id", lambda: TEST_EXECUTION_ID + ) + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_function" + app = create_asgi_app(target, source) + client = TestClient(app) + resp = client.post( + "/", + headers={ + "Content-Type": "application/json", + }, + ) + + assert resp.json()["execution_id"] == TEST_EXECUTION_ID + + +def test_does_not_set_execution_id_when_not_enabled(capsys): + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_print_message" + app = create_asgi_app(target, source) + client = TestClient(app) + client.post( + "/", + headers={ + "Function-Execution-Id": TEST_EXECUTION_ID, + "Content-Type": "application/json", + }, + json={"message": "some-message"}, + ) + record = capsys.readouterr() + assert f'"execution_id": "{TEST_EXECUTION_ID}"' not in record.out + assert "some-message" in record.out + + +def test_does_not_set_execution_id_when_env_var_is_false(capsys, monkeypatch): + monkeypatch.setenv("LOG_EXECUTION_ID", "false") + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_print_message" + app = create_asgi_app(target, source) + client = TestClient(app) + client.post( + "/", + headers={ + "Function-Execution-Id": TEST_EXECUTION_ID, + "Content-Type": "application/json", + }, + json={"message": "some-message"}, + ) + record = capsys.readouterr() + assert f'"execution_id": "{TEST_EXECUTION_ID}"' not in record.out + assert "some-message" in record.out + + +def test_does_not_set_execution_id_when_env_var_is_not_bool_like(capsys, monkeypatch): + monkeypatch.setenv("LOG_EXECUTION_ID", "maybe") + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_print_message" + app = create_asgi_app(target, source) + client = TestClient(app) + client.post( + "/", + headers={ + "Function-Execution-Id": TEST_EXECUTION_ID, + "Content-Type": "application/json", + }, + json={"message": "some-message"}, + ) + record = capsys.readouterr() + assert f'"execution_id": "{TEST_EXECUTION_ID}"' not in record.out + assert "some-message" in record.out + + +def test_generate_execution_id(): + expected_matching_regex = "^[0-9a-zA-Z]{12}$" + actual_execution_id = execution_id._generate_execution_id() + + match = re.match(expected_matching_regex, actual_execution_id).group(0) + assert match == actual_execution_id + + +@pytest.mark.parametrize( + "headers,expected_execution_id,expected_span_id,should_generate", + [ + ( + { + "X-Cloud-Trace-Context": f"TRACE_ID/{TEST_SPAN_ID};o=1", + "Function-Execution-Id": TEST_EXECUTION_ID, + }, + TEST_EXECUTION_ID, + TEST_SPAN_ID, + False, + ), + ({}, None, None, True), # Middleware will generate an ID + ( + { + "X-Cloud-Trace-Context": "malformed trace context string", + "Function-Execution-Id": TEST_EXECUTION_ID, + }, + TEST_EXECUTION_ID, + None, + False, + ), + ], +) +def test_set_execution_context_headers( + headers, expected_execution_id, expected_span_id, should_generate +): + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_trace_test" + app = create_asgi_app(target, source) + client = TestClient(app) + + resp = client.post("/", headers=headers) + + result = resp.json() + if should_generate: + # When no execution ID is provided, middleware generates one + assert result.get("execution_id") is not None + assert len(result.get("execution_id")) == 12 # Generated IDs are 12 chars + else: + assert result.get("execution_id") == expected_execution_id + assert result.get("span_id") == expected_span_id + + +@pytest.mark.asyncio +async def test_maintains_execution_id_for_concurrent_requests(monkeypatch, capsys): + monkeypatch.setenv("LOG_EXECUTION_ID", "true") + + expected_logs = ( + { + "message": "message1", + "logging.googleapis.com/labels": {"execution_id": "test-execution-id-1"}, + }, + { + "message": "message2", + "logging.googleapis.com/labels": {"execution_id": "test-execution-id-2"}, + }, + { + "message": "message1", + "logging.googleapis.com/labels": {"execution_id": "test-execution-id-1"}, + }, + { + "message": "message2", + "logging.googleapis.com/labels": {"execution_id": "test-execution-id-2"}, + }, + ) + + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_sleep" + app = create_asgi_app(target, source) + client = TestClient(app) + loop = asyncio.get_event_loop() + response1 = loop.run_in_executor( + None, + partial( + client.post, + "/", + headers={ + "Content-Type": "application/json", + "Function-Execution-Id": "test-execution-id-1", + }, + json={"message": "message1"}, + ), + ) + response2 = loop.run_in_executor( + None, + partial( + client.post, + "/", + headers={ + "Content-Type": "application/json", + "Function-Execution-Id": "test-execution-id-2", + }, + json={"message": "message2"}, + ), + ) + await asyncio.wait((response1, response2)) + record = capsys.readouterr() + logs = record.err.strip().split("\n") + logs_as_json = tuple(json.loads(log) for log in logs) + + sort_key = lambda d: d["message"] + assert sorted(logs_as_json, key=sort_key) == sorted(expected_logs, key=sort_key) + + +def test_async_decorator_with_sync_function(): + def sync_func(request): + return {"status": "ok"} + + wrapped = execution_id.set_execution_context_async(enable_id_logging=False)( + sync_func + ) + + request = Mock() + request.headers = Mock() + request.headers.get = Mock(return_value="") + + result = wrapped(request) + + assert result == {"status": "ok"} + + +def test_sync_cloudevent_function_has_execution_context(monkeypatch, capsys): + """Test that sync CloudEvent functions can access execution context.""" + monkeypatch.setenv("LOG_EXECUTION_ID", "true") + + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "sync_cloudevent_with_context" + app = create_asgi_app(target, source, signature_type="cloudevent") + client = TestClient(app) + + response = client.post( + "/", + headers={ + "ce-specversion": "1.0", + "ce-type": "com.example.test", + "ce-source": "test-source", + "ce-id": "test-id", + "Function-Execution-Id": TEST_EXECUTION_ID, + "Content-Type": "application/json", + }, + json={"message": "test"}, + ) + + assert response.status_code == 200 + assert response.text == "OK" + + record = capsys.readouterr() + assert f"Execution ID in sync CloudEvent: {TEST_EXECUTION_ID}" in record.err + assert "No execution context in sync CloudEvent function!" not in record.err + + +def test_cloudevent_returns_500(capsys, monkeypatch): + monkeypatch.setenv("LOG_EXECUTION_ID", "true") + source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py" + target = "async_cloudevent_error" + app = create_asgi_app(target, source, signature_type="cloudevent") + client = TestClient(app, raise_server_exceptions=False) + resp = client.post( + "/", + headers={ + "ce-specversion": "1.0", + "ce-type": "com.example.test", + "ce-source": "test-source", + "ce-id": "test-id", + "Function-Execution-Id": TEST_EXECUTION_ID, + "Content-Type": "application/json", + }, + ) + assert resp.status_code == 500 + record = capsys.readouterr() + assert f'"execution_id": "{TEST_EXECUTION_ID}"' in record.err + assert '"logging.googleapis.com/labels"' in record.err + assert "ValueError" in record.err diff --git a/tests/test_functions/execution_id/async_main.py b/tests/test_functions/execution_id/async_main.py new file mode 100644 index 00000000..7149e7fb --- /dev/null +++ b/tests/test_functions/execution_id/async_main.py @@ -0,0 +1,62 @@ +import asyncio +import logging + +from functions_framework import execution_id + +logger = logging.getLogger(__name__) + + +async def async_print_message(request): + json = await request.json() + print(json.get("message")) + return {"status": "success"}, 200 + + +async def async_log_message(request): + json = await request.json() + logger.warning(json.get("message")) + return {"status": "success"}, 200 + + +async def async_function(request): + return {"execution_id": request.headers.get("Function-Execution-Id")} + + +async def async_error(request): + return 1 / 0 + + +async def async_sleep(request): + json = await request.json() + message = json.get("message") + logger.warning(message) + await asyncio.sleep(1) + logger.warning(message) + return {"status": "success"}, 200 + + +async def async_trace_test(request): + context = execution_id._get_current_context() + return { + "execution_id": context.execution_id if context else None, + "span_id": context.span_id if context else None, + } + + +def sync_function_in_async_context(request): + return { + "execution_id": request.headers.get("Function-Execution-Id"), + "type": "sync", + } + + +def sync_cloudevent_with_context(cloud_event): + context = execution_id._get_current_context() + if context: + logger.info(f"Execution ID in sync CloudEvent: {context.execution_id}") + else: + logger.error("No execution context in sync CloudEvent function!") + + +async def async_cloudevent_error(cloudevent): + raise ValueError("This is a test error")