Skip to content

Commit cb6ddfc

Browse files
authored
feat(sync): total difficulty stage (#665)
* feat(sync): total difficulty stage * linter * rm commented log * patch current td in headers tests
1 parent dafc01d commit cb6ddfc

File tree

6 files changed

+239
-57
lines changed

6 files changed

+239
-57
lines changed

bin/reth/src/config.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ impl Config {
4747
pub struct StageConfig {
4848
/// Header stage configuration.
4949
pub headers: HeadersConfig,
50+
/// Total difficulty stage configuration
51+
pub total_difficulty: TotalDifficultyConfig,
5052
/// Body stage configuration.
5153
pub bodies: BodiesConfig,
5254
/// Sender recovery stage configuration.
@@ -70,6 +72,20 @@ impl Default for HeadersConfig {
7072
}
7173
}
7274

75+
/// Total difficulty stage configuration
76+
#[derive(Debug, Clone, Deserialize, Serialize)]
77+
pub struct TotalDifficultyConfig {
78+
/// The maximum number of total difficulty entries to sum up before committing progress to the
79+
/// database.
80+
pub commit_threshold: u64,
81+
}
82+
83+
impl Default for TotalDifficultyConfig {
84+
fn default() -> Self {
85+
Self { commit_threshold: 100_000 }
86+
}
87+
}
88+
7389
/// Body stage configuration.
7490
#[derive(Debug, Clone, Deserialize, Serialize)]
7591
pub struct BodiesConfig {

bin/reth/src/node/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use reth_stages::{
2525
metrics::HeaderMetrics,
2626
stages::{
2727
bodies::BodyStage, execution::ExecutionStage, headers::HeaderStage,
28-
sender_recovery::SenderRecoveryStage,
28+
sender_recovery::SenderRecoveryStage, total_difficulty::TotalDifficultyStage,
2929
},
3030
};
3131
use std::{net::SocketAddr, path::Path, sync::Arc};
@@ -130,6 +130,9 @@ impl Command {
130130
commit_threshold: config.stages.headers.commit_threshold,
131131
metrics: HeaderMetrics::default(),
132132
})
133+
.push(TotalDifficultyStage {
134+
commit_threshold: config.stages.total_difficulty.commit_threshold,
135+
})
133136
.push(BodyStage {
134137
downloader: Arc::new(
135138
bodies::concurrent::ConcurrentDownloader::new(

crates/stages/src/stages/headers.rs

Lines changed: 11 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@ const HEADERS: StageId = StageId("Headers");
3636
/// - [`HeaderNumbers`][reth_interfaces::db::tables::HeaderNumbers]
3737
/// - [`Headers`][reth_interfaces::db::tables::Headers]
3838
/// - [`CanonicalHeaders`][reth_interfaces::db::tables::CanonicalHeaders]
39-
/// - [`HeaderTD`][reth_interfaces::db::tables::HeaderTD]
4039
///
41-
/// NOTE: This stage commits the header changes to the database (everything except the changes to
42-
/// [`HeaderTD`][reth_interfaces::db::tables::HeaderTD] table). The stage does not return the
43-
/// control flow to the pipeline in order to preserve the context of the chain tip.
40+
/// NOTE: This stage downloads headers in reverse. Upon returning the control flow to the pipeline,
41+
/// the stage progress is not updated unless this stage is done.
4442
#[derive(Debug)]
4543
pub struct HeaderStage<D: HeaderDownloader, C: Consensus, H: HeadersClient, S: StatusUpdater> {
4644
/// Strategy for downloading the headers
@@ -101,9 +99,6 @@ impl<DB: Database, D: HeaderDownloader, C: Consensus, H: HeadersClient, S: Statu
10199
self.write_headers::<DB>(tx, res).await?.unwrap_or_default();
102100

103101
if self.is_stage_done(tx, current_progress).await? {
104-
// Update total difficulty values after we have reached fork choice
105-
debug!(target: "sync::stages::headers", head = ?head.hash(), "Writing total difficulty");
106-
self.write_td::<DB>(tx, &head)?;
107102
let stage_progress = current_progress.max(
108103
tx.cursor::<tables::CanonicalHeaders>()?
109104
.last()?
@@ -147,7 +142,6 @@ impl<DB: Database, D: HeaderDownloader, C: Consensus, H: HeadersClient, S: Statu
147142
)?;
148143
tx.unwind_table_by_num::<tables::CanonicalHeaders>(input.unwind_to)?;
149144
tx.unwind_table_by_num_hash::<tables::Headers>(input.unwind_to)?;
150-
tx.unwind_table_by_num_hash::<tables::HeaderTD>(input.unwind_to)?;
151145
Ok(UnwindOutput { stage_progress: input.unwind_to })
152146
}
153147
}
@@ -280,33 +274,6 @@ impl<D: HeaderDownloader, C: Consensus, H: HeadersClient, S: StatusUpdater>
280274
}
281275
Ok(latest)
282276
}
283-
284-
/// Iterate over inserted headers and write td entries
285-
fn write_td<DB: Database>(
286-
&self,
287-
tx: &Transaction<'_, DB>,
288-
head: &SealedHeader,
289-
) -> Result<(), StageError> {
290-
// Acquire cursor over total difficulty table
291-
let mut cursor_td = tx.cursor_mut::<tables::HeaderTD>()?;
292-
293-
// Get latest total difficulty
294-
let last_entry = cursor_td
295-
.seek_exact(head.num_hash().into())?
296-
.ok_or(DatabaseIntegrityError::TotalDifficulty { number: head.number })?;
297-
let mut td: U256 = last_entry.1.into();
298-
299-
// Start at first inserted block during this iteration
300-
let start_key = tx.get_block_numhash(head.number + 1)?;
301-
302-
// Walk over newly inserted headers, update & insert td
303-
for entry in tx.cursor::<tables::Headers>()?.walk(start_key)? {
304-
let (key, header) = entry?;
305-
td += header.difficulty;
306-
cursor_td.append(key, td.into())?;
307-
}
308-
Ok(())
309-
}
310277
}
311278

312279
#[cfg(test)]
@@ -472,7 +439,11 @@ mod tests {
472439
},
473440
ExecInput, ExecOutput, UnwindInput,
474441
};
475-
use reth_db::{models::blocks::BlockNumHash, tables, transaction::DbTx};
442+
use reth_db::{
443+
models::blocks::BlockNumHash,
444+
tables,
445+
transaction::{DbTx, DbTxMut},
446+
};
476447
use reth_downloaders::headers::linear::{LinearDownloadBuilder, LinearDownloader};
477448
use reth_interfaces::{
478449
p2p::headers::downloader::HeaderDownloader,
@@ -533,6 +504,10 @@ mod tests {
533504
let start = input.stage_progress.unwrap_or_default();
534505
let head = random_header(start, None);
535506
self.tx.insert_headers(std::iter::once(&head))?;
507+
// patch td table for `update_head` call
508+
self.tx.commit(|tx| {
509+
tx.put::<tables::HeaderTD>(head.num_hash().into(), U256::zero().into())
510+
})?;
536511

537512
// use previous progress as seed size
538513
let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1;
@@ -571,18 +546,6 @@ mod tests {
571546
assert!(header.is_some());
572547
let header = header.unwrap().seal();
573548
assert_eq!(header.hash(), hash);
574-
575-
// validate td consistency in the database
576-
if header.number > initial_stage_progress {
577-
let parent_td = tx.get::<tables::HeaderTD>(
578-
(header.number - 1, header.parent_hash).into(),
579-
)?;
580-
let td: U256 = *tx.get::<tables::HeaderTD>(key)?.unwrap();
581-
assert_eq!(
582-
parent_td.map(|td| *td + header.difficulty),
583-
Some(td)
584-
);
585-
}
586549
}
587550
Ok(())
588551
})?;
@@ -639,7 +602,6 @@ mod tests {
639602
.check_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
640603
self.tx.check_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
641604
self.tx.check_no_entry_above::<tables::Headers, _>(block, |key| key.number())?;
642-
self.tx.check_no_entry_above::<tables::HeaderTD, _>(block, |key| key.number())?;
643605
Ok(())
644606
}
645607
}

crates/stages/src/stages/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,5 @@ pub mod execution;
66
pub mod headers;
77
/// The sender recovery stage.
88
pub mod sender_recovery;
9+
/// The total difficulty stage
10+
pub mod total_difficulty;
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
use crate::{
2+
db::Transaction, DatabaseIntegrityError, ExecInput, ExecOutput, Stage, StageError, StageId,
3+
UnwindInput, UnwindOutput,
4+
};
5+
use reth_db::{
6+
cursor::{DbCursorRO, DbCursorRW},
7+
database::Database,
8+
tables,
9+
transaction::DbTxMut,
10+
};
11+
use reth_primitives::U256;
12+
use tracing::*;
13+
14+
const TOTAL_DIFFICULTY: StageId = StageId("TotalDifficulty");
15+
16+
/// The total difficulty stage.
17+
///
18+
/// This stage walks over inserted headers and computes total difficulty
19+
/// at each block. The entries are inserted into [`HeaderTD`][reth_interfaces::db::tables::HeaderTD]
20+
/// table.
21+
#[derive(Debug)]
22+
pub struct TotalDifficultyStage {
23+
/// The number of table entries to commit at once
24+
pub commit_threshold: u64,
25+
}
26+
27+
#[async_trait::async_trait]
28+
impl<DB: Database> Stage<DB> for TotalDifficultyStage {
29+
/// Return the id of the stage
30+
fn id(&self) -> StageId {
31+
TOTAL_DIFFICULTY
32+
}
33+
34+
/// Write total difficulty entries
35+
async fn execute(
36+
&mut self,
37+
tx: &mut Transaction<'_, DB>,
38+
input: ExecInput,
39+
) -> Result<ExecOutput, StageError> {
40+
let stage_progress = input.stage_progress.unwrap_or_default();
41+
let previous_stage_progress = input.previous_stage_progress();
42+
43+
let start_block = stage_progress + 1;
44+
let end_block = previous_stage_progress.min(start_block + self.commit_threshold);
45+
46+
if start_block > end_block {
47+
info!(target: "sync::stages::total_difficulty", stage_progress, "Target block already reached");
48+
return Ok(ExecOutput { stage_progress, done: true })
49+
}
50+
51+
debug!(target: "sync::stages::total_difficulty", start_block, end_block, "Commencing sync");
52+
53+
// Acquire cursor over total difficulty and headers tables
54+
let mut cursor_td = tx.cursor_mut::<tables::HeaderTD>()?;
55+
let mut cursor_headers = tx.cursor_mut::<tables::Headers>()?;
56+
57+
// Get latest total difficulty
58+
let last_header_key = tx.get_block_numhash(stage_progress)?;
59+
let last_entry = cursor_td
60+
.seek_exact(last_header_key)?
61+
.ok_or(DatabaseIntegrityError::TotalDifficulty { number: last_header_key.number() })?;
62+
63+
let mut td: U256 = last_entry.1.into();
64+
debug!(target: "sync::stages::total_difficulty", ?td, block_number = last_header_key.number(), "Last total difficulty entry");
65+
66+
let start_key = tx.get_block_numhash(start_block)?;
67+
let walker = cursor_headers
68+
.walk(start_key)?
69+
.take_while(|e| e.as_ref().map(|(_, h)| h.number <= end_block).unwrap_or_default());
70+
// Walk over newly inserted headers, update & insert td
71+
for entry in walker {
72+
let (key, header) = entry?;
73+
td += header.difficulty;
74+
cursor_td.append(key, td.into())?;
75+
}
76+
77+
let done = end_block >= previous_stage_progress;
78+
info!(target: "sync::stages::total_difficulty", stage_progress = end_block, done, "Sync iteration finished");
79+
Ok(ExecOutput { done, stage_progress: end_block })
80+
}
81+
82+
/// Unwind the stage.
83+
async fn unwind(
84+
&mut self,
85+
tx: &mut Transaction<'_, DB>,
86+
input: UnwindInput,
87+
) -> Result<UnwindOutput, Box<dyn std::error::Error + Send + Sync>> {
88+
tx.unwind_table_by_num_hash::<tables::HeaderTD>(input.unwind_to)?;
89+
Ok(UnwindOutput { stage_progress: input.unwind_to })
90+
}
91+
}
92+
93+
#[cfg(test)]
94+
mod tests {
95+
use reth_db::transaction::DbTx;
96+
use reth_interfaces::test_utils::generators::{random_header, random_header_range};
97+
use reth_primitives::{BlockNumber, SealedHeader};
98+
99+
use super::*;
100+
use crate::test_utils::{
101+
stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError,
102+
TestTransaction, UnwindStageTestRunner,
103+
};
104+
105+
stage_test_suite_ext!(TotalDifficultyTestRunner);
106+
107+
#[derive(Default)]
108+
struct TotalDifficultyTestRunner {
109+
tx: TestTransaction,
110+
}
111+
112+
impl StageTestRunner for TotalDifficultyTestRunner {
113+
type S = TotalDifficultyStage;
114+
115+
fn tx(&self) -> &TestTransaction {
116+
&self.tx
117+
}
118+
119+
fn stage(&self) -> Self::S {
120+
TotalDifficultyStage { commit_threshold: 500 }
121+
}
122+
}
123+
124+
#[async_trait::async_trait]
125+
impl ExecuteStageTestRunner for TotalDifficultyTestRunner {
126+
type Seed = Vec<SealedHeader>;
127+
128+
fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
129+
let start = input.stage_progress.unwrap_or_default();
130+
let head = random_header(start, None);
131+
self.tx.insert_headers(std::iter::once(&head))?;
132+
self.tx.commit(|tx| {
133+
let td: U256 = tx
134+
.cursor::<tables::HeaderTD>()?
135+
.last()?
136+
.map(|(_, v)| v)
137+
.unwrap_or_default()
138+
.into();
139+
tx.put::<tables::HeaderTD>(head.num_hash().into(), (td + head.difficulty).into())
140+
})?;
141+
142+
// use previous progress as seed size
143+
let end = input.previous_stage.map(|(_, num)| num).unwrap_or_default() + 1;
144+
145+
if start + 1 >= end {
146+
return Ok(Vec::default())
147+
}
148+
149+
let mut headers = random_header_range(start + 1..end, head.hash());
150+
self.tx.insert_headers(headers.iter())?;
151+
headers.insert(0, head);
152+
Ok(headers)
153+
}
154+
155+
/// Validate stored headers
156+
fn validate_execution(
157+
&self,
158+
input: ExecInput,
159+
output: Option<ExecOutput>,
160+
) -> Result<(), TestRunnerError> {
161+
let initial_stage_progress = input.stage_progress.unwrap_or_default();
162+
match output {
163+
Some(output) if output.stage_progress > initial_stage_progress => {
164+
self.tx.query(|tx| {
165+
let start_hash = tx
166+
.get::<tables::CanonicalHeaders>(initial_stage_progress)?
167+
.expect("no initial header hash");
168+
let start_key = (initial_stage_progress, start_hash).into();
169+
let mut header_cursor = tx.cursor::<tables::Headers>()?;
170+
let (_, mut current_header) =
171+
header_cursor.seek_exact(start_key)?.expect("no initial header");
172+
let mut td: U256 =
173+
tx.get::<tables::HeaderTD>(start_key)?.expect("no initial td").into();
174+
175+
while let Some((next_key, next_header)) = header_cursor.next()? {
176+
assert_eq!(current_header.number + 1, next_header.number);
177+
td += next_header.difficulty;
178+
assert_eq!(
179+
tx.get::<tables::HeaderTD>(next_key)?.map(Into::into),
180+
Some(td)
181+
);
182+
current_header = next_header;
183+
}
184+
Ok(())
185+
})?;
186+
}
187+
_ => self.check_no_td_above(initial_stage_progress)?,
188+
};
189+
Ok(())
190+
}
191+
}
192+
193+
impl UnwindStageTestRunner for TotalDifficultyTestRunner {
194+
fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> {
195+
self.check_no_td_above(input.unwind_to)
196+
}
197+
}
198+
199+
impl TotalDifficultyTestRunner {
200+
fn check_no_td_above(&self, block: BlockNumber) -> Result<(), TestRunnerError> {
201+
self.tx.check_no_entry_above::<tables::HeaderTD, _>(block, |key| key.number())?;
202+
Ok(())
203+
}
204+
}
205+
}

0 commit comments

Comments
 (0)