@@ -927,6 +927,8 @@ struct AMDGPUStreamTy {
927
927
AMDGPUSignalManagerTy *SignalManager;
928
928
};
929
929
930
+ using AMDGPUStreamCallbackTy = Error(void *Data);
931
+
930
932
// / The stream is composed of N stream's slots. The struct below represents
931
933
// / the fields of each slot. Each slot has a signal and an optional action
932
934
// / function. When appending an HSA asynchronous operation to the stream, one
@@ -942,65 +944,82 @@ struct AMDGPUStreamTy {
942
944
// / operation as input signal.
943
945
AMDGPUSignalTy *Signal;
944
946
945
- // / The action that must be performed after the operation's completion. Set
947
+ // / The actions that must be performed after the operation's completion. Set
946
948
// / to nullptr when there is no action to perform.
947
- Error (*ActionFunction)( void *) ;
949
+ llvm::SmallVector<AMDGPUStreamCallbackTy *> Callbacks ;
948
950
949
951
// / Space for the action's arguments. A pointer to these arguments is passed
950
952
// / to the action function. Notice the space of arguments is limited.
951
- union {
953
+ union ActionArgsTy {
952
954
MemcpyArgsTy MemcpyArgs;
953
955
ReleaseBufferArgsTy ReleaseBufferArgs;
954
956
ReleaseSignalArgsTy ReleaseSignalArgs;
955
- } ActionArgs;
957
+ void *CallbackArgs;
958
+ };
959
+
960
+ llvm::SmallVector<ActionArgsTy> ActionArgs;
956
961
957
962
// / Create an empty slot.
958
- StreamSlotTy () : Signal(nullptr ), ActionFunction( nullptr ) {}
963
+ StreamSlotTy () : Signal(nullptr ), Callbacks({}), ActionArgs({} ) {}
959
964
960
965
// / Schedule a host memory copy action on the slot.
961
966
Error schedHostMemoryCopy (void *Dst, const void *Src, size_t Size ) {
962
- ActionFunction = memcpyAction;
963
- ActionArgs.MemcpyArgs = MemcpyArgsTy{Dst, Src, Size };
967
+ Callbacks. emplace_back ( memcpyAction) ;
968
+ ActionArgs.emplace_back (). MemcpyArgs = MemcpyArgsTy{Dst, Src, Size };
964
969
return Plugin::success ();
965
970
}
966
971
967
972
// / Schedule a release buffer action on the slot.
968
973
Error schedReleaseBuffer (void *Buffer, AMDGPUMemoryManagerTy &Manager) {
969
- ActionFunction = releaseBufferAction;
970
- ActionArgs.ReleaseBufferArgs = ReleaseBufferArgsTy{Buffer, &Manager};
974
+ Callbacks.emplace_back (releaseBufferAction);
975
+ ActionArgs.emplace_back ().ReleaseBufferArgs =
976
+ ReleaseBufferArgsTy{Buffer, &Manager};
971
977
return Plugin::success ();
972
978
}
973
979
974
980
// / Schedule a signal release action on the slot.
975
981
Error schedReleaseSignal (AMDGPUSignalTy *SignalToRelease,
976
982
AMDGPUSignalManagerTy *SignalManager) {
977
- ActionFunction = releaseSignalAction;
978
- ActionArgs.ReleaseSignalArgs =
983
+ Callbacks. emplace_back ( releaseSignalAction) ;
984
+ ActionArgs.emplace_back (). ReleaseSignalArgs =
979
985
ReleaseSignalArgsTy{SignalToRelease, SignalManager};
980
986
return Plugin::success ();
981
987
}
982
988
989
+ // / Register a callback to be called on completion.
990
+ Error schedCallback (AMDGPUStreamCallbackTy *Func, void *Data) {
991
+ Callbacks.emplace_back (Func);
992
+ ActionArgs.emplace_back ().CallbackArgs = Data;
993
+
994
+ return Plugin::success ();
995
+ }
996
+
983
997
// Perform the action if needed.
984
998
Error performAction () {
985
- if (!ActionFunction )
999
+ if (Callbacks. empty () )
986
1000
return Plugin::success ();
987
1001
988
- // Perform the action.
989
- if (ActionFunction == memcpyAction) {
990
- if (auto Err = memcpyAction (&ActionArgs))
991
- return Err;
992
- } else if (ActionFunction == releaseBufferAction) {
993
- if (auto Err = releaseBufferAction (&ActionArgs))
994
- return Err;
995
- } else if (ActionFunction == releaseSignalAction) {
996
- if (auto Err = releaseSignalAction (&ActionArgs))
997
- return Err;
998
- } else {
999
- return Plugin::error (" Unknown action function!" );
1002
+ assert (Callbacks.size () == ActionArgs.size () && " Size mismatch" );
1003
+ for (auto [Callback, ActionArg] : llvm::zip (Callbacks, ActionArgs)) {
1004
+ // Perform the action.
1005
+ if (Callback == memcpyAction) {
1006
+ if (auto Err = memcpyAction (&ActionArg))
1007
+ return Err;
1008
+ } else if (Callback == releaseBufferAction) {
1009
+ if (auto Err = releaseBufferAction (&ActionArg))
1010
+ return Err;
1011
+ } else if (Callback == releaseSignalAction) {
1012
+ if (auto Err = releaseSignalAction (&ActionArg))
1013
+ return Err;
1014
+ } else if (Callback) {
1015
+ if (auto Err = Callback (ActionArg.CallbackArgs ))
1016
+ return Err;
1017
+ }
1000
1018
}
1001
1019
1002
1020
// Invalidate the action.
1003
- ActionFunction = nullptr ;
1021
+ Callbacks.clear ();
1022
+ ActionArgs.clear ();
1004
1023
1005
1024
return Plugin::success ();
1006
1025
}
0 commit comments