diff --git a/src/algorithms/onnx/CalorimeterParticleIDPostML.cc b/src/algorithms/onnx/CalorimeterParticleIDPostML.cc index e9e0bd20c6..6bb0ab4599 100644 --- a/src/algorithms/onnx/CalorimeterParticleIDPostML.cc +++ b/src/algorithms/onnx/CalorimeterParticleIDPostML.cc @@ -20,8 +20,8 @@ void CalorimeterParticleIDPostML::init() { void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Input& input, const CalorimeterParticleIDPostML::Output& output) const { - const auto [in_clusters, in_assocs, prediction_tensors] = input; - auto [out_clusters, out_assocs, out_particle_ids] = output; + const auto [in_clusters, in_track_matches, in_assocs, prediction_tensors] = input; + auto [out_clusters, out_track_matches, out_assocs, out_particle_ids] = output; if (prediction_tensors->size() != 1) { error("Expected to find a single tensor, found {}", prediction_tensors->size()); @@ -79,6 +79,15 @@ void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Inp prob_electron // float likelihood )); + // propagate track matches + for (auto in_track_match : *in_track_matches) { + if (in_track_match.getCluster() == in_cluster) { + auto out_track_match = in_track_match.clone(); + out_track_match.setCluster(out_cluster); + out_track_matches->push_back(out_track_match); + } + } + // propagate associations for (auto in_assoc : *in_assocs) { if (in_assoc.getRec() == in_cluster) { diff --git a/src/algorithms/onnx/CalorimeterParticleIDPostML.h b/src/algorithms/onnx/CalorimeterParticleIDPostML.h index 3dd3c68c3c..2ddf73c111 100644 --- a/src/algorithms/onnx/CalorimeterParticleIDPostML.h +++ b/src/algorithms/onnx/CalorimeterParticleIDPostML.h @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -17,10 +18,10 @@ namespace eicrecon { using CalorimeterParticleIDPostMLAlgorithm = algorithms::Algorithm< - algorithms::Input, edm4eic::TensorCollection>, - algorithms::Output, edm4hep::ParticleIDCollection>>; diff --git a/src/algorithms/onnx/CalorimeterParticleIDPreML.cc b/src/algorithms/onnx/CalorimeterParticleIDPreML.cc index 745ea2f085..bdf239a62b 100644 --- a/src/algorithms/onnx/CalorimeterParticleIDPreML.cc +++ b/src/algorithms/onnx/CalorimeterParticleIDPreML.cc @@ -1,19 +1,19 @@ // SPDX-License-Identifier: LGPL-3.0-or-later -// Copyright (C) 2024 Dmitry Kalinkin +// Copyright (C) 2024 - 2025 Dmitry Kalinkin #include #if EDM4EIC_VERSION_MAJOR >= 8 -#include -#include #include +#include #include #include #include #include -#include - +#include +#include #include +#include #include "CalorimeterParticleIDPreML.h" @@ -26,8 +26,8 @@ void CalorimeterParticleIDPreML::init() { void CalorimeterParticleIDPreML::process(const CalorimeterParticleIDPreML::Input& input, const CalorimeterParticleIDPreML::Output& output) const { - const auto [clusters, cluster_assocs] = input; - auto [feature_tensors, target_tensors] = output; + const auto [clusters, track_matches, cluster_assocs] = input; + auto [feature_tensors, target_tensors] = output; edm4eic::MutableTensor feature_tensor = feature_tensors->create(); feature_tensor.addToShape(clusters->size()); @@ -45,19 +45,18 @@ void CalorimeterParticleIDPreML::process(const CalorimeterParticleIDPreML::Input for (edm4eic::Cluster cluster : *clusters) { double momentum = NAN; { - // FIXME: use track momentum once matching to tracks becomes available - edm4eic::MCRecoClusterParticleAssociation best_assoc; - for (auto assoc : *cluster_assocs) { - if (assoc.getRec() == cluster) { - if ((not best_assoc.isAvailable()) || (assoc.getWeight() > best_assoc.getWeight())) { - best_assoc = assoc; + auto best_match = edm4eic::TrackClusterMatch::makeEmpty(); + for (auto match : *track_matches) { + if (match.getCluster() == cluster) { + if ((not best_match.isAvailable()) || (match.getWeight() > best_match.getWeight())) { + best_match = match; } } } - if (best_assoc.isAvailable()) { - momentum = edm4hep::utils::magnitude(best_assoc.getSim().getMomentum()); + if (best_match.isAvailable()) { + momentum = edm4hep::utils::magnitude(best_match.getTrack().getMomentum()); } else { - warning("Can't find association for cluster. Skipping..."); + trace("Can't find a match for the cluster. Skipping..."); continue; } } diff --git a/src/algorithms/onnx/CalorimeterParticleIDPreML.h b/src/algorithms/onnx/CalorimeterParticleIDPreML.h index 9ae4858d66..bbdb340e10 100644 --- a/src/algorithms/onnx/CalorimeterParticleIDPreML.h +++ b/src/algorithms/onnx/CalorimeterParticleIDPreML.h @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -16,7 +17,7 @@ namespace eicrecon { using CalorimeterParticleIDPreMLAlgorithm = algorithms::Algorithm< - algorithms::Input>, algorithms::Output>>; @@ -25,8 +26,10 @@ class CalorimeterParticleIDPreML : public CalorimeterParticleIDPreMLAlgorithm, public: CalorimeterParticleIDPreML(std::string_view name) - : CalorimeterParticleIDPreMLAlgorithm{ - name, {"inputClusters"}, {"outputFeatureTensor", "outputTargetTensor"}, ""} {} + : CalorimeterParticleIDPreMLAlgorithm{name, + {"inputClusters", "inputTrackClusterMatches"}, + {"outputFeatureTensor", "outputTargetTensor"}, + ""} {} void init() final; void process(const Input&, const Output&) const final; diff --git a/src/detectors/EEMC/EEMC.cc b/src/detectors/EEMC/EEMC.cc index f03b906771..24cd49bdbb 100644 --- a/src/detectors/EEMC/EEMC.cc +++ b/src/detectors/EEMC/EEMC.cc @@ -150,7 +150,7 @@ void InitPlugin(JApplication* app) { #endif #if EDM4EIC_VERSION_MAJOR >= 8 {"EcalEndcapNClustersWithoutPIDAndShapes", // edm4eic::Cluster - "EcalEndcapNClusterAssociationsWithoutPIDAndShapes"}, // edm4eic::MCRecoClusterParticleAssociation + "EcalEndcapNClusterWithoutPIDAssociationsAndShapes"}, // edm4eic::MCRecoClusterParticleAssociation #else {"EcalEndcapNClustersWithoutShapes", // edm4eic::Cluster "EcalEndcapNClusterAssociationsWithoutShapes"}, // edm4eic::MCRecoClusterParticleAssociation @@ -168,8 +168,8 @@ void InitPlugin(JApplication* app) { #if EDM4EIC_VERSION_MAJOR >= 8 "EcalEndcapNClustersWithoutPID", {"EcalEndcapNClustersWithoutPIDAndShapes", - "EcalEndcapNClusterAssociationsWithoutPIDAndShapes"}, - {"EcalEndcapNClustersWithoutPID", "EcalEndcapNClusterAssociationsWithoutPID"}, + "EcalEndcapNClusterWithoutPIDAssociationsAndShapes"}, + {"EcalEndcapNClustersWithoutPID", "EcalEndcapNClusterWithoutPIDAssociations"}, #else "EcalEndcapNClusters", {"EcalEndcapNClustersWithoutShapes", "EcalEndcapNClusterAssociationsWithoutShapes"}, @@ -196,7 +196,8 @@ void InitPlugin(JApplication* app) { "EcalEndcapNParticleIDPreML", { "EcalEndcapNClustersWithoutPID", - "EcalEndcapNClusterAssociationsWithoutPID", + "EcalEndcapNTrackClusterWithoutPIDMatches", + "EcalEndcapNClusterWithoutPIDAssociations", }, { "EcalEndcapNParticleIDInput_features", @@ -220,11 +221,13 @@ void InitPlugin(JApplication* app) { "EcalEndcapNParticleIDPostML", { "EcalEndcapNClustersWithoutPID", - "EcalEndcapNClusterAssociationsWithoutPID", + "EcalEndcapNTrackClusterWithoutPIDMatches", + "EcalEndcapNClusterWithoutPIDAssociations", "EcalEndcapNParticleIDOutput_probability_tensor", }, { "EcalEndcapNClusters", + "EcalEndcapNTrackClusterMatches", "EcalEndcapNClusterAssociations", "EcalEndcapNClusterParticleIDs", }, diff --git a/src/factories/calorimetry/CalorimeterParticleIDPostML_factory.h b/src/factories/calorimetry/CalorimeterParticleIDPostML_factory.h index 399a3af72f..90926a6047 100644 --- a/src/factories/calorimetry/CalorimeterParticleIDPostML_factory.h +++ b/src/factories/calorimetry/CalorimeterParticleIDPostML_factory.h @@ -19,10 +19,12 @@ class CalorimeterParticleIDPostML_factory std::unique_ptr m_algo; PodioInput m_cluster_input{this}; + PodioInput m_track_cluster_matches_input{this}; PodioInput m_cluster_assoc_input{this}; PodioInput m_prediction_tensor_input{this}; PodioOutput m_cluster_output{this}; + PodioOutput m_track_cluster_matches_output{this}; PodioOutput m_cluster_assoc_output{this}; PodioOutput m_particle_id_output{this}; @@ -37,9 +39,10 @@ class CalorimeterParticleIDPostML_factory void ChangeRun(int32_t /* run_number */) {} void Process(int32_t /* run_number */, uint64_t /* event_number */) { - m_algo->process( - {m_cluster_input(), m_cluster_assoc_input(), m_prediction_tensor_input()}, - {m_cluster_output().get(), m_cluster_assoc_output().get(), m_particle_id_output().get()}); + m_algo->process({m_cluster_input(), m_track_cluster_matches_input(), m_cluster_assoc_input(), + m_prediction_tensor_input()}, + {m_cluster_output().get(), m_track_cluster_matches_output().get(), + m_cluster_assoc_output().get(), m_particle_id_output().get()}); } }; diff --git a/src/factories/calorimetry/CalorimeterParticleIDPreML_factory.h b/src/factories/calorimetry/CalorimeterParticleIDPreML_factory.h index fb5fe9e161..99f4da0c3d 100644 --- a/src/factories/calorimetry/CalorimeterParticleIDPreML_factory.h +++ b/src/factories/calorimetry/CalorimeterParticleIDPreML_factory.h @@ -19,6 +19,7 @@ class CalorimeterParticleIDPreML_factory std::unique_ptr m_algo; PodioInput m_cluster_input{this}; + PodioInput m_track_cluster_matches_input{this}; PodioInput m_cluster_assoc_input{this}; PodioOutput m_feature_tensor_output{this}; @@ -35,7 +36,7 @@ class CalorimeterParticleIDPreML_factory void ChangeRun(int32_t /* run_number */) {} void Process(int32_t /* run_number */, uint64_t /* event_number */) { - m_algo->process({m_cluster_input(), m_cluster_assoc_input()}, + m_algo->process({m_cluster_input(), m_track_cluster_matches_input(), m_cluster_assoc_input()}, {m_feature_tensor_output().get(), m_target_tensor_output().get()}); } }; diff --git a/src/global/reco/reco.cc b/src/global/reco/reco.cc index 4e75c56c95..3b6c28d956 100644 --- a/src/global/reco/reco.cc +++ b/src/global/reco/reco.cc @@ -209,14 +209,13 @@ void InitPlugin(JApplication* app) { // Backward app->Add(new JOmniFactoryGeneratorT( - "EcalEndcapNBarrelTrackClusterMatches", - {"CalorimeterTrackProjections", "EcalEndcapNClusters"}, {"EcalEndcapNTrackClusterMatches"}, - {.calo_id = "EcalEndcapN_ID"}, app)); + "EcalEndcapNTrackClusterMatches", + {"CalorimeterTrackProjections", "EcalEndcapNClustersWithoutPID"}, + {"EcalEndcapNTrackClusterWithoutPIDMatches"}, {.calo_id = "EcalEndcapN_ID"}, app)); app->Add(new JOmniFactoryGeneratorT( - "HcalEndcapNBarrelTrackClusterMatches", - {"CalorimeterTrackProjections", "HcalEndcapNClusters"}, {"HcalEndcapNTrackClusterMatches"}, - {.calo_id = "HcalEndcapN_ID"}, app)); + "HcalEndcapNTrackClusterMatches", {"CalorimeterTrackProjections", "HcalEndcapNClusters"}, + {"HcalEndcapNTrackClusterMatches"}, {.calo_id = "HcalEndcapN_ID"}, app)); #endif // EDM4EIC_VERSION_MAJOR >= 8