@@ -37,6 +37,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
3737import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
3838import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
3939import software.amazon.smithy.rust.codegen.core.rustlang.withBlock
40+ import software.amazon.smithy.rust.codegen.core.rustlang.withBlockTemplate
4041import software.amazon.smithy.rust.codegen.core.rustlang.writable
4142import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
4243import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
@@ -200,12 +201,10 @@ class ServerProtocolTestGenerator(
200201 #{RegistryBuilderMethods:W}
201202 }
202203
203- /// The operation full name is a concatenation of `<operation namespace>.<operation name>`.
204204 pub(crate) async fn build_router_and_make_request(
205205 http_request: #{Http}::request::Request<#{SmithyHttpServer}::body::Body>,
206- operation_full_name: &str,
207206 f: &dyn Fn(RegistryBuilder) -> RegistryBuilder,
208- ) {
207+ ) -> #{Http}::response::Response<#{SmithyHttpServer}::body::BoxBody> {
209208 let mut router: #{Router} = f(create_operation_registry_builder())
210209 .build()
211210 .expect("unable to build operation registry")
@@ -214,6 +213,12 @@ class ServerProtocolTestGenerator(
214213 .call(http_request)
215214 .await
216215 .expect("unable to make an HTTP request");
216+
217+ http_response
218+ }
219+
220+ /// The operation full name is a concatenation of `<operation namespace>.<operation name>`.
221+ pub(crate) fn check_operation_extension_was_set(http_response: #{Http}::response::Response<#{SmithyHttpServer}::body::BoxBody>, operation_full_name: &str) {
217222 let operation_extension = http_response.extensions()
218223 .get::<#{SmithyHttpServer}::extension::OperationExtension>()
219224 .expect("extension `OperationExtension` not found");
@@ -284,6 +289,7 @@ class ServerProtocolTestGenerator(
284289
285290 is TestCase .MalformedRequestTest -> this .renderHttpMalformedRequestTestCase(
286291 it.testCase,
292+ operationShape,
287293 operationSymbol,
288294 )
289295 }
@@ -388,15 +394,17 @@ class ServerProtocolTestGenerator(
388394 renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull())
389395 }
390396 if (protocolSupport.requestBodyDeserialization) {
391- checkRequest(operationShape, operationSymbol, httpRequestTestCase, this )
397+ makeRequest(operationShape, this , checkRequestHandler(operationShape, httpRequestTestCase))
398+ checkHandlerWasEntered(operationShape, operationSymbol, this )
392399 }
393400
394401 // Test against new service builder.
395402 with (httpRequestTestCase) {
396403 renderHttpRequest(uri, method, headers, body.orNull(), queryParams, host.orNull())
397404 }
398405 if (protocolSupport.requestBodyDeserialization) {
399- checkRequest2(operationShape, operationSymbol, httpRequestTestCase, this )
406+ makeRequest2(operationShape, operationSymbol, this , checkRequestHandler(operationShape, httpRequestTestCase))
407+ checkHandlerWasEntered2(this )
400408 }
401409
402410 // Explicitly warn if the test case defined parameters that we aren't doing anything with
@@ -467,24 +475,30 @@ class ServerProtocolTestGenerator(
467475 */
468476 private fun RustWriter.renderHttpMalformedRequestTestCase (
469477 testCase : HttpMalformedRequestTestCase ,
478+ operationShape : OperationShape ,
470479 operationSymbol : Symbol ,
471480 ) {
472- with (testCase.request) {
473- // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`.
474- renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull())
481+ val (_, outputT) = operationInputOutputTypes[operationShape]!!
482+
483+ rust(" // Use the `OperationRegistryBuilder`" )
484+ rustBlock(" " ) {
485+ with (testCase.request) {
486+ // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`.
487+ renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull())
488+ }
489+ makeRequest(operationShape, this , writable(" todo!() as $outputT " ))
490+ checkResponse(this , testCase.response)
475491 }
476492
477- val operationName = " ${operationSymbol.name}${ServerHttpBoundProtocolGenerator .OPERATION_INPUT_WRAPPER_SUFFIX } "
478- rustTemplate(
479- """
480- let mut http_request = #{SmithyHttpServer}::request::RequestParts::new(http_request);
481- let rejection = super::$operationName ::from_request(&mut http_request).await.expect_err("request was accepted but we expected it to be rejected");
482- let http_response = #{SmithyHttpServer}::response::IntoResponse::<#{Protocol}>::into_response(rejection);
483- """ ,
484- " Protocol" to protocolGenerator.protocol.markerStruct(),
485- * codegenScope,
486- )
487- checkResponse(this , testCase.response)
493+ rust(" // Use new service builder" )
494+ rustBlock(" " ) {
495+ with (testCase.request) {
496+ // TODO(https://github.com/awslabs/smithy/issues/1102): `uri` should probably not be an `Optional`.
497+ renderHttpRequest(uri.get(), method, headers, body.orNull(), queryParams, host.orNull())
498+ }
499+ makeRequest2(operationShape, operationSymbol, this , writable(" todo!() as $outputT " ))
500+ checkResponse(this , testCase.response)
501+ }
488502 }
489503
490504 private fun RustWriter.renderHttpRequest (
@@ -563,41 +577,53 @@ class ServerProtocolTestGenerator(
563577 }
564578
565579 /* * Checks the request using the `OperationRegistryBuilder`. */
566- private fun checkRequest (
580+ private fun makeRequest (
567581 operationShape : OperationShape ,
568- operationSymbol : Symbol ,
569- httpRequestTestCase : HttpRequestTestCase ,
570582 rustWriter : RustWriter ,
583+ operationBody : Writable ,
571584 ) {
572585 val (inputT, outputT) = operationInputOutputTypes[operationShape]!!
573586
574- rustWriter.withBlock (
587+ rustWriter.withBlockTemplate (
575588 """
576- super::$PROTOCOL_TEST_HELPER_MODULE_NAME ::build_router_and_make_request(
589+ let http_response = super::$PROTOCOL_TEST_HELPER_MODULE_NAME ::build_router_and_make_request(
577590 http_request,
578- "${operationShape.id.namespace} .${operationSymbol.name} ",
579591 &|builder| {
580592 builder.${operationShape.toName()} ((|input| Box::pin(async move {
581593 """ ,
582594
583595 " })) as super::$PROTOCOL_TEST_HELPER_MODULE_NAME ::Fun<$inputT , $outputT >)}).await;" ,
584-
596+ * codegenScope,
585597 ) {
586- checkRequestHandler(operationShape, httpRequestTestCase) ()
598+ operationBody ()
587599 }
588600 }
589601
602+ private fun checkHandlerWasEntered (
603+ operationShape : OperationShape ,
604+ operationSymbol : Symbol ,
605+ rustWriter : RustWriter ,
606+ ) {
607+ val operationFullName = " ${operationShape.id.namespace} .${operationSymbol.name} "
608+ rustWriter.rust(
609+ """
610+ super::$PROTOCOL_TEST_HELPER_MODULE_NAME ::check_operation_extension_was_set(http_response, "$operationFullName ");
611+ """ ,
612+ )
613+ }
614+
590615 /* * Checks the request using the new service builder. */
591- private fun checkRequest2 (
616+ private fun makeRequest2 (
592617 operationShape : OperationShape ,
593618 operationSymbol : Symbol ,
594- httpRequestTestCase : HttpRequestTestCase ,
595619 rustWriter : RustWriter ,
620+ body : Writable ,
596621 ) {
597622 val (inputT, _) = operationInputOutputTypes[operationShape]!!
598623 val operationName = RustReservedWords .escapeIfNeeded(operationSymbol.name.toSnakeCase())
599624 rustWriter.rustTemplate(
600625 """
626+ ##[allow(unused_mut)]
601627 let (sender, mut receiver) = #{Tokio}::sync::mpsc::channel(1);
602628 let service = crate::service::$serviceName ::unchecked_builder()
603629 .$operationName (move |input: $inputT | {
@@ -612,13 +638,20 @@ class ServerProtocolTestGenerator(
612638 let http_response = #{Tower}::ServiceExt::oneshot(service, http_request)
613639 .await
614640 .expect("unable to make an HTTP request");
615- assert!(receiver.recv().await.is_some())
616641 """ ,
617- " Body" to checkRequestHandler(operationShape, httpRequestTestCase) ,
642+ " Body" to body ,
618643 * codegenScope,
619644 )
620645 }
621646
647+ private fun checkHandlerWasEntered2 (rustWriter : RustWriter ) {
648+ rustWriter.rust(
649+ """
650+ assert!(receiver.recv().await.is_some());
651+ """ ,
652+ )
653+ }
654+
622655 private fun checkRequestParams (inputShape : StructureShape , rustWriter : RustWriter ) {
623656 if (inputShape.hasStreamingMember(model)) {
624657 // A streaming shape does not implement `PartialEq`, so we have to iterate over the input shape's members
@@ -842,7 +875,7 @@ class ServerProtocolTestGenerator(
842875 private fun assertOk (rustWriter : RustWriter , inner : Writable ) {
843876 rustWriter.rust(" #T(" , RuntimeType .ProtocolTestHelper (codegenContext.runtimeConfig, " assert_ok" ))
844877 inner(rustWriter)
845- rustWriter.rust (" );" )
878+ rustWriter.write (" );" )
846879 }
847880
848881 private fun strSlice (writer : RustWriter , args : List <String >) {
@@ -872,6 +905,10 @@ class ServerProtocolTestGenerator(
872905 private val AwsQuery = " aws.protocoltests.query#AwsQuery"
873906 private val Ec2Query = " aws.protocoltests.ec2#AwsEc2"
874907 private val ExpectFail = setOf<FailingTest >(
908+ // Pending merge from the Smithy team: see https://github.com/awslabs/smithy/pull/1477.
909+ FailingTest (RestJson , " RestJsonWithPayloadExpectsImpliedContentType" , TestType .MalformedRequest ),
910+ FailingTest (RestJson , " RestJsonBodyMalformedMapNullKey" , TestType .MalformedRequest ),
911+
875912 // Pending resolution from the Smithy team, see https://github.com/awslabs/smithy/issues/1068.
876913 FailingTest (RestJson , " RestJsonHttpWithHeadersButNoPayload" , TestType .Request ),
877914
0 commit comments