diff --git a/docs/source/developer_guide/guides/3_simple_cpp_stage.md b/docs/source/developer_guide/guides/3_simple_cpp_stage.md index 610fe49f8d..b9362e9755 100644 --- a/docs/source/developer_guide/guides/3_simple_cpp_stage.md +++ b/docs/source/developer_guide/guides/3_simple_cpp_stage.md @@ -383,7 +383,7 @@ As mentioned in the previous section, our `_build_single` method needs to be upd ```python def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject: - if self._build_cpp_node() and isinstance(self._input_type, ControlMessage): + if self._build_cpp_node() and issubclass(self._input_type, ControlMessage): from ._lib import pass_thru_cpp node = pass_thru_cpp.PassThruStage(builder, self.unique_name) @@ -438,7 +438,7 @@ class PassThruStage(PassThruTypeMixin, GpuAndCpuMixin, SinglePortStage): return message def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject: - if self._build_cpp_node() and isinstance(self._input_type, ControlMessage): + if self._build_cpp_node() and issubclass(self._input_type, ControlMessage): from ._lib import pass_thru_cpp node = pass_thru_cpp.PassThruStage(builder, self.unique_name) diff --git a/examples/developer_guide/3_simple_cpp_stage/src/simple_cpp_stage/pass_thru.py b/examples/developer_guide/3_simple_cpp_stage/src/simple_cpp_stage/pass_thru.py index 8364b89644..728e8ef431 100644 --- a/examples/developer_guide/3_simple_cpp_stage/src/simple_cpp_stage/pass_thru.py +++ b/examples/developer_guide/3_simple_cpp_stage/src/simple_cpp_stage/pass_thru.py @@ -53,9 +53,8 @@ def on_data(self, message: typing.Any): return message def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject: - if self._build_cpp_node() and isinstance(self._input_type, ControlMessage): + if self._build_cpp_node() and issubclass(self._input_type, ControlMessage): from ._lib import pass_thru_cpp - node = pass_thru_cpp.PassThruStage(builder, self.unique_name) else: node = builder.make_node(self.unique_name, ops.map(self.on_data))