Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/algorithms/onnx/CalorimeterParticleIDPostML.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ void CalorimeterParticleIDPostML::init() {
void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Input& input,
const CalorimeterParticleIDPostML::Output& output) const {

const auto [in_clusters, in_assocs, prediction_tensors] = input;
auto [out_clusters, out_assocs, out_particle_ids] = output;
const auto [in_clusters, in_track_matches, in_assocs, prediction_tensors] = input;
auto [out_clusters, out_track_matches, out_assocs, out_particle_ids] = output;

if (prediction_tensors->size() != 1) {
error("Expected to find a single tensor, found {}", prediction_tensors->size());
Expand Down Expand Up @@ -79,6 +79,15 @@ void CalorimeterParticleIDPostML::process(const CalorimeterParticleIDPostML::Inp
prob_electron // float likelihood
));

// propagate track matches
for (auto in_track_match : *in_track_matches) {
if (in_track_match.getCluster() == in_cluster) {
auto out_track_match = in_track_match.clone();
out_track_match.setCluster(out_cluster);
out_track_matches->push_back(out_track_match);
}
}

// propagate associations
for (auto in_assoc : *in_assocs) {
if (in_assoc.getRec() == in_cluster) {
Expand Down
5 changes: 3 additions & 2 deletions src/algorithms/onnx/CalorimeterParticleIDPostML.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <algorithms/algorithm.h>
#include <edm4eic/ClusterCollection.h>
#include <edm4eic/TrackClusterMatchCollection.h>
#include <edm4eic/MCRecoClusterParticleAssociationCollection.h>
#include <edm4eic/TensorCollection.h>
#include <edm4hep/ParticleIDCollection.h>
Expand All @@ -17,10 +18,10 @@
namespace eicrecon {

using CalorimeterParticleIDPostMLAlgorithm = algorithms::Algorithm<
algorithms::Input<edm4eic::ClusterCollection,
algorithms::Input<edm4eic::ClusterCollection, edm4eic::TrackClusterMatchCollection,
std::optional<edm4eic::MCRecoClusterParticleAssociationCollection>,
edm4eic::TensorCollection>,
algorithms::Output<edm4eic::ClusterCollection,
algorithms::Output<edm4eic::ClusterCollection, edm4eic::TrackClusterMatchCollection,
std::optional<edm4eic::MCRecoClusterParticleAssociationCollection>,
edm4hep::ParticleIDCollection>>;

Expand Down
31 changes: 15 additions & 16 deletions src/algorithms/onnx/CalorimeterParticleIDPreML.cc
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
// Copyright (C) 2024 Dmitry Kalinkin
// Copyright (C) 2024 - 2025 Dmitry Kalinkin

#include <edm4eic/EDM4eicVersion.h>

#if EDM4EIC_VERSION_MAJOR >= 8
#include <cstddef>
#include <cstdint>
#include <edm4hep/MCParticle.h>
#include <edm4eic/Track.h>
#include <edm4hep/Vector3f.h>
#include <edm4hep/utils/vector_utils.h>
#include <fmt/core.h>
#include <cmath>
#include <stdexcept>

#include <cstddef>
#include <cstdint>
#include <gsl/pointers>
#include <stdexcept>

#include "CalorimeterParticleIDPreML.h"

Expand All @@ -26,8 +26,8 @@ void CalorimeterParticleIDPreML::init() {
void CalorimeterParticleIDPreML::process(const CalorimeterParticleIDPreML::Input& input,
const CalorimeterParticleIDPreML::Output& output) const {

const auto [clusters, cluster_assocs] = input;
auto [feature_tensors, target_tensors] = output;
const auto [clusters, track_matches, cluster_assocs] = input;
auto [feature_tensors, target_tensors] = output;

edm4eic::MutableTensor feature_tensor = feature_tensors->create();
feature_tensor.addToShape(clusters->size());
Expand All @@ -45,19 +45,18 @@ void CalorimeterParticleIDPreML::process(const CalorimeterParticleIDPreML::Input
for (edm4eic::Cluster cluster : *clusters) {
double momentum = NAN;
{
// FIXME: use track momentum once matching to tracks becomes available
edm4eic::MCRecoClusterParticleAssociation best_assoc;
for (auto assoc : *cluster_assocs) {
if (assoc.getRec() == cluster) {
if ((not best_assoc.isAvailable()) || (assoc.getWeight() > best_assoc.getWeight())) {
best_assoc = assoc;
auto best_match = edm4eic::TrackClusterMatch::makeEmpty();
for (auto match : *track_matches) {
if (match.getCluster() == cluster) {
if ((not best_match.isAvailable()) || (match.getWeight() > best_match.getWeight())) {
best_match = match;
}
}
}
if (best_assoc.isAvailable()) {
momentum = edm4hep::utils::magnitude(best_assoc.getSim().getMomentum());
if (best_match.isAvailable()) {
momentum = edm4hep::utils::magnitude(best_match.getTrack().getMomentum());
} else {
warning("Can't find association for cluster. Skipping...");
trace("Can't find a match for the cluster. Skipping...");
continue;
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/algorithms/onnx/CalorimeterParticleIDPreML.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <algorithms/algorithm.h>
#include <edm4eic/ClusterCollection.h>
#include <edm4eic/TrackClusterMatchCollection.h>
#include <edm4eic/MCRecoClusterParticleAssociationCollection.h>
#include <edm4eic/TensorCollection.h>
#include <optional>
Expand All @@ -16,7 +17,7 @@
namespace eicrecon {

using CalorimeterParticleIDPreMLAlgorithm = algorithms::Algorithm<
algorithms::Input<edm4eic::ClusterCollection,
algorithms::Input<edm4eic::ClusterCollection, edm4eic::TrackClusterMatchCollection,
std::optional<edm4eic::MCRecoClusterParticleAssociationCollection>>,
algorithms::Output<edm4eic::TensorCollection, std::optional<edm4eic::TensorCollection>>>;

Expand All @@ -25,8 +26,10 @@ class CalorimeterParticleIDPreML : public CalorimeterParticleIDPreMLAlgorithm,

public:
CalorimeterParticleIDPreML(std::string_view name)
: CalorimeterParticleIDPreMLAlgorithm{
name, {"inputClusters"}, {"outputFeatureTensor", "outputTargetTensor"}, ""} {}
: CalorimeterParticleIDPreMLAlgorithm{name,
{"inputClusters", "inputTrackClusterMatches"},
{"outputFeatureTensor", "outputTargetTensor"},
""} {}

void init() final;
void process(const Input&, const Output&) const final;
Expand Down
13 changes: 8 additions & 5 deletions src/detectors/EEMC/EEMC.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void InitPlugin(JApplication* app) {
#endif
#if EDM4EIC_VERSION_MAJOR >= 8
{"EcalEndcapNClustersWithoutPIDAndShapes", // edm4eic::Cluster
"EcalEndcapNClusterAssociationsWithoutPIDAndShapes"}, // edm4eic::MCRecoClusterParticleAssociation
"EcalEndcapNClusterWithoutPIDAssociationsAndShapes"}, // edm4eic::MCRecoClusterParticleAssociation
#else
{"EcalEndcapNClustersWithoutShapes", // edm4eic::Cluster
"EcalEndcapNClusterAssociationsWithoutShapes"}, // edm4eic::MCRecoClusterParticleAssociation
Expand All @@ -168,8 +168,8 @@ void InitPlugin(JApplication* app) {
#if EDM4EIC_VERSION_MAJOR >= 8
"EcalEndcapNClustersWithoutPID",
{"EcalEndcapNClustersWithoutPIDAndShapes",
"EcalEndcapNClusterAssociationsWithoutPIDAndShapes"},
{"EcalEndcapNClustersWithoutPID", "EcalEndcapNClusterAssociationsWithoutPID"},
"EcalEndcapNClusterWithoutPIDAssociationsAndShapes"},
{"EcalEndcapNClustersWithoutPID", "EcalEndcapNClusterWithoutPIDAssociations"},
#else
"EcalEndcapNClusters",
{"EcalEndcapNClustersWithoutShapes", "EcalEndcapNClusterAssociationsWithoutShapes"},
Expand All @@ -196,7 +196,8 @@ void InitPlugin(JApplication* app) {
"EcalEndcapNParticleIDPreML",
{
"EcalEndcapNClustersWithoutPID",
"EcalEndcapNClusterAssociationsWithoutPID",
"EcalEndcapNTrackClusterWithoutPIDMatches",
"EcalEndcapNClusterWithoutPIDAssociations",
},
{
"EcalEndcapNParticleIDInput_features",
Expand All @@ -220,11 +221,13 @@ void InitPlugin(JApplication* app) {
"EcalEndcapNParticleIDPostML",
{
"EcalEndcapNClustersWithoutPID",
"EcalEndcapNClusterAssociationsWithoutPID",
"EcalEndcapNTrackClusterWithoutPIDMatches",
"EcalEndcapNClusterWithoutPIDAssociations",
"EcalEndcapNParticleIDOutput_probability_tensor",
},
{
"EcalEndcapNClusters",
"EcalEndcapNTrackClusterMatches",
"EcalEndcapNClusterAssociations",
"EcalEndcapNClusterParticleIDs",
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ class CalorimeterParticleIDPostML_factory
std::unique_ptr<AlgoT> m_algo;

PodioInput<edm4eic::Cluster> m_cluster_input{this};
PodioInput<edm4eic::TrackClusterMatch> m_track_cluster_matches_input{this};
PodioInput<edm4eic::MCRecoClusterParticleAssociation> m_cluster_assoc_input{this};
PodioInput<edm4eic::Tensor> m_prediction_tensor_input{this};

PodioOutput<edm4eic::Cluster> m_cluster_output{this};
PodioOutput<edm4eic::TrackClusterMatch> m_track_cluster_matches_output{this};
PodioOutput<edm4eic::MCRecoClusterParticleAssociation> m_cluster_assoc_output{this};
PodioOutput<edm4hep::ParticleID> m_particle_id_output{this};

Expand All @@ -37,9 +39,10 @@ class CalorimeterParticleIDPostML_factory
void ChangeRun(int32_t /* run_number */) {}

void Process(int32_t /* run_number */, uint64_t /* event_number */) {
m_algo->process(
{m_cluster_input(), m_cluster_assoc_input(), m_prediction_tensor_input()},
{m_cluster_output().get(), m_cluster_assoc_output().get(), m_particle_id_output().get()});
m_algo->process({m_cluster_input(), m_track_cluster_matches_input(), m_cluster_assoc_input(),
m_prediction_tensor_input()},
{m_cluster_output().get(), m_track_cluster_matches_output().get(),
m_cluster_assoc_output().get(), m_particle_id_output().get()});
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class CalorimeterParticleIDPreML_factory
std::unique_ptr<AlgoT> m_algo;

PodioInput<edm4eic::Cluster> m_cluster_input{this};
PodioInput<edm4eic::TrackClusterMatch> m_track_cluster_matches_input{this};
PodioInput<edm4eic::MCRecoClusterParticleAssociation> m_cluster_assoc_input{this};

PodioOutput<edm4eic::Tensor> m_feature_tensor_output{this};
Expand All @@ -35,7 +36,7 @@ class CalorimeterParticleIDPreML_factory
void ChangeRun(int32_t /* run_number */) {}

void Process(int32_t /* run_number */, uint64_t /* event_number */) {
m_algo->process({m_cluster_input(), m_cluster_assoc_input()},
m_algo->process({m_cluster_input(), m_track_cluster_matches_input(), m_cluster_assoc_input()},
{m_feature_tensor_output().get(), m_target_tensor_output().get()});
}
};
Expand Down
11 changes: 5 additions & 6 deletions src/global/reco/reco.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,13 @@ void InitPlugin(JApplication* app) {

// Backward
app->Add(new JOmniFactoryGeneratorT<TrackClusterMatch_factory>(
"EcalEndcapNBarrelTrackClusterMatches",
{"CalorimeterTrackProjections", "EcalEndcapNClusters"}, {"EcalEndcapNTrackClusterMatches"},
{.calo_id = "EcalEndcapN_ID"}, app));
"EcalEndcapNTrackClusterMatches",
{"CalorimeterTrackProjections", "EcalEndcapNClustersWithoutPID"},
{"EcalEndcapNTrackClusterWithoutPIDMatches"}, {.calo_id = "EcalEndcapN_ID"}, app));

app->Add(new JOmniFactoryGeneratorT<TrackClusterMatch_factory>(
"HcalEndcapNBarrelTrackClusterMatches",
{"CalorimeterTrackProjections", "HcalEndcapNClusters"}, {"HcalEndcapNTrackClusterMatches"},
{.calo_id = "HcalEndcapN_ID"}, app));
"HcalEndcapNTrackClusterMatches", {"CalorimeterTrackProjections", "HcalEndcapNClusters"},
{"HcalEndcapNTrackClusterMatches"}, {.calo_id = "HcalEndcapN_ID"}, app));

#endif // EDM4EIC_VERSION_MAJOR >= 8

Expand Down
Loading