Skip to content

Commit 9ae2e52

Browse files
authored
fix: Fix the task leak with the lazy in-mem rpc client while still keeping it lazy (#31)
* Make blobs more cheaply cloneable by giving it an Inner * Remove the lazy part. The lazy handler kept a reference to Blobs alive. This caused both the task and the blobs to never be dropped. To solve this you can just split the inner part in 2 parts, one that has the handle and one that has the logic. But that is not nice. I think it is best for the mem rpc handler to exist completely separately, especially given that rpc is a non-default feature. * spawn_rpc should make it sufficiently clear that this is a thing you need to put away somewhere. Or maybe spawn_client? * back to the lazy client * add comment about the purpose of the handler
1 parent dba7850 commit 9ae2e52

File tree

5 files changed

+276
-201
lines changed

5 files changed

+276
-201
lines changed

examples/custom-protocol.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ async fn main() -> Result<()> {
122122

123123
// Print out our query results.
124124
for hash in hashes {
125-
read_and_print(&blobs_client, hash).await?;
125+
read_and_print(blobs_client, hash).await?;
126126
}
127127
}
128128
}

src/net_protocol.rs

Lines changed: 47 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,22 @@
55

66
use std::{collections::BTreeSet, fmt::Debug, ops::DerefMut, sync::Arc};
77

8-
use anyhow::{anyhow, bail, Result};
8+
use anyhow::{bail, Result};
99
use futures_lite::future::Boxed as BoxedFuture;
1010
use futures_util::future::BoxFuture;
1111
use iroh::{endpoint::Connecting, protocol::ProtocolHandler, Endpoint, NodeAddr};
1212
use iroh_base::hash::{BlobFormat, Hash};
1313
use serde::{Deserialize, Serialize};
14-
use tracing::{debug, warn};
14+
use tracing::debug;
1515

1616
use crate::{
17-
downloader::{DownloadRequest, Downloader},
18-
get::{
19-
db::{DownloadProgress, GetState},
20-
Stats,
21-
},
17+
downloader::Downloader,
2218
provider::EventSender,
2319
store::GcConfig,
2420
util::{
2521
local_pool::{self, LocalPoolHandle},
26-
progress::{AsyncChannelProgressSender, ProgressSender},
2722
SetTagOption,
2823
},
29-
HashAndFormat,
3024
};
3125

3226
/// A callback that blobs can ask about a set of hashes that should not be garbage collected.
@@ -47,16 +41,21 @@ impl Default for GcState {
4741
}
4842
}
4943

50-
#[derive(Debug, Clone)]
51-
pub struct Blobs<S> {
52-
rt: LocalPoolHandle,
44+
#[derive(Debug)]
45+
pub(crate) struct BlobsInner<S> {
46+
pub(crate) rt: LocalPoolHandle,
5347
pub(crate) store: S,
5448
events: EventSender,
55-
downloader: Downloader,
49+
pub(crate) downloader: Downloader,
50+
pub(crate) endpoint: Endpoint,
51+
gc_state: std::sync::Mutex<GcState>,
5652
#[cfg(feature = "rpc")]
57-
batches: Arc<tokio::sync::Mutex<BlobBatches>>,
58-
endpoint: Endpoint,
59-
gc_state: Arc<std::sync::Mutex<GcState>>,
53+
pub(crate) batches: tokio::sync::Mutex<BlobBatches>,
54+
}
55+
56+
#[derive(Debug, Clone)]
57+
pub struct Blobs<S> {
58+
pub(crate) inner: Arc<BlobsInner<S>>,
6059
#[cfg(feature = "rpc")]
6160
pub(crate) rpc_handler: Arc<std::sync::OnceLock<crate::rpc::RpcHandler>>,
6261
}
@@ -76,7 +75,7 @@ pub(crate) struct BlobBatches {
7675
#[derive(Debug, Default)]
7776
struct BlobBatch {
7877
/// The tags in this batch.
79-
tags: std::collections::BTreeMap<HashAndFormat, Vec<crate::TempTag>>,
78+
tags: std::collections::BTreeMap<iroh::hash::HashAndFormat, Vec<crate::TempTag>>,
8079
}
8180

8281
#[cfg(feature = "rpc")]
@@ -95,7 +94,11 @@ impl BlobBatches {
9594
}
9695

9796
/// Remove a tag from a batch.
98-
pub fn remove_one(&mut self, batch: BatchId, content: &HashAndFormat) -> Result<()> {
97+
pub fn remove_one(
98+
&mut self,
99+
batch: BatchId,
100+
content: &iroh::hash::HashAndFormat,
101+
) -> Result<()> {
99102
if let Some(batch) = self.batches.get_mut(&batch) {
100103
if let Some(tags) = batch.tags.get_mut(content) {
101104
tags.pop();
@@ -178,40 +181,46 @@ impl<S: crate::store::Store> Blobs<S> {
178181
endpoint: Endpoint,
179182
) -> Self {
180183
Self {
181-
rt,
182-
store,
183-
events,
184-
downloader,
185-
endpoint,
186-
#[cfg(feature = "rpc")]
187-
batches: Default::default(),
188-
gc_state: Default::default(),
184+
inner: Arc::new(BlobsInner {
185+
rt,
186+
store,
187+
events,
188+
downloader,
189+
endpoint,
190+
#[cfg(feature = "rpc")]
191+
batches: Default::default(),
192+
gc_state: Default::default(),
193+
}),
189194
#[cfg(feature = "rpc")]
190195
rpc_handler: Default::default(),
191196
}
192197
}
193198

194199
pub fn store(&self) -> &S {
195-
&self.store
200+
&self.inner.store
201+
}
202+
203+
pub fn events(&self) -> &EventSender {
204+
&self.inner.events
196205
}
197206

198207
pub fn rt(&self) -> &LocalPoolHandle {
199-
&self.rt
208+
&self.inner.rt
200209
}
201210

202211
pub fn downloader(&self) -> &Downloader {
203-
&self.downloader
212+
&self.inner.downloader
204213
}
205214

206215
pub fn endpoint(&self) -> &Endpoint {
207-
&self.endpoint
216+
&self.inner.endpoint
208217
}
209218

210219
/// Add a callback that will be called before the garbage collector runs.
211220
///
212221
/// This can only be called before the garbage collector has started, otherwise it will return an error.
213222
pub fn add_protected(&self, cb: ProtectCb) -> Result<()> {
214-
let mut state = self.gc_state.lock().unwrap();
223+
let mut state = self.inner.gc_state.lock().unwrap();
215224
match &mut *state {
216225
GcState::Initial(cbs) => {
217226
cbs.push(cb);
@@ -225,7 +234,7 @@ impl<S: crate::store::Store> Blobs<S> {
225234

226235
/// Start garbage collection with the given settings.
227236
pub fn start_gc(&self, config: GcConfig) -> Result<()> {
228-
let mut state = self.gc_state.lock().unwrap();
237+
let mut state = self.inner.gc_state.lock().unwrap();
229238
let protected = match state.deref_mut() {
230239
GcState::Initial(items) => std::mem::take(items),
231240
GcState::Started(_) => bail!("gc already started"),
@@ -241,161 +250,20 @@ impl<S: crate::store::Store> Blobs<S> {
241250
set
242251
}
243252
};
244-
let store = self.store.clone();
253+
let store = self.store().clone();
245254
let run = self
246-
.rt
255+
.rt()
247256
.spawn(move || async move { store.gc_run(config, protected_cb).await });
248257
*state = GcState::Started(Some(run));
249258
Ok(())
250259
}
251-
252-
#[cfg(feature = "rpc")]
253-
pub(crate) async fn batches(&self) -> tokio::sync::MutexGuard<'_, BlobBatches> {
254-
self.batches.lock().await
255-
}
256-
257-
pub(crate) async fn download(
258-
&self,
259-
endpoint: Endpoint,
260-
req: BlobDownloadRequest,
261-
progress: AsyncChannelProgressSender<DownloadProgress>,
262-
) -> Result<()> {
263-
let BlobDownloadRequest {
264-
hash,
265-
format,
266-
nodes,
267-
tag,
268-
mode,
269-
} = req;
270-
let hash_and_format = HashAndFormat { hash, format };
271-
let temp_tag = self.store.temp_tag(hash_and_format);
272-
let stats = match mode {
273-
DownloadMode::Queued => {
274-
self.download_queued(endpoint, hash_and_format, nodes, progress.clone())
275-
.await?
276-
}
277-
DownloadMode::Direct => {
278-
self.download_direct_from_nodes(endpoint, hash_and_format, nodes, progress.clone())
279-
.await?
280-
}
281-
};
282-
283-
progress.send(DownloadProgress::AllDone(stats)).await.ok();
284-
match tag {
285-
SetTagOption::Named(tag) => {
286-
self.store.set_tag(tag, Some(hash_and_format)).await?;
287-
}
288-
SetTagOption::Auto => {
289-
self.store.create_tag(hash_and_format).await?;
290-
}
291-
}
292-
drop(temp_tag);
293-
294-
Ok(())
295-
}
296-
297-
async fn download_queued(
298-
&self,
299-
endpoint: Endpoint,
300-
hash_and_format: HashAndFormat,
301-
nodes: Vec<NodeAddr>,
302-
progress: AsyncChannelProgressSender<DownloadProgress>,
303-
) -> Result<Stats> {
304-
/// Name used for logging when new node addresses are added from gossip.
305-
const BLOB_DOWNLOAD_SOURCE_NAME: &str = "blob_download";
306-
307-
let mut node_ids = Vec::with_capacity(nodes.len());
308-
let mut any_added = false;
309-
for node in nodes {
310-
node_ids.push(node.node_id);
311-
if !node.info.is_empty() {
312-
endpoint.add_node_addr_with_source(node, BLOB_DOWNLOAD_SOURCE_NAME)?;
313-
any_added = true;
314-
}
315-
}
316-
let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some());
317-
anyhow::ensure!(can_download, "no way to reach a node for download");
318-
let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress);
319-
let handle = self.downloader.queue(req).await;
320-
let stats = handle.await?;
321-
Ok(stats)
322-
}
323-
324-
#[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))]
325-
async fn download_direct_from_nodes(
326-
&self,
327-
endpoint: Endpoint,
328-
hash_and_format: HashAndFormat,
329-
nodes: Vec<NodeAddr>,
330-
progress: AsyncChannelProgressSender<DownloadProgress>,
331-
) -> Result<Stats> {
332-
let mut last_err = None;
333-
let mut remaining_nodes = nodes.len();
334-
let mut nodes_iter = nodes.into_iter();
335-
'outer: loop {
336-
match crate::get::db::get_to_db_in_steps(
337-
self.store.clone(),
338-
hash_and_format,
339-
progress.clone(),
340-
)
341-
.await?
342-
{
343-
GetState::Complete(stats) => return Ok(stats),
344-
GetState::NeedsConn(needs_conn) => {
345-
let (conn, node_id) = 'inner: loop {
346-
match nodes_iter.next() {
347-
None => break 'outer,
348-
Some(node) => {
349-
remaining_nodes -= 1;
350-
let node_id = node.node_id;
351-
if node_id == endpoint.node_id() {
352-
debug!(
353-
?remaining_nodes,
354-
"skip node {} (it is the node id of ourselves)",
355-
node_id.fmt_short()
356-
);
357-
continue 'inner;
358-
}
359-
match endpoint.connect(node, crate::protocol::ALPN).await {
360-
Ok(conn) => break 'inner (conn, node_id),
361-
Err(err) => {
362-
debug!(
363-
?remaining_nodes,
364-
"failed to connect to {}: {err}",
365-
node_id.fmt_short()
366-
);
367-
continue 'inner;
368-
}
369-
}
370-
}
371-
}
372-
};
373-
match needs_conn.proceed(conn).await {
374-
Ok(stats) => return Ok(stats),
375-
Err(err) => {
376-
warn!(
377-
?remaining_nodes,
378-
"failed to download from {}: {err}",
379-
node_id.fmt_short()
380-
);
381-
last_err = Some(err);
382-
}
383-
}
384-
}
385-
}
386-
}
387-
match last_err {
388-
Some(err) => Err(err.into()),
389-
None => Err(anyhow!("No nodes to download from provided")),
390-
}
391-
}
392260
}
393261

394262
impl<S: crate::store::Store> ProtocolHandler for Blobs<S> {
395263
fn accept(&self, conn: Connecting) -> BoxedFuture<Result<()>> {
396-
let db = self.store.clone();
397-
let events = self.events.clone();
398-
let rt = self.rt.clone();
264+
let db = self.store().clone();
265+
let events = self.events().clone();
266+
let rt = self.rt().clone();
399267

400268
Box::pin(async move {
401269
crate::provider::handle_connection(conn.await?, db, events, rt).await;
@@ -404,7 +272,7 @@ impl<S: crate::store::Store> ProtocolHandler for Blobs<S> {
404272
}
405273

406274
fn shutdown(&self) -> BoxedFuture<()> {
407-
let store = self.store.clone();
275+
let store = self.store().clone();
408276
Box::pin(async move {
409277
store.shutdown().await;
410278
})

0 commit comments

Comments
 (0)