[AnomalyDetection] Add base classes and specifiable protocol (#33845)
* Add base classes and specifiable protocol for anomaly detection.

* Add subspaces to global specifiable map

* Add __init__.py

* Fix lints

* Fix get_subspace when called from from_spec

* Refactor code, add tests and add docstrings.

* Minor changes to docstrings and comments

* Remove the fallback subspace '*' from accepted list. Use it in tests only.

* Bring fallback subspace back to accepted list. Clarify the use of spec_type to resolve naming conflict.

* Make _KNOWN_SPECIFIABLE a defaultdict. Remove error_if_exists.

* Minor adjustment on docstrings.
shunping authored Feb 10, 2025
1 parent 25425db commit 30a8c2d
Showing 5 changed files with 1,265 additions and 0 deletions.
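The commit notes above only name the moving parts of the specifiable registry (subspaces, a fallback subspace '*', a `_KNOWN_SPECIFIABLE` defaultdict, `get_subspace`, `from_spec`). As a rough, hypothetical sketch of that idea only, and not the actual code in specifiable.py, a subspaced registry could be arranged like this; the `register` helper and the exact lookup order are assumptions:

from collections import defaultdict

# Hypothetical sketch only: a global map keyed by subspace, where each
# subspace holds spec_type -> class entries and "*" is the fallback subspace.
_KNOWN_SPECIFIABLE = defaultdict(dict)


def register(subspace: str, spec_type: str, cls: type) -> None:
  # With a defaultdict, a missing subspace is created on first use, so no
  # explicit error_if_exists handling is needed when adding a new subspace.
  _KNOWN_SPECIFIABLE[subspace][spec_type] = cls


def get_subspace(spec_type: str) -> str:
  # Resolve which subspace a spec_type was registered under, falling back
  # to "*" when it is only known in the fallback subspace.
  for subspace, entries in _KNOWN_SPECIFIABLE.items():
    if subspace != "*" and spec_type in entries:
      return subspace
  return "*"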
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/ml/anomaly/__init__.py
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
211 changes: 211 additions & 0 deletions sdks/python/apache_beam/ml/anomaly/base.py
@@ -0,0 +1,211 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Base classes for anomaly detection
"""
from __future__ import annotations

import abc
from dataclasses import dataclass
from typing import Iterable
from typing import List
from typing import Optional

import apache_beam as beam

__all__ = [
    "AnomalyPrediction",
    "AnomalyResult",
    "ThresholdFn",
    "AggregationFn",
    "AnomalyDetector",
    "EnsembleAnomalyDetector"
]


@dataclass(frozen=True)
class AnomalyPrediction():
  """A dataclass for anomaly detection predictions."""
  #: The ID of the detector (model) that generates the prediction.
  model_id: Optional[str] = None
  #: The outlier score resulting from applying the detector to the input data.
  score: Optional[float] = None
  #: The outlier label (normal or outlier) derived from the outlier score.
  label: Optional[int] = None
  #: The threshold used to determine the label.
  threshold: Optional[float] = None
  #: Additional information about the prediction.
  info: str = ""
  #: If enabled, a list of `AnomalyPrediction` objects used to derive the
  #: aggregated prediction.
  agg_history: Optional[Iterable[AnomalyPrediction]] = None


@dataclass(frozen=True)
class AnomalyResult():
  """A dataclass for the anomaly detection results."""
  #: The original input data.
  example: beam.Row
  #: The `AnomalyPrediction` object containing the prediction.
  prediction: AnomalyPrediction


class ThresholdFn(abc.ABC):
  """An abstract base class for threshold functions.

  Args:
    normal_label: The integer label used to identify normal data. Defaults to 0.
    outlier_label: The integer label used to identify outlier data. Defaults to
      1.
  """
  def __init__(self, normal_label: int = 0, outlier_label: int = 1):
    self._normal_label = normal_label
    self._outlier_label = outlier_label

  @property
  @abc.abstractmethod
  def is_stateful(self) -> bool:
    """Indicates whether the threshold function is stateful or not."""
    raise NotImplementedError

  @property
  @abc.abstractmethod
  def threshold(self) -> Optional[float]:
    """Retrieves the current threshold value, or None if not set."""
    raise NotImplementedError

  @abc.abstractmethod
  def apply(self, score: Optional[float]) -> int:
    """Applies the threshold function to a given score to classify it as
    normal or outlier.

    Args:
      score: The outlier score generated from the detector (model).

    Returns:
      The label assigned to the score, either `self._normal_label`
      or `self._outlier_label`.
    """
    raise NotImplementedError


class AggregationFn(abc.ABC):
  """An abstract base class for aggregation functions."""
  @abc.abstractmethod
  def apply(
      self, predictions: Iterable[AnomalyPrediction]) -> AnomalyPrediction:
    """Applies the aggregation function to an iterable of predictions, either on
    their outlier scores or labels.

    Args:
      predictions: An Iterable of `AnomalyPrediction` objects to aggregate.

    Returns:
      An `AnomalyPrediction` object containing the aggregated result.
    """
    raise NotImplementedError


class AnomalyDetector(abc.ABC):
  """An abstract base class for anomaly detectors.

  Args:
    model_id: The ID of the detector (model). Defaults to the value of the
      `spec_type` attribute, or 'unknown' if not set.
    features: An Iterable of strings representing the names of the input
      features in the `beam.Row`.
    target: The name of the target field in the `beam.Row`.
    threshold_criterion: An optional `ThresholdFn` to apply to the outlier score
      and yield a label.
  """
  def __init__(
      self,
      model_id: Optional[str] = None,
      features: Optional[Iterable[str]] = None,
      target: Optional[str] = None,
      threshold_criterion: Optional[ThresholdFn] = None,
      **kwargs):
    self._model_id = model_id if model_id is not None else getattr(
        self, 'spec_type', 'unknown')
    self._features = features
    self._target = target
    self._threshold_criterion = threshold_criterion

  @abc.abstractmethod
  def learn_one(self, x: beam.Row) -> None:
    """Trains the detector on a single data instance.

    Args:
      x: A `beam.Row` representing the data instance.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def score_one(self, x: beam.Row) -> float:
    """Scores a single data instance for anomalies.

    Args:
      x: A `beam.Row` representing the data instance.

    Returns:
      The outlier score as a float.
    """
    raise NotImplementedError


class EnsembleAnomalyDetector(AnomalyDetector):
  """An abstract base class for an ensemble of anomaly (sub-)detectors.

  Args:
    sub_detectors: A List of `AnomalyDetector` used in this ensemble model.
    aggregation_strategy: An optional `AggregationFn` to apply to the
      predictions from all sub-detectors and yield an aggregated result.
    model_id: Inherited from `AnomalyDetector`.
    features: Inherited from `AnomalyDetector`.
    target: Inherited from `AnomalyDetector`.
    threshold_criterion: Inherited from `AnomalyDetector`.
  """
  def __init__(
      self,
      sub_detectors: Optional[List[AnomalyDetector]] = None,
      aggregation_strategy: Optional[AggregationFn] = None,
      **kwargs):
    if "model_id" not in kwargs or kwargs["model_id"] is None:
      kwargs["model_id"] = getattr(self, 'spec_type', 'custom')

    super().__init__(**kwargs)

    self._aggregation_strategy = aggregation_strategy
    self._sub_detectors = sub_detectors

  def learn_one(self, x: beam.Row) -> None:
    """Inherited from `AnomalyDetector.learn_one`.

    This method is never called during ensemble detector training. The training
    process is done on each sub-detector independently and in parallel.
    """
    raise NotImplementedError

  def score_one(self, x: beam.Row) -> float:
    """Inherited from `AnomalyDetector.score_one`.

    This method is never called during ensemble detector scoring. The scoring
    process is done on each sub-detector independently and in parallel, and
    then the results are aggregated in the pipeline.
    """
    raise NotImplementedError
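
To make the intended use of these base classes concrete, the following is a minimal, hypothetical sketch that is not part of the commit: a fixed-cutoff `ThresholdFn` and a running-mean `AnomalyDetector` subclass written against the abstract methods above. The class names, the cutoff value, and the scoring logic are invented purely for illustration.

from typing import Optional

import apache_beam as beam

from apache_beam.ml.anomaly.base import AnomalyDetector
from apache_beam.ml.anomaly.base import ThresholdFn


class FixedThreshold(ThresholdFn):
  """Hypothetical example: labels scores above a fixed cutoff as outliers."""
  def __init__(self, cutoff: float, **kwargs):
    super().__init__(**kwargs)
    self._cutoff = cutoff

  @property
  def is_stateful(self) -> bool:
    # The cutoff never changes, so no state needs to be carried.
    return False

  @property
  def threshold(self) -> Optional[float]:
    return self._cutoff

  def apply(self, score: Optional[float]) -> int:
    if score is None or score <= self._cutoff:
      return self._normal_label
    return self._outlier_label


class MeanDistanceDetector(AnomalyDetector):
  """Hypothetical example: scores a value by its distance from a running mean."""
  def __init__(self, target: str, **kwargs):
    super().__init__(target=target, **kwargs)
    self._count = 0
    self._mean = 0.0

  def learn_one(self, x: beam.Row) -> None:
    # Incrementally update the running mean of the target field.
    value = getattr(x, self._target)
    self._count += 1
    self._mean += (value - self._mean) / self._count

  def score_one(self, x: beam.Row) -> float:
    # The outlier score is the absolute distance from the running mean.
    return abs(getattr(x, self._target) - self._mean)


detector = MeanDistanceDetector(
    target="value", threshold_criterion=FixedThreshold(cutoff=10.0))
detector.learn_one(beam.Row(value=5.0))
print(detector.score_one(beam.Row(value=42.0)))

The direct calls at the end are only to show the learn_one/score_one contract; in practice these objects would be wrapped by whatever transform consumes them rather than being invoked by hand.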