Skip to content

refactor!: make Dialer trait private and inline iroh::dialer::Dialer #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 88 additions & 8 deletions src/downloader.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! Handle downloading blobs and collections concurrently and from nodes.
//!
//! The [`Downloader`] interacts with four main components to this end.
//! - [`Dialer`]: Used to queue opening connections to nodes we need to perform downloads.
//! - `ProviderMap`: Where the downloader obtains information about nodes that could be
//! used to perform a download.
//! - [`Store`]: Where data is stored.
Expand All @@ -10,7 +9,7 @@
//! 1. The `ProviderMap` is queried for nodes. From these nodes some are selected
//! prioritizing connected nodes with lower number of active requests. If no useful node is
//! connected, or useful connected nodes have no capacity to perform the request, a connection
//! attempt is started using the [`Dialer`].
//! attempt is started using the `DialerT`.
//! 2. The download is queued for processing at a later time. Downloads are not performed right
//! away. Instead, they are initially delayed to allow the node to obtain the data itself, and
//! to wait for the new connection to be established if necessary.
Expand All @@ -34,13 +33,16 @@ use std::{
fmt,
future::Future,
num::NonZeroUsize,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::Poll,
time::Duration,
};

use anyhow::anyhow;
use futures_lite::{future::BoxedLocal, Stream, StreamExt};
use hashlink::LinkedHashSet;
use iroh::{endpoint, Endpoint, NodeAddr, NodeId};
Expand All @@ -51,7 +53,7 @@ use tokio::{
task::JoinSet,
};
use tokio_util::{either::Either, sync::CancellationToken, time::delay_queue};
use tracing::{debug, error_span, trace, warn, Instrument};
use tracing::{debug, error, error_span, trace, warn, Instrument};

use crate::{
get::{db::DownloadProgress, Stats},
Expand All @@ -77,7 +79,7 @@ const SERVICE_CHANNEL_CAPACITY: usize = 128;
pub struct IntentId(pub u64);

/// Trait modeling a dialer. This allows for IO-less testing.
pub trait Dialer: Stream<Item = (NodeId, anyhow::Result<Self::Connection>)> + Unpin {
trait DialerT: Stream<Item = (NodeId, anyhow::Result<Self::Connection>)> + Unpin {
/// Type of connections returned by the Dialer.
type Connection: Clone + 'static;
/// Dial a node.
Expand Down Expand Up @@ -354,7 +356,7 @@ impl Downloader {
{
let me = endpoint.node_id().fmt_short();
let (msg_tx, msg_rx) = mpsc::channel(SERVICE_CHANNEL_CAPACITY);
let dialer = iroh::dialer::Dialer::new(endpoint);
let dialer = Dialer::new(endpoint);

let create_future = move || {
let getter = get::IoGetter {
Expand Down Expand Up @@ -532,7 +534,7 @@ enum NodeState<'a, Conn> {
}

#[derive(Debug)]
struct Service<G: Getter, D: Dialer> {
struct Service<G: Getter, D: DialerT> {
/// The getter performs individual requests.
getter: G,
/// Map to query for nodes that we believe have the data we are looking for.
Expand Down Expand Up @@ -564,7 +566,7 @@ struct Service<G: Getter, D: Dialer> {
/// Progress tracker
progress_tracker: ProgressTracker,
}
impl<G: Getter<Connection = D::Connection>, D: Dialer> Service<G, D> {
impl<G: Getter<Connection = D::Connection>, D: DialerT> Service<G, D> {
fn new(
getter: G,
dialer: D,
Expand Down Expand Up @@ -1492,7 +1494,7 @@ impl Queue {
}
}

impl Dialer for iroh::dialer::Dialer {
impl DialerT for Dialer {
type Connection = endpoint::Connection;

fn queue_dial(&mut self, node_id: NodeId) {
Expand All @@ -1511,3 +1513,81 @@ impl Dialer for iroh::dialer::Dialer {
self.endpoint().node_id()
}
}

/// Dials nodes and maintains a queue of pending dials.
///
/// The [`Dialer`] wraps an [`Endpoint`], connects to nodes through the endpoint, stores the
/// pending connect futures and emits finished connect results.
///
/// The [`Dialer`] also implements [`Stream`] to retrieve the dialled connections.
#[derive(Debug)]
struct Dialer {
endpoint: Endpoint,
pending: JoinSet<(NodeId, anyhow::Result<quinn::Connection>)>,
pending_dials: HashMap<NodeId, CancellationToken>,
}

impl Dialer {
/// Create a new dialer for a [`Endpoint`]
fn new(endpoint: Endpoint) -> Self {
Self {
endpoint,
pending: Default::default(),
pending_dials: Default::default(),
}
}

/// Starts to dial a node by [`NodeId`].
fn queue_dial(&mut self, node_id: NodeId, alpn: &'static [u8]) {
if self.is_pending(node_id) {
return;
}
let cancel = CancellationToken::new();
self.pending_dials.insert(node_id, cancel.clone());
let endpoint = self.endpoint.clone();
self.pending.spawn(async move {
let res = tokio::select! {
biased;
_ = cancel.cancelled() => Err(anyhow!("Cancelled")),
res = endpoint.connect(node_id, alpn) => res
};
(node_id, res)
});
}

/// Checks if a node is currently being dialed.
fn is_pending(&self, node: NodeId) -> bool {
self.pending_dials.contains_key(&node)
}

/// Number of pending connections to be opened.
fn pending_count(&self) -> usize {
self.pending_dials.len()
}

/// Returns a reference to the endpoint used in this dialer.
fn endpoint(&self) -> &Endpoint {
&self.endpoint
}
}

impl Stream for Dialer {
type Item = (NodeId, anyhow::Result<quinn::Connection>);

fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.pending.poll_join_next(cx) {
Poll::Ready(Some(Ok((node_id, result)))) => {
self.pending_dials.remove(&node_id);
Poll::Ready(Some((node_id, result)))
}
Poll::Ready(Some(Err(e))) => {
error!("dialer error: {:?}", e);
Poll::Pending
}
_ => Poll::Pending,
}
}
}
2 changes: 1 addition & 1 deletion src/downloader/invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use super::*;

/// invariants for the service.
impl<G: Getter<Connection = D::Connection>, D: Dialer> Service<G, D> {
impl<G: Getter<Connection = D::Connection>, D: DialerT> Service<G, D> {
/// Checks the various invariants the service must maintain
#[track_caller]
pub(in crate::downloader) fn check_invariants(&self) {
Expand Down
2 changes: 1 addition & 1 deletion src/downloader/test/dialer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Default for TestingDialerInner {
}
}

impl Dialer for TestingDialer {
impl DialerT for TestingDialer {
type Connection = NodeId;

fn queue_dial(&mut self, node_id: NodeId) {
Expand Down
Loading