Skip to content

Commit 3bc1928

Browse files
authored
Limit network actions to tracked chains (in main) (#2393)
* Add a `tracked_chains` field to `ChainWorkerState` Prepare to only create network actions for the set of tracked chains. * Add a `tracked_chains` field to `WorkerState` Share it with the created chain worker actors. * Add a `tracked_chains` field to `Client` type Keep track of the chains that client is interested in. * Select tracked chains when starting Specify which chains should be tracked by a new `Client`. * Forward tracked chains to `WorkerState` Configure the worker based on the client's selection. * Only create network actions for tracked chains Avoid handling chains that aren't interesting to the client. * Add `retry_pending_cross_chain_requests` helper Allow resending messages intended for chains that weren't tracked when the outgoing message was scheduled, but became tracked later. * Add `Client::track_chain` method Allow adding more chains to the initial set of tracked chains. * Track newly created chains Ensure that chains that the client open are tracked. * Ensure newly assigned chain is tracked So that the worker can properly handle it. * Track chains used in benchmark Ensure that they are properly executed during the benchmark. * Add a TODO to merge tracked chains set Remember to replace the quick-fix with a more comprehensive refactor. * Track chains created during block execution Check all executed blocks for messages that open new chains, and add the new chain IDs to the set of tracked chains.
1 parent a07a3af commit 3bc1928

File tree

10 files changed

+144
-15
lines changed

10 files changed

+144
-15
lines changed

linera-client/src/client_context.rs

+7
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ where
146146
storage,
147147
options.max_pending_messages,
148148
delivery,
149+
wallet.chain_ids(),
149150
);
150151

151152
ClientContext {
@@ -519,9 +520,15 @@ where
519520
.expect("failed to create new chain");
520521
let chain_id = ChainId::child(message_id);
521522
key_pairs.insert(chain_id, key_pair.copy());
523+
self.client.track_chain(chain_id);
522524
self.update_wallet_for_new_chain(chain_id, Some(key_pair.copy()), timestamp);
523525
}
524526
}
527+
let updated_chain_client = self.make_chain_client(default_chain_id);
528+
updated_chain_client
529+
.retry_pending_outgoing_messages()
530+
.await
531+
.context("outgoing messages to create the new chains should be delivered")?;
525532

526533
for chain_id in key_pairs.keys() {
527534
let child_client = self.make_chain_client(*chain_id);

linera-client/src/unit_tests/chain_listener.rs

+1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ async fn test_chain_listener() -> anyhow::Result<()> {
150150
storage.clone(),
151151
10,
152152
delivery,
153+
[chain_id0],
153154
)),
154155
};
155156
let key_pair = KeyPair::generate_from(&mut rng);

linera-core/src/chain_worker/actor.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
//! An actor that runs a chain worker.
55
66
use std::{
7+
collections::HashSet,
78
fmt::{self, Debug, Formatter},
8-
sync::Arc,
9+
sync::{Arc, RwLock},
910
};
1011

1112
use linera_base::{
@@ -151,6 +152,7 @@ where
151152
storage: StorageClient,
152153
certificate_value_cache: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
153154
blob_cache: Arc<ValueCache<BlobId, Blob>>,
155+
tracked_chains: Option<Arc<RwLock<HashSet<ChainId>>>>,
154156
chain_id: ChainId,
155157
) -> Result<Self, WorkerError> {
156158
let (service_runtime_thread, execution_state_receiver, runtime_request_sender) =
@@ -161,6 +163,7 @@ where
161163
storage,
162164
certificate_value_cache,
163165
blob_cache,
166+
tracked_chains,
164167
chain_id,
165168
execution_state_receiver,
166169
runtime_request_sender,

linera-core/src/chain_worker/state/attempted_changes.rs

+1
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ where
346346
tip.num_outgoing_messages += executed_block.outcome.messages.len() as u32;
347347
self.state.chain.confirmed_log.push(certificate.hash());
348348
let info = ChainInfoResponse::new(&self.state.chain, self.state.config.key_pair());
349+
self.state.track_newly_created_chains(executed_block);
349350
let mut actions = self.state.create_network_actions().await?;
350351
actions.notifications.push(Notification {
351352
chain_id: block.chain_id,

linera-core/src/chain_worker/state/mod.rs

+39-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ mod temporary_changes;
99
use std::{
1010
borrow::Cow,
1111
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
12-
sync::Arc,
12+
sync::{self, Arc},
1313
};
1414

1515
use linera_base::{
@@ -26,8 +26,8 @@ use linera_chain::{
2626
ChainError, ChainStateView,
2727
};
2828
use linera_execution::{
29-
committee::Epoch, ExecutionRequest, Query, QueryContext, Response, ServiceRuntimeRequest,
30-
UserApplicationDescription, UserApplicationId,
29+
committee::Epoch, ExecutionRequest, Message, Query, QueryContext, Response,
30+
ServiceRuntimeRequest, SystemMessage, UserApplicationDescription, UserApplicationId,
3131
};
3232
use linera_storage::Storage;
3333
use linera_views::views::{ClonableView, ViewError};
@@ -60,6 +60,7 @@ where
6060
runtime_request_sender: std::sync::mpsc::Sender<ServiceRuntimeRequest>,
6161
recent_hashed_certificate_values: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
6262
recent_blobs: Arc<ValueCache<BlobId, Blob>>,
63+
tracked_chains: Option<Arc<sync::RwLock<HashSet<ChainId>>>>,
6364
knows_chain_is_active: bool,
6465
}
6566

@@ -69,11 +70,13 @@ where
6970
ViewError: From<StorageClient::StoreError>,
7071
{
7172
/// Creates a new [`ChainWorkerState`] using the provided `storage` client.
73+
#[allow(clippy::too_many_arguments)]
7274
pub async fn load(
7375
config: ChainWorkerConfig,
7476
storage: StorageClient,
7577
certificate_value_cache: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
7678
blob_cache: Arc<ValueCache<BlobId, Blob>>,
79+
tracked_chains: Option<Arc<sync::RwLock<HashSet<ChainId>>>>,
7780
chain_id: ChainId,
7881
execution_state_receiver: futures::channel::mpsc::UnboundedReceiver<ExecutionRequest>,
7982
runtime_request_sender: std::sync::mpsc::Sender<ServiceRuntimeRequest>,
@@ -89,6 +92,7 @@ where
8992
runtime_request_sender,
9093
recent_hashed_certificate_values: certificate_value_cache,
9194
recent_blobs: blob_cache,
95+
tracked_chains,
9296
knows_chain_is_active: false,
9397
})
9498
}
@@ -369,11 +373,41 @@ where
369373
self.recent_blobs.insert(blob).await
370374
}
371375

376+
/// Adds any newly created chains to the set of `tracked_chains`.
377+
fn track_newly_created_chains(&self, block: &ExecutedBlock) {
378+
if let Some(tracked_chains) = self.tracked_chains.as_ref() {
379+
let messages = block.messages().iter().flatten();
380+
let open_chain_message_indices =
381+
messages
382+
.enumerate()
383+
.filter_map(|(index, outgoing_message)| match outgoing_message.message {
384+
Message::System(SystemMessage::OpenChain(_)) => Some(index),
385+
_ => None,
386+
});
387+
let open_chain_message_ids =
388+
open_chain_message_indices.map(|index| block.message_id(index as u32));
389+
let new_chain_ids = open_chain_message_ids.map(ChainId::child);
390+
391+
tracked_chains
392+
.write()
393+
.expect("Panics should not happen while holding a lock to `tracked_chains`")
394+
.extend(new_chain_ids);
395+
}
396+
}
397+
372398
/// Loads pending cross-chain requests.
373399
async fn create_network_actions(&self) -> Result<NetworkActions, WorkerError> {
374400
let mut heights_by_recipient: BTreeMap<_, BTreeMap<_, _>> = Default::default();
375-
let pairs = self.chain.outboxes.try_load_all_entries().await?;
376-
for (target, outbox) in pairs {
401+
let mut targets = self.chain.outboxes.indices().await?;
402+
if let Some(tracked_chains) = self.tracked_chains.as_ref() {
403+
let tracked_chains = tracked_chains
404+
.read()
405+
.expect("Panics should not happen while holding a lock to `tracked_chains`");
406+
targets.retain(|target| tracked_chains.contains(&target.recipient));
407+
}
408+
let outboxes = self.chain.outboxes.try_load_entries(&targets).await?;
409+
for (target, outbox) in targets.into_iter().zip(outboxes) {
410+
let outbox = outbox.expect("Only existing outboxes should be referenced by `indices`");
377411
let heights = outbox.queue.elements().await?;
378412
heights_by_recipient
379413
.entry(target.recipient)

linera-core/src/client.rs

+39-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{
77
convert::Infallible,
88
iter,
99
ops::{Deref, DerefMut},
10-
sync::Arc,
10+
sync::{Arc, RwLock},
1111
};
1212

1313
use dashmap::{
@@ -91,6 +91,9 @@ where
9191
message_policy: MessagePolicy,
9292
/// Whether to block on cross-chain message delivery.
9393
cross_chain_message_delivery: CrossChainMessageDelivery,
94+
/// Chains that should be tracked by the client.
95+
// TODO(#2412): Merge with set of chains the client is receiving notifications from validators
96+
tracked_chains: Arc<RwLock<HashSet<ChainId>>>,
9497
/// References to clients waiting for chain notifications.
9598
notifier: Arc<Notifier<Notification>>,
9699
/// A copy of the storage client so that we don't have to lock the local node client
@@ -111,10 +114,16 @@ where
111114
storage: S,
112115
max_pending_messages: usize,
113116
cross_chain_message_delivery: CrossChainMessageDelivery,
117+
tracked_chains: impl IntoIterator<Item = ChainId>,
114118
) -> Self {
115-
let state = WorkerState::new_for_client("Client node".to_string(), storage.clone())
116-
.with_allow_inactive_chains(true)
117-
.with_allow_messages_from_deprecated_epochs(true);
119+
let tracked_chains = Arc::new(RwLock::new(tracked_chains.into_iter().collect()));
120+
let state = WorkerState::new_for_client(
121+
"Client node".to_string(),
122+
storage.clone(),
123+
tracked_chains.clone(),
124+
)
125+
.with_allow_inactive_chains(true)
126+
.with_allow_messages_from_deprecated_epochs(true);
118127
let local_node = LocalNodeClient::new(state);
119128

120129
Self {
@@ -124,6 +133,7 @@ where
124133
max_pending_messages,
125134
message_policy: MessagePolicy::new(BlanketMessagePolicy::Accept, None),
126135
cross_chain_message_delivery,
136+
tracked_chains,
127137
notifier: Arc::new(Notifier::default()),
128138
storage,
129139
}
@@ -141,6 +151,15 @@ where
141151
&self.local_node
142152
}
143153

154+
#[tracing::instrument(level = "trace", skip(self))]
155+
/// Adds a chain to the set of chains tracked by the local node.
156+
pub fn track_chain(&self, chain_id: ChainId) {
157+
self.tracked_chains
158+
.write()
159+
.expect("Panics should not happen while holding a lock to `tracked_chains`")
160+
.insert(chain_id);
161+
}
162+
144163
#[tracing::instrument(level = "trace", skip_all, fields(chain_id, next_block_height))]
145164
/// Creates a new `ChainClient`.
146165
#[allow(clippy::too_many_arguments)]
@@ -2501,6 +2520,12 @@ where
25012520
executed_block.message_id_for_operation(0, OPEN_CHAIN_MESSAGE_INDEX)
25022521
})
25032522
.ok_or_else(|| ChainClientError::InternalError("Failed to create new chain"))?;
2523+
// Add the new chain to the list of tracked chains
2524+
self.client.track_chain(ChainId::child(message_id));
2525+
self.client
2526+
.local_node
2527+
.retry_pending_cross_chain_requests(self.chain_id)
2528+
.await?;
25042529
return Ok(ClientOutcome::Committed((message_id, certificate)));
25052530
}
25062531
}
@@ -2791,6 +2816,16 @@ where
27912816
.await
27922817
}
27932818

2819+
#[tracing::instrument(level = "trace")]
2820+
/// Handles any cross-chain requests for any pending outgoing messages.
2821+
pub async fn retry_pending_outgoing_messages(&self) -> Result<(), ChainClientError> {
2822+
self.client
2823+
.local_node
2824+
.retry_pending_cross_chain_requests(self.chain_id)
2825+
.await?;
2826+
Ok(())
2827+
}
2828+
27942829
#[tracing::instrument(level = "trace", skip(from, limit))]
27952830
pub async fn read_hashed_certificate_values_downward(
27962831
&self,

linera-core/src/local_node.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
use std::{
66
borrow::Cow,
7-
collections::{HashMap, HashSet},
7+
collections::{HashMap, HashSet, VecDeque},
88
sync::Arc,
99
};
1010

@@ -570,4 +570,23 @@ where
570570
}
571571
}
572572
}
573+
574+
/// Handles any pending local cross-chain requests.
575+
#[tracing::instrument(level = "trace", skip(self))]
576+
pub async fn retry_pending_cross_chain_requests(
577+
&self,
578+
sender_chain: ChainId,
579+
) -> Result<(), LocalNodeError> {
580+
let (_response, actions) = self
581+
.node
582+
.state
583+
.handle_chain_info_query(ChainInfoQuery::new(sender_chain))
584+
.await?;
585+
let mut requests = VecDeque::from_iter(actions.cross_chain_requests);
586+
while let Some(request) = requests.pop_front() {
587+
let new_actions = self.node.state.handle_cross_chain_request(request).await?;
588+
requests.extend(new_actions.cross_chain_requests);
589+
}
590+
Ok(())
591+
}
573592
}

linera-core/src/unit_tests/test_utils.rs

+1
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,7 @@ where
796796
storage,
797797
10,
798798
CrossChainMessageDelivery::NonBlocking,
799+
[chain_id],
799800
));
800801
Ok(builder.create_chain_client(
801802
chain_id,

linera-core/src/worker.rs

+31-4
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
use std::{
66
borrow::Cow,
7-
collections::{hash_map, BTreeMap, HashMap, VecDeque},
7+
collections::{hash_map, BTreeMap, HashMap, HashSet, VecDeque},
88
num::NonZeroUsize,
9-
sync::{Arc, LazyLock, Mutex},
9+
sync::{Arc, LazyLock, Mutex, RwLock},
1010
time::Duration,
1111
};
1212

@@ -235,6 +235,8 @@ where
235235
recent_hashed_certificate_values: Arc<ValueCache<CryptoHash, HashedCertificateValue>>,
236236
/// Cached blobs by `BlobId`.
237237
recent_blobs: Arc<ValueCache<BlobId, Blob>>,
238+
/// Chain IDs that should be tracked by a worker.
239+
tracked_chains: Option<Arc<RwLock<HashSet<ChainId>>>>,
238240
/// One-shot channels to notify callers when messages of a particular chain have been
239241
/// delivered.
240242
delivery_notifiers: Arc<Mutex<DeliveryNotifiers>>,
@@ -264,15 +266,30 @@ where
264266
chain_worker_config: ChainWorkerConfig::default().with_key_pair(key_pair),
265267
recent_hashed_certificate_values: Arc::new(ValueCache::default()),
266268
recent_blobs: Arc::new(ValueCache::default()),
269+
tracked_chains: None,
267270
delivery_notifiers: Arc::default(),
268271
chain_worker_tasks: Arc::default(),
269272
chain_workers: Arc::new(Mutex::new(LruCache::new(*CHAIN_WORKER_LIMIT))),
270273
}
271274
}
272275

273276
#[tracing::instrument(level = "trace", skip(nickname, storage))]
274-
pub fn new_for_client(nickname: String, storage: StorageClient) -> Self {
275-
Self::new(nickname, None, storage)
277+
pub fn new_for_client(
278+
nickname: String,
279+
storage: StorageClient,
280+
tracked_chains: Arc<RwLock<HashSet<ChainId>>>,
281+
) -> Self {
282+
WorkerState {
283+
nickname,
284+
storage,
285+
chain_worker_config: ChainWorkerConfig::default(),
286+
recent_hashed_certificate_values: Arc::new(ValueCache::default()),
287+
recent_blobs: Arc::new(ValueCache::default()),
288+
tracked_chains: Some(tracked_chains),
289+
delivery_notifiers: Arc::default(),
290+
chain_worker_tasks: Arc::default(),
291+
chain_workers: Arc::new(Mutex::new(LruCache::new(*CHAIN_WORKER_LIMIT))),
292+
}
276293
}
277294

278295
#[tracing::instrument(level = "trace", skip(self, value))]
@@ -288,6 +305,15 @@ where
288305
self
289306
}
290307

308+
/// Configures the subset of chains that this worker is tracking.
309+
pub fn with_tracked_chains(
310+
mut self,
311+
tracked_chains: impl IntoIterator<Item = ChainId>,
312+
) -> Self {
313+
self.tracked_chains = Some(Arc::new(RwLock::new(tracked_chains.into_iter().collect())));
314+
self
315+
}
316+
291317
/// Returns an instance with the specified grace period, in microseconds.
292318
///
293319
/// Blocks with a timestamp this far in the future will still be accepted, but the validator
@@ -665,6 +691,7 @@ where
665691
self.storage.clone(),
666692
self.recent_hashed_certificate_values.clone(),
667693
self.recent_blobs.clone(),
694+
self.tracked_chains.clone(),
668695
chain_id,
669696
)
670697
.await?;

linera-service/src/linera/main.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,7 @@ impl Job {
11001100
ViewError: From<S::StoreError>,
11011101
{
11021102
let state = WorkerState::new("Local node".to_string(), None, storage)
1103+
.with_tracked_chains([message_id.chain_id, chain_id])
11031104
.with_allow_inactive_chains(true)
11041105
.with_allow_messages_from_deprecated_epochs(true);
11051106
let node_client = LocalNodeClient::new(state);

0 commit comments

Comments
 (0)