subcoin_mempool/
inner.rs

1//! Inner mempool state protected by RwLock.
2
3use crate::arena::MemPoolArena;
4use crate::types::{EntryId, RemovalReason};
5use bitcoin::{Amount, OutPoint as COutPoint, Txid, Wtxid};
6use std::collections::{HashMap, HashSet};
7
8/// Inner mempool state (protected by RwLock in main MemPool).
9pub struct MemPoolInner {
10    /// Arena-based entry storage with multi-index support.
11    pub(crate) arena: MemPoolArena,
12
13    /// Track which outputs are spent by mempool transactions.
14    /// Maps outpoint -> txid that spends it (for conflict detection).
15    pub(crate) map_next_tx: HashMap<COutPoint, Txid>,
16
17    /// Priority deltas (external to entries, applied via prioritise_transaction).
18    /// Maps txid -> fee delta.
19    #[allow(dead_code)]
20    pub(crate) map_deltas: HashMap<Txid, Amount>,
21
22    /// Randomized list for transaction relay (getdata responses).
23    /// Contains (wtxid, entry_id) pairs shuffled for privacy.
24    pub(crate) txns_randomized: Vec<(Wtxid, EntryId)>,
25
26    /// Transactions not yet broadcast to peers.
27    pub(crate) unbroadcast: HashSet<Txid>,
28
29    // === Statistics (updated on add/remove) ===
30    /// Total size of all transactions in bytes.
31    pub(crate) total_tx_size: u64,
32
33    /// Total fees of all transactions.
34    pub(crate) total_fee: Amount,
35
36    /// Rolling minimum fee rate for mempool acceptance.
37    /// Updated when mempool is trimmed to size.
38    #[allow(dead_code)]
39    pub(crate) rolling_minimum_feerate: u64,
40
41    /// Last time rolling fee was updated.
42    #[allow(dead_code)]
43    pub(crate) last_rolling_fee_update: i64,
44}
45
46impl MemPoolInner {
47    /// Create new empty mempool inner state.
48    pub fn new() -> Self {
49        Self {
50            arena: MemPoolArena::new(),
51            map_next_tx: HashMap::new(),
52            map_deltas: HashMap::new(),
53            txns_randomized: Vec::new(),
54            unbroadcast: HashSet::new(),
55            total_tx_size: 0,
56            total_fee: Amount::ZERO,
57            rolling_minimum_feerate: 0,
58            last_rolling_fee_update: 0,
59        }
60    }
61
62    /// Get entry by txid.
63    pub fn get_entry(&self, txid: &Txid) -> Option<&crate::arena::TxMemPoolEntry> {
64        let entry_id = self.arena.get_by_txid(txid)?;
65        self.arena.get(entry_id)
66    }
67
68    /// Check if transaction exists in mempool by txid.
69    pub fn contains_txid(&self, txid: &Txid) -> bool {
70        self.arena.get_by_txid(txid).is_some()
71    }
72
73    /// Check if transaction exists in mempool by wtxid.
74    pub fn contains_wtxid(&self, wtxid: &Wtxid) -> bool {
75        self.arena.get_by_wtxid(wtxid).is_some()
76    }
77
78    /// Get transaction that spends the given outpoint (conflict detection).
79    pub fn get_conflict_tx(&self, outpoint: &COutPoint) -> Option<Txid> {
80        self.map_next_tx.get(outpoint).copied()
81    }
82
83    /// Calculate descendants of a transaction (recursively).
84    ///
85    /// Returns set of all descendant entry IDs (including the starting entry).
86    pub fn calculate_descendants(&self, entry_id: EntryId, descendants: &mut HashSet<EntryId>) {
87        if !descendants.insert(entry_id) {
88            return; // Already visited
89        }
90
91        if let Some(entry) = self.arena.get(entry_id) {
92            for &child_id in &entry.children {
93                self.calculate_descendants(child_id, descendants);
94            }
95        }
96    }
97
98    /// Calculate ancestors of a transaction (recursively).
99    ///
100    /// Returns set of all ancestor entry IDs (including the starting entry).
101    pub fn calculate_ancestors(&self, entry_id: EntryId, ancestors: &mut HashSet<EntryId>) {
102        if !ancestors.insert(entry_id) {
103            return; // Already visited
104        }
105
106        if let Some(entry) = self.arena.get(entry_id) {
107            for &parent_id in &entry.parents {
108                self.calculate_ancestors(parent_id, ancestors);
109            }
110        }
111    }
112
113    /// Remove transactions and update state.
114    ///
115    /// This is the core removal function that:
116    /// - Removes entries from arena
117    /// - Updates map_next_tx
118    /// - Updates statistics
119    /// - Optionally updates ancestor/descendant state for surviving children (RBF)
120    pub fn remove_staged(
121        &mut self,
122        to_remove: &HashSet<EntryId>,
123        update_descendants: bool,
124        _reason: RemovalReason,
125    ) {
126        // Phase 1: Collect metadata for surviving children BEFORE any removal
127        let surviving_children_metadata = if update_descendants {
128            collect_surviving_children_metadata(&self.arena, to_remove)
129        } else {
130            Vec::new()
131        };
132
133        // Phase 2: Remove entries from arena
134        for &entry_id in to_remove {
135            if let Some(entry) = self.arena.remove(entry_id) {
136                // Update statistics
137                self.total_tx_size -= entry.tx_weight.to_wu();
138                self.total_fee =
139                    Amount::from_sat(self.total_fee.to_sat().saturating_sub(entry.fee.to_sat()));
140
141                // Remove from map_next_tx
142                for input in &entry.tx.input {
143                    self.map_next_tx.remove(&input.previous_output);
144                }
145
146                // Remove from unbroadcast set
147                self.unbroadcast.remove(&entry.tx.compute_txid());
148            }
149        }
150
151        // Phase 3: Update surviving children (AFTER removal)
152        if update_descendants {
153            for child_metadata in surviving_children_metadata {
154                update_child_for_removed_ancestors(&mut self.arena, child_metadata);
155            }
156        }
157
158        // Rebuild randomized transaction list
159        self.rebuild_randomized_list();
160    }
161
162    /// Trim mempool to maximum size by evicting lowest-feerate transactions.
163    pub fn trim_to_size(&mut self, max_size: u64) {
164        while self.total_tx_size > max_size {
165            // Get lowest feerate descendant cluster
166            let Some((entry_id, _)) = self.arena.iter_by_descendant_score().next() else {
167                break;
168            };
169
170            // Collect all descendants
171            let mut to_remove = HashSet::new();
172            self.calculate_descendants(entry_id, &mut to_remove);
173
174            // Remove cluster
175            self.remove_staged(&to_remove, false, RemovalReason::SizeLimit);
176        }
177    }
178
179    /// Expire old transactions.
180    pub fn expire(&mut self, current_time: i64, max_age_seconds: i64) {
181        let cutoff_time = current_time - max_age_seconds;
182        let mut to_remove = HashSet::new();
183
184        // Find all transactions older than cutoff
185        for (entry_id, entry) in self.arena.iter_by_entry_time() {
186            if entry.time < cutoff_time {
187                to_remove.insert(entry_id);
188            } else {
189                break; // Sorted by time, rest are newer
190            }
191        }
192
193        if !to_remove.is_empty() {
194            self.remove_staged(&to_remove, false, RemovalReason::Expiry);
195        }
196    }
197
198    /// Rebuild the randomized transaction list for relay.
199    fn rebuild_randomized_list(&mut self) {
200        self.txns_randomized.clear();
201
202        for (entry_id, entry) in self.arena.iter_by_ancestor_score() {
203            self.txns_randomized
204                .push((entry.tx.compute_wtxid(), entry_id));
205        }
206
207        // TODO: Shuffle for privacy (use a proper RNG)
208        // For now, we keep the ancestor score order
209    }
210
211    /// Get total number of transactions.
212    pub fn size(&self) -> usize {
213        self.arena.len()
214    }
215
216    /// Get total transaction size in bytes.
217    pub fn total_size(&self) -> u64 {
218        self.total_tx_size
219    }
220
221    /// Get total fees.
222    pub fn total_fees(&self) -> Amount {
223        self.total_fee
224    }
225}
226
227/// Metadata for a child whose ancestors are being removed.
228///
229/// Captured BEFORE removal to preserve parent graph state.
230struct SurvivingChildMetadata {
231    child_id: EntryId,
232    removed_ancestor_stats: AncestorStats,
233}
234
235/// Aggregated statistics for ancestors being removed.
236#[derive(Default)]
237struct AncestorStats {
238    count: i64,
239    size: i64,
240    fees: bitcoin::SignedAmount,
241    sigops: i64,
242}
243
244/// Collect metadata for surviving children before removal.
245///
246/// A "surviving child" is a transaction that:
247/// - Has at least one parent in `to_remove`
248/// - Has at least one parent NOT in `to_remove` (survives)
249///
250/// Returns metadata for each surviving child, capturing the stats of
251/// ancestors that will be truly removed (not reachable through surviving parents).
252fn collect_surviving_children_metadata(
253    arena: &MemPoolArena,
254    to_remove: &HashSet<EntryId>,
255) -> Vec<SurvivingChildMetadata> {
256    let mut result = Vec::new();
257
258    // Find all children of transactions being removed
259    let mut potential_survivors = HashSet::new();
260    for &entry_id in to_remove {
261        if let Some(entry) = arena.get(entry_id) {
262            for &child_id in &entry.children {
263                // Only consider children that are NOT being removed
264                if !to_remove.contains(&child_id) {
265                    potential_survivors.insert(child_id);
266                }
267            }
268        }
269    }
270
271    // For each surviving child, calculate which ancestors are truly removed
272    for child_id in potential_survivors {
273        let removed_stats = calculate_truly_removed_ancestors(arena, child_id, to_remove);
274
275        result.push(SurvivingChildMetadata {
276            child_id,
277            removed_ancestor_stats: removed_stats,
278        });
279    }
280
281    result
282}
283
284/// Calculate statistics for ancestors that are truly being removed.
285///
286/// An ancestor is "truly removed" if:
287/// - It's in the `to_remove` set
288/// - It's NOT reachable through any surviving parent
289///
290/// This prevents double-counting shared ancestors.
291fn calculate_truly_removed_ancestors(
292    arena: &MemPoolArena,
293    child_id: EntryId,
294    to_remove: &HashSet<EntryId>,
295) -> AncestorStats {
296    // Step 1: Find all ancestors of this child
297    let mut all_ancestors = HashSet::new();
298    calculate_ancestors_recursive(arena, child_id, &mut all_ancestors);
299    all_ancestors.remove(&child_id); // Exclude the child itself
300
301    // Step 2: Find ancestors reachable through surviving parents
302    let mut reachable_through_survivors = HashSet::new();
303    if let Some(child_entry) = arena.get(child_id) {
304        for &parent_id in &child_entry.parents {
305            // If parent is NOT being removed, it's a surviving parent
306            if !to_remove.contains(&parent_id) {
307                // Add all ancestors of this surviving parent
308                calculate_ancestors_recursive(arena, parent_id, &mut reachable_through_survivors);
309            }
310        }
311    }
312
313    // Step 3: Truly removed = (all ancestors ∩ to_remove) - reachable_through_survivors
314    let mut stats = AncestorStats::default();
315
316    for ancestor_id in all_ancestors {
317        // Must be in removal set AND not reachable through survivors
318        if to_remove.contains(&ancestor_id) && !reachable_through_survivors.contains(&ancestor_id) {
319            if let Some(entry) = arena.get(ancestor_id) {
320                stats.count += 1;
321                stats.size += entry.vsize();
322                stats.fees = bitcoin::SignedAmount::from_sat(
323                    stats.fees.to_sat() + entry.modified_fee.to_sat() as i64,
324                );
325                stats.sigops += entry.sigop_cost;
326            }
327        }
328    }
329
330    stats
331}
332
333/// Recursively collect all ancestors (including the starting entry).
334fn calculate_ancestors_recursive(
335    arena: &MemPoolArena,
336    entry_id: EntryId,
337    ancestors: &mut HashSet<EntryId>,
338) {
339    if !ancestors.insert(entry_id) {
340        return; // Already visited
341    }
342
343    if let Some(entry) = arena.get(entry_id) {
344        for &parent_id in &entry.parents {
345            calculate_ancestors_recursive(arena, parent_id, ancestors);
346        }
347    }
348}
349
350/// Update a surviving child after its ancestors have been removed.
351///
352/// Applies negative deltas to subtract the removed ancestors from the child's
353/// cached ancestor state.
354///
355/// **CRITICAL**: Must be called AFTER entries are removed from arena.
356fn update_child_for_removed_ancestors(arena: &mut MemPoolArena, metadata: SurvivingChildMetadata) {
357    let stats = metadata.removed_ancestor_stats;
358
359    // Apply negative deltas to the child
360    arena.update_ancestor_state(
361        metadata.child_id,
362        -stats.size,
363        -stats.fees,
364        -stats.count,
365        -stats.sigops,
366    );
367
368    // Also update all descendants of this child (they lose the same ancestors)
369    let mut descendants = HashSet::new();
370    if let Some(entry) = arena.get(metadata.child_id) {
371        for &desc_id in &entry.children {
372            calculate_descendants_recursive(arena, desc_id, &mut descendants);
373        }
374    }
375
376    for desc_id in descendants {
377        arena.update_ancestor_state(
378            desc_id,
379            -stats.size,
380            -stats.fees,
381            -stats.count,
382            -stats.sigops,
383        );
384    }
385}
386
387/// Recursively collect all descendants (including the starting entry).
388fn calculate_descendants_recursive(
389    arena: &MemPoolArena,
390    entry_id: EntryId,
391    descendants: &mut HashSet<EntryId>,
392) {
393    if !descendants.insert(entry_id) {
394        return; // Already visited
395    }
396
397    if let Some(entry) = arena.get(entry_id) {
398        for &child_id in &entry.children {
399            calculate_descendants_recursive(arena, child_id, descendants);
400        }
401    }
402}
403
404impl Default for MemPoolInner {
405    fn default() -> Self {
406        Self::new()
407    }
408}