Skip to content

Commit e36152e

Browse files
committed
allow EP to be selected
1 parent fe66c12 commit e36152e

File tree

5 files changed

+121
-32
lines changed

5 files changed

+121
-32
lines changed

mobile/examples/model_tester/common/include/model_runner.h

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,42 @@
55
#include <chrono>
66
#include <optional>
77
#include <string>
8+
#include <unordered_map>
89
#include <vector>
910

1011
namespace model_runner {
1112

13+
using Clock = std::chrono::steady_clock;
14+
using Duration = Clock::duration;
15+
1216
struct RunConfig {
17+
// Path to the model to run.
1318
std::string model_path{};
1419

15-
size_t num_warmup_iterations{};
16-
size_t num_iterations{};
20+
// Whether to run a warmup iteration before running the measured (timed) iterations.
21+
bool run_warmup_iteration{true};
22+
23+
// Number of iterations to run.
24+
size_t num_iterations{10};
25+
26+
// Configuration for an Execution Provider (EP).
27+
struct EpConfig {
28+
std::string provider_name{};
29+
std::unordered_map<std::string, std::string> provider_options{};
30+
};
1731

32+
// Specifies the EP to use in the session.
33+
std::optional<EpConfig> ep{};
34+
35+
// Specifies the onnxruntime log level.
1836
std::optional<int> log_level{};
1937
};
2038

21-
using Clock = std::chrono::steady_clock;
22-
using Duration = Clock::duration;
23-
2439
struct RunResult {
40+
// Time taken to load the model.
2541
Duration load_duration;
42+
43+
// Times taken to run the model.
2644
std::vector<Duration> run_durations;
2745
};
2846

mobile/examples/model_tester/common/model_runner.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ RunResult Run(const RunConfig& run_config) {
177177

178178
auto session_options = Ort::SessionOptions{};
179179

180+
if (const auto& ep_config = run_config.ep; ep_config.has_value()) {
181+
session_options.AppendExecutionProvider(ep_config->provider_name, ep_config->provider_options);
182+
}
183+
180184
Timer timer{};
181185
auto session = Ort::Session{env, run_config.model_path.c_str(), session_options};
182186
run_result.load_duration = timer.Elapsed();
@@ -191,15 +195,16 @@ RunResult Run(const RunConfig& run_config) {
191195

192196
auto run_options = Ort::RunOptions{};
193197

198+
run_result.run_durations.reserve(run_config.num_iterations);
199+
194200
// warmup
195-
for (size_t i = 0; i < run_config.num_warmup_iterations; ++i) {
201+
if (run_config.run_warmup_iteration) {
196202
auto outputs = session.Run(run_options,
197203
input_name_cstrs.data(), input_values.data(), input_values.size(),
198204
output_name_cstrs.data(), output_name_cstrs.size());
199205
}
200206

201207
// measure runs
202-
run_result.run_durations.reserve(run_config.num_iterations);
203208
for (size_t i = 0; i < run_config.num_iterations; ++i) {
204209
timer.Reset();
205210
auto outputs = session.Run(run_options,

mobile/examples/model_tester/ios/ModelTester/ModelRunner/model_runner_objc_wrapper.h

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
#import <Foundation/Foundation.h>
2-
#include <stdint.h>
32

43
NS_ASSUME_NONNULL_BEGIN
54

65
/**
7-
* This class is an Objective-C wrapper around the C++ model runner functionality.
6+
* This class is an Objective-C wrapper around the C++ `model_runner::RunConfig` structure.
7+
*/
8+
@interface ModelRunnerRunConfig : NSObject
9+
10+
- (void)setModelPath:(NSString*)modelPath;
11+
12+
- (void)setNumIterations:(NSUInteger)numIterations;
13+
14+
- (void)setExecutionProvider:(NSString*)providerName
15+
options:(nullable NSDictionary<NSString*, NSString*>*)providerOptions;
16+
17+
@end
18+
19+
/**
20+
* This class is an Objective-C wrapper around the C++ model runner functions.
821
*/
922
@interface ModelRunner : NSObject
1023

11-
+ (nullable NSString*)runWithModelPath:(NSString*)modelPath
12-
numIterations:(uint32_t)numIterations
13-
error:(NSError**)error;
24+
+ (nullable NSString*)runWithConfig:(ModelRunnerRunConfig*)config
25+
error:(NSError**)error NS_SWIFT_NAME(run(config:));
1426

1527
@end
1628

mobile/examples/model_tester/ios/ModelTester/ModelRunner/model_runner_objc_wrapper.mm

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,52 @@
22

33
#include "model_runner.h"
44

5+
NS_ASSUME_NONNULL_BEGIN
6+
7+
@interface ModelRunnerRunConfig ()
8+
9+
- (const model_runner::RunConfig&)cppRunConfig;
10+
11+
@end
12+
13+
@implementation ModelRunnerRunConfig {
14+
model_runner::RunConfig _runConfig;
15+
}
16+
17+
- (void)setModelPath:(nonnull NSString*)modelPath {
18+
_runConfig.model_path = modelPath.UTF8String;
19+
}
20+
21+
- (void)setNumIterations:(NSUInteger)numIterations {
22+
_runConfig.num_iterations = static_cast<size_t>(numIterations);
23+
}
24+
25+
- (void)setExecutionProvider:(NSString*)providerName
26+
options:(nullable NSDictionary<NSString*, NSString*>*)providerOptions {
27+
model_runner::RunConfig::EpConfig ep_config{};
28+
ep_config.provider_name = providerName.UTF8String;
29+
if (providerOptions != nil) {
30+
for (NSString* optionName in providerOptions) {
31+
NSString* optionValue = providerOptions[optionName];
32+
ep_config.provider_options.emplace(optionName.UTF8String,
33+
optionValue.UTF8String);
34+
}
35+
}
36+
_runConfig.ep = std::move(ep_config);
37+
}
38+
39+
- (const model_runner::RunConfig&)cppRunConfig {
40+
return _runConfig;
41+
}
42+
43+
@end
44+
545
@implementation ModelRunner
646

7-
+ (nullable NSString*)runWithModelPath:(NSString*)modelPath
8-
numIterations:(uint32_t)numIterations
9-
error:(NSError**)error {
47+
+ (nullable NSString*)runWithConfig:(ModelRunnerRunConfig*)objcConfig
48+
error:(NSError**)error {
1049
try {
11-
model_runner::RunConfig config{};
12-
config.model_path = modelPath.UTF8String;
13-
config.num_iterations = numIterations;
14-
config.num_warmup_iterations = 1;
50+
const auto& config = [objcConfig cppRunConfig];
1551

1652
auto result = model_runner::Run(config);
1753

@@ -32,3 +68,5 @@ + (nullable NSString*)runWithModelPath:(NSString*)modelPath
3268
}
3369

3470
@end
71+
72+
NS_ASSUME_NONNULL_END

mobile/examples/model_tester/ios/ModelTester/ModelTester/ContentView.swift

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,18 @@ enum ModelTesterError: Error {
44
case runtimeError(msg: String)
55
}
66

7+
enum ExecutionProviderType: String, CaseIterable, Identifiable {
8+
case cpu = "CPU"
9+
case coreml = "CoreML"
10+
11+
var id: Self { self }
12+
}
13+
714
struct ContentView: View {
815
@State private var runResultMessage: String = ""
916
@State private var isRunning: Bool = false
10-
@State private var numIterations: UInt32 = 10
17+
@State private var numIterations: UInt = 10
18+
@State private var executionProviderType: ExecutionProviderType = .cpu
1119

1220
private func Run() {
1321
isRunning = true
@@ -19,9 +27,15 @@ struct ContentView: View {
1927
throw ModelTesterError.runtimeError(msg: "Failed to find model file path.")
2028
}
2129

22-
output = try ModelRunner.run(
23-
withModelPath: modelPath,
24-
numIterations: numIterations)
30+
let config = ModelRunnerRunConfig()
31+
config.setModelPath(modelPath)
32+
config.setNumIterations(numIterations)
33+
34+
if executionProviderType != .cpu {
35+
config.setExecutionProvider(executionProviderType.rawValue)
36+
}
37+
38+
output = try ModelRunner.run(config: config)
2539
} catch {
2640
output = "Error: \(error)"
2741
}
@@ -33,14 +47,17 @@ struct ContentView: View {
3347
}
3448

3549
var body: some View {
36-
VStack {
37-
HStack {
38-
Text("Iterations:")
39-
TextField(
40-
"", value: $numIterations,
41-
format: IntegerFormatStyle<UInt32>.number
42-
)
43-
.keyboardType(.numberPad)
50+
Form {
51+
Text("Iterations:")
52+
TextField(
53+
"", value: $numIterations,
54+
format: IntegerFormatStyle<UInt>.number
55+
).keyboardType(.numberPad)
56+
57+
Picker("Execution provider type", selection: $executionProviderType) {
58+
ForEach(ExecutionProviderType.allCases) { epType in
59+
Text(epType.rawValue).tag(epType)
60+
}
4461
}
4562

4663
Button(action: Run) { Text("Run") }
@@ -49,7 +66,6 @@ struct ContentView: View {
4966
Text(runResultMessage)
5067
.font(.body.monospaced())
5168
}
52-
.padding()
5369
}
5470
}
5571

0 commit comments

Comments
 (0)