Skip to content

Commit 02aa4c9

Browse files
mzientstiepan
authored andcommitted
C API 2.0 Checkpointing + unblock dali.h (#5879)
* This commit adds C API for pipeline checkpointing. * Checkpointing support in dali::Pipeline is also slightly refactored to allow the Pipeline object to serialize checkpoint objects passed as an argument. * The checkpoint in C API can be obtained from a Pipeline handle by calling daliPipelineGetCheckpoint and then serialized separately. This allows for the checkpoint to be passed between pipeline instances without serialization. * External checkpoint data (pipeline_data, iterator_data) is stored with the checkpoint object and does not need manual memory management. * This commit also removes the #error that blocked inclusion of dali.h header, making it more or less official. --------- Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
1 parent 54c6221 commit 02aa4c9

File tree

22 files changed

+435
-48
lines changed

22 files changed

+435
-48
lines changed

dali/benchmark/checkpointing_bench.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class CheckpointingOverhead : public DALIBenchmark {
4949
if (policy == CheckpointingPolicy::SaveEveryIter) {
5050
volatile auto cpt = pipe->GetCheckpoint();
5151
} else if (policy == CheckpointingPolicy::SerializeEveryIter) {
52-
volatile auto cpt = pipe->SerializedCheckpoint({});
52+
volatile auto cpt = pipe->GetSerializedCheckpoint({});
5353
}
5454
}
5555

dali/c_api/c_api.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ void daliGetSerializedCheckpoint(
873873
external_context->iterator_data.size
874874
};
875875
}
876-
std::string cpt = pipeline->SerializedCheckpoint(ctx);
876+
std::string cpt = pipeline->GetSerializedCheckpoint(ctx);
877877
*n = cpt.size();
878878
*checkpoint = reinterpret_cast<char *>(daliAlloc(cpt.size()));
879879
DALI_ENFORCE(*checkpoint, "Failed to allocate memory");

dali/c_api_2/c_api_internal_test.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include <gtest/gtest.h>
1616
#include <stdexcept>
1717
#include <system_error>
18-
#define DALI_ALLOW_NEW_C_API
1918
#include "dali/dali.h"
2019
#include "dali/c_api_2/error_handling.h"
2120
#include "dali/core/cuda_error.h"

dali/c_api_2/checkpoint.h

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef DALI_C_API_2_CHECKPOINT_H_
16+
#define DALI_C_API_2_CHECKPOINT_H_
17+
18+
#include <memory>
19+
#include <string>
20+
#include <string_view>
21+
#include <utility>
22+
#include <vector>
23+
#include "dali/dali.h"
24+
#include "dali/pipeline/operator/checkpointing/checkpoint.h"
25+
26+
// A dummy base that the handle points to
27+
struct _DALICheckpoint {
28+
protected:
29+
_DALICheckpoint() = default;
30+
~_DALICheckpoint() = default;
31+
};
32+
33+
34+
namespace dali::c_api {
35+
36+
class PipelineWrapper;
37+
38+
class CheckpointWrapper : public _DALICheckpoint {
39+
public:
40+
explicit CheckpointWrapper(Checkpoint &&cpt)
41+
: cpt_(std::move(cpt)) {}
42+
43+
const std::string &Serialized() const & {
44+
return serialized_.value();
45+
}
46+
47+
void Serialize(const PipelineWrapper &pipeline);
48+
49+
daliCheckpointExternalData_t ExternalData() const {
50+
daliCheckpointExternalData_t ext;
51+
ext.iterator_data.data = cpt_.external_ctx_cpt_.iterator_data.data();
52+
ext.iterator_data.size = cpt_.external_ctx_cpt_.iterator_data.size();
53+
ext.pipeline_data.data = cpt_.external_ctx_cpt_.pipeline_data.data();
54+
ext.pipeline_data.size = cpt_.external_ctx_cpt_.pipeline_data.size();
55+
return ext;
56+
}
57+
58+
Checkpoint *Unwrap() & {
59+
return &cpt_;
60+
}
61+
62+
const Checkpoint *Unwrap() const & {
63+
return &cpt_;
64+
}
65+
66+
private:
67+
Checkpoint cpt_;
68+
std::optional<std::string> serialized_;
69+
};
70+
71+
} // namespace dali::c_api
72+
73+
#endif // DALI_C_API_2_CHECKPOINT_H_

dali/c_api_2/data_objects.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include <string>
2222
#include <utility>
2323
#include <vector>
24-
#define DALI_ALLOW_NEW_C_API
2524
#include "dali/dali.h"
2625
#include "dali/pipeline/data/tensor_list.h"
2726
#include "dali/c_api_2/ref_counting.h"
@@ -661,9 +660,9 @@ class TensorListWrapper : public ITensorList {
661660
}
662661

663662
RefCountedPtr<ITensor> ViewAsTensor() const override {
664-
if (!tl_->IsContiguous())
663+
if (!tl_->IsDenseTensor())
665664
throw std::runtime_error(
666-
"The TensorList is not contiguous and cannot be viewed as a Tensor.");
665+
"Only a densely packed list of tensors of uniform shape can be viewed as a Tensor.");
667666

668667
auto t = std::make_shared<Tensor<Backend>>();
669668
auto buf = unsafe_owner(*tl_);

dali/c_api_2/error_handling.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#include <iostream>
2020
#include <string>
2121
#include <sstream>
22-
#define DALI_ALLOW_NEW_C_API
2322
#include "dali/dali.h"
2423
#include "dali/core/error_handling.h"
2524

dali/c_api_2/init.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
// limitations under the License.
1414

1515
#include <atomic>
16-
#define DALI_ALLOW_NEW_C_API
1716
#include "dali/dali.h"
1817
#include "dali/c_api_2/error_handling.h"
1918
#include "dali/pipeline/init.h"

dali/c_api_2/managed_handle.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include <cassert>
1919
#include <stdexcept>
2020
#include <utility>
21-
#define DALI_ALLOW_NEW_C_API
2221
#include "dali/dali.h"
2322
#include "dali/core/unique_handle.h"
2423

@@ -122,6 +121,7 @@ class Resource##Handle \
122121

123122
DALI_C_UNIQUE_HANDLE(Pipeline);
124123
DALI_C_UNIQUE_HANDLE(PipelineOutputs);
124+
DALI_C_UNIQUE_HANDLE(Checkpoint);
125125
DALI_C_REF_HANDLE(TensorList);
126126
DALI_C_REF_HANDLE(Tensor);
127127

dali/c_api_2/op_test/complex_pipeline_test.cc

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,17 @@
2424
namespace dali::c_api::test {
2525

2626
std::unique_ptr<Pipeline>
27-
ReaderDecoderPipe(std::string_view decoder_device, StorageDevice output_device) {
27+
ReaderDecoderPipe(
28+
std::string_view decoder_device,
29+
StorageDevice output_device,
30+
PipelineParams params = {}) {
2831
std::string file_root = testing::dali_extra_path() + "/db/single/jpeg/";
2932
std::string file_list = file_root + "image_list.txt";
30-
auto pipe = std::make_unique<Pipeline>(4, 1, 0, 12345, true, 2, true, true);
33+
if (!params.max_batch_size) params.max_batch_size = 4;
34+
if (!params.num_threads) params.num_threads = 1;
35+
if (!params.seed) params.seed = 12345;
36+
if (!params.executor_type) params.executor_type = ExecutorType::Dynamic;
37+
auto pipe = std::make_unique<Pipeline>(params);
3138
pipe->AddOperator(OpSpec("FileReader")
3239
.AddArg("device", "cpu")
3340
.AddArg("file_root", file_root)
@@ -81,4 +88,79 @@ TEST(CAPI2_PipelineTest, ReaderDecoderMixed2CPU) {
8188
TestReaderDecoder("mixed", StorageDevice::CPU);
8289
}
8390

91+
TEST(CAPI2_PipelineTest, Checkpointing) {
92+
// This test creates three pipelines - a C++ pipeline (ref) and two C pipelines (pipe1, pipe2),
93+
// created by deserializing the serialized representation of the C++ pipeline.
94+
//
95+
// (pipe1) advances 3 iterations and then a checkpoint is taken and restored in (ref),
96+
// after which 5 iterations of outputs are compared.
97+
// Then a checkpoint is taken in (ref) and restored in (pipe2), after which another 5 iterations
98+
// are compared.
99+
PipelineParams params{};
100+
params.enable_checkpointing = true;
101+
params.seed = 1234;
102+
auto ref = ReaderDecoderPipe("cpu", StorageDevice::GPU, params);
103+
ref->Build();
104+
auto pipe_str = ref->SerializeToProtobuf(); // serialize the ref...
105+
auto pipe1 = Deserialize(pipe_str, {}); // ...and create pipe1
106+
auto pipe2 = Deserialize(pipe_str, {}); // ...and pipe2 from serialized ref
107+
CHECK_DALI(daliPipelineBuild(pipe1));
108+
CHECK_DALI(daliPipelineBuild(pipe2));
109+
110+
// Advance a few iterations...
111+
CHECK_DALI(daliPipelinePrefetch(pipe1));
112+
daliPipelineOutputs_h out1_h{};
113+
CHECK_DALI(daliPipelinePopOutputs(pipe1, &out1_h));
114+
CHECK_DALI(daliPipelineOutputsDestroy(out1_h));
115+
CHECK_DALI(daliPipelineRun(pipe1));
116+
CHECK_DALI(daliPipelinePopOutputs(pipe1, &out1_h));
117+
CHECK_DALI(daliPipelineOutputsDestroy(out1_h));
118+
CHECK_DALI(daliPipelineRun(pipe1));
119+
CHECK_DALI(daliPipelinePopOutputs(pipe1, &out1_h));
120+
CHECK_DALI(daliPipelineOutputsDestroy(out1_h));
121+
122+
const char pipeline_data[] = "A rose by any other name would smell as sweet";
123+
size_t pipeline_data_size = strlen(pipeline_data);
124+
125+
daliCheckpointExternalData_t ext{};
126+
ext.iterator_data.data = "ITER";
127+
ext.iterator_data.size = 4;
128+
ext.pipeline_data.data = pipeline_data;
129+
ext.pipeline_data.size = strlen(ext.pipeline_data.data);
130+
131+
daliCheckpoint_h checkpoint_h{};
132+
// Take a checkpoint...
133+
CHECK_DALI(daliPipelineGetCheckpoint(pipe1, &checkpoint_h, &ext));
134+
CheckpointHandle checkpoint(checkpoint_h);
135+
136+
const char *data = nullptr;
137+
size_t size = 0;
138+
CHECK_DALI(daliPipelineSerializeCheckpoint(pipe1, checkpoint, &data, &size));
139+
ASSERT_NE(data, nullptr);
140+
141+
// ...restore...
142+
ref->RestoreFromSerializedCheckpoint(std::string(data, size));
143+
// ...run and compare.
144+
ComparePipelineOutputs(*ref, pipe1, 5, false);
145+
146+
// Now take another checkpoint...
147+
auto chk_str = ref->GetSerializedCheckpoint({ ext.pipeline_data.data, ext.iterator_data.data });
148+
149+
// ...deserialize...
150+
CHECK_DALI(daliPipelineDeserializeCheckpoint(
151+
pipe2, &checkpoint_h, chk_str.data(), chk_str.length()));
152+
CheckpointHandle checkpoint2(checkpoint_h);
153+
154+
daliCheckpointExternalData_t ext2{};
155+
CHECK_DALI(daliCheckpointGetExternalData(checkpoint2, &ext2));
156+
EXPECT_EQ(ext2.iterator_data.size, 4);
157+
EXPECT_STREQ(ext2.iterator_data.data, "ITER");
158+
EXPECT_EQ(ext2.pipeline_data.size, pipeline_data_size);
159+
EXPECT_STREQ(ext2.pipeline_data.data, pipeline_data);
160+
// ...restore...
161+
CHECK_DALI(daliPipelineRestoreCheckpoint(pipe2, checkpoint2));
162+
// ...run and compare.
163+
ComparePipelineOutputs(*ref, pipe2, 5, false);
164+
}
165+
84166
} // namespace dali::c_api::test

dali/c_api_2/pipeline.cc

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "dali/c_api_2/pipeline.h"
1616
#include "dali/c_api_2/pipeline_outputs.h"
17+
#include "dali/c_api_2/checkpoint.h"
1718
#include "dali/c_api_2/error_handling.h"
1819
#include "dali/pipeline/pipeline.h"
1920
#include "dali/c_api_2/utils.h"
@@ -27,6 +28,12 @@ PipelineWrapper *ToPointer(daliPipeline_h handle) {
2728
return static_cast<PipelineWrapper *>(handle);
2829
}
2930

31+
CheckpointWrapper *ToPointer(daliCheckpoint_h handle) {
32+
if (!handle)
33+
throw NullHandle("Checkpoint");
34+
return static_cast<CheckpointWrapper *>(handle);
35+
}
36+
3037
PipelineParams ToCppParams(const daliPipelineParams_t &params) {
3138
PipelineParams cpp_params = {};
3239

@@ -237,6 +244,39 @@ void PipelineWrapper::FeedInputImpl(
237244
data_id ? std::optional<std::string>(std::in_place, *data_id) : std::nullopt);
238245
}
239246

247+
std::unique_ptr<CheckpointWrapper> PipelineWrapper::GetCheckpoint(
248+
const daliCheckpointExternalData_t *ext) const {
249+
auto cpt = std::make_unique<CheckpointWrapper>(pipeline_->GetCheckpoint());
250+
if (ext) {
251+
cpt->Unwrap()->external_ctx_cpt_.pipeline_data =
252+
std::string(ext->pipeline_data.data, ext->pipeline_data.size);
253+
cpt->Unwrap()->external_ctx_cpt_.iterator_data =
254+
std::string(ext->iterator_data.data, ext->iterator_data.size);
255+
}
256+
return cpt;
257+
}
258+
259+
std::string_view PipelineWrapper::SerializeCheckpoint(CheckpointWrapper &chk) const {
260+
chk.Serialize(*this);
261+
return chk.Serialized();
262+
}
263+
264+
void CheckpointWrapper::Serialize(const PipelineWrapper &pipeline) {
265+
if (!serialized_)
266+
serialized_ = pipeline.Unwrap()->SerializeCheckpoint(cpt_);
267+
}
268+
269+
std::unique_ptr<CheckpointWrapper>
270+
PipelineWrapper::DeserializeCheckpoint(std::string_view serialized) {
271+
return std::make_unique<CheckpointWrapper>(pipeline_->DeserializeCheckpoint(serialized));
272+
}
273+
274+
void PipelineWrapper::RestoreFromCheckpoint(CheckpointWrapper &chk) {
275+
pipeline_->RestoreFromCheckpoint(*chk.Unwrap());
276+
}
277+
278+
279+
240280
} // namespace dali::c_api
241281

242282
using namespace dali::c_api; // NOLINT
@@ -393,3 +433,83 @@ daliResult_t daliPipelinePopOutputsAsync(
393433
*out = pipe->PopOutputs(stream).release();
394434
DALI_EPILOG();
395435
}
436+
437+
438+
daliResult_t daliPipelineGetCheckpoint(
439+
daliPipeline_h pipeline,
440+
daliCheckpoint_h *out_checkpoint,
441+
const daliCheckpointExternalData_t *checkpoint_ext) {
442+
DALI_PROLOG();
443+
auto pipe = ToPointer(pipeline);
444+
CHECK_OUTPUT(out_checkpoint);
445+
auto chk = pipe->GetCheckpoint(checkpoint_ext);
446+
*out_checkpoint = chk.release(); // No throwing beyond this point!
447+
DALI_EPILOG();
448+
}
449+
450+
daliResult_t daliPipelineRestoreCheckpoint(
451+
daliPipeline_h pipeline,
452+
daliCheckpoint_h checkpoint) {
453+
DALI_PROLOG();
454+
auto pipe = ToPointer(pipeline);
455+
auto chk = ToPointer(checkpoint);
456+
pipe->RestoreFromCheckpoint(*chk);
457+
DALI_EPILOG();
458+
}
459+
460+
daliResult_t daliPipelineDeserializeCheckpoint(
461+
daliPipeline_h pipeline,
462+
daliCheckpoint_h *out_checkpoint,
463+
const char *serialized_checkpoint,
464+
size_t serialized_checkpoint_size) {
465+
DALI_PROLOG();
466+
auto pipe = ToPointer(pipeline);
467+
CHECK_OUTPUT(out_checkpoint);
468+
if (!serialized_checkpoint_size) {
469+
*out_checkpoint = nullptr;
470+
return DALI_NO_DATA;
471+
}
472+
if (!serialized_checkpoint) {
473+
throw std::invalid_argument("The parameter `serialized_checkpoint` must not be NULL if "
474+
"`serialize_checkpoint_size` is nonzero.");
475+
}
476+
477+
auto cpt = pipe->DeserializeCheckpoint(
478+
std::string_view(serialized_checkpoint, serialized_checkpoint_size));
479+
480+
*out_checkpoint = cpt.release(); // No throwing beyond this point!
481+
DALI_EPILOG();
482+
}
483+
484+
daliResult_t daliCheckpointGetExternalData(
485+
daliCheckpoint_h checkpoint,
486+
daliCheckpointExternalData_t *out_ext_data) {
487+
DALI_PROLOG();
488+
auto cpt = ToPointer(checkpoint);
489+
CHECK_OUTPUT(out_ext_data);
490+
*out_ext_data = cpt->ExternalData();
491+
DALI_EPILOG();
492+
}
493+
494+
daliResult_t daliPipelineSerializeCheckpoint(
495+
daliPipeline_h pipeline,
496+
daliCheckpoint_h checkpoint,
497+
const char **out_data,
498+
size_t *out_size) {
499+
DALI_PROLOG();
500+
auto pipe = ToPointer(pipeline);
501+
auto cpt = ToPointer(checkpoint);
502+
CHECK_OUTPUT(out_data);
503+
CHECK_OUTPUT(out_size);
504+
auto serialized = pipe->SerializeCheckpoint(*cpt);
505+
*out_data = serialized.data();
506+
*out_size = serialized.size();
507+
DALI_EPILOG();
508+
}
509+
510+
/** Destroys a checkpoint object */
511+
daliResult_t daliCheckpointDestroy(daliCheckpoint_h checkpoint) {
512+
DALI_PROLOG();
513+
delete ToPointer(checkpoint);
514+
DALI_EPILOG();
515+
}

0 commit comments

Comments
 (0)