88
99#include < executorch/extension/runner_util/inputs.h>
1010
11+ #include < cstdlib>
12+ #include < cstring>
13+
1114#include < executorch/extension/data_loader/file_data_loader.h>
1215#include < executorch/runtime/core/exec_aten/exec_aten.h>
1316#include < executorch/runtime/core/span.h>
@@ -40,52 +43,81 @@ class InputsTest : public ::testing::Test {
4043 void SetUp () override {
4144 torch::executor::runtime_init ();
4245
43- // Create a loader for the serialized ModuleAdd program.
44- const char * path = std::getenv (" ET_MODULE_ADD_PATH" );
45- Result<FileDataLoader> loader = FileDataLoader::from (path);
46- ASSERT_EQ (loader.error (), Error::Ok);
47- loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
46+ // Load ModuleAdd
47+ const char * add_path = std::getenv (" ET_MODULE_ADD_PATH" );
48+ ASSERT_NE (add_path, nullptr )
49+ << " ET_MODULE_ADD_PATH environment variable must be set" ;
50+ Result<FileDataLoader> add_loader = FileDataLoader::from (add_path);
51+ ASSERT_EQ (add_loader.error (), Error::Ok);
52+ add_loader_ = std::make_unique<FileDataLoader>(std::move (add_loader.get ()));
53+
54+ Result<Program> add_program = Program::load (
55+ add_loader_.get (), Program::Verification::InternalConsistency);
56+ ASSERT_EQ (add_program.error (), Error::Ok);
57+ add_program_ = std::make_unique<Program>(std::move (add_program.get ()));
58+
59+ add_mmm_ = std::make_unique<ManagedMemoryManager>(
60+ /* planned_memory_bytes=*/ 32 * 1024U ,
61+ /* method_allocator_bytes=*/ 32 * 1024U );
62+
63+ Result<Method> add_method =
64+ add_program_->load_method (" forward" , &add_mmm_->get ());
65+ ASSERT_EQ (add_method.error (), Error::Ok);
66+ add_method_ = std::make_unique<Method>(std::move (add_method.get ()));
67+
68+ // Load ModuleIntBool
69+ const char * intbool_path = std::getenv (" ET_MODULE_INTBOOL_PATH" );
70+ ASSERT_NE (intbool_path, nullptr )
71+ << " ET_MODULE_INTBOOL_PATH environment variable must be set" ;
72+ Result<FileDataLoader> intbool_loader = FileDataLoader::from (intbool_path);
73+ ASSERT_EQ (intbool_loader.error (), Error::Ok);
74+ intbool_loader_ =
75+ std::make_unique<FileDataLoader>(std::move (intbool_loader.get ()));
4876
49- // Use it to load the program.
50- Result<Program> program = Program::load (
51- loader_. get (), Program::Verification::InternalConsistency );
52- ASSERT_EQ (program. error (), Error::Ok);
53- program_ = std::make_unique<Program>(std::move (program .get ()));
77+ Result<Program> intbool_program = Program:: load(
78+ intbool_loader_. get (), Program::Verification::InternalConsistency);
79+ ASSERT_EQ (intbool_program. error (), Error::Ok );
80+ intbool_program_ =
81+ std::make_unique<Program>(std::move (intbool_program .get ()));
5482
55- mmm_ = std::make_unique<ManagedMemoryManager>(
83+ intbool_mmm_ = std::make_unique<ManagedMemoryManager>(
5684 /* planned_memory_bytes=*/ 32 * 1024U ,
5785 /* method_allocator_bytes=*/ 32 * 1024U );
5886
59- // Load the forward method.
60- Result<Method> method = program_ ->load_method (" forward" , &mmm_ ->get ());
61- ASSERT_EQ (method .error (), Error::Ok);
62- method_ = std::make_unique<Method>(std::move (method .get ()));
87+ Result<Method> intbool_method =
88+ intbool_program_ ->load_method (" forward" , &intbool_mmm_ ->get ());
89+ ASSERT_EQ (intbool_method .error (), Error::Ok);
90+ intbool_method_ = std::make_unique<Method>(std::move (intbool_method .get ()));
6391 }
6492
6593 private:
66- // Must outlive method_, but tests shouldn't need to touch them.
67- std::unique_ptr<FileDataLoader> loader_;
68- std::unique_ptr<ManagedMemoryManager> mmm_;
69- std::unique_ptr<Program> program_;
94+ std::unique_ptr<FileDataLoader> add_loader_;
95+ std::unique_ptr<Program> add_program_;
96+ std::unique_ptr<ManagedMemoryManager> add_mmm_;
97+
98+ std::unique_ptr<FileDataLoader> intbool_loader_;
99+ std::unique_ptr<Program> intbool_program_;
100+ std::unique_ptr<ManagedMemoryManager> intbool_mmm_;
70101
71102 protected:
72- std::unique_ptr<Method> method_;
103+ std::unique_ptr<Method> add_method_;
104+ std::unique_ptr<Method> intbool_method_;
73105};
74106
75107TEST_F (InputsTest, Smoke) {
76- Result<BufferCleanup> input_buffers = prepare_input_tensors (*method_ );
108+ Result<BufferCleanup> input_buffers = prepare_input_tensors (*add_method_ );
77109 ASSERT_EQ (input_buffers.error (), Error::Ok);
78- auto input_err = method_ ->set_input (executorch::runtime::EValue (1.0 ), 2 );
110+ auto input_err = add_method_ ->set_input (executorch::runtime::EValue (1.0 ), 2 );
79111 ASSERT_EQ (input_err, Error::Ok);
80112
81113 // We can't look at the input tensors, but we can check that the outputs make
82114 // sense after executing the method.
83- Error status = method_ ->execute ();
115+ Error status = add_method_ ->execute ();
84116 ASSERT_EQ (status, Error::Ok);
85117
86118 // Get the single output, which should be a floating-point Tensor.
87- ASSERT_EQ (method_ ->outputs_size (), 1 );
88- const EValue& output_value = method_ ->get_output (0 );
119+ ASSERT_EQ (add_method_ ->outputs_size (), 1 );
120+ const EValue& output_value = add_method_ ->get_output (0 );
89121 ASSERT_EQ (output_value.tag , Tag::Tensor);
90122 Tensor output = output_value.toTensor ();
91123 ASSERT_EQ (output.scalar_type (), ScalarType::Float);
@@ -107,14 +139,14 @@ TEST_F(InputsTest, ExceedingInputCountLimitFails) {
107139 // The smoke test above demonstrated that we can prepare inputs with the
108140 // default limits. It should fail if we lower the max below the number of
109141 // actual inputs.
110- MethodMeta method_meta = method_ ->method_meta ();
142+ MethodMeta method_meta = add_method_ ->method_meta ();
111143 size_t num_inputs = method_meta.num_inputs ();
112144 ASSERT_GE (num_inputs, 1 );
113145 executorch::extension::PrepareInputTensorsOptions options;
114146 options.max_inputs = num_inputs - 1 ;
115147
116148 Result<BufferCleanup> input_buffers =
117- prepare_input_tensors (*method_ , options);
149+ prepare_input_tensors (*add_method_ , options);
118150 ASSERT_NE (input_buffers.error (), Error::Ok);
119151}
120152
@@ -128,7 +160,7 @@ TEST_F(InputsTest, ExceedingInputAllocationLimitFails) {
128160 options.max_total_allocation_size = 1 ;
129161
130162 Result<BufferCleanup> input_buffers =
131- prepare_input_tensors (*method_ , options);
163+ prepare_input_tensors (*add_method_ , options);
132164 ASSERT_NE (input_buffers.error (), Error::Ok);
133165}
134166
@@ -186,3 +218,107 @@ TEST(BufferCleanupTest, Smoke) {
186218 // complaint.
187219 bc2.reset ();
188220}
221+
222+ TEST_F (InputsTest, DoubleInputWrongSizeFails) {
223+ MethodMeta method_meta = add_method_->method_meta ();
224+
225+ // ModuleAdd has 3 inputs: tensor, tensor, double (alpha)
226+ ASSERT_EQ (method_meta.num_inputs (), 3 );
227+
228+ // Verify input 2 is a Double
229+ auto tag = method_meta.input_tag (2 );
230+ ASSERT_TRUE (tag.ok ());
231+ ASSERT_EQ (tag.get (), Tag::Double);
232+
233+ // Create input_buffers with wrong size for the Double input
234+ std::vector<std::pair<char *, size_t >> input_buffers;
235+
236+ // Allocate correct buffers for tensors (inputs 0 and 1)
237+ auto tensor0_meta = method_meta.input_tensor_meta (0 );
238+ auto tensor1_meta = method_meta.input_tensor_meta (1 );
239+ ASSERT_TRUE (tensor0_meta.ok ());
240+ ASSERT_TRUE (tensor1_meta.ok ());
241+
242+ std::vector<char > buf0 (tensor0_meta->nbytes (), 0 );
243+ std::vector<char > buf1 (tensor1_meta->nbytes (), 0 );
244+
245+ // ModuleAdd expects alpha=1.0. Need to set this correctly, otherwise
246+ // set_input fails validation before the buffer overflow happens.
247+ double alpha = 1.0 ;
248+ // Double is size 8; use a larger buffer to invoke overflow.
249+ char large_buffer[16 ];
250+ memcpy (large_buffer, &alpha, sizeof (double ));
251+
252+ input_buffers.push_back ({buf0.data (), buf0.size ()});
253+ input_buffers.push_back ({buf1.data (), buf1.size ()});
254+ input_buffers.push_back ({large_buffer, sizeof (large_buffer)});
255+
256+ Result<BufferCleanup> result =
257+ prepare_input_tensors (*add_method_, {}, input_buffers);
258+ EXPECT_EQ (result.error (), Error::InvalidArgument);
259+ }
260+
261+ TEST_F (InputsTest, IntBoolInputWrongSizeFails) {
262+ MethodMeta method_meta = intbool_method_->method_meta ();
263+
264+ // ModuleIntBool has 3 inputs: tensor, int, bool
265+ ASSERT_EQ (method_meta.num_inputs (), 3 );
266+
267+ // Verify input types
268+ auto int_tag = method_meta.input_tag (1 );
269+ ASSERT_TRUE (int_tag.ok ());
270+ ASSERT_EQ (int_tag.get (), Tag::Int);
271+
272+ auto bool_tag = method_meta.input_tag (2 );
273+ ASSERT_TRUE (bool_tag.ok ());
274+ ASSERT_EQ (bool_tag.get (), Tag::Bool);
275+
276+ // Allocate correct buffer for tensor (input 0)
277+ auto tensor0_meta = method_meta.input_tensor_meta (0 );
278+ ASSERT_TRUE (tensor0_meta.ok ());
279+ std::vector<char > buf0 (tensor0_meta->nbytes (), 0 );
280+
281+ // Prepare scalar values
282+ int64_t y = 1 ;
283+ bool z = true ;
284+
285+ // Test 1: Int input with wrong size
286+ {
287+ std::vector<std::pair<char *, size_t >> input_buffers;
288+
289+ // Int is size 8; use a larger buffer to invoke overflow.
290+ char large_int_buffer[16 ];
291+ memcpy (large_int_buffer, &y, sizeof (int64_t ));
292+
293+ char bool_buffer[sizeof (bool )];
294+ memcpy (bool_buffer, &z, sizeof (bool ));
295+
296+ input_buffers.push_back ({buf0.data (), buf0.size ()});
297+ input_buffers.push_back ({large_int_buffer, sizeof (large_int_buffer)});
298+ input_buffers.push_back ({bool_buffer, sizeof (bool_buffer)});
299+
300+ Result<BufferCleanup> result =
301+ prepare_input_tensors (*intbool_method_, {}, input_buffers);
302+ EXPECT_EQ (result.error (), Error::InvalidArgument);
303+ }
304+
305+ // Test 2: Bool input with wrong size
306+ {
307+ std::vector<std::pair<char *, size_t >> input_buffers;
308+
309+ char int_buffer[sizeof (int64_t )];
310+ memcpy (int_buffer, &y, sizeof (int64_t ));
311+
312+ // Bool is size 1; use a larger buffer to invoke overflow.
313+ char large_bool_buffer[8 ];
314+ memcpy (large_bool_buffer, &z, sizeof (bool ));
315+
316+ input_buffers.push_back ({buf0.data (), buf0.size ()});
317+ input_buffers.push_back ({int_buffer, sizeof (int_buffer)});
318+ input_buffers.push_back ({large_bool_buffer, sizeof (large_bool_buffer)});
319+
320+ Result<BufferCleanup> result =
321+ prepare_input_tensors (*intbool_method_, {}, input_buffers);
322+ EXPECT_EQ (result.error (), Error::InvalidArgument);
323+ }
324+ }
0 commit comments