A DHT for iroh

by Rüdiger Klaehn

RPC client

Using an irpc client directly is not exactly horrible, but we still want to add some sugar to make it easier to use. So we write a wrapper around the irpc client that makes it more convenient. Set and FindNode are plain async fns, while GetAll returns a stream of responses.

impl RpcClient {
    ...

    pub async fn set(&self, key: Id, value: Value) -> irpc::Result<SetResponse> {
        self.0.rpc(Set { key, value }).await
    }

    pub async fn get_all(
        &self,
        key: Id,
        kind: Kind,
    ) -> irpc::Result<irpc::channel::mpsc::Receiver<Value>> {
        self.0
            .server_streaming(GetAll { key, kind }, 32)
            .await
    }

    pub async fn find_node(
        &self,
        id: Id,
        requester: Option<NodeId>,
    ) -> irpc::Result<Vec<NodeAddr>> {
        self.0.rpc(FindNode { id, requester }).await
    }
}

This client can now be used either with a remote node that is connected via a memory transport, or with a node that is connected via an iroh connection.

Storage implementation

The first thing we need to implement for this protocol is the storage part. For this experiment we will use a very simple memory storage. This might even be a good idea for production! We have a limited value size, and DHTs are not persistent storage anyway. DHT records need to be continuously republished, so if a DHT node goes down it will just be repopulated with values shortly after coming back online.

The only notable thing we do here is to store values of different kinds separately for simpler retrieval, and to use an IndexSet for the values so that they stay in insertion order.

struct MemStorage {
    /// The DHT data storage, mapping keys to values.
    /// Separated by kind to allow for efficient retrieval.
    data: BTreeMap<Id, BTreeMap<Kind, IndexSet<Value>>>,
}

impl MemStorage {
    fn new() -> Self {
        Self {
            data: BTreeMap::new(),
        }
    }

    /// Set a value for a key.
    fn set(&mut self, key: Id, value: Value) {
        let kind = value.kind();
        self.data
            .entry(key)
            .or_default()
            .entry(kind)
            .or_default()
            .insert(value);
    }

    /// Get all values of a certain kind for a key.
    fn get_all(&self, key: &Id, kind: &Kind) -> Option<&IndexSet<Value>> {
        self.data.get(key).and_then(|kinds| kinds.get(kind))
    }
}
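
To illustrate the behavior this layout gives us, here is a minimal std-only sketch. It is not the code from the post: a Vec with a contains check stands in for IndexSet, and the Kind variants and the shape of Value are made up for the example.

```rust
use std::collections::BTreeMap;

type Id = [u8; 32];

// Hypothetical kinds and value shape, just for this sketch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Kind {
    Blob,
    Record,
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct Value {
    kind: Kind,
    data: Vec<u8>,
}

#[derive(Default)]
struct MemStorage {
    // Vec stands in for IndexSet: unique values, kept in insertion order.
    data: BTreeMap<Id, BTreeMap<Kind, Vec<Value>>>,
}

impl MemStorage {
    /// Store a value under its kind, ignoring exact duplicates.
    fn set(&mut self, key: Id, value: Value) {
        let values = self.data.entry(key).or_default().entry(value.kind).or_default();
        if !values.contains(&value) {
            values.push(value);
        }
    }

    /// All values of one kind for a key, oldest first. Empty if there are none.
    fn get_all(&self, key: &Id, kind: Kind) -> &[Value] {
        self.data
            .get(key)
            .and_then(|kinds| kinds.get(&kind))
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }
}
```

Setting the same value twice leaves a single entry, and values of a different kind under the same key never show up in each other's results.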

Routing implementation

Now it looks like we have run out of simple things to do and need to actually implement the routing part. The routing API does not care how the routing table is organized internally - it could just as well be the full set of nodes. But we want to implement the Kademlia algorithm to get that nice power law distribution.

So let's define the routing table. First of all we need some simple integer arithmetic like xor and leading_zeros for 256 bit numbers. There are various crates that provide this, but since we don't need anything fancy like multiplication or division, we just quickly implemented it inline.
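
The post does not show these helpers, so here is a minimal sketch of what they could look like, assuming ids are 32 byte arrays interpreted as big-endian 256 bit numbers. The function names match the calls used later in bucket_index, but the bodies are an assumption.

```rust
/// XOR two 256 bit values stored as big-endian byte arrays.
fn xor(a: &[u8; 32], b: &[u8; 32]) -> [u8; 32] {
    let mut out = [0u8; 32];
    for i in 0..32 {
        out[i] = a[i] ^ b[i];
    }
    out
}

/// Count the leading zero bits of a 256 bit big-endian value.
/// Returns 256 for the all-zero value.
fn leading_zeros(value: &[u8; 32]) -> usize {
    for (i, byte) in value.iter().enumerate() {
        if *byte != 0 {
            // i full zero bytes, plus the zero bits at the top of this byte
            return i * 8 + byte.leading_zeros() as usize;
        }
    }
    256
}
```

Note that xor is its own inverse: xor-ing a distance with one of the two ids gives back the other id.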

The routing table itself is just a 2d array of node ids. Each row (k-bucket) has a small fixed upper size, so we are going to use ArrayVec from the arrayvec crate to prevent allocations. For each node id we just keep a tiny bit of extra information: a timestamp of when we last saw any evidence that the node actually exists and responds. We use it to decide which nodes to check for liveness.

A KBucket is tiny, so doing full scans for addition and removal is totally acceptable. We don't want any clever algorithms here.

struct NodeInfo {
    pub id: NodeId,
    pub last_seen: u64,
}

struct KBucket {
    nodes: ArrayVec<NodeInfo, 20>,
}

The routing table data is now just one bucket per bit, so 256 buckets in our case where we have decided to bucket by leading zero bits:

struct Buckets([KBucket; 256]);

The only additional information we need for the routing table is our own node id. Data in the routing table is organized in terms of closeness to the local node id, so we frequently need to access the local node id when inserting data.

struct RoutingTable {
    buckets: Box<Buckets>,
    local_id: NodeId,
}

Now assuming that the system has some way to find valid DHT nodes, all we need is a way to insert nodes into the routing table, and to query the routing table for the k closest nodes to some key x to implement the FindNode rpc call.

Insertion

Insertion means first computing which bucket the node should go into, and then inserting at that index. Computing the bucket index is computing the xor distance to our own node id, then counting leading zeros and flipping the result around, since we want bucket 0 to contain the closest nodes and bucket 255 to contain the furthest away nodes as per Kademlia convention.

fn bucket_index(&self, target: &[u8; 32]) -> usize {
    let distance = xor(&self.local_id.as_bytes(), target);
    let zeros = leading_zeros(&distance);
    if zeros >= BUCKET_COUNT {
        0 // Same node case
    } else {
        BUCKET_COUNT - 1 - zeros
    }
}

fn add_node(&mut self, node: NodeInfo) -> bool {
    if node.id == self.local_id {
        return false;
    }

    let bucket_idx = self.bucket_index(node.id.as_bytes());
    self.buckets.0[bucket_idx].add_node(node)
}
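
To make the flipping of the bucket index concrete, here is a small standalone sketch. It takes the local id as a plain parameter instead of reading it from the routing table, and it fuses the xor and the leading zero count into a single pass. Both are simplifications for illustration, not the code from the post.

```rust
const BUCKET_COUNT: usize = 256;

/// Bucket index of `target` as seen from `local`, with ids as big-endian bytes.
fn bucket_index(local: &[u8; 32], target: &[u8; 32]) -> usize {
    // Leading zero bits of the xor distance; BUCKET_COUNT if the ids are equal.
    let zeros = local
        .iter()
        .zip(target.iter())
        .enumerate()
        .find_map(|(i, (a, b))| {
            let diff = a ^ b;
            (diff != 0).then(|| i * 8 + diff.leading_zeros() as usize)
        })
        .unwrap_or(BUCKET_COUNT);
    if zeros >= BUCKET_COUNT {
        0 // same node case
    } else {
        BUCKET_COUNT - 1 - zeros
    }
}
```

An id that differs from the local id only in the very last bit has 255 leading zeros in its distance and lands in bucket 0. An id that differs in the very first bit has 0 leading zeros and lands in bucket 255.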

Inserting a node that already exists in the bucket just updates its timestamp. Otherwise we append the node. If there is no room, we could either make room by evicting the oldest existing node, or ping the oldest node and fail the insertion if it responds. For now we just fail, favoring stability. Nodes will be pinged at regular intervals anyway, and non-responsive nodes will be purged.

impl KBucket {
    fn add_node(&mut self, node: NodeInfo) -> bool {
        // Check if node already exists and update it
        for existing in &mut self.nodes {
            if existing.id == node.id {
                existing.last_seen = node.last_seen;
                return true; // Updated existing node
            }
        }

        // Add new node if space available
        if self.nodes.len() < K {
            self.nodes.push(node);
            return true;
        }

        false // Bucket full
    }
}

As you can see this is a very simple implementation. Within the bucket we don't care about order.
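
For comparison, here is a sketch of the eviction alternative mentioned above, which the implementation in this post deliberately does not use. It replaces the least recently seen node when the bucket is full. The method name add_node_evicting is made up, NodeInfo uses a plain byte array as id, and a Vec stands in for the ArrayVec.

```rust
const K: usize = 20;

#[derive(Debug, Clone, PartialEq)]
struct NodeInfo {
    id: [u8; 32],
    last_seen: u64,
}

struct KBucket {
    nodes: Vec<NodeInfo>, // stand-in for ArrayVec<NodeInfo, 20>
}

impl KBucket {
    /// Add or refresh a node. If the bucket is full, evict the node with the
    /// smallest last_seen timestamp and return it.
    fn add_node_evicting(&mut self, node: NodeInfo) -> Option<NodeInfo> {
        // Known node: just refresh the timestamp.
        if let Some(existing) = self.nodes.iter_mut().find(|n| n.id == node.id) {
            existing.last_seen = node.last_seen;
            return None;
        }
        // Room left: append.
        if self.nodes.len() < K {
            self.nodes.push(node);
            return None;
        }
        // Bucket full: full scan for the least recently seen entry.
        let (oldest, _) = self
            .nodes
            .iter()
            .enumerate()
            .min_by_key(|(_, n)| n.last_seen)?;
        Some(std::mem::replace(&mut self.nodes[oldest], node))
    }
}
```

The tradeoff is that a flood of new node ids can push out nodes that have been reliable for a long time, which is why failing the insertion is the more stable choice.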

Querying

Since the xor metric is so simple, and the routing table is of limited size, it is not worth doing anything fancy when querying for a key. Conceptually we just create an array of nodes and distances, sort it by distance, take the k smallest, and that's it.

Since this operation is performed very frequently we did a few simple optimizations though.

impl RoutingTable {
    fn find_closest_nodes(&self, target: &Id, k: usize) -> Vec<NodeId> {
        let mut candidates = Vec::with_capacity(self.nodes().count());
        candidates.extend(
            self.nodes()
                .map(|node| Distance::between(target, node.id.as_bytes())),
        );
        if k < candidates.len() {
            candidates.select_nth_unstable(k - 1);
            candidates.truncate(k);
        }
        candidates.sort_unstable();

        candidates
            .into_iter()
            .map(|dist| {
                NodeId::from_bytes(&dist.inverse(target))
                    .expect("inverse called with different target than between")
            })
            .collect()
    }
}

We first create an array of candidates that contains all node ids in the routing table. This will almost always be larger than k.

We could just sort it, but we are only interested in the order of the k smallest values, not the overall order. So we can save some comparisons by using select_nth_unstable to partition the array such that the kth element is in the right place and everything before it is smaller, then truncate and sort just the remaining <= k elements. We can always use unstable sort since xor with a fixed target is injective: no two distinct nodes can have the same distance to the target id.

As a last trick, instead of storing (id, distance) tuples we just store the distance itself while sorting, and recompute the node id itself by xor-ing again with the target id. This reduces the size of the temporary array by half.
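
Here is a self-contained sketch of the select, truncate, sort and invert steps on plain byte arrays. The Distance type below is an assumption about what the real one looks like (a wrapper around the xor result that compares as a big-endian integer), and k_closest takes a slice of ids instead of a routing table.

```rust
/// XOR distance between two 256 bit ids, ordered as a big-endian integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Distance([u8; 32]);

impl Distance {
    fn between(a: &[u8; 32], b: &[u8; 32]) -> Self {
        let mut out = [0u8; 32];
        for i in 0..32 {
            out[i] = a[i] ^ b[i];
        }
        Distance(out)
    }

    /// Recover the id this distance was computed from: (id ^ target) ^ target == id.
    fn inverse(&self, target: &[u8; 32]) -> [u8; 32] {
        Distance::between(&self.0, target).0
    }
}

/// The k ids closest to target, closest first.
fn k_closest(ids: &[[u8; 32]], target: &[u8; 32], k: usize) -> Vec<[u8; 32]> {
    if k == 0 {
        return Vec::new(); // guard, since k - 1 below would underflow
    }
    // Store only distances, not (distance, id) pairs.
    let mut dists: Vec<Distance> = ids.iter().map(|id| Distance::between(target, id)).collect();
    if k < dists.len() {
        // Partition so that the k smallest distances are at the front, in any order.
        dists.select_nth_unstable(k - 1);
        dists.truncate(k);
    }
    // Only the surviving <= k elements need a full sort.
    dists.sort_unstable();
    dists.into_iter().map(|d| d.inverse(target)).collect()
}
```

The explicit guard for k == 0 is an addition in this sketch; the version above assumes k is at least 1.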

Wiring it up

The handler for our rpc protocol is a typical rust actor. The actor has the mem storage as well as the routing table as state, and processes messages one by one. If the storage was persistent, you might want to perform the actual storage and retrieval as well as the sending of the response stream in a background task, but for now it is all sequential.

There are some background tasks to update the routing table to add new nodes and forget unreachable nodes, but these are omitted for now.

struct Node {
    routing_table: RoutingTable,
    storage: MemStorage,
}

struct Actor {
    node: Node,
    /// receiver for rpc messages from the network
    rpc_rx: tokio::sync::mpsc::Receiver<RpcMessage>,
    // ... more plumbing for background tasks
}

impl Actor {
    async fn run(mut self) {
        loop {
            tokio::select! {
                msg = self.rpc_rx.recv() => {
                    if let Some(msg) = msg {
                        self.handle_rpc(msg).await;
                    } else {
                        break;
                    }
                }
                // ... other background tasks and stuff
            }
        }
    }

    async fn handle_rpc(&mut self, message: RpcMessage) {
        match message {
            RpcMessage::Set(msg) => {
                // msg validation omitted
                self.node.storage.set(msg.key, msg.value.clone());
                msg.tx.send(SetResponse::Ok).await.ok();
            }
            RpcMessage::GetAll(msg) => {
                let Some(values) = self.node.storage.get_all(&msg.key, &msg.kind) else {
                    return;
                };
                // sampling values and randomizing omitted
                for value in values {
                    if msg.tx.send(value.clone()).await.is_err() {
                        break;
                    }
                }
            }
            RpcMessage::FindNode(msg) => {
                // call local find_node and just return the results
                let ids = self
                    .node
                    .routing_table
                    .find_closest_nodes(&msg.id, self.state.config.k)
                    .into_iter()
                    .map(|id| self.state.pool.node_addr(id))
                    .collect();
                msg.tx.send(ids).await.ok();
            }
        }
    }
}

Set is trivial. It just sets the value and returns Ok to the requester. There is some logic to validate the value based on the key, but this has been omitted here.

GetAll is a bit more complex. It queries the storage for values, then does some limiting and randomizing (omitted here), and then streams out the responses.

FindNode queries the routing table and gets back a sequence of node ids. It then augments this information with dialing information from the connection pool (a wrapper around an iroh endpoint) and sends out the response all at once.

What we have now is an actor that stores values and maintains a routing table. All rpc operations are fully local, there is no way for a remote node to trigger something expensive.

The next step is to implement the iterative lookup algorithm. Once we have that, storage and retrieval are just calls to the k closest nodes to a key that are the result of the iterative lookup algorithm.

Both storage and retrieval involve a lot of network operations. To hide all these details from the user, we will need a message based protocol that the DHT client uses to communicate with the DHT actor. This will also be an irpc protocol, but it will be used either in memory or to control a DHT node running in a different process on a local machine, so it does not have to concern itself as much with having small messages and with adversarial scenarios.

We also don't have to care about stability, since this will be used only between the same version of the binary.

As mentioned, the main complexity of a DHT is the routing. What values we store almost does not matter, as long as they can be validated somehow and are small enough to fit. So for testing, we are going to implement just storage of immutable small blobs.

We need the ability to store and retrieve such blobs, and for the user facing API we don't care about nodes. All these details are for the DHT to sort out internally. So let's design the API.

The API protocol will also contain internal messages that the DHT needs for periodic tasks. We can just hide them from the public API wrapper if we don't want our users to mess with internals.

#[rpc_requests(message = ApiMessage)]
#[derive(Debug, Serialize, Deserialize)]
pub enum ApiProto {
    #[rpc(wrap, tx = mpsc::Sender<NodeId>)]
    NetworkPut { id: Id, value: Value },
    #[rpc(wrap, tx = mpsc::Sender<(NodeId, Value)>)]
    NetworkGet { id: Id, kind: Kind },
    // ... plumbing rpc calls
}

We need the ability to store and retrieve values.

Storing values is a two step process: first use the iterative algorithm to find the k closest nodes, then, in parallel, try to store the value on all these nodes. To give the user some feedback about where the data is stored, we return a stream of the node ids where the data was successfully stored.

Retrieval is almost identical. We first find the k closest nodes, then, in parallel, ask all of them for the value. Again we return a stream of (NodeId, Value) so we can get answers to the user as soon as they become available.

For immutable values, the first validated value is all it takes: as soon as we have it, we can abort the operation. For other values we might want to wait for all results and then choose the most recent one, or use them all, e.g. to retrieve content over iroh-blobs from multiple sources.

Here are the ApiClient methods put_immutable and get_immutable, which wrap the NetworkPut and NetworkGet rpc calls:

async fn put_immutable(
    &self,
    value: &[u8],
) -> irpc::Result<(blake3::Hash, Vec<NodeId>)> {
    let hash = blake3::hash(value);
    let id = Id::from(*hash.as_bytes());
    let mut rx = self
        .0
        .server_streaming(
            NetworkPut {
                id,
                value: Value::Blake3Immutable(Blake3Immutable {
                    timestamp: now(),
                    data: value.to_vec(),
                }),
            },
            32,
        )
        .await?;
    let mut res = Vec::new();
    loop {
        match rx.recv().await {
            Ok(Some(id)) => res.push(id),
            Ok(None) => break,
            Err(_) => break, // stop on a receive error; keep the ids collected so far
        }
    }
    Ok((hash, res))
}

async fn get_immutable(&self, hash: blake3::Hash) -> irpc::Result<Option<Vec<u8>>> {
    let id = Id::from(*hash.as_bytes());
    let mut rx = self
        .0
        .server_streaming(
            NetworkGet {
                id,
                kind: Kind::Blake3Immutable,
            },
            32,
        )
        .await?;
    loop {
        match rx.recv().await {
            Ok(Some((_, value))) => {
                let Value::Blake3Immutable(Blake3Immutable { data, .. }) = value else {
                    continue; // Skip non-Blake3Immutable values
                };
                if blake3::hash(&data) == hash {
                    return Ok(Some(data));
                } else {
                    continue; // Hash mismatch, skip this value
                }
            }
            Ok(None) => {
                break Ok(None);
            }
            Err(e) => {
                break Err(e.into());
            }
        }
    }
}

put_immutable aggregates all node ids where the data was stored and returns them. You could have a different API where you don't wait for storage on all nodes. get_immutable just returns after the first correct value - at this point you have the correct data and there is no point in waiting for more of the same.

The first phase of the implementation of both NetworkPut and NetworkGet is powered by the iterative lookup algo. Since this is used both externally to store and retrieve values and internally to perform random lookups to maintain the routing table, it gets its own RPC call.

pub enum ApiProto {
    ...
    #[rpc(wrap, tx = oneshot::Sender<Vec<NodeId>>)]
    Lookup {
        initial: Option<Vec<NodeId>>,
        id: Id,
    },
    ...
}
...
impl ApiClient {
    async fn lookup(
        &self,
        id: Id,
        initial: Option<Vec<NodeId>>,
    ) -> irpc::Result<Vec<NodeId>> {
        self.0.rpc(Lookup { id, initial }).await
    }
}

Lookup gets an initial set of node ids to start, and an id to look up. It returns the k closest nodes to the id. All the internals of the iterative lookup algorithm are hidden.

The plumbing to process this is not that interesting. But let's take a look at the iterative lookup algorithm itself now:

async fn iterative_find_node(self, target: Id, initial: Vec<NodeId>) -> Vec<NodeId> {
    let mut candidates = initial
        .into_iter()
        .filter(|addr| *addr != self.pool.id())
        .map(|id| (Distance::between(&target, &id.as_bytes()), id))
        .collect::<BTreeSet<_>>();
    let mut queried = HashSet::new();
    let mut tasks = FuturesUnordered::new();
    let mut result = BTreeSet::new();
    queried.insert(self.pool.id());
    result.insert((
        Distance::between(&self.pool.id().as_bytes(), &target),
        self.pool.id(),
    ));

    loop {
        for _ in 0..self.config.alpha {
            let Some(pair @ (_, id)) = candidates.pop_first() else {
                break;
            };
            queried.insert(id);
            let fut = self.query_one(id, target);
            tasks.push(async move { (pair, fut.await) });
        }

        while let Some((pair @ (_, id), cands)) = tasks.next().await {
            let Ok(cands) = cands else {
                self.api.nodes_dead(&[id]).await.ok();
                continue;
            };
            for cand in cands {
                let dist = Distance::between(&target, &cand.as_bytes());
                if !queried.contains(&cand) {
                    candidates.insert((dist, cand));
                }
            }
            self.api.nodes_seen(&[id]).await.ok();
            result.insert(pair);
        }

        // truncate the result to k.
        while result.len() > self.config.k {
            result.pop_last();
        }

        // find the k-th best distance
        let kth_best_distance = result
            .iter()
            .nth(self.config.k - 1)
            .map(|(dist, _)| *dist)
            .unwrap_or(Distance::MAX);

        // true if we have candidates that are closer than result[k-1].
        let has_closer_candidates = candidates
            .first()
            .map(|(dist, _)| *dist < kth_best_distance)
            .unwrap_or_default();

        if !has_closer_candidates {
            break;
        }
    }

    // result already has size <= k
    result.into_iter().map(|(_, id)| id).collect()
}

The algorithm maintains a set of candidates sorted by distance to the target. This set is initially populated from the local routing table, or can be passed in by the user.

It also maintains a set of nodes that were already queried to prevent running around in circles. The id of the node itself is added to the queried set to prevent self queries. It is also added to the result set for the case where the node itself is in the set of closest nodes.

Last but not least, it maintains a set of result nodes sorted by distance to the target. The only way for a node to end up in the result set is if it has actually answered a query, so is proven to be alive from the point of view of the local node at the time of the query.

The algorithm makes no assumptions whatsoever about the candidates, not even that they exist. So it pulls candidates in distance order with a configurable parallelism level alpha and validates them by performing a FindNode query on them for the target key. Ongoing FindNode queries are tracked in a FuturesUnordered.

As FindNode tasks complete, we get information about whether the candidate is alive or not.

For unreachable candidates, we drop them but also inform the actor that the node id should be removed from the routing table. For reachable nodes, we add the node we called to the result and the nodes the node returned to the candidates set. We can't add them to the result yet since they might be stale, or the node might lie to us. We also call nodes_seen to update the timestamp of the node in the routing table.

We only ever keep the best k nodes in the result set. The abort criterion is the trickiest part. We want to abort if none of the candidates is better than the kth best node in the result set, or if we have run out of candidates to try.

There might be an off by one error somewhere in there, but the algorithm seems to work well. The algorithm does not do any pings to check node liveness, since a Ping request and a FindNode request are identical in terms of latency and cost, and a FindNode request does something useful.
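
To see the abort criterion in action without any networking, here is a toy simulation. Everything in it is an assumption made for illustration: ids are u32 instead of 256 bit values, the network is a HashMap from node id to the peers that node knows, a missing entry means the node is unreachable, and queries run one after another instead of concurrently. Unlike the real function, it also does not add the local node to the result set.

```rust
use std::collections::{BTreeSet, HashMap, HashSet};

/// Toy network: node id -> peers it knows. Ids missing from the map are unreachable.
type Network = HashMap<u32, Vec<u32>>;

/// Simulated FindNode: the k peers of `node` closest to `target`, None if unreachable.
fn find_node(net: &Network, node: u32, target: u32, k: usize) -> Option<Vec<u32>> {
    let mut peers = net.get(&node)?.clone();
    peers.sort_unstable_by_key(|p| p ^ target);
    peers.truncate(k);
    Some(peers)
}

fn iterative_lookup(net: &Network, target: u32, initial: &[u32], k: usize, alpha: usize) -> Vec<u32> {
    // Candidates and results are (distance, id) pairs, sorted by distance.
    let mut candidates: BTreeSet<(u32, u32)> =
        initial.iter().map(|id| (id ^ target, *id)).collect();
    let mut queried = HashSet::new();
    let mut result = BTreeSet::new();
    loop {
        // Query up to alpha of the closest candidates.
        for _ in 0..alpha {
            let Some((dist, id)) = candidates.pop_first() else { break };
            queried.insert(id);
            // Unreachable nodes are dropped and never enter the result.
            let Some(found) = find_node(net, id, target, k) else { continue };
            for cand in found {
                if !queried.contains(&cand) {
                    candidates.insert((cand ^ target, cand));
                }
            }
            // Only nodes that answered make it into the result.
            result.insert((dist, id));
        }
        // Keep only the k best results.
        while result.len() > k {
            result.pop_last();
        }
        // Distance of the k-th best result, or MAX while we have fewer than k.
        let kth_best = result
            .iter()
            .nth(k.saturating_sub(1))
            .map(|(d, _)| *d)
            .unwrap_or(u32::MAX);
        // Abort if no candidate is left that could improve the result.
        let has_closer = candidates.first().map(|(d, _)| *d < kth_best).unwrap_or(false);
        if !has_closer {
            break;
        }
    }
    result.into_iter().map(|(_, id)| id).collect()
}
```

On a 16 node network where every node n knows n^1, n^2, n^4 and n^8, a lookup for target 13 starting from node 0 with k = 3 and alpha = 2 ends up with the nodes 13, 12 and 15. If node 12 is removed from the map, the lookup skips it and settles on 13, 15 and 14 instead.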

NetworkGet and NetworkPut

The iterative lookup algorithm is really the most complex part, but for completeness' sake here is how the actual NetworkGet and NetworkPut rpc calls are implemented. We first get good initial candidates from the local routing table. Even a pure client node will after some time accumulate some knowledge about the network that is useful for finding good places to start the search.

async fn handle_api(&mut self, message: ApiMessage) {
    match message {
        ...
        ApiMessage::NetworkGet(msg) => {
            let initial = self.node.routing_table.find_closest_nodes(&msg.id, K);
            self.tasks.spawn(self.state.clone().network_get(initial, msg.inner, msg.tx));
        }
        ...
    }
}

The network_get itself, which is run in a task, then just calls the iterative_find_node fn above to find the k closest nodes, and then attempts to get the value from all of them with a configurable parallelism level, returning results to the caller immediately as they arrive.

The network_put fn looks extremely similar.

async fn network_get(
    self,
    initial: Vec<NodeId>,
    msg: NetworkGet,
    tx: mpsc::Sender<(NodeId, Value)>,
) {
    let ids = self.clone().iterative_find_node(msg.id, initial).await;
    stream::iter(ids)
        .for_each_concurrent(self.config.parallelism, |id| {
            let pool = self.pool.clone();
            let tx = tx.clone();
            let msg = NetworkGet { id: msg.id, kind: msg.kind };
            async move {
                let Ok(client) = pool.client(id).await else {
                    return;
                };
                // Get all values of the specified kind for the key
                let Ok(mut rx) = client.get_all(msg.id, msg.kind).await else {
                    return;
                };
                while let Ok(Some(value)) = rx.recv().await {
                    if tx.send((id, value)).await.is_err() {
                        break;
                    }
                }
                drop(client);
            }
        })
        .await;
}