Skip to content

Commit 2175f92

Browse files
committed
Add weak_ptr<void> callback_lifetime to SubscriptionOptions
Avoid potential use after free usage of a registered subscription callback function by allowing user to specify a weak_ptr to be checked for expiry before the associated subscription callback is called. If user does not specify callback_lifetime, the mechanism falls back to a tracking the lifetime of a user specified callback_group, failing that it tracks the lifetime of the nodes default_callback_group. Signed-off-by: Mike Wake <[email protected]>
1 parent 9cabd69 commit 2175f92

File tree

3 files changed

+62
-18
lines changed

3 files changed

+62
-18
lines changed

rclcpp/include/rclcpp/experimental/subscription_intra_process.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,16 @@ class SubscriptionIntraProcess
7777
rclcpp::Context::SharedPtr context,
7878
const std::string & topic_name,
7979
const rclcpp::QoS & qos_profile,
80-
rclcpp::IntraProcessBufferType buffer_type)
80+
rclcpp::IntraProcessBufferType buffer_type,
81+
std::weak_ptr<void> callback_lifetime)
8182
: SubscriptionIntraProcessBuffer<SubscribedType, SubscribedTypeAlloc,
8283
SubscribedTypeDeleter, ROSMessageType>(
8384
std::make_shared<SubscribedTypeAlloc>(*allocator),
8485
context,
8586
topic_name,
8687
qos_profile,
8788
buffer_type),
89+
callback_lifetime_(callback_lifetime),
8890
any_callback_(callback)
8991
{
9092
TRACETOOLS_TRACEPOINT(
@@ -166,6 +168,10 @@ class SubscriptionIntraProcess
166168
typename std::enable_if<!std::is_same<T, rcl_serialized_message_t>::value, void>::type
167169
execute_impl(const std::shared_ptr<void> & data)
168170
{
171+
if (callback_lifetime_.expired()) {
172+
return;
173+
}
174+
169175
if (nullptr == data) {
170176
return;
171177
}
@@ -187,6 +193,7 @@ class SubscriptionIntraProcess
187193
shared_ptr.reset();
188194
}
189195

196+
std::weak_ptr<void> callback_lifetime_;
190197
AnySubscriptionCallback<MessageT, Alloc> any_callback_;
191198
};
192199

rclcpp/include/rclcpp/subscription.hpp

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,29 @@ class Subscription : public SubscriptionBase
156156
"' is not allowed with 0 depth qos policy");
157157
}
158158

159+
// Use std::weak_ptr owner_before trick to determine if user
160+
// has assigned a subscription options_.callback_lifetime weak_ptr.
161+
// https://stackoverflow.com/a/45507610
162+
std::weak_ptr<void> empty;
163+
if (!options_.callback_lifetime.owner_before(empty) &&
164+
!empty.owner_before(options_.callback_lifetime)) {
165+
// options_.callback_lifetime was not user assigned,
166+
// So use options_.callback_group if user assigned,
167+
// falling back to node's default_callback_group
168+
std::shared_ptr<void> vsp = options_.callback_group == nullptr ?
169+
node_base->get_default_callback_group() :
170+
options_.callback_group;
171+
std::weak_ptr<void> vwp = vsp;
172+
options_.callback_lifetime = vwp;
173+
}
174+
175+
if (options_.callback_lifetime.expired())
176+
{
177+
throw std::invalid_argument(
178+
"callback_lifetime weak_ptr for topic '" + topic_name +
179+
"' has already expired");
180+
}
181+
159182
using SubscriptionIntraProcessT = rclcpp::experimental::SubscriptionIntraProcess<
160183
MessageT,
161184
SubscribedType,
@@ -172,7 +195,8 @@ class Subscription : public SubscriptionBase
172195
context,
173196
this->get_topic_name(), // important to get like this, as it has the fully-qualified name
174197
qos_profile,
175-
resolve_intra_process_buffer_type(options_.intra_process_buffer_type, callback));
198+
resolve_intra_process_buffer_type(options_.intra_process_buffer_type, callback),
199+
options_.callback_lifetime);
176200
TRACETOOLS_TRACEPOINT(
177201
rclcpp_subscription_init,
178202
static_cast<const void *>(get_subscription_handle().get()),
@@ -300,12 +324,15 @@ class Subscription : public SubscriptionBase
300324
now = std::chrono::system_clock::now();
301325
}
302326

303-
any_callback_.dispatch(typed_message, message_info);
327+
if (!options_.callback_lifetime.expired())
328+
{
329+
any_callback_.dispatch(typed_message, message_info);
304330

305-
if (subscription_topic_statistics_) {
306-
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
307-
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
308-
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
331+
if (subscription_topic_statistics_) {
332+
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
333+
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
334+
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
335+
}
309336
}
310337
}
311338

@@ -321,12 +348,15 @@ class Subscription : public SubscriptionBase
321348
now = std::chrono::system_clock::now();
322349
}
323350

324-
any_callback_.dispatch(serialized_message, message_info);
351+
if (!options_.callback_lifetime.expired())
352+
{
353+
any_callback_.dispatch(serialized_message, message_info);
325354

326-
if (subscription_topic_statistics_) {
327-
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
328-
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
329-
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
355+
if (subscription_topic_statistics_) {
356+
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
357+
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
358+
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
359+
}
330360
}
331361
}
332362

@@ -353,12 +383,15 @@ class Subscription : public SubscriptionBase
353383
now = std::chrono::system_clock::now();
354384
}
355385

356-
any_callback_.dispatch(sptr, message_info);
386+
if (!options_.callback_lifetime.expired())
387+
{
388+
any_callback_.dispatch(sptr, message_info);
357389

358-
if (subscription_topic_statistics_) {
359-
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
360-
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
361-
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
390+
if (subscription_topic_statistics_) {
391+
const auto nanos = std::chrono::time_point_cast<std::chrono::nanoseconds>(now);
392+
const auto time = rclcpp::Time(nanos.time_since_epoch().count());
393+
subscription_topic_statistics_->handle_message(message_info.get_rmw_message_info(), time);
394+
}
362395
}
363396
}
364397

@@ -449,7 +482,9 @@ class Subscription : public SubscriptionBase
449482
* It is important to save a copy of this so that the rmw payload which it
450483
* may contain is kept alive for the duration of the subscription.
451484
*/
452-
const rclcpp::SubscriptionOptionsWithAllocator<AllocatorT> options_;
485+
// NOTE: Had to drop const in order to set default options_.callback_lifetime
486+
// if not set in user code.
487+
rclcpp::SubscriptionOptionsWithAllocator<AllocatorT> options_;
453488
typename message_memory_strategy::MessageMemoryStrategy<ROSMessageType, AllocatorT>::SharedPtr
454489
message_memory_strategy_;
455490

rclcpp/include/rclcpp/subscription_options.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ struct SubscriptionOptionsBase
8989
QosOverridingOptions qos_overriding_options;
9090

9191
ContentFilterOptions content_filter_options;
92+
93+
std::weak_ptr<void> callback_lifetime;
9294
};
9395

9496
/// Structure containing optional configuration for Subscriptions.

0 commit comments

Comments
 (0)