Skip to content

Commit f1a9434

Browse files
refactor(pre-compute-app): remove unnecessary Option<> fields
1 parent 64fffb8 commit f1a9434

File tree

2 files changed

+72
-97
lines changed

2 files changed

+72
-97
lines changed

src/compute/app_runner.rs

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::api::worker_api::{ExitMessage, WorkerApiClient};
22
use crate::compute::pre_compute_app::{PreComputeApp, PreComputeAppTrait};
3+
use crate::compute::pre_compute_args::PreComputeArgs;
34
use crate::compute::{
45
errors::ReplicateStatusCause,
56
signer::get_challenge,
@@ -27,10 +28,6 @@ pub enum ExitMode {
2728
/// It uses the provided app to execute core operations and handles all the
2829
/// workflow states and transitions.
2930
///
30-
/// # Arguments
31-
///
32-
/// * `pre_compute_app` - An implementation of [`PreComputeAppTrait`] that will be used to execute the pre-compute operations.
33-
///
3431
/// # Example
3532
///
3633
/// ```
@@ -40,20 +37,10 @@ pub enum ExitMode {
4037
/// let pre_compute_app = PreComputeApp::new();
4138
/// let exit_code = start_with_app(pre_compute_app);
4239
/// ```
43-
pub fn start_with_app<A: PreComputeAppTrait>(pre_compute_app: &mut A) -> ExitMode {
44-
info!("TEE pre-compute started");
45-
40+
pub fn start_with_app<A: PreComputeAppTrait>(pre_compute_app: &A, chain_task_id: &str) -> ExitMode {
4641
let exit_cause = ReplicateStatusCause::PreComputeFailedUnknownIssue;
47-
let chain_task_id =
48-
match get_env_var_or_error(IexecTaskId, ReplicateStatusCause::PreComputeTaskIdMissing) {
49-
Ok(id) => id,
50-
Err(e) => {
51-
error!("TEE pre-compute cannot proceed without taskID context: {e:?}");
52-
return ExitMode::InitializationFailure;
53-
}
54-
};
5542

56-
match pre_compute_app.run(&chain_task_id) {
43+
match pre_compute_app.run() {
5744
Ok(_) => {
5845
info!("TEE pre-compute completed");
5946
return ExitMode::Success;
@@ -63,7 +50,7 @@ pub fn start_with_app<A: PreComputeAppTrait>(pre_compute_app: &mut A) -> ExitMod
6350
}
6451
}
6552

66-
let authorization = match get_challenge(&chain_task_id) {
53+
let authorization = match get_challenge(chain_task_id) {
6754
Ok(auth) => auth,
6855
Err(_) => {
6956
error!("Failed to sign exitCause message [{exit_cause:?}]");
@@ -77,7 +64,7 @@ pub fn start_with_app<A: PreComputeAppTrait>(pre_compute_app: &mut A) -> ExitMod
7764

7865
match WorkerApiClient::from_env().send_exit_cause_for_pre_compute_stage(
7966
&authorization,
80-
&chain_task_id,
67+
chain_task_id,
8168
&exit_message,
8269
) {
8370
Ok(_) => ExitMode::ReportedFailure,
@@ -102,8 +89,23 @@ pub fn start_with_app<A: PreComputeAppTrait>(pre_compute_app: &mut A) -> ExitMod
10289
/// std::process::exit(exit_code);
10390
/// ```
10491
pub fn start() -> ExitMode {
105-
let mut pre_compute_app = PreComputeApp::new();
106-
start_with_app(&mut pre_compute_app)
92+
info!("TEE pre-compute started");
93+
94+
let chain_task_id =
95+
match get_env_var_or_error(IexecTaskId, ReplicateStatusCause::PreComputeTaskIdMissing) {
96+
Ok(id) => id,
97+
Err(e) => {
98+
error!("TEE pre-compute cannot proceed without taskID context: {e:?}");
99+
return ExitMode::InitializationFailure;
100+
}
101+
};
102+
let pre_compute_args = match PreComputeArgs::read_args() {
103+
Ok(pre_compute_args) => pre_compute_args,
104+
Err(_) => { return ExitMode::InitializationFailure; }
105+
};
106+
107+
let pre_compute_app = PreComputeApp::new(chain_task_id.clone(),pre_compute_args);
108+
start_with_app(&pre_compute_app,&chain_task_id)
107109
}
108110

109111
#[cfg(test)]
@@ -124,6 +126,17 @@ mod pre_compute_start_with_app_tests {
124126
const ENV_SIGN_TEE_CHALLENGE_PRIVATE_KEY: &str = "SIGN_TEE_CHALLENGE_PRIVATE_KEY";
125127
const ENV_WORKER_HOST: &str = "WORKER_HOST_ENV_VAR";
126128

129+
#[test]
130+
fn start_fails_when_read_args_fails() {
131+
temp_env::with_vars([(ENV_IEXEC_TASK_ID, Some(CHAIN_TASK_ID))], || {
132+
assert_eq!(
133+
start(),
134+
ExitMode::InitializationFailure,
135+
"Should return 3 if IEXEC_TASK_ID is missing"
136+
);
137+
});
138+
}
139+
127140
#[test]
128141
fn start_fails_when_task_id_missing() {
129142
temp_env::with_vars_unset(vec![ENV_IEXEC_TASK_ID], || {
@@ -148,13 +161,12 @@ mod pre_compute_start_with_app_tests {
148161

149162
let mut mock = MockPreComputeAppTrait::new();
150163
mock.expect_run()
151-
.withf(|chain_task_id| chain_task_id == CHAIN_TASK_ID)
152-
.returning(|_| Err(ReplicateStatusCause::PreComputeWorkerAddressMissing));
164+
.returning(|| Err(ReplicateStatusCause::PreComputeWorkerAddressMissing));
153165

154166
temp_env::with_vars(env_vars_to_set, || {
155167
temp_env::with_vars_unset(env_vars_to_unset, || {
156168
assert_eq!(
157-
start_with_app(&mut mock),
169+
start_with_app(&mock,CHAIN_TASK_ID),
158170
ExitMode::UnreportedFailure,
159171
"Should return 2 if get_challenge fails due to missing signer address"
160172
);
@@ -172,13 +184,12 @@ mod pre_compute_start_with_app_tests {
172184

173185
let mut mock = MockPreComputeAppTrait::new();
174186
mock.expect_run()
175-
.withf(|chain_task_id| chain_task_id == CHAIN_TASK_ID)
176-
.returning(|_| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
187+
.returning(|| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
177188

178189
temp_env::with_vars(env_vars_to_set, || {
179190
temp_env::with_vars_unset(env_vars_to_unset, || {
180191
assert_eq!(
181-
start_with_app(&mut mock),
192+
start_with_app(&mock, CHAIN_TASK_ID),
182193
ExitMode::UnreportedFailure,
183194
"Should return 2 if get_challenge fails due to missing private key"
184195
);
@@ -200,8 +211,7 @@ mod pre_compute_start_with_app_tests {
200211

201212
let mut mock = MockPreComputeAppTrait::new();
202213
mock.expect_run()
203-
.withf(|chain_task_id| chain_task_id == CHAIN_TASK_ID)
204-
.returning(|_| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
214+
.returning(|| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
205215

206216
let result_code = tokio::task::spawn_blocking(move || {
207217
let env_vars = vec![
@@ -214,7 +224,7 @@ mod pre_compute_start_with_app_tests {
214224
(ENV_WORKER_HOST, Some(mock_server_addr_string.as_str())),
215225
];
216226

217-
temp_env::with_vars(env_vars, || start_with_app(&mut mock))
227+
temp_env::with_vars(env_vars, || start_with_app(&mock, CHAIN_TASK_ID))
218228
})
219229
.await
220230
.expect("Blocking task panicked");
@@ -247,8 +257,7 @@ mod pre_compute_start_with_app_tests {
247257

248258
let mut mock = MockPreComputeAppTrait::new();
249259
mock.expect_run()
250-
.withf(|chain_task_id| chain_task_id == CHAIN_TASK_ID)
251-
.returning(|_| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
260+
.returning(|| Err(ReplicateStatusCause::PreComputeTeeChallengePrivateKeyMissing));
252261

253262
// Move the blocking operations into spawn_blocking
254263
let result_code = tokio::task::spawn_blocking(move || {
@@ -262,7 +271,7 @@ mod pre_compute_start_with_app_tests {
262271
(ENV_WORKER_HOST, Some(mock_server_addr_string.as_str())),
263272
];
264273

265-
temp_env::with_vars(env_vars, || start_with_app(&mut mock))
274+
temp_env::with_vars(env_vars, || start_with_app(&mock, CHAIN_TASK_ID))
266275
})
267276
.await
268277
.expect("Blocking task panicked");

src/compute/pre_compute_app.rs

Lines changed: 31 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ const AES_IV_LENGTH: usize = 16;
2626

2727
#[cfg_attr(test, automock)]
2828
pub trait PreComputeAppTrait {
29-
fn run(&mut self, chain_task_id: &str) -> Result<(), ReplicateStatusCause>;
29+
fn run(&self) -> Result<(), ReplicateStatusCause>;
3030
fn check_output_folder(&self) -> Result<(), ReplicateStatusCause>;
3131
fn download_input_files(&self) -> Result<(), ReplicateStatusCause>;
3232
fn download_encrypted_dataset(&self) -> Result<Vec<u8>, ReplicateStatusCause>;
@@ -35,25 +35,23 @@ pub trait PreComputeAppTrait {
3535
}
3636

3737
pub struct PreComputeApp {
38-
chain_task_id: Option<String>,
39-
pre_compute_args: Option<PreComputeArgs>,
38+
chain_task_id: String,
39+
pre_compute_args: PreComputeArgs,
4040
}
4141

4242
impl PreComputeApp {
43-
pub fn new() -> Self {
43+
pub fn new(chain_task_id: String, pre_compute_args: PreComputeArgs) -> Self {
4444
PreComputeApp {
45-
chain_task_id: None,
46-
pre_compute_args: None,
45+
chain_task_id,
46+
pre_compute_args,
4747
}
4848
}
4949
}
5050

5151
impl PreComputeAppTrait for PreComputeApp {
52-
fn run(&mut self, chain_task_id: &str) -> Result<(), ReplicateStatusCause> {
53-
self.chain_task_id = Some(chain_task_id.to_string());
54-
self.pre_compute_args = Some(PreComputeArgs::read_args()?);
52+
fn run(&self) -> Result<(), ReplicateStatusCause> {
5553
self.check_output_folder()?;
56-
if self.pre_compute_args.as_ref().unwrap().is_dataset_required {
54+
if self.pre_compute_args.is_dataset_required {
5755
let encrypted_content = self.download_encrypted_dataset()?;
5856
let plain_content = self.decrypt_dataset(&encrypted_content)?;
5957
self.save_plain_dataset_file(&plain_content)?;
@@ -82,14 +80,8 @@ impl PreComputeAppTrait for PreComputeApp {
8280
/// pre_compute_app.check_output_folder()?;
8381
/// ```
8482
fn check_output_folder(&self) -> Result<(), ReplicateStatusCause> {
85-
let output_dir = self
86-
.pre_compute_args
87-
.as_ref()
88-
.ok_or(ReplicateStatusCause::PreComputeOutputFolderNotFound)?
89-
.output_dir
90-
.clone();
91-
92-
let chain_task_id = self.chain_task_id.as_deref().unwrap_or("unknown");
83+
let output_dir: &str = &self.pre_compute_args.output_dir;
84+
let chain_task_id: &str = &self.chain_task_id;
9385

9486
info!("Checking output folder [chainTaskId:{chain_task_id}, path:{output_dir}]");
9587

@@ -130,8 +122,8 @@ impl PreComputeAppTrait for PreComputeApp {
130122
/// pre_compute_app.download_input_files()?;
131123
/// ```
132124
fn download_input_files(&self) -> Result<(), ReplicateStatusCause> {
133-
let args = self.pre_compute_args.as_ref().unwrap();
134-
let chain_task_id = self.chain_task_id.as_ref().unwrap();
125+
let args = &self.pre_compute_args;
126+
let chain_task_id: &str = &self.chain_task_id;
135127

136128
for url in &args.input_files {
137129
info!("Downloading input file [chainTaskId:{chain_task_id}, url:{url}]");
@@ -162,8 +154,8 @@ impl PreComputeAppTrait for PreComputeApp {
162154
/// app.download_encrypted_dataset()?;
163155
/// ```
164156
fn download_encrypted_dataset(&self) -> Result<Vec<u8>, ReplicateStatusCause> {
165-
let args = self.pre_compute_args.as_ref().unwrap();
166-
let chain_task_id = self.chain_task_id.as_ref().unwrap();
157+
let args = &self.pre_compute_args;
158+
let chain_task_id = &self.chain_task_id;
167159
let encrypted_dataset_url = args.encrypted_dataset_url.as_ref().unwrap();
168160

169161
info!(
@@ -233,8 +225,6 @@ impl PreComputeAppTrait for PreComputeApp {
233225
fn decrypt_dataset(&self, encrypted_content: &[u8]) -> Result<Vec<u8>, ReplicateStatusCause> {
234226
let base64_key = self
235227
.pre_compute_args
236-
.as_ref()
237-
.unwrap()
238228
.encrypted_dataset_base64_key
239229
.as_ref()
240230
.unwrap();
@@ -280,10 +270,10 @@ impl PreComputeAppTrait for PreComputeApp {
280270
/// app.save_plain_dataset_file(&plain_data)?;
281271
/// ```
282272
fn save_plain_dataset_file(&self, plain_dataset: &[u8]) -> Result<(), ReplicateStatusCause> {
283-
let chain_task_id = self.chain_task_id.as_ref().unwrap();
284-
let args = self.pre_compute_args.as_ref().unwrap();
285-
let output_dir = &args.output_dir;
286-
let plain_dataset_filename = args.plain_dataset_filename.as_ref().unwrap();
273+
let chain_task_id: &str = &self.chain_task_id;
274+
let args = &self.pre_compute_args;
275+
let output_dir: &str = &args.output_dir;
276+
let plain_dataset_filename: &str = args.plain_dataset_filename.as_ref().unwrap();
287277

288278
let mut path = PathBuf::from(output_dir);
289279
path.push(plain_dataset_filename);
@@ -330,16 +320,16 @@ mod tests {
330320
output_dir: &str,
331321
) -> PreComputeApp {
332322
PreComputeApp {
333-
chain_task_id: Some(chain_task_id.to_string()),
334-
pre_compute_args: Some(PreComputeArgs {
323+
chain_task_id: chain_task_id.to_string(),
324+
pre_compute_args: PreComputeArgs {
335325
input_files: urls.into_iter().map(String::from).collect(),
336326
output_dir: output_dir.to_string(),
337327
is_dataset_required: true,
338328
encrypted_dataset_url: Some(HTTP_DATASET_URL.to_string()),
339329
encrypted_dataset_base64_key: Some(ENCRYPTED_DATASET_KEY.to_string()),
340330
encrypted_dataset_checksum: Some(DATASET_CHECKSUM.to_string()),
341331
plain_dataset_filename: Some(PLAIN_DATA_FILE.to_string()),
342-
}),
332+
},
343333
}
344334
}
345335

@@ -383,19 +373,6 @@ mod tests {
383373
);
384374
}
385375

386-
#[test]
387-
fn check_output_folder_returns_err_with_invalid_pre_compute_args() {
388-
let app = PreComputeApp {
389-
chain_task_id: Some(CHAIN_TASK_ID.to_string()),
390-
pre_compute_args: None,
391-
};
392-
393-
let result = app.check_output_folder();
394-
assert_eq!(
395-
result,
396-
Err(ReplicateStatusCause::PreComputeOutputFolderNotFound)
397-
);
398-
}
399376
// endregion
400377

401378
// region download_input_files
@@ -500,9 +477,7 @@ mod tests {
500477
#[test]
501478
fn download_encrypted_dataset_failure_with_invalid_dataset_url() {
502479
let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], "");
503-
if let Some(args) = &mut app.pre_compute_args {
504-
args.encrypted_dataset_url = Some("http://bad-url".to_string());
505-
}
480+
app.pre_compute_args.encrypted_dataset_url = Some("http://bad-url".to_string());
506481
let actual_content = app.download_encrypted_dataset();
507482
assert_eq!(
508483
actual_content,
@@ -513,12 +488,9 @@ mod tests {
513488
#[test]
514489
fn download_encrypted_dataset_success_with_valid_iexec_gateway() {
515490
let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], "");
516-
if let Some(args) = &mut app.pre_compute_args {
517-
args.encrypted_dataset_url = Some(IPFS_DATASET_URL.to_string());
518-
args.encrypted_dataset_checksum = Some(
519-
"0x323b1637c7999942fbebfe5d42fe15dbfe93737577663afa0181938d7ad4a2ac".to_string(),
520-
)
521-
}
491+
app.pre_compute_args.encrypted_dataset_url = Some(IPFS_DATASET_URL.to_string());
492+
app.pre_compute_args.encrypted_dataset_checksum =
493+
Some("0x323b1637c7999942fbebfe5d42fe15dbfe93737577663afa0181938d7ad4a2ac".to_string());
522494
let actual_content = app.download_encrypted_dataset();
523495
let expected_content = Ok("hello world !\n".as_bytes().to_vec());
524496
assert_eq!(actual_content, expected_content);
@@ -527,9 +499,8 @@ mod tests {
527499
#[test]
528500
fn download_encrypted_dataset_failure_with_invalid_gateway() {
529501
let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], "");
530-
if let Some(args) = &mut app.pre_compute_args {
531-
args.encrypted_dataset_url = Some("/ipfs/INVALID_IPFS_DATASET_URL".to_string());
532-
}
502+
app.pre_compute_args.encrypted_dataset_url =
503+
Some("/ipfs/INVALID_IPFS_DATASET_URL".to_string());
533504
let actual_content = app.download_encrypted_dataset();
534505
let expected_content = Err(ReplicateStatusCause::PreComputeDatasetDownloadFailed);
535506
assert_eq!(actual_content, expected_content);
@@ -538,9 +509,8 @@ mod tests {
538509
#[test]
539510
fn download_encrypted_dataset_failure_with_invalid_dataset_checksum() {
540511
let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], "");
541-
if let Some(args) = &mut app.pre_compute_args {
542-
args.encrypted_dataset_checksum = Some("invalid_dataset_checksum".to_string())
543-
}
512+
app.pre_compute_args.encrypted_dataset_checksum =
513+
Some("invalid_dataset_checksum".to_string());
544514
let actual_content = app.download_encrypted_dataset();
545515
let expected_content = Err(ReplicateStatusCause::PreComputeInvalidDatasetChecksum);
546516
assert_eq!(actual_content, expected_content);
@@ -562,9 +532,7 @@ mod tests {
562532
#[test]
563533
fn decrypt_dataset_failure_with_bad_key() {
564534
let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], "");
565-
if let Some(args) = &mut app.pre_compute_args {
566-
args.encrypted_dataset_base64_key = Some("bad_key".to_string());
567-
}
535+
app.pre_compute_args.encrypted_dataset_base64_key = Some("bad_key".to_string());
568536
let encrypted_data = app.download_encrypted_dataset().unwrap();
569537
let actual_plain_data = app.decrypt_dataset(&encrypted_data);
570538

@@ -608,9 +576,7 @@ mod tests {
608576
let output_path = temp_dir.path().to_str().unwrap();
609577

610578
let mut app = get_pre_compute_app(CHAIN_TASK_ID, vec![], output_path);
611-
if let Some(args) = &mut app.pre_compute_args {
612-
args.plain_dataset_filename = Some("/some-folder-123/not-found".to_string());
613-
}
579+
app.pre_compute_args.plain_dataset_filename = Some("/some-folder-123/not-found".to_string());
614580
let plain_dataset = "Some very useful data.".as_bytes().to_vec();
615581
let saved_dataset = app.save_plain_dataset_file(&plain_dataset);
616582

0 commit comments

Comments
 (0)