diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 137ce180071..63528105a57 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1,14 +1,16 @@
 # detection-rules code owners
 # POC: Elastic Security Intelligence and Analytics Team

-tests/**/*.py @mikaayenson @eric-forte-elastic @terrancedejesus
-detection_rules/ @mikaayenson @eric-forte-elastic @terrancedejesus
-tests/ @mikaayenson @eric-forte-elastic @terrancedejesus
-lib/ @mikaayenson @eric-forte-elastic @terrancedejesus
-hunting/ @mikaayenson @eric-forte-elastic @terrancedejesus
+tests/**/*.py @mikaayenson @eric-forte-elastic @traut
+detection_rules/ @mikaayenson @eric-forte-elastic @traut
+tests/ @mikaayenson @eric-forte-elastic @traut
+lib/ @mikaayenson @eric-forte-elastic @traut
+hunting/**/*.py @mikaayenson @eric-forte-elastic @traut

 # skip rta-mapping to avoid the spam
-detection_rules/etc/packages.yaml @mikaayenson @eric-forte-elastic @terrancedejesus
-detection_rules/etc/*.json @mikaayenson @eric-forte-elastic @terrancedejesus
-detection_rules/etc/*.json @mikaayenson @eric-forte-elastic @terrancedejesus
-detection_rules/etc/*/* @mikaayenson @eric-forte-elastic @terrancedejesus
+detection_rules/etc/packages.yaml @mikaayenson @eric-forte-elastic @traut
+detection_rules/etc/*.json @mikaayenson @eric-forte-elastic @traut
+detection_rules/etc/*/* @mikaayenson @eric-forte-elastic @traut
+
+# exclude files from code owners
+detection_rules/etc/non-ecs-schema.json
diff --git a/.github/PULL_REQUEST_GUIDELINES/bug_guidelines.md b/.github/PULL_REQUEST_GUIDELINES/bug_guidelines.md
index bdb5359cf95..9b25a8cbae8 100644
--- a/.github/PULL_REQUEST_GUIDELINES/bug_guidelines.md
+++ b/.github/PULL_REQUEST_GUIDELINES/bug_guidelines.md
@@ -11,11 +11,7 @@ These guidelines serve as a reminder set of considerations when addressing a bug

 ### Code Standards and Practices
 - [ ] Code follows established design patterns within the repo and avoids duplication.
-- [ ] Code changes do not introduce new warnings or errors.
-- [ ] Variables and functions are well-named and descriptive.
-- [ ] Any unnecessary / commented-out code is removed.
 - [ ] Ensure that the code is modular and reusable where applicable.
-- [ ] Check for proper exception handling and messaging.

 ### Testing

@@ -25,11 +21,9 @@ These guidelines serve as a reminder set of considerations when addressing a bug
 - [ ] Validate that any rules affected by the bug are correctly updated.
 - [ ] Ensure that performance is not negatively impacted by the changes.
 - [ ] Verify that any release artifacts are properly generated and tested.
+- [ ] Conducted system testing, including fleet, import, and create APIs (e.g., run `make test-cli`, `make test-remote-cli`, `make test-hunting-cli`)

 ### Additional Checks
-- [ ] Ensure that the bug fix does not break existing functionality.
-- [ ] Review the bug fix with a peer or team member for additional insights.
 - [ ] Verify that the bug fix works across all relevant environments (e.g., different OS versions).
-- [ ] Confirm that all dependencies are up-to-date and compatible with the changes.
 - [ ] Confirm that the proper version label is applied to the PR `patch`, `minor`, `major`.
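The CODEOWNERS change above leans on two GitHub behaviors: the last matching pattern takes precedence, and an entry with no owners clears ownership for the paths it matches, which is how `detection_rules/etc/non-ecs-schema.json` ends up excluded from automatic review requests. A minimal sketch of that precedence logic (simplified globbing via `fnmatch`, not GitHub's exact matcher; the pattern list is an illustrative subset, not the full file):

```python
from fnmatch import fnmatch

# Illustrative subset of the new CODEOWNERS entries.
CODEOWNERS = [
    ("detection_rules/etc/*.json", ["@mikaayenson", "@eric-forte-elastic", "@traut"]),
    ("detection_rules/etc/*/*", ["@mikaayenson", "@eric-forte-elastic", "@traut"]),
    ("detection_rules/etc/non-ecs-schema.json", []),  # no owners -> ownership cleared
]


def effective_owners(path: str) -> list[str]:
    """Return the owners of the last matching pattern, mimicking CODEOWNERS precedence."""
    owners: list[str] = []
    for pattern, pattern_owners in CODEOWNERS:
        if fnmatch(path, pattern):
            owners = pattern_owners  # later matches override earlier ones
    return owners


print(effective_owners("detection_rules/etc/attack-crosswalk.json"))  # owned by the team
print(effective_owners("detection_rules/etc/non-ecs-schema.json"))    # [] -> unowned / excluded
```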
diff --git a/.github/PULL_REQUEST_GUIDELINES/enhancement_guidelines.md b/.github/PULL_REQUEST_GUIDELINES/enhancement_guidelines.md
index c02664f7e33..30c210d9498 100644
--- a/.github/PULL_REQUEST_GUIDELINES/enhancement_guidelines.md
+++ b/.github/PULL_REQUEST_GUIDELINES/enhancement_guidelines.md
@@ -11,11 +11,7 @@ These guidelines serve as a reminder set of considerations when addressing addin

 ### Code Standards and Practices
 - [ ] Code follows established design patterns within the repo and avoids duplication.
-- [ ] Code changes do not introduce new warnings or errors.
-- [ ] Variables and functions are well-named and descriptive.
-- [ ] Any unnecessary / commented-out code is removed.
 - [ ] Ensure that the code is modular and reusable where applicable.
-- [ ] Check for proper exception handling and messaging.

 ### Testing

@@ -25,11 +21,9 @@ These guidelines serve as a reminder set of considerations when addressing addin
 - [ ] Validate that any rules affected by the enhancement are correctly updated.
 - [ ] Ensure that performance is not negatively impacted by the changes.
 - [ ] Verify that any release artifacts are properly generated and tested.
+- [ ] Conducted system testing, including fleet, import, and create APIs (e.g., run `make test-cli`, `make test-remote-cli`, `make test-hunting-cli`)

 ### Additional Checks
-- [ ] Ensure that the enhancement does not break existing functionality.
-- [ ] Review the enhancement with a peer or team member for additional insights.
 - [ ] Verify that the enhancement works across all relevant environments (e.g., different OS versions).
-- [ ] Confirm that all dependencies are up-to-date and compatible with the changes.
 - [ ] Confirm that the proper version label is applied to the PR `patch`, `minor`, `major`.
diff --git a/.github/PULL_REQUEST_GUIDELINES/schema_enhancement_guidelines.md b/.github/PULL_REQUEST_GUIDELINES/schema_enhancement_guidelines.md
index 3d715339d32..ef5e52308c7 100644
--- a/.github/PULL_REQUEST_GUIDELINES/schema_enhancement_guidelines.md
+++ b/.github/PULL_REQUEST_GUIDELINES/schema_enhancement_guidelines.md
@@ -11,11 +11,7 @@ These guidelines serve as a reminder set of considerations when addressing addin

 ### Code Standards and Practices
 - [ ] Code follows established design patterns within the repo and avoids duplication.
-- [ ] Code changes do not introduce new warnings or errors.
-- [ ] Variables and functions are well-named and descriptive.
-- [ ] Any unnecessary / commented-out code is removed.
 - [ ] Ensure that the code is modular and reusable where applicable.
-- [ ] Check for proper exception handling and messaging.

 ### Testing

@@ -25,23 +21,21 @@ These guidelines serve as a reminder set of considerations when addressing addin
 - [ ] Validate that any rules affected by the enhancement are correctly updated.
 - [ ] Ensure that performance is not negatively impacted by the changes.
 - [ ] Verify that any release artifacts are properly generated and tested.
+- [ ] Conducted system testing, including fleet, import, and create APIs (e.g., run `make test-cli`, `make test-remote-cli`, `make test-hunting-cli`)

 ### Additional Schema Related Checks
-- [ ] Ensure that the enhancement does not break existing functionality. (e.g., run `make test-cli`)
-- [ ] Review the enhancement with a peer or team member for additional insights.
 - [ ] Verify that the enhancement works across all relevant environments (e.g., different OS versions).
-- [ ] Confirm that all dependencies are up-to-date and compatible with the changes.
- [ ] Link to the relevant Kibana PR or issue provided -- [ ] Exported detection rule(s) from Kibana to showcase the feature(s) -- [ ] Converted the exported ndjson file(s) to toml in the detection-rules repo -- [ ] Re-exported the toml rule(s) to ndjson and re-imported into Kibana +- [ ] Test export/import flow: + - [ ] Exported detection rule(s) from Kibana to showcase the feature(s) + - [ ] Converted the exported ndjson file(s) to toml in the detection-rules repo + - [ ] Re-exported the toml rule(s) to ndjson and re-imported into Kibana - [ ] Updated necessary unit tests to accommodate the feature +- [ ] Incorporated a comprehensive test rule in unit tests for full schema coverage - [ ] Applied min_compat restrictions to limit the feature to a specified minimum stack version - [ ] Executed all unit tests locally with a test toml rule to confirm passing - [ ] Included Kibana PR implementer as an optional reviewer for insights on the feature - [ ] Implemented requisite downgrade functionality - [ ] Cross-referenced the feature with product documentation for consistency -- [ ] Incorporated a comprehensive test rule in unit tests for full schema coverage -- [ ] Conducted system testing, including fleet, import, and create APIs (e.g., run `make test-remote-cli`) - [ ] Confirm that the proper version label is applied to the PR `patch`, `minor`, `major`. diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml new file mode 100644 index 00000000000..7e5698b206b --- /dev/null +++ b/.github/workflows/code-checks.yaml @@ -0,0 +1,47 @@ +name: Code checks + +on: + push: + branches: [ "main", "7.*", "8.*", "9.*" ] + pull_request: + branches: [ "*" ] + paths: + - 'detection_rules/**/*.py' + - 'hunting/**/*.py' + +jobs: + code-checks: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Set up Python 3.13 + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip cache purge + pip install .[dev] + + - name: Linting check + run: | + ruff check --exit-non-zero-on-fix + + - name: Formatting check + run: | + ruff format --check + + - name: Pyright check + run: | + pyright + + - name: Python License Check + run: | + python -m detection_rules dev license-check diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 97ca3e0f025..5b454cd7187 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -20,10 +20,10 @@ jobs: run: | git fetch origin main:refs/remotes/origin/main - - name: Set up Python 3.12 + - name: Set up Python 3.13 uses: actions/setup-python@v5 with: - python-version: '3.12' + python-version: '3.13' - name: Install dependencies run: | @@ -31,14 +31,6 @@ jobs: pip cache purge pip install .[dev] - - name: Python Lint - run: | - python -m flake8 tests detection_rules --ignore D203,N815 --max-line-length 120 - - - name: Python License Check - run: | - python -m detection_rules dev license-check - - name: Unit tests env: # only run the test test_rule_change_has_updated_date on pull request events to main diff --git a/Makefile b/Makefile index 88f7e370b3a..c25a785719d 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,9 @@ license-check: $(VENV) deps .PHONY: lint lint: $(VENV) deps @echo "LINTING" - $(PYTHON) -m flake8 tests detection_rules --ignore D203,N815 --max-line-length 120 + $(PYTHON) -m ruff check --exit-non-zero-on-fix + $(PYTHON) -m ruff format 
--check + $(PYTHON) -m pyright .PHONY: test test: $(VENV) lint pytest diff --git a/detection_rules/__init__.py b/detection_rules/__init__.py index ebf6fdb0b37..a82df289240 100644 --- a/detection_rules/__init__.py +++ b/detection_rules/__init__.py @@ -5,43 +5,38 @@ """Detection rules.""" -import sys - - -assert (3, 12) <= sys.version_info < (4, 0), "Only Python 3.12+ supported" - -from . import ( # noqa: E402 - custom_schemas, +from . import ( custom_rules, + custom_schemas, devtools, docs, eswrap, ghwrap, kbwrap, main, - ml, misc, + ml, navigator, rule_formatter, rule_loader, schemas, - utils + utils, ) __all__ = ( - 'custom_rules', - 'custom_schemas', - 'devtools', - 'docs', - 'eswrap', - 'ghwrap', - 'kbwrap', + "custom_rules", + "custom_schemas", + "devtools", + "docs", + "eswrap", + "ghwrap", + "kbwrap", "main", - 'misc', - 'ml', - 'navigator', - 'rule_formatter', - 'rule_loader', - 'schemas', - 'utils' + "misc", + "ml", + "navigator", + "rule_formatter", + "rule_loader", + "schemas", + "utils", ) diff --git a/detection_rules/__main__.py b/detection_rules/__main__.py index 8d576e3e014..c768c465c3b 100644 --- a/detection_rules/__main__.py +++ b/detection_rules/__main__.py @@ -3,17 +3,13 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. -# coding=utf-8 """Shell for detection-rules.""" -import sys + from pathlib import Path import click -assert (3, 12) <= sys.version_info < (4, 0), "Only Python 3.12+ supported" - - -from .main import root # noqa: E402 +from .main import root CURR_DIR = Path(__file__).resolve().parent CLI_DIR = CURR_DIR.parent @@ -26,7 +22,7 @@ """ -def main(): +def main() -> None: """CLI entry point.""" click.echo(BANNER) root(prog_name="detection_rules") diff --git a/detection_rules/action.py b/detection_rules/action.py index 95ee9b997d2..678973be360 100644 --- a/detection_rules/action.py +++ b/detection_rules/action.py @@ -4,9 +4,10 @@ # 2.0. 
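Before moving on to `action.py`: the Makefile `lint` target and the new code-checks workflow above standardize on the same three commands (`ruff check --exit-non-zero-on-fix`, `ruff format --check`, `pyright`), replacing the old flake8 step. A small helper for running them locally in one pass is sketched below; it assumes ruff and pyright are available in the active environment (e.g. after `pip install .[dev]`), and it is only a convenience sketch, not repo tooling:

```python
import subprocess
import sys

# The same checks the new Makefile `lint` target and code-checks workflow run.
CHECKS = [
    ["ruff", "check", "--exit-non-zero-on-fix"],
    ["ruff", "format", "--check"],
    ["pyright"],
]


def main() -> int:
    failures = 0
    for cmd in CHECKS:
        print(f"[*] running: {' '.join(cmd)}")
        result = subprocess.run(cmd, check=False)  # keep going so every failure is reported
        failures += result.returncode != 0
    return 1 if failures else 0


if __name__ == "__main__":
    sys.exit(main())
```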
"""Dataclasses for Action.""" + from dataclasses import dataclass from pathlib import Path -from typing import List, Optional +from typing import Any from .mixins import MarshmallowDataclassMixin from .schemas import definitions @@ -15,50 +16,56 @@ @dataclass(frozen=True) class ActionMeta(MarshmallowDataclassMixin): """Data stored in an exception's [metadata] section of TOML.""" + creation_date: definitions.Date - rule_id: List[definitions.UUIDString] + rule_id: list[definitions.UUIDString] rule_name: str updated_date: definitions.Date # Optional fields - deprecation_date: Optional[definitions.Date] - comments: Optional[str] - maturity: Optional[definitions.Maturity] + deprecation_date: definitions.Date | None = None + comments: str | None = None + maturity: definitions.Maturity | None = None -@dataclass +@dataclass(frozen=True) class Action(MarshmallowDataclassMixin): """Data object for rule Action.""" + @dataclass class ActionParams: """Data object for rule Action params.""" + body: str action_type_id: definitions.ActionTypeId group: str params: ActionParams - id: Optional[str] - frequency: Optional[dict] - alerts_filter: Optional[dict] + + id: str | None = None + frequency: dict[str, Any] | None = None + alerts_filter: dict[str, Any] | None = None @dataclass(frozen=True) class TOMLActionContents(MarshmallowDataclassMixin): """Object for action from TOML file.""" + metadata: ActionMeta - actions: List[Action] + actions: list[Action] @dataclass(frozen=True) class TOMLAction: """Object for action from TOML file.""" + contents: TOMLActionContents path: Path @property - def name(self): + def name(self) -> str: return self.contents.metadata.rule_name @property - def id(self): + def id(self) -> list[definitions.UUIDString]: return self.contents.metadata.rule_id diff --git a/detection_rules/action_connector.py b/detection_rules/action_connector.py index 8a31c2a8f0e..404b92be4db 100644 --- a/detection_rules/action_connector.py +++ b/detection_rules/action_connector.py @@ -4,17 +4,18 @@ # 2.0. 
"""Dataclasses for Action.""" + from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import List, Optional, Tuple +from typing import Any -import pytoml +import pytoml # type: ignore[reportMissingTypeStubs] from marshmallow import EXCLUDE +from .config import parse_rules_config from .mixins import MarshmallowDataclassMixin from .schemas import definitions -from .config import parse_rules_config RULES_CONFIG = parse_rules_config() @@ -25,26 +26,26 @@ class ActionConnectorMeta(MarshmallowDataclassMixin): creation_date: definitions.Date action_connector_name: str - rule_ids: List[definitions.UUIDString] - rule_names: List[str] + rule_ids: list[definitions.UUIDString] + rule_names: list[str] updated_date: definitions.Date # Optional fields - deprecation_date: Optional[definitions.Date] - comments: Optional[str] - maturity: Optional[definitions.Maturity] + deprecation_date: definitions.Date | None = None + comments: str | None = None + maturity: definitions.Maturity | None = None -@dataclass +@dataclass(frozen=True) class ActionConnector(MarshmallowDataclassMixin): """Data object for rule Action Connector.""" id: str - attributes: dict - frequency: Optional[dict] - managed: Optional[bool] - type: Optional[str] - references: Optional[List] + attributes: dict[str, Any] + frequency: dict[str, Any] | None = None + managed: bool | None = None + type: str | None = None + references: list[Any] | None = None @dataclass(frozen=True) @@ -52,21 +53,23 @@ class TOMLActionConnectorContents(MarshmallowDataclassMixin): """Object for action connector from TOML file.""" metadata: ActionConnectorMeta - action_connectors: List[ActionConnector] + action_connectors: list[ActionConnector] @classmethod - def from_action_connector_dict(cls, actions_dict: dict, rule_list: dict) -> "TOMLActionConnectorContents": + def from_action_connector_dict( + cls, actions_dict: dict[str, Any], rule_list: list[dict[str, Any]] + ) -> "TOMLActionConnectorContents": """Create a TOMLActionContents from a kibana rule resource.""" - rule_ids = [] - rule_names = [] + rule_ids: list[str] = [] + rule_names: list[str] = [] for rule in rule_list: rule_ids.append(rule["id"]) rule_names.append(rule["name"]) # Format date to match schema - creation_date = datetime.strptime(actions_dict["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%Y/%m/%d") - updated_date = datetime.strptime(actions_dict["updated_at"], "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%Y/%m/%d") + creation_date = datetime.strptime(actions_dict["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%Y/%m/%d") # noqa: DTZ007 + updated_date = datetime.strptime(actions_dict["updated_at"], "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%Y/%m/%d") # noqa: DTZ007 metadata = { "creation_date": creation_date, "rule_ids": rule_ids, @@ -77,13 +80,9 @@ def from_action_connector_dict(cls, actions_dict: dict, rule_list: dict) -> "TOM return cls.from_dict({"metadata": metadata, "action_connectors": [actions_dict]}, unknown=EXCLUDE) - def to_api_format(self) -> List[dict]: + def to_api_format(self) -> list[dict[str, Any]]: """Convert the TOML Action Connector to the API format.""" - converted = [] - - for action in self.action_connectors: - converted.append(action.to_dict()) - return converted + return [action.to_dict() for action in self.action_connectors] @dataclass(frozen=True) @@ -94,12 +93,13 @@ class TOMLActionConnector: path: Path @property - def name(self): + def name(self) -> str: return self.contents.metadata.action_connector_name - def save_toml(self): + def 
save_toml(self) -> None: """Save the action to a TOML file.""" - assert self.path is not None, f"Can't save action for {self.name} without a path" + if not self.path: + raise ValueError(f"Can't save action for {self.name} without a path") # Check if self.path has a .toml extension path = self.path if path.suffix != ".toml": @@ -109,13 +109,15 @@ def save_toml(self): contents_dict = self.contents.to_dict() # Sort the dictionary so that 'metadata' is at the top sorted_dict = dict(sorted(contents_dict.items(), key=lambda item: item[0] != "metadata")) - pytoml.dump(sorted_dict, f) + pytoml.dump(sorted_dict, f) # type: ignore[reportUnknownMemberType] -def parse_action_connector_results_from_api(results: List[dict]) -> tuple[List[dict], List[dict]]: +def parse_action_connector_results_from_api( + results: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """Filter Kibana export rule results for action connector dictionaries.""" - action_results = [] - non_action_results = [] + action_results: list[dict[str, Any]] = [] + non_action_results: list[dict[str, Any]] = [] for result in results: if result.get("type") != "action": non_action_results.append(result) @@ -125,43 +127,47 @@ def parse_action_connector_results_from_api(results: List[dict]) -> tuple[List[d return action_results, non_action_results -def build_action_connector_objects(action_connectors: List[dict], action_connector_rule_table: dict, - action_connectors_directory: Path, save_toml: bool = False, - skip_errors: bool = False, verbose=False, - ) -> Tuple[List[TOMLActionConnector], List[str], List[str]]: +def build_action_connector_objects( # noqa: PLR0913 + action_connectors: list[dict[str, Any]], + action_connector_rule_table: dict[str, Any], + action_connectors_directory: Path | None, + save_toml: bool = False, + skip_errors: bool = False, + verbose: bool = False, +) -> tuple[list[TOMLActionConnector], list[str], list[str]]: """Build TOMLActionConnector objects from a list of action connector dictionaries.""" - output = [] - errors = [] - toml_action_connectors = [] + output: list[str] = [] + errors: list[str] = [] + toml_action_connectors: list[TOMLActionConnector] = [] for action_connector_dict in action_connectors: try: - connector_id = action_connector_dict.get("id") + connector_id = action_connector_dict["id"] rule_list = action_connector_rule_table.get(connector_id) if not rule_list: output.append(f"Warning action connector {connector_id} has no associated rules. Loading skipped.") continue - else: - contents = TOMLActionConnectorContents.from_action_connector_dict(action_connector_dict, rule_list) - filename = f"{connector_id}_actions.toml" - if RULES_CONFIG.action_connector_dir is None and not action_connectors_directory: - raise FileNotFoundError( - "No Action Connector directory is specified. Please specify either in the config or CLI." - ) - actions_path = ( - Path(action_connectors_directory) / filename - if action_connectors_directory - else RULES_CONFIG.action_connector_dir / filename - ) - if verbose: - output.append(f"[+] Building action connector(s) for {actions_path}") - - ac_object = TOMLActionConnector( - contents=contents, - path=actions_path, + contents = TOMLActionConnectorContents.from_action_connector_dict(action_connector_dict, rule_list) + filename = f"{connector_id}_actions.toml" + if RULES_CONFIG.action_connector_dir is None and not action_connectors_directory: + raise FileNotFoundError( # noqa: TRY301 + "No Action Connector directory is specified. 
Please specify either in the config or CLI." ) - if save_toml: - ac_object.save_toml() - toml_action_connectors.append(ac_object) + actions_path = ( + Path(action_connectors_directory) / filename + if action_connectors_directory + else RULES_CONFIG.action_connector_dir / filename + ) + if verbose: + output.append(f"[+] Building action connector(s) for {actions_path}") + + ac_object = TOMLActionConnector( + contents=contents, + path=actions_path, + ) + if save_toml: + ac_object.save_toml() + + toml_action_connectors.append(ac_object) except Exception as e: if skip_errors: diff --git a/detection_rules/attack.py b/detection_rules/attack.py index 2178d2c8cd5..71725ffe503 100644 --- a/detection_rules/attack.py +++ b/detection_rules/attack.py @@ -4,240 +4,247 @@ # 2.0. """Mitre attack info.""" + +import json import re import time +from collections import OrderedDict from pathlib import Path -from typing import Optional +from typing import Any -import json import requests -from collections import OrderedDict - from semver import Version -from .utils import cached, clear_caches, get_etc_path, get_etc_glob_path, read_gzip, gzip_compress -PLATFORMS = ['Windows', 'macOS', 'Linux'] -CROSSWALK_FILE = get_etc_path('attack-crosswalk.json') -TECHNIQUES_REDIRECT_FILE = get_etc_path('attack-technique-redirects.json') +from .utils import cached, clear_caches, get_etc_glob_path, get_etc_path, gzip_compress, read_gzip -tactics_map = {} +PLATFORMS = ["Windows", "macOS", "Linux"] +CROSSWALK_FILE = get_etc_path(["attack-crosswalk.json"]) +TECHNIQUES_REDIRECT_FILE = get_etc_path(["attack-technique-redirects.json"]) + +tactics_map: dict[str, Any] = {} @cached -def load_techniques_redirect() -> dict: - return json.loads(TECHNIQUES_REDIRECT_FILE.read_text())['mapping'] +def load_techniques_redirect() -> dict[str, Any]: + return json.loads(TECHNIQUES_REDIRECT_FILE.read_text())["mapping"] def get_attack_file_path() -> Path: - pattern = 'attack-v*.json.gz' - attack_file = get_etc_glob_path(pattern) + pattern = "attack-v*.json.gz" + attack_file = get_etc_glob_path([pattern]) if len(attack_file) < 1: - raise FileNotFoundError(f'Missing required {pattern} file') - elif len(attack_file) != 1: - raise FileExistsError(f'Multiple files found with {pattern} pattern. Only one is allowed') + raise FileNotFoundError(f"Missing required {pattern} file") + if len(attack_file) != 1: + raise FileExistsError(f"Multiple files found with {pattern} pattern. 
Only one is allowed") return Path(attack_file[0]) -_, _attack_path_base = str(get_attack_file_path()).split('-v') -_ext_length = len('.json.gz') +_, _attack_path_base = str(get_attack_file_path()).split("-v") +_ext_length = len(".json.gz") CURRENT_ATTACK_VERSION = _attack_path_base[:-_ext_length] -def load_attack_gz() -> dict: - +def load_attack_gz() -> dict[str, Any]: return json.loads(read_gzip(get_attack_file_path())) attack = load_attack_gz() -technique_lookup = {} -revoked = {} -deprecated = {} +technique_lookup: dict[str, Any] = {} +revoked: dict[str, Any] = {} +deprecated: dict[str, Any] = {} for item in attack["objects"]: if item["type"] == "x-mitre-tactic": - tactics_map[item['name']] = item['external_references'][0]['external_id'] + tactics_map[item["name"]] = item["external_references"][0]["external_id"] - if item["type"] == "attack-pattern" and item["external_references"][0]['source_name'] == 'mitre-attack': - technique_id = item['external_references'][0]['external_id'] + if item["type"] == "attack-pattern" and item["external_references"][0]["source_name"] == "mitre-attack": + technique_id = item["external_references"][0]["external_id"] technique_lookup[technique_id] = item - if item.get('revoked'): + if item.get("revoked"): revoked[technique_id] = item - if item.get('x_mitre_deprecated'): + if item.get("x_mitre_deprecated"): deprecated[technique_id] = item revoked = dict(sorted(revoked.items())) deprecated = dict(sorted(deprecated.items())) tactics = list(tactics_map) -matrix = {tactic: [] for tactic in tactics} -no_tactic = [] -attack_tm = 'ATT&CK\u2122' +matrix: dict[str, list[str]] = {tactic: [] for tactic in tactics} +no_tactic: list[str] = [] +attack_tm = "ATT&CK\u2122" # Enumerate over the techniques and build the matrix back up -for technique_id, technique in sorted(technique_lookup.items(), key=lambda kv: kv[1]['name'].lower()): - kill_chain = technique.get('kill_chain_phases') +for technique_id, technique in sorted(technique_lookup.items(), key=lambda kv: kv[1]["name"].lower()): + kill_chain = technique.get("kill_chain_phases") if kill_chain: for tactic in kill_chain: - tactic_name = next(t for t in tactics if tactic['kill_chain_name'] == 'mitre-attack' and t.lower() == tactic['phase_name'].replace("-", " ")) # noqa: E501 + tactic_name = next( + t + for t in tactics + if tactic["kill_chain_name"] == "mitre-attack" and t.lower() == tactic["phase_name"].replace("-", " ") + ) matrix[tactic_name].append(technique_id) - else: - no_tactic.append(technique_id) + no_tactic.append(technique_id) -for tactic in matrix: - matrix[tactic].sort(key=lambda tid: technique_lookup[tid]['name'].lower()) +for val in matrix.values(): + val.sort(key=lambda tid: technique_lookup[tid]["name"].lower()) technique_lookup = OrderedDict(sorted(technique_lookup.items())) -techniques = sorted({v['name'] for k, v in technique_lookup.items()}) -technique_id_list = [t for t in technique_lookup if '.' not in t] -sub_technique_id_list = [t for t in technique_lookup if '.' in t] +techniques = sorted({v["name"] for _, v in technique_lookup.items()}) +technique_id_list = [t for t in technique_lookup if "." not in t] +sub_technique_id_list = [t for t in technique_lookup if "." 
in t] -def refresh_attack_data(save=True) -> (Optional[dict], Optional[bytes]): +def refresh_attack_data(save: bool = True) -> tuple[dict[str, Any] | None, bytes | None]: """Refresh ATT&CK data from Mitre.""" attack_path = get_attack_file_path() - filename, _, _ = attack_path.name.rsplit('.', 2) + filename, _, _ = attack_path.name.rsplit(".", 2) - def get_version_from_tag(name, pattern='att&ck-v'): + def get_version_from_tag(name: str, pattern: str = "att&ck-v") -> str: _, version = name.lower().split(pattern, 1) return version - current_version = Version.parse(get_version_from_tag(filename, 'attack-v'), optional_minor_and_patch=True) + current_version = Version.parse(get_version_from_tag(filename, "attack-v"), optional_minor_and_patch=True) - r = requests.get('https://api.github.com/repos/mitre/cti/tags') + r = requests.get("https://api.github.com/repos/mitre/cti/tags", timeout=30) r.raise_for_status() - releases = [t for t in r.json() if t['name'].startswith('ATT&CK-v')] - latest_release = max(releases, key=lambda release: Version.parse(get_version_from_tag(release['name']), - optional_minor_and_patch=True)) - release_name = latest_release['name'] + releases = [t for t in r.json() if t["name"].startswith("ATT&CK-v")] + latest_release = max( + releases, + key=lambda release: Version.parse(get_version_from_tag(release["name"]), optional_minor_and_patch=True), + ) + release_name = latest_release["name"] latest_version = Version.parse(get_version_from_tag(release_name), optional_minor_and_patch=True) if current_version >= latest_version: - print(f'No versions newer than the current detected: {current_version}') + print(f"No versions newer than the current detected: {current_version}") return None, None - download = f'https://raw.githubusercontent.com/mitre/cti/{release_name}/enterprise-attack/enterprise-attack.json' - r = requests.get(download) + download = f"https://raw.githubusercontent.com/mitre/cti/{release_name}/enterprise-attack/enterprise-attack.json" + r = requests.get(download, timeout=30) r.raise_for_status() attack_data = r.json() compressed = gzip_compress(json.dumps(attack_data, sort_keys=True)) if save: - new_path = get_etc_path(f'attack-v{latest_version}.json.gz') - new_path.write_bytes(compressed) + new_path = get_etc_path([f"attack-v{latest_version}.json.gz"]) + _ = new_path.write_bytes(compressed) attack_path.unlink() - print(f'Replaced file: {attack_path} with {new_path}') + print(f"Replaced file: {attack_path} with {new_path}") return attack_data, compressed -def build_threat_map_entry(tactic: str, *technique_ids: str) -> dict: +def build_threat_map_entry(tactic: str, *technique_ids: str) -> dict[str, Any]: """Build rule threat map from technique IDs.""" techniques_redirect_map = load_techniques_redirect() - url_base = 'https://attack.mitre.org/{type}/{id}/' + url_base = "https://attack.mitre.org/{type}/{id}/" tactic_id = tactics_map[tactic] - tech_entries = {} + tech_entries: dict[str, Any] = {} - def make_entry(_id): - e = { - 'id': _id, - 'name': technique_lookup[_id]['name'], - 'reference': url_base.format(type='techniques', id=_id.replace('.', '/')) + def make_entry(_id: str) -> dict[str, Any]: + return { + "id": _id, + "name": technique_lookup[_id]["name"], + "reference": url_base.format(type="techniques", id=_id.replace(".", "/")), } - return e for tid in technique_ids: # fail if deprecated or else convert if it has been replaced if tid in deprecated: - raise ValueError(f'Technique ID: {tid} has been deprecated and should not be used') - elif tid in 
techniques_redirect_map: - tid = techniques_redirect_map[tid] + raise ValueError(f"Technique ID: {tid} has been deprecated and should not be used") + if tid in techniques_redirect_map: + tid = techniques_redirect_map[tid] # noqa: PLW2901 if tid not in matrix[tactic]: - raise ValueError(f'Technique ID: {tid} does not fall under tactic: {tactic}') + raise ValueError(f"Technique ID: {tid} does not fall under tactic: {tactic}") # sub-techniques - if '.' in tid: - parent_technique, _ = tid.split('.', 1) + if "." in tid: + parent_technique, _ = tid.split(".", 1) tech_entries.setdefault(parent_technique, make_entry(parent_technique)) - tech_entries[parent_technique].setdefault('subtechnique', []).append(make_entry(tid)) + tech_entries[parent_technique].setdefault("subtechnique", []).append(make_entry(tid)) else: tech_entries.setdefault(tid, make_entry(tid)) - entry = { - 'framework': 'MITRE ATT&CK', - 'tactic': { - 'id': tactic_id, - 'name': tactic, - 'reference': url_base.format(type='tactics', id=tactic_id) - } + entry: dict[str, Any] = { + "framework": "MITRE ATT&CK", + "tactic": {"id": tactic_id, "name": tactic, "reference": url_base.format(type="tactics", id=tactic_id)}, } if tech_entries: - entry['technique'] = sorted(tech_entries.values(), key=lambda x: x['id']) + entry["technique"] = sorted(tech_entries.values(), key=lambda x: x["id"]) return entry -def update_threat_map(rule_threat_map): +def update_threat_map(rule_threat_map: list[dict[str, Any]]) -> None: """Update rule map techniques to reflect changes from ATT&CK.""" for entry in rule_threat_map: - for tech in entry['technique']: - tech['name'] = technique_lookup[tech['id']]['name'] + for tech in entry["technique"]: + tech["name"] = technique_lookup[tech["id"]]["name"] -def retrieve_redirected_id(asset_id: str): +def retrieve_redirected_id(asset_id: str) -> str | Any: """Get the ID for a redirected ATT&CK asset.""" if asset_id in (tactics_map.values()): - attack_type = 'tactics' + attack_type = "tactics" elif asset_id in list(technique_lookup): - attack_type = 'techniques' + attack_type = "techniques" else: - raise ValueError(f'Unknown asset_id: {asset_id}') + raise ValueError(f"Unknown asset_id: {asset_id}") - response = requests.get(f'https://attack.mitre.org/{attack_type}/{asset_id.replace(".", "/")}') + response = requests.get( + f"https://attack.mitre.org/{attack_type}/{asset_id.replace('.', '/')}", + timeout=30, + ) text = response.text.strip().strip("'").lower() if text.startswith(' dict[str, Any]: """Build a mapping of revoked technique IDs to new technique IDs.""" from multiprocessing.pool import ThreadPool - technique_map = {} + technique_map: dict[str, Any] = {} - def download_worker(tech_id): + def download_worker(tech_id: str) -> None: new = retrieve_redirected_id(tech_id) if new: technique_map[tech_id] = new pool = ThreadPool(processes=threads) - pool.map(download_worker, list(technique_lookup)) + _ = pool.map(download_worker, list(technique_lookup)) pool.close() pool.join() return technique_map -def refresh_redirected_techniques_map(threads: int = 50): +def refresh_redirected_techniques_map(threads: int = 50) -> None: """Refresh the locally saved copy of the mapping.""" replacement_map = build_redirected_techniques_map(threads) - mapping = {'saved_date': time.asctime(), 'mapping': replacement_map} + mapping = {"saved_date": time.asctime(), "mapping": replacement_map} - TECHNIQUES_REDIRECT_FILE.write_text(json.dumps(mapping, sort_keys=True, indent=2)) + _ = TECHNIQUES_REDIRECT_FILE.write_text(json.dumps(mapping, 
sort_keys=True, indent=2)) # reset the cached redirect contents clear_caches() - print(f'refreshed mapping file: {TECHNIQUES_REDIRECT_FILE}') + print(f"refreshed mapping file: {TECHNIQUES_REDIRECT_FILE}") @cached -def load_crosswalk_map() -> dict: +def load_crosswalk_map() -> dict[str, Any]: """Retrieve the replacement mapping.""" - return json.loads(CROSSWALK_FILE.read_text())['mapping'] + return json.loads(CROSSWALK_FILE.read_text())["mapping"] diff --git a/detection_rules/beats.py b/detection_rules/beats.py index ed2a1d9b9e6..d5afaa923cd 100644 --- a/detection_rules/beats.py +++ b/detection_rules/beats.py @@ -4,51 +4,51 @@ # 2.0. """ECS Schemas management.""" + import json import os import re -from typing import List, Optional, Union +from pathlib import Path +from typing import Any -import eql +import eql # type: ignore[reportMissingTypeStubs] +import kql # type: ignore[reportMissingTypeStubs] import requests -from semver import Version import yaml +from semver import Version -import kql - -from .utils import (DateTimeEncoder, cached, get_etc_path, gzip_compress, - read_gzip, unzip) +from .utils import DateTimeEncoder, cached, get_etc_path, gzip_compress, read_gzip, unzip -def _decompress_and_save_schema(url, release_name): +def _decompress_and_save_schema(url: str, release_name: str) -> None: print(f"Downloading beats {release_name}") - response = requests.get(url) + response = requests.get(url, timeout=30) print(f"Downloaded {len(response.content) / 1024.0 / 1024.0:.2f} MB release.") - fs = {} - parsed = {} + fs: dict[str, Any] = {} with unzip(response.content) as archive: base_directory = archive.namelist()[0] for name in archive.namelist(): - if os.path.basename(name) in ("fields.yml", "fields.common.yml", "config.yml"): + path = Path(name) + if path.name in ("fields.yml", "fields.common.yml", "config.yml"): contents = archive.read(name) # chop off the base directory name - key = name[len(base_directory):] + key = name[len(base_directory) :] if key.startswith("x-pack"): - key = key[len("x-pack") + 1:] + key = key[len("x-pack") + 1 :] try: decoded = yaml.safe_load(contents) - except yaml.YAMLError: + except yaml.YAMLError as e: print(f"Error loading {name}") + raise ValueError(f"Error loading {name}") from e # create a hierarchical structure - parsed[key] = decoded branch = fs directory, base_name = os.path.split(key) for limb in directory.split(os.path.sep): @@ -61,53 +61,53 @@ def _decompress_and_save_schema(url, release_name): print(f"Saving detection_rules/etc/beats_schema/{release_name}.json") compressed = gzip_compress(json.dumps(fs, sort_keys=True, cls=DateTimeEncoder)) - path = get_etc_path("beats_schemas", release_name + ".json.gz") - with open(path, 'wb') as f: - f.write(compressed) + path = get_etc_path(["beats_schemas", release_name + ".json.gz"]) + with path.open("wb") as f: + _ = f.write(compressed) -def download_beats_schema(version: str): +def download_beats_schema(version: str) -> None: """Download a beats schema by version.""" - url = 'https://api.github.com/repos/elastic/beats/releases' - releases = requests.get(url) + url = "https://api.github.com/repos/elastic/beats/releases" + releases = requests.get(url, timeout=30) - version = f'v{version.lstrip("v")}' + version = f"v{version.lstrip('v')}" beats_release = None for release in releases.json(): - if release['tag_name'] == version: + if release["tag_name"] == version: beats_release = release break if not beats_release: - print(f'beats release {version} not found!') + print(f"beats release {version} not found!") 
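Looping back to `build_threat_map_entry` above: it assembles the `threat` mapping that rules carry in TOML, grouping sub-techniques under their parent technique and attaching attack.mitre.org references. A hand-written example of the resulting shape (the tactic/technique IDs and names are standard MITRE ATT&CK values, included purely for illustration):

```python
# Roughly what build_threat_map_entry("Execution", "T1059.001") assembles:
example_threat_entry = {
    "framework": "MITRE ATT&CK",
    "tactic": {
        "id": "TA0002",
        "name": "Execution",
        "reference": "https://attack.mitre.org/tactics/TA0002/",
    },
    "technique": [
        {
            "id": "T1059",
            "name": "Command and Scripting Interpreter",
            "reference": "https://attack.mitre.org/techniques/T1059/",
            "subtechnique": [
                {
                    "id": "T1059.001",
                    "name": "PowerShell",
                    "reference": "https://attack.mitre.org/techniques/T1059/001/",
                }
            ],
        }
    ],
}
```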
return - beats_url = beats_release['zipball_url'] - name = beats_release['tag_name'] + beats_url = beats_release["zipball_url"] + name = beats_release["tag_name"] _decompress_and_save_schema(beats_url, name) -def download_latest_beats_schema(): +def download_latest_beats_schema() -> None: """Download additional schemas from beats releases.""" - url = 'https://api.github.com/repos/elastic/beats/releases' - releases = requests.get(url) + url = "https://api.github.com/repos/elastic/beats/releases" + releases = requests.get(url, timeout=30) latest_release = max(releases.json(), key=lambda release: Version.parse(release["tag_name"].lstrip("v"))) download_beats_schema(latest_release["tag_name"]) -def refresh_main_schema(): +def refresh_main_schema() -> None: """Download and refresh beats schema from main.""" - _decompress_and_save_schema('https://github.com/elastic/beats/archive/main.zip', 'main') + _decompress_and_save_schema("https://github.com/elastic/beats/archive/main.zip", "main") -def _flatten_schema(schema: list, prefix="") -> list: +def _flatten_schema(schema: list[dict[str, Any]] | None, prefix: str = "") -> list[dict[str, Any]]: if schema is None: # sometimes we see `fields: null` in the yaml return [] - flattened = [] + flattened: list[dict[str, Any]] = [] for s in schema: if s.get("type") == "group": nested_prefix = prefix + s["name"] + "." @@ -123,7 +123,7 @@ def _flatten_schema(schema: list, prefix="") -> list: # integrations sometimes have a group with a single field flattened.extend(_flatten_schema(s["field"], prefix=nested_prefix)) continue - elif "fields" not in s: + if "fields" not in s: # integrations sometimes have a group with no fields continue @@ -131,25 +131,29 @@ def _flatten_schema(schema: list, prefix="") -> list: elif "fields" in s: flattened.extend(_flatten_schema(s["fields"], prefix=prefix)) elif "name" in s: - s = s.copy() + _s = s.copy() # type is implicitly keyword if not defined # example: https://github.com/elastic/beats/blob/main/packetbeat/_meta/fields.common.yml#L7-L12 - s.setdefault("type", "keyword") - s["name"] = prefix + s["name"] - flattened.append(s) + _s.setdefault("type", "keyword") + _s["name"] = prefix + s["name"] + flattened.append(_s) return flattened -def flatten_ecs_schema(schema: dict) -> dict: +def flatten_ecs_schema(schema: list[dict[str, Any]]) -> list[dict[str, Any]]: return _flatten_schema(schema) -def get_field_schema(base_directory, prefix="", include_common=False): +def get_field_schema( + base_directory: dict[str, Any], + prefix: str = "", + include_common: bool = False, +) -> list[dict[str, Any]]: base_directory = base_directory.get("folders", {}).get("_meta", {}).get("files", {}) - flattened = [] + flattened: list[dict[str, Any]] = [] - file_names = ("fields.yml", "fields.common.yml") if include_common else ("fields.yml", ) + file_names = ("fields.yml", "fields.common.yml") if include_common else ("fields.yml",) for name in file_names: if name in base_directory: @@ -158,7 +162,7 @@ def get_field_schema(base_directory, prefix="", include_common=False): return flattened -def get_beat_root_schema(schema: dict, beat: str): +def get_beat_root_schema(schema: dict[str, Any], beat: str) -> dict[str, Any]: if beat not in schema: raise KeyError(f"Unknown beats module {beat}") @@ -168,22 +172,20 @@ def get_beat_root_schema(schema: dict, beat: str): return {field["name"]: field for field in sorted(flattened, key=lambda f: f["name"])} -def get_beats_sub_schema(schema: dict, beat: str, module: str, *datasets: str): +def 
get_beats_sub_schema(schema: dict[str, Any], beat: str, module: str, *datasets: str) -> dict[str, Any]: if beat not in schema: raise KeyError(f"Unknown beats module {beat}") - flattened = [] + flattened: list[dict[str, Any]] = [] beat_dir = schema[beat] module_dir = beat_dir.get("folders", {}).get("module", {}).get("folders", {}).get(module, {}) # if we only have a module then we'll work with what we got - if not datasets: - datasets = [d for d in module_dir.get("folders", {}) if not d.startswith("_")] + all_datasets = datasets if datasets else [d for d in module_dir.get("folders", {}) if not d.startswith("_")] - for dataset in datasets: + for _dataset in all_datasets: # replace aws.s3 -> s3 - if dataset.startswith(module + "."): - dataset = dataset[len(module) + 1:] + dataset = _dataset[len(module) + 1 :] if _dataset.startswith(module + ".") else _dataset dataset_dir = module_dir.get("folders", {}).get(dataset, {}) flattened.extend(get_field_schema(dataset_dir, prefix=module + ".", include_common=True)) @@ -195,10 +197,10 @@ def get_beats_sub_schema(schema: dict, beat: str, module: str, *datasets: str): @cached -def get_versions() -> List[Version]: - versions = [] - for filename in os.listdir(get_etc_path("beats_schemas")): - version_match = re.match(r'v(.+)\.json\.gz', filename) +def get_versions() -> list[Version]: + versions: list[Version] = [] + for filename in os.listdir(get_etc_path(["beats_schemas"])): # noqa: PTH208 + version_match = re.match(r"v(.+)\.json\.gz", filename) if version_match: versions.append(Version.parse(version_match.groups()[0])) @@ -211,23 +213,29 @@ def get_max_version() -> str: @cached -def read_beats_schema(version: str = None): - if version and version.lower() == 'main': - return json.loads(read_gzip(get_etc_path('beats_schemas', 'main.json.gz'))) +def read_beats_schema(version: str | None = None) -> dict[str, Any]: + if version and version.lower() == "main": + path = get_etc_path(["beats_schemas", "main.json.gz"]) + return json.loads(read_gzip(path)) - version = Version.parse(version) if version else None + ver = Version.parse(version) if version else None beats_schemas = get_versions() - if version and version not in beats_schemas: - raise ValueError(f'Unknown beats schema: {version}') + if ver and ver not in beats_schemas: + raise ValueError(f"Unknown beats schema: {ver}") version = version or get_max_version() - return json.loads(read_gzip(get_etc_path('beats_schemas', f'v{version}.json.gz'))) + return json.loads(read_gzip(get_etc_path(["beats_schemas", f"v{version}.json.gz"]))) -def get_schema_from_datasets(beats, modules, datasets, version=None): - filtered = {} +def get_schema_from_datasets( + beats: list[str], + modules: set[str], + datasets: set[str], + version: str | None = None, +) -> dict[str, Any]: + filtered: dict[str, Any] = {} beats_schema = read_beats_schema(version=version) # infer the module if only a dataset are defined @@ -236,8 +244,6 @@ def get_schema_from_datasets(beats, modules, datasets, version=None): for beat in beats: # if no modules are specified then grab them all - # all_modules = list(beats_schema.get(beat, {}).get("folders", {}).get("module", {}).get("folders", {})) - # beat_modules = modules or all_modules filtered.update(get_beat_root_schema(beats_schema, beat)) for module in modules: @@ -246,53 +252,50 @@ def get_schema_from_datasets(beats, modules, datasets, version=None): return filtered -def get_datasets_and_modules(tree: Union[eql.ast.BaseNode, kql.ast.BaseNode]) -> tuple: +def get_datasets_and_modules(tree: 
eql.ast.BaseNode | kql.ast.BaseNode) -> tuple[set[str], set[str]]: """Get datasets and modules from an EQL or KQL AST.""" - modules = set() - datasets = set() + modules: set[str] = set() + datasets: set[str] = set() # extract out event.module and event.dataset from the query's AST - for node in tree: - if isinstance(node, eql.ast.Comparison) and node.comparator == node.EQ and \ - isinstance(node.right, eql.ast.String): + for node in tree: # type: ignore[reportUnknownVariableType] + if ( + isinstance(node, eql.ast.Comparison) + and node.comparator == node.EQ + and isinstance(node.right, eql.ast.String) + ): if node.left == eql.ast.Field("event", ["module"]): - modules.add(node.right.render()) + modules.add(node.right.render()) # type: ignore[reportUnknownMemberType] elif node.left == eql.ast.Field("event", ["dataset"]): - datasets.add(node.right.render()) + datasets.add(node.right.render()) # type: ignore[reportUnknownMemberType] elif isinstance(node, eql.ast.InSet): if node.expression == eql.ast.Field("event", ["module"]): - modules.add(node.get_literals()) + modules.update(node.get_literals()) # type: ignore[reportUnknownMemberType] elif node.expression == eql.ast.Field("event", ["dataset"]): - datasets.add(node.get_literals()) - elif isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.module"): - modules.update(child.value for child in node.value if isinstance(child, kql.ast.String)) - elif isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.dataset"): - datasets.update(child.value for child in node.value if isinstance(child, kql.ast.String)) + datasets.update(node.get_literals()) # type: ignore[reportUnknownMemberType] + elif isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.module"): # type: ignore[reportUnknownMemberType] + modules.update(child.value for child in node.value if isinstance(child, kql.ast.String)) # type: ignore[reportUnknownMemberType, reportUnknownVariableType] + elif isinstance(node, kql.ast.FieldComparison) and node.field == kql.ast.Field("event.dataset"): # type: ignore[reportUnknownMemberType] + datasets.update(child.value for child in node.value if isinstance(child, kql.ast.String)) # type: ignore[reportUnknownMemberType, reportUnknownVariableType] return datasets, modules -def get_schema_from_eql(tree: eql.ast.BaseNode, beats: list, version: str = None) -> dict: - """Get a schema based on datasets and modules in an EQL AST.""" - datasets, modules = get_datasets_and_modules(tree) - return get_schema_from_datasets(beats, modules, datasets, version=version) - - -def get_schema_from_kql(tree: kql.ast.BaseNode, beats: list, version: str = None) -> dict: +def get_schema_from_kql(tree: kql.ast.BaseNode, beats: list[str], version: str | None = None) -> dict[str, Any]: """Get a schema based on datasets and modules in an KQL AST.""" datasets, modules = get_datasets_and_modules(tree) return get_schema_from_datasets(beats, modules, datasets, version=version) -def parse_beats_from_index(index: Optional[list]) -> List[str]: +def parse_beats_from_index(indexes: list[str] | None) -> list[str]: """Parse beats schema types from index.""" - indexes = index or [] - beat_types = [] + indexes = indexes or [] + beat_types: list[str] = [] # Need to split on : or :: to support cross-cluster search # e.g. 
mycluster:logs-* -> logs-* for index in indexes: if "beat-*" in index: - index_parts = index.replace('::', ':').split(':', 1) + index_parts = index.replace("::", ":").split(":", 1) last_part = index_parts[-1] beat_type = last_part.split("-")[0] beat_types.append(beat_type) diff --git a/detection_rules/cli_utils.py b/detection_rules/cli_utils.py index 0ab967bdb69..ef2b022a1a0 100644 --- a/detection_rules/cli_utils.py +++ b/detection_rules/cli_utils.py @@ -8,52 +8,51 @@ import functools import os import typing +from collections.abc import Callable from pathlib import Path -from typing import List, Optional +from typing import Any import click - -import kql +import kql # type: ignore[reportMissingTypeStubs] from . import ecs from .attack import build_threat_map_entry, matrix, tactics +from .config import parse_rules_config from .rule import BYPASS_VERSION_LOCK, TOMLRule, TOMLRuleContents -from .rule_loader import (DEFAULT_PREBUILT_BBR_DIRS, - DEFAULT_PREBUILT_RULES_DIRS, RuleCollection, - dict_filter) +from .rule_loader import DEFAULT_PREBUILT_BBR_DIRS, DEFAULT_PREBUILT_RULES_DIRS, RuleCollection, dict_filter from .schemas import definitions from .utils import clear_caches, ensure_list_of_strings, rulename_to_filename -from .config import parse_rules_config RULES_CONFIG = parse_rules_config() -def single_collection(f): +def single_collection(f: Callable[..., Any]) -> Callable[..., Any]: """Add arguments to get a RuleCollection by file, directory or a list of IDs""" - from .misc import client_error + from .misc import raise_client_error - @click.option('--rule-file', '-f', multiple=False, required=False, type=click.Path(dir_okay=False)) - @click.option('--rule-id', '-id', multiple=False, required=False) + @click.option("--rule-file", "-f", multiple=False, required=False, type=click.Path(dir_okay=False)) + @click.option("--rule-id", "-id", multiple=False, required=False) @functools.wraps(f) - def get_collection(*args, **kwargs): - rule_name: List[str] = kwargs.pop("rule_name", []) - rule_id: List[str] = kwargs.pop("rule_id", []) - rule_files: List[str] = kwargs.pop("rule_file") - directories: List[str] = kwargs.pop("directory") + def get_collection(*args: Any, **kwargs: Any) -> Any: + rule_name: list[str] = kwargs.pop("rule_name", []) + rule_id: list[str] = kwargs.pop("rule_id", []) + rule_files: list[str] = kwargs.pop("rule_file") + directories: list[str] = kwargs.pop("directory") rules = RuleCollection() if bool(rule_name) + bool(rule_id) + bool(rule_files) != 1: - client_error('Required: exactly one of --rule-id, --rule-file, or --directory') + raise_client_error("Required: exactly one of --rule-id, --rule-file, or --directory") rules.load_files(Path(p) for p in rule_files) rules.load_directories(Path(d) for d in directories) if rule_id: - rules.load_directories(DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS, - obj_filter=dict_filter(rule__rule_id=rule_id)) + rules.load_directories( + DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS, obj_filter=dict_filter(rule__rule_id=rule_id) + ) if len(rules) != 1: - client_error(f"Could not find rule with ID {rule_id}") + raise_client_error(f"Could not find rule with ID {rule_id}") kwargs["rules"] = rules return f(*args, **kwargs) @@ -61,28 +60,38 @@ def get_collection(*args, **kwargs): return get_collection -def multi_collection(f): +def multi_collection(f: Callable[..., Any]) -> Callable[..., Any]: """Add arguments to get a RuleCollection by file, directory or a list of IDs""" - from .misc import client_error + from .misc import 
raise_client_error @click.option("--rule-file", "-f", multiple=True, type=click.Path(dir_okay=False), required=False) - @click.option("--directory", "-d", multiple=True, type=click.Path(file_okay=False), required=False, - help="Recursively load rules from a directory") + @click.option( + "--directory", + "-d", + multiple=True, + type=click.Path(file_okay=False), + required=False, + help="Recursively load rules from a directory", + ) @click.option("--rule-id", "-id", multiple=True, required=False) - @click.option("--no-tactic-filename", "-nt", is_flag=True, required=False, - help="Allow rule filenames without tactic prefix. " - "Use this if rules have been exported with this flag.") + @click.option( + "--no-tactic-filename", + "-nt", + is_flag=True, + required=False, + help="Allow rule filenames without tactic prefix. Use this if rules have been exported with this flag.", + ) @functools.wraps(f) - def get_collection(*args, **kwargs): - rule_id: List[str] = kwargs.pop("rule_id", []) - rule_files: List[str] = kwargs.pop("rule_file") - directories: List[str] = kwargs.pop("directory") + def get_collection(*args: Any, **kwargs: Any) -> Any: + rule_id: list[str] = kwargs.pop("rule_id", []) + rule_files: list[str] = kwargs.pop("rule_file") + directories: list[str] = kwargs.pop("directory") no_tactic_filename: bool = kwargs.pop("no_tactic_filename", False) rules = RuleCollection() if not (directories or rule_id or rule_files or (DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS)): - client_error("Required: at least one of --rule-id, --rule-file, or --directory") + raise_client_error("Required: at least one of --rule-id, --rule-file, or --directory") rules.load_files(Path(p) for p in rule_files) rules.load_directories(Path(d) for d in directories) @@ -95,12 +104,12 @@ def get_collection(*args, **kwargs): missing = set(rule_id).difference(found_ids) if missing: - client_error(f'Could not find rules with IDs: {", ".join(missing)}') + raise_client_error(f"Could not find rules with IDs: {', '.join(missing)}") elif not rule_files and not directories: rules.load_directories(Path(d) for d in (DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS)) if len(rules) == 0: - client_error("No rules found") + raise_client_error("No rules found") # Warn that if the path does not match the expected path, it will be saved to the expected path for rule in rules: @@ -110,7 +119,9 @@ def get_collection(*args, **kwargs): no_tactic_filename = no_tactic_filename or RULES_CONFIG.no_tactic_filename tactic_name = None if no_tactic_filename else first_tactic rule_name = rulename_to_filename(rule.contents.data.name, tactic_name=tactic_name) - if rule.path.name != rule_name: + if not rule.path: + click.secho(f"WARNING: Rule path for rule not found: {rule_name}", fg="yellow") + elif rule.path.name != rule_name: click.secho( f"WARNING: Rule path does not match required path: {rule.path.name} != {rule_name}", fg="yellow" ) @@ -121,67 +132,84 @@ def get_collection(*args, **kwargs): return get_collection -def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbose=False, - additional_required: Optional[list] = None, skip_errors: bool = False, strip_none_values=True, **kwargs, - ) -> TOMLRule: +def rule_prompt( # noqa: PLR0912, PLR0913, PLR0915 + path: Path | None = None, + rule_type: str | None = None, + required_only: bool = True, + save: bool = True, + verbose: bool = False, + additional_required: list[str] | None = None, + skip_errors: bool = False, + strip_none_values: bool = True, + **kwargs: Any, 
+) -> TOMLRule | str: """Prompt loop to build a rule.""" from .misc import schema_prompt additional_required = additional_required or [] - creation_date = datetime.date.today().strftime("%Y/%m/%d") + creation_date = datetime.date.today().strftime("%Y/%m/%d") # noqa: DTZ011 if verbose and path: - click.echo(f'[+] Building rule for {path}') + click.echo(f"[+] Building rule for {path}") kwargs = copy.deepcopy(kwargs) - rule_name = kwargs.get('name') + rule_name = kwargs.get("name") - if 'rule' in kwargs and 'metadata' in kwargs: - kwargs.update(kwargs.pop('metadata')) - kwargs.update(kwargs.pop('rule')) + if "rule" in kwargs and "metadata" in kwargs: + kwargs.update(kwargs.pop("metadata")) + kwargs.update(kwargs.pop("rule")) - rule_type = rule_type or kwargs.get('type') or \ - click.prompt('Rule type', type=click.Choice(typing.get_args(definitions.RuleType))) + rule_type_val = ( + rule_type + or kwargs.get("type") + or click.prompt("Rule type", type=click.Choice(typing.get_args(definitions.RuleType))) + ) - target_data_subclass = TOMLRuleContents.get_data_subclass(rule_type) + target_data_subclass = TOMLRuleContents.get_data_subclass(rule_type_val) schema = target_data_subclass.jsonschema() - props = schema['properties'] - required_fields = schema.get('required', []) + additional_required - contents = {} - skipped = [] + props = schema["properties"] + required_fields = schema.get("required", []) + additional_required + contents: dict[str, Any] = {} + skipped: list[str] = [] for name, options in props.items(): - - if name == 'index' and kwargs.get("type") == "esql": + if name == "index" and kwargs.get("type") == "esql": continue - if name == 'type': - contents[name] = rule_type + if name == "type": + contents[name] = rule_type_val continue # these are set at package release time depending on the version strategy - if (name == 'version' or name == 'revision') and not BYPASS_VERSION_LOCK: + if name in ("version", "revision") and not BYPASS_VERSION_LOCK: continue if required_only and name not in required_fields: continue # build this from technique ID - if name == 'threat': - threat_map = [] + if name == "threat": + threat_map: list[dict[str, Any]] = [] if not skip_errors: - while click.confirm('add mitre tactic?'): - tactic = schema_prompt('mitre tactic name', type='string', enum=tactics, is_required=True) - technique_ids = schema_prompt(f'technique or sub-technique IDs for {tactic}', type='array', - is_required=False, enum=list(matrix[tactic])) or [] + while click.confirm("add mitre tactic?"): + tactic = schema_prompt("mitre tactic name", type="string", enum=tactics, is_required=True) + technique_ids = ( # type: ignore[reportUnknownVariableType] + schema_prompt( + f"technique or sub-technique IDs for {tactic}", + type="array", + is_required=False, + enum=list(matrix[tactic]), + ) + or [] + ) try: - threat_map.append(build_threat_map_entry(tactic, *technique_ids)) + threat_map.append(build_threat_map_entry(tactic, *technique_ids)) # type: ignore[reportUnknownArgumentType] except KeyError as e: - click.secho(f'Unknown ID: {e.args[0]} - entry not saved for: {tactic}', fg='red', err=True) + click.secho(f"Unknown ID: {e.args[0]} - entry not saved for: {tactic}", fg="red", err=True) continue except ValueError as e: - click.secho(f'{e} - entry not saved for: {tactic}', fg='red', err=True) + click.secho(f"{e} - entry not saved for: {tactic}", fg="red", err=True) continue if len(threat_map) > 0: @@ -194,7 +222,7 @@ def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbos if name 
== "new_terms": # patch to allow new_term imports - result = {"field": "new_terms_fields"} + result: dict[str, Any] = {"field": "new_terms_fields"} new_terms_fields_value = schema_prompt("new_terms_fields", value=kwargs.pop("new_terms_fields", None)) result["value"] = ensure_list_of_strings(new_terms_fields_value) history_window_start_value = kwargs.pop("history_window_start", None) @@ -205,52 +233,55 @@ def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbos } ] + elif skip_errors: + # return missing information + return f"Rule: {kwargs['id']}, Rule Name: {rule_name} is missing {name} information" else: - if skip_errors: - # return missing information - return f"Rule: {kwargs["id"]}, Rule Name: {rule_name} is missing {name} information" - else: - result = schema_prompt(name, is_required=name in required_fields, **options.copy()) + result = schema_prompt(name, is_required=name in required_fields, **options.copy()) if result: - if name not in required_fields and result == options.get('default', ''): + if name not in required_fields and result == options.get("default", ""): skipped.append(name) continue contents[name] = result # DEFAULT_PREBUILT_RULES_DIRS[0] is a required directory just as a suggestion - suggested_path = Path(DEFAULT_PREBUILT_RULES_DIRS[0]) / contents['name'] - path = Path(path or input(f'File path for rule [{suggested_path}]: ') or suggested_path).resolve() + suggested_path: Path = Path(DEFAULT_PREBUILT_RULES_DIRS[0]) / contents["name"] + path = Path(path or input(f"File path for rule [{suggested_path}]: ") or suggested_path).resolve() # Inherit maturity and optionally local dates from the rule if it already exists meta = { "creation_date": kwargs.get("creation_date") or creation_date, "updated_date": kwargs.get("updated_date") or creation_date, - "maturity": "development" or kwargs.get("maturity"), + "maturity": "development", } try: - rule = TOMLRule(path=Path(path), contents=TOMLRuleContents.from_dict({'rule': contents, 'metadata': meta})) + rule = TOMLRule(path=Path(path), contents=TOMLRuleContents.from_dict({"rule": contents, "metadata": meta})) except kql.KqlParseError as e: if skip_errors: return f"Rule: {kwargs['id']}, Rule Name: {rule_name} query failed to parse: {e.error_msg}" - if e.error_msg == 'Unknown field': - warning = ('If using a non-ECS field, you must update "ecs{}.non-ecs-schema.json" under `beats` or ' - '`legacy-endgame` (Non-ECS fields should be used minimally).'.format(os.path.sep)) - click.secho(e.args[0], fg='red', err=True) - click.secho(warning, fg='yellow', err=True) + if e.error_msg == "Unknown field": + warning = ( + f'If using a non-ECS field, you must update "ecs{os.path.sep}.non-ecs-schema.json" under `beats` or ' + "`legacy-endgame` (Non-ECS fields should be used minimally)." 
+ ) + click.secho(e.args[0], fg="red", err=True) + click.secho(warning, fg="yellow", err=True) click.pause() # if failing due to a query, loop until resolved or terminated while True: try: - contents['query'] = click.edit(contents['query'], extension='.eql') - rule = TOMLRule(path=Path(path), - contents=TOMLRuleContents.from_dict({'rule': contents, 'metadata': meta})) + contents["query"] = click.edit(contents["query"], extension=".eql") + rule = TOMLRule( + path=Path(path), + contents=TOMLRuleContents.from_dict({"rule": contents, "metadata": meta}), + ) except kql.KqlParseError as e: - click.secho(e.args[0], fg='red', err=True) + click.secho(e.args[0], fg="red", err=True) click.pause() - if e.error_msg.startswith("Unknown field"): + if e.error_msg.startswith("Unknown field"): # type: ignore[reportUnknownMemberType] # get the latest schema for schema errors clear_caches() ecs.get_kql_schema(indexes=contents.get("index", [])) @@ -260,13 +291,13 @@ def rule_prompt(path=None, rule_type=None, required_only=True, save=True, verbos except Exception as e: if skip_errors: return f"Rule: {kwargs['id']}, Rule Name: {rule_name} failed: {e}" - raise e + raise if save: rule.save_toml(strip_none_values=strip_none_values) if skipped: - print('Did not set the following values because they are un-required when set to the default value') - print(' - {}'.format('\n - '.join(skipped))) + print("Did not set the following values because they are un-required when set to the default value") + print(" - {}".format("\n - ".join(skipped))) return rule diff --git a/detection_rules/config.py b/detection_rules/config.py index cd2804c35f3..132ba0cc293 100644 --- a/detection_rules/config.py +++ b/detection_rules/config.py @@ -4,42 +4,46 @@ # 2.0. """Configuration support for custom components.""" + import fnmatch import os from dataclasses import dataclass, field -from pathlib import Path from functools import cached_property -from typing import Dict, List, Optional +from pathlib import Path +from typing import Any import yaml -from eql.utils import load_dump +from eql.utils import load_dump # type: ignore[reportMissingTypeStubs] from .misc import discover_tests -from .utils import cached, load_etc_dump, get_etc_path, set_all_validation_bypass +from .utils import cached, get_etc_path, load_etc_dump, set_all_validation_bypass ROOT_DIR = Path(__file__).parent.parent -CUSTOM_RULES_DIR = os.getenv('CUSTOM_RULES_DIR', None) +CUSTOM_RULES_DIR = os.getenv("CUSTOM_RULES_DIR", None) @dataclass class UnitTest: """Base object for unit tests configuration.""" - bypass: Optional[List[str]] = None - test_only: Optional[List[str]] = None - def __post_init__(self): - assert (self.bypass is None or self.test_only is None), \ - 'Cannot set both `test_only` and `bypass` in test_config!' 
+ bypass: list[str] | None = None + test_only: list[str] | None = None + + def __post_init__(self) -> None: + if self.bypass and self.test_only: + raise ValueError("Cannot set both `test_only` and `bypass` in test_config!") @dataclass class RuleValidation: """Base object for rule validation configuration.""" - bypass: Optional[List[str]] = None - test_only: Optional[List[str]] = None - def __post_init__(self): - assert not (self.bypass and self.test_only), 'Cannot use both test_only and bypass' + bypass: list[str] | None = None + test_only: list[str] | None = None + + def __post_init__(self) -> None: + if self.bypass and self.test_only: + raise ValueError("Cannot use both test_only and bypass") @dataclass @@ -50,32 +54,30 @@ class ConfigFile: class FilePaths: packages_file: str stack_schema_map_file: str - deprecated_rules_file: Optional[str] = None - version_lock_file: Optional[str] = None + deprecated_rules_file: str | None = None + version_lock_file: str | None = None @dataclass class TestConfigPath: config: str files: FilePaths - rule_dir: List[str] - testing: Optional[TestConfigPath] = None + rule_dir: list[str] + testing: TestConfigPath | None = None @classmethod - def from_dict(cls, obj: dict) -> 'ConfigFile': - files_data = obj.get('files', {}) + def from_dict(cls, obj: dict[str, Any]) -> "ConfigFile": + files_data = obj.get("files", {}) files = cls.FilePaths( - deprecated_rules_file=files_data.get('deprecated_rules'), - packages_file=files_data['packages'], - stack_schema_map_file=files_data['stack_schema_map'], - version_lock_file=files_data.get('version_lock') + deprecated_rules_file=files_data.get("deprecated_rules"), + packages_file=files_data["packages"], + stack_schema_map_file=files_data["stack_schema_map"], + version_lock_file=files_data.get("version_lock"), ) - rule_dir = obj['rule_dirs'] + rule_dir = obj["rule_dirs"] - testing_data = obj.get('testing') - testing = cls.TestConfigPath( - config=testing_data['config'] - ) if testing_data else None + testing_data = obj.get("testing") + testing = cls.TestConfigPath(config=testing_data["config"]) if testing_data else None return cls(files=files, rule_dir=rule_dir, testing=testing) @@ -83,59 +85,70 @@ def from_dict(cls, obj: dict) -> 'ConfigFile': @dataclass class TestConfig: """Detection rules test config file""" - test_file: Optional[Path] = None - unit_tests: Optional[UnitTest] = None - rule_validation: Optional[RuleValidation] = None + + test_file: Path | None = None + unit_tests: UnitTest | None = None + rule_validation: RuleValidation | None = None @classmethod - def from_dict(cls, test_file: Optional[Path] = None, unit_tests: Optional[dict] = None, - rule_validation: Optional[dict] = None) -> 'TestConfig': - return cls(test_file=test_file or None, unit_tests=UnitTest(**unit_tests or {}), - rule_validation=RuleValidation(**rule_validation or {})) + def from_dict( + cls, + test_file: Path | None = None, + unit_tests: dict[str, Any] | None = None, + rule_validation: dict[str, Any] | None = None, + ) -> "TestConfig": + return cls( + test_file=test_file or None, + unit_tests=UnitTest(**unit_tests or {}), + rule_validation=RuleValidation(**rule_validation or {}), + ) @cached_property - def all_tests(self): + def all_tests(self) -> list[str]: """Get the list of all test names.""" return discover_tests() - def tests_by_patterns(self, *patterns: str) -> List[str]: + def tests_by_patterns(self, *patterns: str) -> list[str]: """Get the list of test names by patterns.""" - tests = set() + tests: set[str] = set() for pattern in 
patterns: tests.update(list(fnmatch.filter(self.all_tests, pattern))) return sorted(tests) @staticmethod - def parse_out_patterns(names: List[str]) -> (List[str], List[str]): + def parse_out_patterns(names: list[str]) -> tuple[list[str], list[str]]: """Parse out test patterns from a list of test names.""" - patterns = [] - tests = [] + patterns: list[str] = [] + tests: list[str] = [] for name in names: - if name.startswith('pattern:') and '*' in name: - patterns.append(name[len('pattern:'):]) + if name.startswith("pattern:") and "*" in name: + patterns.append(name[len("pattern:") :]) else: tests.append(name) return patterns, tests @staticmethod - def format_tests(tests: List[str]) -> List[str]: + def format_tests(tests: list[str]) -> list[str]: """Format unit test names into expected format for direct calling.""" - raw = [t.rsplit('.', maxsplit=2) for t in tests] - formatted = [] + raw = [t.rsplit(".", maxsplit=2) for t in tests] + formatted: list[str] = [] for test in raw: path, clazz, method = test - path = f'{path.replace(".", os.path.sep)}.py' - formatted.append('::'.join([path, clazz, method])) + path = f"{path.replace('.', os.path.sep)}.py" + formatted.append(f"{path}::{clazz}::{method}") return formatted - def get_test_names(self, formatted: bool = False) -> (List[str], List[str]): + def get_test_names(self, formatted: bool = False) -> tuple[list[str], list[str]]: """Get the list of test names to run.""" + if not self.unit_tests: + raise ValueError("No unit tests defined") patterns_t, tests_t = self.parse_out_patterns(self.unit_tests.test_only or []) patterns_b, tests_b = self.parse_out_patterns(self.unit_tests.bypass or []) defined_tests = tests_t + tests_b patterns = patterns_t + patterns_b unknowns = sorted(set(defined_tests) - set(self.all_tests)) - assert not unknowns, f'Unrecognized test names in config ({self.test_file}): {unknowns}' + if unknowns: + raise ValueError(f"Unrecognized test names in config ({self.test_file}): {unknowns}") combined_tests = sorted(set(defined_tests + self.tests_by_patterns(*patterns))) @@ -143,8 +156,8 @@ def get_test_names(self, formatted: bool = False) -> (List[str], List[str]): tests = combined_tests skipped = [t for t in self.all_tests if t not in tests] elif self.unit_tests.bypass: - tests = [] - skipped = [] + tests: list[str] = [] + skipped: list[str] = [] for test in self.all_tests: if test not in combined_tests: tests.append(test) @@ -156,11 +169,12 @@ def get_test_names(self, formatted: bool = False) -> (List[str], List[str]): if formatted: return self.format_tests(tests), self.format_tests(skipped) - else: - return tests, skipped + return tests, skipped def check_skip_by_rule_id(self, rule_id: str) -> bool: """Check if a rule_id should be skipped.""" + if not self.rule_validation: + raise ValueError("No rule validation specified") bypass = self.rule_validation.bypass test_only = self.rule_validation.test_only @@ -168,50 +182,52 @@ def check_skip_by_rule_id(self, rule_id: str) -> bool: if not (bypass or test_only): return False # if defined in bypass or not defined in test_only, then skip - return (bypass and rule_id in bypass) or (test_only and rule_id not in test_only) + return bool((bypass and rule_id in bypass) or (test_only and rule_id not in test_only)) @dataclass class RulesConfig: """Detection rules config file.""" + deprecated_rules_file: Path - deprecated_rules: Dict[str, dict] + deprecated_rules: dict[str, dict[str, Any]] packages_file: Path - packages: Dict[str, dict] - rule_dirs: List[Path] + packages: dict[str, dict[str, 
Any]] + rule_dirs: list[Path] stack_schema_map_file: Path - stack_schema_map: Dict[str, dict] + stack_schema_map: dict[str, dict[str, Any]] test_config: TestConfig version_lock_file: Path - version_lock: Dict[str, dict] + version_lock: dict[str, dict[str, Any]] - action_dir: Optional[Path] = None - action_connector_dir: Optional[Path] = None - auto_gen_schema_file: Optional[Path] = None - bbr_rules_dirs: Optional[List[Path]] = field(default_factory=list) + action_dir: Path | None = None + action_connector_dir: Path | None = None + auto_gen_schema_file: Path | None = None + bbr_rules_dirs: list[Path] = field(default_factory=list) # type: ignore[reportUnknownVariableType] bypass_version_lock: bool = False - exception_dir: Optional[Path] = None + exception_dir: Path | None = None normalize_kql_keywords: bool = True bypass_optional_elastic_validation: bool = False no_tactic_filename: bool = False - def __post_init__(self): + def __post_init__(self) -> None: """Perform post validation on packages.yaml file.""" - if 'package' not in self.packages: - raise ValueError('Missing the `package` field defined in packages.yaml.') + if "package" not in self.packages: + raise ValueError("Missing the `package` field defined in packages.yaml.") - if 'name' not in self.packages['package']: - raise ValueError('Missing the `name` field defined in packages.yaml.') + if "name" not in self.packages["package"]: + raise ValueError("Missing the `name` field defined in packages.yaml.") @cached -def parse_rules_config(path: Optional[Path] = None) -> RulesConfig: +def parse_rules_config(path: Path | None = None) -> RulesConfig: # noqa: PLR0912, PLR0915 """Parse the _config.yaml file for default or custom rules.""" if path: - assert path.exists(), f'rules config file does not exist: {path}' + if not path.exists(): + raise ValueError(f"rules config file does not exist: {path}") loaded = yaml.safe_load(path.read_text()) elif CUSTOM_RULES_DIR: - path = Path(CUSTOM_RULES_DIR) / '_config.yaml' + path = Path(CUSTOM_RULES_DIR) / "_config.yaml" if not path.exists(): raise FileNotFoundError( """ @@ -222,104 +238,101 @@ def parse_rules_config(path: Optional[Path] = None) -> RulesConfig: ) loaded = yaml.safe_load(path.read_text()) else: - path = Path(get_etc_path('_config.yaml')) - loaded = load_etc_dump('_config.yaml') + path = Path(get_etc_path(["_config.yaml"])) + loaded = load_etc_dump(["_config.yaml"]) try: - ConfigFile.from_dict(loaded) + _ = ConfigFile.from_dict(loaded) except KeyError as e: - raise SystemExit(f'Missing key `{str(e)}` in _config.yaml file.') - except (AttributeError, TypeError): - raise SystemExit(f'No data properly loaded from {path}') + raise SystemExit(f"Missing key `{e!s}` in _config.yaml file.") from e + except (AttributeError, TypeError) as e: + raise SystemExit(f"No data properly loaded from {path}") from e except ValueError as e: - raise SystemExit(e) + raise SystemExit(e) from e base_dir = path.resolve().parent # testing # precedence to the environment variable # environment variable is absolute path and config file is relative to the _config.yaml file - test_config_ev = os.getenv('DETECTION_RULES_TEST_CONFIG', None) + test_config_ev = os.getenv("DETECTION_RULES_TEST_CONFIG", None) if test_config_ev: test_config_path = Path(test_config_ev) else: - test_config_file = loaded.get('testing', {}).get('config') - if test_config_file: - test_config_path = base_dir.joinpath(test_config_file) - else: - test_config_path = None + test_config_file = loaded.get("testing", {}).get("config") + test_config_path = 
base_dir.joinpath(test_config_file) if test_config_file else None if test_config_path: test_config_data = yaml.safe_load(test_config_path.read_text()) # overwrite None with empty list to allow implicit exemption of all tests with `test_only` defined to None in # test config - if 'unit_tests' in test_config_data and test_config_data['unit_tests'] is not None: - test_config_data['unit_tests'] = {k: v or [] for k, v in test_config_data['unit_tests'].items()} + if "unit_tests" in test_config_data and test_config_data["unit_tests"] is not None: + test_config_data["unit_tests"] = {k: v or [] for k, v in test_config_data["unit_tests"].items()} test_config = TestConfig.from_dict(test_file=test_config_path, **test_config_data) else: test_config = TestConfig.from_dict() # files # paths are relative - files = {f'{k}_file': base_dir.joinpath(v) for k, v in loaded['files'].items()} - contents = {k: load_dump(str(base_dir.joinpath(v).resolve())) for k, v in loaded['files'].items()} + files = {f"{k}_file": base_dir.joinpath(v) for k, v in loaded["files"].items()} + contents = {k: load_dump(str(base_dir.joinpath(v).resolve())) for k, v in loaded["files"].items()} contents.update(**files) # directories # paths are relative - if loaded.get('directories'): - contents.update({k: base_dir.joinpath(v).resolve() for k, v in loaded['directories'].items()}) + if loaded.get("directories"): + contents.update({k: base_dir.joinpath(v).resolve() for k, v in loaded["directories"].items()}) # rule_dirs # paths are relative - contents['rule_dirs'] = [base_dir.joinpath(d).resolve() for d in loaded.get('rule_dirs')] + contents["rule_dirs"] = [base_dir.joinpath(d).resolve() for d in loaded.get("rule_dirs")] # directories # paths are relative - if loaded.get('directories'): - directories = loaded.get('directories') - if directories.get('exception_dir'): - contents['exception_dir'] = base_dir.joinpath(directories.get('exception_dir')).resolve() - if directories.get('action_dir'): - contents['action_dir'] = base_dir.joinpath(directories.get('action_dir')).resolve() - if directories.get('action_connector_dir'): - contents['action_connector_dir'] = base_dir.joinpath(directories.get('action_connector_dir')).resolve() + if loaded.get("directories"): + directories = loaded.get("directories") + if directories.get("exception_dir"): + contents["exception_dir"] = base_dir.joinpath(directories.get("exception_dir")).resolve() + if directories.get("action_dir"): + contents["action_dir"] = base_dir.joinpath(directories.get("action_dir")).resolve() + if directories.get("action_connector_dir"): + contents["action_connector_dir"] = base_dir.joinpath(directories.get("action_connector_dir")).resolve() # version strategy - contents['bypass_version_lock'] = loaded.get('bypass_version_lock', False) + contents["bypass_version_lock"] = loaded.get("bypass_version_lock", False) # bbr_rules_dirs # paths are relative - if loaded.get('bbr_rules_dirs'): - contents['bbr_rules_dirs'] = [base_dir.joinpath(d).resolve() for d in loaded.get('bbr_rules_dirs', [])] + if loaded.get("bbr_rules_dirs"): + contents["bbr_rules_dirs"] = [base_dir.joinpath(d).resolve() for d in loaded.get("bbr_rules_dirs", [])] # kql keyword normalization - contents['normalize_kql_keywords'] = loaded.get('normalize_kql_keywords', True) + contents["normalize_kql_keywords"] = loaded.get("normalize_kql_keywords", True) - if loaded.get('auto_gen_schema_file'): - contents['auto_gen_schema_file'] = base_dir.joinpath(loaded['auto_gen_schema_file']) + if loaded.get("auto_gen_schema_file"): + 
contents["auto_gen_schema_file"] = base_dir.joinpath(loaded["auto_gen_schema_file"]) # Check if the file exists - if not contents['auto_gen_schema_file'].exists(): + if not contents["auto_gen_schema_file"].exists(): # If the file doesn't exist, create the necessary directories and file - contents['auto_gen_schema_file'].parent.mkdir(parents=True, exist_ok=True) - contents['auto_gen_schema_file'].write_text('{}') + contents["auto_gen_schema_file"].parent.mkdir(parents=True, exist_ok=True) + _ = contents["auto_gen_schema_file"].write_text("{}") # bypass_optional_elastic_validation - contents['bypass_optional_elastic_validation'] = loaded.get('bypass_optional_elastic_validation', False) - if contents['bypass_optional_elastic_validation']: - set_all_validation_bypass(contents['bypass_optional_elastic_validation']) + contents["bypass_optional_elastic_validation"] = loaded.get("bypass_optional_elastic_validation", False) + if contents["bypass_optional_elastic_validation"]: + set_all_validation_bypass(contents["bypass_optional_elastic_validation"]) # no_tactic_filename - contents['no_tactic_filename'] = loaded.get('no_tactic_filename', False) + contents["no_tactic_filename"] = loaded.get("no_tactic_filename", False) # return the config try: - rules_config = RulesConfig(test_config=test_config, **contents) + rules_config = RulesConfig(test_config=test_config, **contents) # type: ignore[reportArgumentType] except (ValueError, TypeError) as e: - raise SystemExit(f'Error parsing packages.yaml: {str(e)}') + raise SystemExit(f"Error parsing packages.yaml: {e!s}") from e return rules_config @@ -327,4 +340,4 @@ def parse_rules_config(path: Optional[Path] = None) -> RulesConfig: @cached def load_current_package_version() -> str: """Load the current package version from config file.""" - return parse_rules_config().packages['package']['name'] + return parse_rules_config().packages["package"]["name"] diff --git a/detection_rules/custom_rules.py b/detection_rules/custom_rules.py index dd99006750e..62352e50d71 100644 --- a/detection_rules/custom_rules.py +++ b/detection_rules/custom_rules.py @@ -4,6 +4,7 @@ # 2.0. 
"""Commands for supporting custom rules.""" + from pathlib import Path import click @@ -14,12 +15,12 @@ from .main import root from .utils import ROOT_DIR, get_etc_path, load_etc_dump -DEFAULT_CONFIG_PATH = Path(get_etc_path('_config.yaml')) -CUSTOM_RULES_DOC_PATH = Path(ROOT_DIR).joinpath(REPO_DOCS_DIR, 'custom-rules-management.md') +DEFAULT_CONFIG_PATH = Path(get_etc_path(["_config.yaml"])) +CUSTOM_RULES_DOC_PATH = ROOT_DIR / REPO_DOCS_DIR / "custom-rules-management.md" -@root.group('custom-rules') -def custom_rules(): +@root.group("custom-rules") +def custom_rules() -> None: """Commands for supporting custom rules.""" @@ -27,22 +28,20 @@ def create_config_content() -> str: """Create the initial content for the _config.yaml file.""" # Base structure of the configuration config_content = { - 'rule_dirs': ['rules'], - 'bbr_rules_dirs': ['rules_building_block'], - 'directories': { - 'action_dir': 'actions', - 'action_connector_dir': 'action_connectors', - 'exception_dir': 'exceptions', + "rule_dirs": ["rules"], + "bbr_rules_dirs": ["rules_building_block"], + "directories": { + "action_dir": "actions", + "action_connector_dir": "action_connectors", + "exception_dir": "exceptions", }, - 'files': { - 'deprecated_rules': 'etc/deprecated_rules.json', - 'packages': 'etc/packages.yaml', - 'stack_schema_map': 'etc/stack-schema-map.yaml', - 'version_lock': 'etc/version.lock.json', + "files": { + "deprecated_rules": "etc/deprecated_rules.json", + "packages": "etc/packages.yaml", + "stack_schema_map": "etc/stack-schema-map.yaml", + "version_lock": "etc/version.lock.json", }, - 'testing': { - 'config': 'etc/test_config.yaml' - } + "testing": {"config": "etc/test_config.yaml"}, } return yaml.safe_dump(config_content, default_flow_style=False) @@ -77,24 +76,24 @@ def format_test_string(test_string: str, comment_char: str) -> str: return "\n".join(lines) -@custom_rules.command('setup-config') -@click.argument('directory', type=Path) -@click.argument('kibana-version', type=str, default=load_etc_dump('packages.yaml')['package']['name']) -@click.option('--overwrite', is_flag=True, help="Overwrite the existing _config.yaml file.") +@custom_rules.command("setup-config") +@click.argument("directory", type=Path) +@click.argument("kibana-version", type=str, default=load_etc_dump(["packages.yaml"])["package"]["name"]) +@click.option("--overwrite", is_flag=True, help="Overwrite the existing _config.yaml file.") @click.option( "--enable-prebuilt-tests", "-e", is_flag=True, help="Enable all prebuilt tests instead of default subset." ) -def setup_config(directory: Path, kibana_version: str, overwrite: bool, enable_prebuilt_tests: bool): +def setup_config(directory: Path, kibana_version: str, overwrite: bool, enable_prebuilt_tests: bool) -> None: """Setup the custom rules configuration directory and files with defaults.""" - config = directory / '_config.yaml' + config = directory / "_config.yaml" if not overwrite and config.exists(): - raise FileExistsError(f'{config} already exists. Use --overwrite to update') + raise FileExistsError(f"{config} already exists. 
Use --overwrite to update") - etc_dir = directory / 'etc' - test_config = etc_dir / 'test_config.yaml' - package_config = etc_dir / 'packages.yaml' - stack_schema_map_config = etc_dir / 'stack-schema-map.yaml' + etc_dir = directory / "etc" + test_config = etc_dir / "test_config.yaml" + package_config = etc_dir / "packages.yaml" + stack_schema_map_config = etc_dir / "stack-schema-map.yaml" config_files = [ package_config, stack_schema_map_config, @@ -102,49 +101,49 @@ def setup_config(directory: Path, kibana_version: str, overwrite: bool, enable_p config, ] directories = [ - directory / 'actions', - directory / 'action_connectors', - directory / 'exceptions', - directory / 'rules', - directory / 'rules_building_block', + directory / "actions", + directory / "action_connectors", + directory / "exceptions", + directory / "rules", + directory / "rules_building_block", etc_dir, ] version_files = [ - etc_dir / 'deprecated_rules.json', - etc_dir / 'version.lock.json', + etc_dir / "deprecated_rules.json", + etc_dir / "version.lock.json", ] # Create directories for dir_ in directories: dir_.mkdir(parents=True, exist_ok=True) - click.echo(f'Created directory: {dir_}') + click.echo(f"Created directory: {dir_}") # Create version_files and populate with default content if applicable for file_ in version_files: - file_.write_text('{}') - click.echo( - f'Created file with default content: {file_}' - ) + _ = file_.write_text("{}") + click.echo(f"Created file with default content: {file_}") # Create the stack-schema-map.yaml file - stack_schema_map_content = load_etc_dump('stack-schema-map.yaml') + stack_schema_map_content = load_etc_dump(["stack-schema-map.yaml"]) latest_version = max(stack_schema_map_content.keys(), key=lambda v: Version.parse(v)) latest_entry = {latest_version: stack_schema_map_content[latest_version]} - stack_schema_map_config.write_text(yaml.safe_dump(latest_entry, default_flow_style=False)) + _ = stack_schema_map_config.write_text(yaml.safe_dump(latest_entry, default_flow_style=False)) # Create default packages.yaml - package_content = {'package': {'name': kibana_version}} - package_config.write_text(yaml.safe_dump(package_content, default_flow_style=False)) + package_content = {"package": {"name": kibana_version}} + _ = package_config.write_text(yaml.safe_dump(package_content, default_flow_style=False)) # Create and configure test_config.yaml - test_config.write_text(create_test_config_content(enable_prebuilt_tests)) + _ = test_config.write_text(create_test_config_content(enable_prebuilt_tests)) # Create and configure _config.yaml - config.write_text(create_config_content()) + _ = config.write_text(create_config_content()) for file_ in config_files: - click.echo(f'Created file with default content: {file_}') + click.echo(f"Created file with default content: {file_}") - click.echo(f'\n# For details on how to configure the _config.yaml file,\n' - f'# consult: {DEFAULT_CONFIG_PATH.resolve()}\n' - f'# or the docs: {CUSTOM_RULES_DOC_PATH.resolve()}') + click.echo( + f"\n# For details on how to configure the _config.yaml file,\n" + f"# consult: {DEFAULT_CONFIG_PATH.resolve()}\n" + f"# or the docs: {CUSTOM_RULES_DOC_PATH.resolve()}" + ) diff --git a/detection_rules/custom_schemas.py b/detection_rules/custom_schemas.py index 84252178b7d..0b1390996c5 100644 --- a/detection_rules/custom_schemas.py +++ b/detection_rules/custom_schemas.py @@ -4,12 +4,13 @@ # 2.0. 
"""Custom Schemas management.""" + import uuid from pathlib import Path +from typing import Any -import eql -import eql.types -from eql import load_dump, save_dump +import eql # type: ignore[reportMissingTypeStubs] +from eql import load_dump, save_dump # type: ignore[reportMissingTypeStubs] from .config import parse_rules_config from .utils import cached, clear_caches @@ -19,9 +20,9 @@ @cached -def get_custom_schemas(stack_version: str = None) -> dict: +def get_custom_schemas(stack_version: str | None = None) -> dict[str, Any]: """Load custom schemas if present.""" - custom_schema_dump = {} + custom_schema_dump: dict[str, Any] = {} stack_versions = [stack_version] if stack_version else RULES_CONFIG.stack_schema_map.keys() @@ -34,7 +35,7 @@ def get_custom_schemas(stack_version: str = None) -> dict: if not schema_path.is_absolute(): schema_path = RULES_CONFIG.stack_schema_map_file.parent / value if schema_path.is_file(): - custom_schema_dump.update(eql.utils.load_dump(str(schema_path))) + custom_schema_dump.update(eql.utils.load_dump(str(schema_path))) # type: ignore[reportUnknownMemberType] else: raise ValueError(f"Custom schema must be a file: {schema_path}") @@ -47,19 +48,22 @@ def resolve_schema_path(path: str) -> Path: return path_obj if path_obj.is_absolute() else RULES_CONFIG.stack_schema_map_file.parent.joinpath(path) -def update_data(index: str, field: str, data: dict, field_type: str = None) -> dict: +def update_data(index: str, field: str, data: dict[str, Any], field_type: str | None = None) -> dict[str, Any]: """Update the schema entry with the appropriate index and field.""" data.setdefault(index, {})[field] = field_type if field_type else "keyword" return data -def update_stack_schema_map(stack_schema_map: dict, auto_gen_schema_file: str) -> dict: +def update_stack_schema_map( + stack_schema_map: dict[str, Any], + auto_gen_schema_file: str, +) -> tuple[dict[str, Any], str | None, str]: """Update the stack-schema-map.yaml file with the appropriate auto_gen_schema_file location.""" random_uuid = str(uuid.uuid4()) auto_generated_id = None - for version in stack_schema_map: + for val in stack_schema_map.values(): key_found = False - for key, value in stack_schema_map[version].items(): + for key, value in val.items(): value_path = resolve_schema_path(value) if value_path == Path(auto_gen_schema_file).resolve() and key not in RESERVED_SCHEMA_NAMES: auto_generated_id = key @@ -68,19 +72,21 @@ def update_stack_schema_map(stack_schema_map: dict, auto_gen_schema_file: str) - if key_found is False: if auto_generated_id is None: auto_generated_id = random_uuid - stack_schema_map[version][auto_generated_id] = str(auto_gen_schema_file) + val[auto_generated_id] = str(auto_gen_schema_file) return stack_schema_map, auto_generated_id, random_uuid -def clean_stack_schema_map(stack_schema_map: dict, auto_generated_id: str, random_uuid: str) -> dict: +def clean_stack_schema_map( + stack_schema_map: dict[str, Any], auto_generated_id: str, random_uuid: str +) -> dict[str, Any]: """Clean up the stack-schema-map.yaml file replacing the random UUID with a known key if possible.""" - for version in stack_schema_map: - if random_uuid in stack_schema_map[version]: - stack_schema_map[version][auto_generated_id] = stack_schema_map[version].pop(random_uuid) + for val in stack_schema_map.values(): + if random_uuid in val: + val[auto_generated_id] = val.pop(random_uuid) return stack_schema_map -def update_auto_generated_schema(index: str, field: str, field_type: str = None): +def 
update_auto_generated_schema(index: str, field: str, field_type: str | None = None) -> None: """Load custom schemas if present.""" auto_gen_schema_file = str(RULES_CONFIG.auto_gen_schema_file) stack_schema_map_file = str(RULES_CONFIG.stack_schema_map_file) @@ -93,6 +99,10 @@ def update_auto_generated_schema(index: str, field: str, field_type: str = None) # Update the stack-schema-map.yaml file with the appropriate auto_gen_schema_file location stack_schema_map = load_dump(stack_schema_map_file) stack_schema_map, auto_generated_id, random_uuid = update_stack_schema_map(stack_schema_map, auto_gen_schema_file) + + if not auto_generated_id: + raise ValueError("Autogenerated ID not found") + save_dump(stack_schema_map, stack_schema_map_file) # Clean up the stack-schema-map.yaml file replacing the random UUID with the auto_generated_id diff --git a/detection_rules/devtools.py b/detection_rules/devtools.py index 11073780b30..a9d8f2b4c13 100644 --- a/detection_rules/devtools.py +++ b/detection_rules/devtools.py @@ -4,9 +4,9 @@ # 2.0. """CLI commands for internal detection_rules dev team.""" + import csv import dataclasses -import io import json import os import re @@ -18,51 +18,60 @@ import urllib.parse from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Any, Literal +from uuid import uuid4 import click -import pytoml +import pytoml # type: ignore[reportMissingTypeStubs] import requests.exceptions import yaml from elasticsearch import Elasticsearch -from eql.table import Table +from eql.table import Table # type: ignore[reportMissingTypeStubs] +from eql.utils import load_dump # type: ignore[reportMissingTypeStubs, reportUnknownVariableType] +from kibana.connector import Kibana # type: ignore[reportMissingTypeStubs] +from kibana.resources import Signal # type: ignore[reportMissingTypeStubs] from semver import Version -from kibana.connector import Kibana - from . 
import attack, rule_loader, utils -from .beats import (download_beats_schema, download_latest_beats_schema, - refresh_main_schema) +from .beats import download_beats_schema, download_latest_beats_schema, refresh_main_schema from .cli_utils import single_collection from .config import parse_rules_config -from .docs import IntegrationSecurityDocs, IntegrationSecurityDocsMDX, REPO_DOCS_DIR +from .docs import REPO_DOCS_DIR, IntegrationSecurityDocs, IntegrationSecurityDocsMDX from .ecs import download_endpoint_schemas, download_schemas from .endgame import EndgameSchemaManager from .eswrap import CollectEvents, add_range_to_dsl from .ghwrap import GithubClient, update_gist -from .integrations import (SecurityDetectionEngine, - build_integrations_manifest, - build_integrations_schemas, - find_latest_compatible_version, - find_latest_integration_version, - load_integrations_manifests) +from .integrations import ( + SecurityDetectionEngine, + build_integrations_manifest, + build_integrations_schemas, + find_latest_compatible_version, + find_latest_integration_version, + load_integrations_manifests, +) from .main import root -from .misc import PYTHON_LICENSE, add_client, client_error -from .packaging import (CURRENT_RELEASE_PATH, PACKAGE_FILE, RELEASE_DIR, - Package) -from .rule import (AnyRuleData, BaseRuleData, DeprecatedRule, QueryRuleData, - RuleTransform, ThreatMapping, TOMLRule, TOMLRuleContents) +from .misc import PYTHON_LICENSE, add_client, raise_client_error +from .packaging import CURRENT_RELEASE_PATH, PACKAGE_FILE, RELEASE_DIR, Package +from .rule import ( + AnyRuleData, + BaseRuleData, + DeprecatedRule, + QueryRuleData, + RuleTransform, + ThreatMapping, + TOMLRule, + TOMLRuleContents, +) from .rule_loader import RuleCollection, production_filter from .schemas import definitions, get_stack_versions -from .utils import (dict_hash, get_etc_path, get_path, check_version_lock_double_bumps, - load_dump) +from .utils import check_version_lock_double_bumps, dict_hash, get_etc_path, get_path from .version_lock import VersionLockFile, loaded_version_lock GH_CONFIG = Path.home() / ".config" / "gh" / "hosts.yml" -NAVIGATOR_GIST_ID = '0443cfb5016bed103f1940b2f336e45a' -NAVIGATOR_URL = 'https://ela.st/detection-rules-navigator-trade' +NAVIGATOR_GIST_ID = "0443cfb5016bed103f1940b2f336e45a" +NAVIGATOR_URL = "https://ela.st/detection-rules-navigator-trade" NAVIGATOR_BADGE = ( - f'[![ATT&CK navigator coverage](https://img.shields.io/badge/ATT&CK-Navigator-red.svg)]({NAVIGATOR_URL})' + f"[![ATT&CK navigator coverage](https://img.shields.io/badge/ATT&CK-Navigator-red.svg)]({NAVIGATOR_URL})" ) RULES_CONFIG = parse_rules_config() @@ -74,7 +83,7 @@ MAX_HISTORICAL_VERSIONS_PRE_DIFF = 1 -def get_github_token() -> Optional[str]: +def get_github_token() -> str | None: """Get the current user's GitHub token.""" token = os.getenv("GITHUB_TOKEN") @@ -84,60 +93,81 @@ def get_github_token() -> Optional[str]: return token -@root.group('dev') -def dev_group(): +@root.group("dev") +def dev_group() -> None: """Commands related to the Elastic Stack rules release lifecycle.""" -@dev_group.command('build-release') -@click.argument('config-file', type=click.Path(exists=True, dir_okay=False), required=False, default=PACKAGE_FILE) -@click.option('--update-version-lock', '-u', is_flag=True, - help='Save version.lock.json file with updated rule versions in the package') -@click.option('--generate-navigator', is_flag=True, help='Generate ATT&CK navigator files') -@click.option('--generate-docs', is_flag=True, default=False, 
help='Generate markdown documentation') -@click.option('--update-message', type=str, help='Update message for new package') +@dev_group.command("build-release") +@click.argument( + "config-file", type=click.Path(exists=True, dir_okay=False, path_type=Path), required=False, default=PACKAGE_FILE +) +@click.option( + "--update-version-lock", + "-u", + is_flag=True, + help="Save version.lock.json file with updated rule versions in the package", +) +@click.option("--generate-navigator", is_flag=True, help="Generate ATT&CK navigator files") +@click.option("--generate-docs", is_flag=True, default=False, help="Generate markdown documentation") +@click.option("--update-message", type=str, help="Update message for new package") @click.pass_context -def build_release(ctx: click.Context, config_file, update_version_lock: bool, generate_navigator: bool, - generate_docs: str, update_message: str, release=None, verbose=True): +def build_release( # noqa: PLR0913 + ctx: click.Context, + config_file: Path, + update_version_lock: bool, + generate_navigator: bool, + generate_docs: str, + update_message: str, + release: str | None = None, + verbose: bool = True, +) -> Package: """Assemble all the rules into Kibana-ready release files.""" if RULES_CONFIG.bypass_version_lock: - click.echo('WARNING: You cannot run this command when the versioning strategy is configured to bypass the ' - 'version lock. Set `bypass_version_lock` to `False` in the rules config to use the version lock.') + click.echo( + "WARNING: You cannot run this command when the versioning strategy is configured to bypass the " + "version lock. Set `bypass_version_lock` to `False` in the rules config to use the version lock." + ) ctx.exit() - config = load_dump(config_file)['package'] + config = load_dump(str(config_file))["package"] - err_msg = f'No `registry_data` in package config. Please see the {get_etc_path("package.yaml")} file for an' \ - f' example on how to supply this field in {PACKAGE_FILE}.' - assert 'registry_data' in config, err_msg + package_path = get_etc_path(["package.yaml"]) + if "registry_data" not in config: + raise ValueError( + f"No `registry_data` in package config. Please see the {package_path} file for an" + f" example on how to supply this field in {PACKAGE_FILE}." 
+ ) - registry_data = config['registry_data'] + registry_data = config["registry_data"] if generate_navigator: - config['generate_navigator'] = True + config["generate_navigator"] = True if release is not None: - config['release'] = release + config["release"] = release if verbose: - click.echo(f'[+] Building package {config.get("name")}') + click.echo(f"[+] Building package {config.get('name')}") package = Package.from_config(config=config, verbose=verbose) if update_version_lock: - loaded_version_lock.manage_versions(package.rules, save_changes=True, verbose=verbose) + _ = loaded_version_lock.manage_versions(package.rules, save_changes=True, verbose=verbose) package.save(verbose=verbose) - previous_pkg_version = find_latest_integration_version("security_detection_engine", "ga", - registry_data['conditions']['kibana.version'].strip("^")) + previous_pkg_version = find_latest_integration_version( + "security_detection_engine", "ga", registry_data["conditions"]["kibana.version"].strip("^") + ) sde = SecurityDetectionEngine() historical_rules = sde.load_integration_assets(previous_pkg_version) - current_pkg_version = Version.parse(registry_data['version']) + current_pkg_version = Version.parse(registry_data["version"]) # pre-release versions are not included in the version comparison # Version 8.17.0-beta.1 is considered lower than 8.17.0 - current_pkg_version_no_prerelease = Version(major=current_pkg_version.major, - minor=current_pkg_version.minor, patch=current_pkg_version.patch) + current_pkg_version_no_prerelease = Version( + major=current_pkg_version.major, minor=current_pkg_version.minor, patch=current_pkg_version.patch + ) hist_versions_num = ( MAX_HISTORICAL_VERSIONS_FOR_DIFF @@ -145,78 +175,106 @@ def build_release(ctx: click.Context, config_file, update_version_lock: bool, ge else MAX_HISTORICAL_VERSIONS_PRE_DIFF ) click.echo( - '[+] Limit historical rule versions in the release package for ' - f'version {current_pkg_version_no_prerelease}: {hist_versions_num} versions') + "[+] Limit historical rule versions in the release package for " + f"version {current_pkg_version_no_prerelease}: {hist_versions_num} versions" + ) limited_historical_rules = sde.keep_latest_versions(historical_rules, num_versions=hist_versions_num) - package.add_historical_rules(limited_historical_rules, registry_data['version']) - click.echo(f'[+] Adding historical rules from {previous_pkg_version} package') + _ = package.add_historical_rules(limited_historical_rules, registry_data["version"]) + click.echo(f"[+] Adding historical rules from {previous_pkg_version} package") # NOTE: stopgap solution until security doc migration if generate_docs: - click.echo(f'[+] Generating security docs for {registry_data["version"]} package') - docs = IntegrationSecurityDocsMDX(registry_data['version'], Path(f'releases/{config["name"]}-docs'), - True, limited_historical_rules, package, note=update_message) - docs.generate() + click.echo(f"[+] Generating security docs for {registry_data['version']} package") + docs = IntegrationSecurityDocsMDX( + registry_data["version"], + Path(f"releases/{config['name']}-docs"), + True, + package, + limited_historical_rules, + note=update_message, + ) + _ = docs.generate() if verbose: - package.get_package_hash(verbose=verbose) - click.echo(f'- {len(package.rules)} rules included') + _ = package.get_package_hash(verbose=verbose) + click.echo(f"- {len(package.rules)} rules included") return package -def get_release_diff(pre: str, post: str, remote: Optional[str] = 'origin' - ) -> (Dict[str, 
TOMLRule], Dict[str, TOMLRule], Dict[str, DeprecatedRule]): +def get_release_diff( + pre: str, + post: str, + remote: str = "origin", +) -> tuple[dict[str, TOMLRule], dict[str, TOMLRule], dict[str, DeprecatedRule]]: """Build documents from two git tags for an integration package.""" pre_rules = RuleCollection() - pre_rules.load_git_tag(f'integration-v{pre}', remote, skip_query_validation=True) + pre_rules.load_git_tag(f"integration-v{pre}", remote, skip_query_validation=True) if pre_rules.errors: - click.echo(f'error loading {len(pre_rules.errors)} rule(s) from: {pre}, skipping:') - click.echo(' - ' + '\n - '.join([str(p) for p in pre_rules.errors])) + click.echo(f"error loading {len(pre_rules.errors)} rule(s) from: {pre}, skipping:") + click.echo(" - " + "\n - ".join([str(p) for p in pre_rules.errors])) post_rules = RuleCollection() - post_rules.load_git_tag(f'integration-v{post}', remote, skip_query_validation=True) + post_rules.load_git_tag(f"integration-v{post}", remote, skip_query_validation=True) if post_rules.errors: - click.echo(f'error loading {len(post_rules.errors)} rule(s) from: {post}, skipping:') - click.echo(' - ' + '\n - '.join([str(p) for p in post_rules.errors])) + click.echo(f"error loading {len(post_rules.errors)} rule(s) from: {post}, skipping:") + click.echo(" - " + "\n - ".join([str(p) for p in post_rules.errors])) - rules_changes = pre_rules.compare_collections(post_rules) - return rules_changes + return pre_rules.compare_collections(post_rules) -@dev_group.command('build-integration-docs') -@click.argument('registry-version') -@click.option('--pre', required=True, type=str, help='Tag for pre-existing rules') -@click.option('--post', required=True, type=str, help='Tag for rules post updates') -@click.option('--directory', '-d', type=Path, required=True, help='Output directory to save docs to') -@click.option('--force', '-f', is_flag=True, help='Bypass the confirmation prompt') -@click.option('--remote', '-r', default='origin', help='Override the remote from "origin"') -@click.option('--update-message', default='Rule Updates.', type=str, help='Update message for new package') +@dev_group.command("build-integration-docs") +@click.argument("registry-version") +@click.option("--pre", required=True, type=str, help="Tag for pre-existing rules") +@click.option("--post", required=True, type=str, help="Tag for rules post updates") +@click.option("--directory", "-d", type=Path, required=True, help="Output directory to save docs to") +@click.option("--force", "-f", is_flag=True, help="Bypass the confirmation prompt") +@click.option("--remote", "-r", default="origin", help='Override the remote from "origin"') +@click.option("--update-message", default="Rule Updates.", type=str, help="Update message for new package") @click.pass_context -def build_integration_docs(ctx: click.Context, registry_version: str, pre: str, post: str, - directory: Path, force: bool, update_message: str, - remote: Optional[str] = 'origin') -> IntegrationSecurityDocs: +def build_integration_docs( # noqa: PLR0913 + ctx: click.Context, + registry_version: str, + pre: str, + post: str, + directory: Path, + force: bool, + update_message: str, + remote: str = "origin", +) -> IntegrationSecurityDocs: """Build documents from two git tags for an integration package.""" - if not force: - if not click.confirm(f'This will refresh tags and may overwrite local tags for: {pre} and {post}. 
Continue?'): - ctx.exit(1) + if not force and not click.confirm( + f"This will refresh tags and may overwrite local tags for: {pre} and {post}. Continue?" + ): + ctx.exit(1) + + if Version.parse(pre) >= Version.parse(post): + raise ValueError(f"pre: {pre} is not less than post: {post}") + + if not Version.parse(pre): + raise ValueError(f"pre: {pre} is not a valid semver") - assert Version.parse(pre) < Version.parse(post), f'pre: {pre} is not less than post: {post}' - assert Version.parse(pre), f'pre: {pre} is not a valid semver' - assert Version.parse(post), f'post: {post} is not a valid semver' + if not Version.parse(post): + raise ValueError(f"post: {post} is not a valid semver") rules_changes = get_release_diff(pre, post, remote) - docs = IntegrationSecurityDocs(registry_version, directory, True, *rules_changes, update_message=update_message) + docs = IntegrationSecurityDocs( + registry_version, + directory, + True, + *rules_changes, + update_message=update_message, + ) package_dir = docs.generate() - click.echo(f'Generated documents saved to: {package_dir}') + click.echo(f"Generated documents saved to: {package_dir}") updated, new, deprecated = rules_changes - click.echo(f'- {len(updated)} updated rules') - click.echo(f'- {len(new)} new rules') - click.echo(f'- {len(deprecated)} deprecated rules') + click.echo(f"- {len(updated)} updated rules") + click.echo(f"- {len(new)} new rules") + click.echo(f"- {len(deprecated)} deprecated rules") return docs @@ -225,13 +283,23 @@ def build_integration_docs(ctx: click.Context, registry_version: str, pre: str, @click.option("--major-release", is_flag=True, help="bump the major version") @click.option("--minor-release", is_flag=True, help="bump the minor version") @click.option("--patch-release", is_flag=True, help="bump the patch version") -@click.option("--new-package", type=click.Choice(['true', 'false']), help="indicates new package") -@click.option("--maturity", type=click.Choice(['beta', 'ga'], case_sensitive=False), - required=True, help="beta or production versions") -def bump_versions(major_release: bool, minor_release: bool, patch_release: bool, new_package: str, maturity: str): +@click.option("--new-package", type=click.Choice(["true", "false"]), help="indicates new package") +@click.option( + "--maturity", + type=click.Choice(["beta", "ga"], case_sensitive=False), + required=True, + help="beta or production versions", +) +def bump_versions( + major_release: bool, + minor_release: bool, + patch_release: bool, + new_package: str, + maturity: str, +) -> None: """Bump the versions""" - pkg_data = RULES_CONFIG.packages['package'] + pkg_data = RULES_CONFIG.packages["package"] kibana_ver = Version.parse(pkg_data["name"], optional_minor_and_patch=True) pkg_ver = Version.parse(pkg_data["registry_data"]["version"]) pkg_kibana_ver = Version.parse(pkg_data["registry_data"]["conditions"]["kibana.version"].lstrip("^")) @@ -246,8 +314,9 @@ def bump_versions(major_release: bool, minor_release: bool, patch_release: bool, pkg_data["registry_data"]["conditions"]["kibana.version"] = f"^{pkg_kibana_ver.bump_minor()}" pkg_data["registry_data"]["version"] = str(pkg_ver.bump_minor().bump_prerelease("beta")) if patch_release: - latest_patch_release_ver = find_latest_integration_version("security_detection_engine", - maturity, pkg_kibana_ver) + latest_patch_release_ver = find_latest_integration_version( + "security_detection_engine", maturity, pkg_kibana_ver + ) # if an existing minor or major does not have a package, bump from the last # example is 
8.10.0-beta.1 is last, but on 9.0.0 major @@ -265,8 +334,8 @@ def bump_versions(major_release: bool, minor_release: bool, patch_release: bool, latest_patch_release_ver = latest_patch_release_ver.bump_patch() pkg_data["registry_data"]["version"] = str(latest_patch_release_ver.bump_prerelease("beta")) - if 'release' in pkg_data['registry_data']: - pkg_data['registry_data']['release'] = maturity + if "release" in pkg_data["registry_data"]: + pkg_data["registry_data"]["release"] = maturity click.echo(f"Kibana version: {pkg_data['name']}") click.echo(f"Package Kibana version: {pkg_data['registry_data']['conditions']['kibana.version']}") @@ -294,9 +363,14 @@ def bump_versions(major_release: bool, minor_release: bool, patch_release: bool, @click.option("--comment", is_flag=True, help="If set, enables commenting on the PR (requires --pr-number)") @click.option("--save-double-bumps", type=Path, help="Optional path to save the double bumps to a file") @click.pass_context -def check_version_lock( - ctx: click.Context, pr_number: int, local_file: str, token: str, comment: bool, save_double_bumps: Path -): +def check_version_lock( # noqa: PLR0913 + ctx: click.Context, + pr_number: int, + local_file: str, + token: str, + comment: bool, + save_double_bumps: Path, +) -> None: """ Check the version lock file and optionally comment on the PR if the --comment flag is set. @@ -312,7 +386,7 @@ def check_version_lock( double_bumps = [] comment_body = "No double bumps detected." - def format_comment_body(double_bumps: list) -> str: + def format_comment_body(double_bumps: list[tuple[str, str, int, int]]) -> str: """Format the comment body for double bumps.""" comment_body = f"{len(double_bumps)} Double bumps detected:\n\n" comment_body += "
\n" @@ -325,16 +399,18 @@ def format_comment_body(double_bumps: list) -> str: comment_body += "\n
\n" return comment_body - def save_double_bumps_to_file(double_bumps: list, save_path: Path): + def save_double_bumps_to_file(double_bumps: list[tuple[str, str, int, int]], save_path: Path) -> None: """Save double bumps to a CSV file.""" save_path.parent.mkdir(parents=True, exist_ok=True) if save_path.is_file(): click.echo(f"File {save_path} already exists. Skipping save.") else: with save_path.open("w", newline="") as csvfile: - csv.writer(csvfile).writerows([["Rule ID", "Rule Name", "Removed", "Added"]] + double_bumps) + csv.writer(csvfile).writerows([["Rule ID", "Rule Name", "Removed", "Added"], *double_bumps]) click.echo(f"Double bumps saved to {save_path}") + pr = None + if pr_number: click.echo(f"Fetching version lock file from PR #{pr_number}") pr = repo.get_pull(pr_number) @@ -349,47 +425,50 @@ def save_double_bumps_to_file(double_bumps: list, save_path: Path): click.echo(f"{len(double_bumps)} Double bumps detected") if comment and pr_number: comment_body = format_comment_body(double_bumps) - pr.create_issue_comment(comment_body) + if pr: + _ = pr.create_issue_comment(comment_body) if save_double_bumps: save_double_bumps_to_file(double_bumps, save_double_bumps) ctx.exit(1) else: click.echo("No double bumps detected.") - if comment and pr_number: - pr.create_issue_comment(comment_body) + if comment and pr_number and pr: + _ = pr.create_issue_comment(comment_body) @dataclasses.dataclass class GitChangeEntry: status: str original_path: Path - new_path: Optional[Path] = None + new_path: Path | None = None @classmethod - def from_line(cls, text: str) -> 'GitChangeEntry': + def from_line(cls, text: str) -> "GitChangeEntry": columns = text.split("\t") - assert 2 <= len(columns) <= 3 - - columns[1:] = [Path(c) for c in columns[1:]] - return cls(*columns) + if len(columns) not in (2, 3): + raise ValueError("Unexpected number of columns") + paths = [Path(c) for c in columns[1:]] + return cls(columns[0], *paths) @property def path(self) -> Path: return self.new_path or self.original_path - def revert(self, dry_run=False): + def revert(self, dry_run: bool = False) -> None: """Run a git command to revert this change.""" - def git(*args): + def git(*args: Any) -> None: command_line = ["git"] + [str(arg) for arg in args] click.echo(subprocess.list2cmdline(command_line)) if not dry_run: - subprocess.check_call(command_line) + _ = subprocess.check_call(command_line) if self.status.startswith("R"): # renames are actually Delete (D) and Add (A) # revert in opposite order + if not self.new_path: + raise ValueError("No new path found") GitChangeEntry("A", self.new_path).revert(dry_run=dry_run) GitChangeEntry("D", self.original_path).revert(dry_run=dry_run) return @@ -397,11 +476,11 @@ def git(*args): # remove the file from the staging area (A|M|D) git("restore", "--staged", self.original_path) - def read(self, git_tree="HEAD") -> bytes: + def read(self, git_tree: str = "HEAD") -> bytes: """Read the file from disk or git.""" if self.status == "D": # deleted files need to be recovered from git - return subprocess.check_output(["git", "show", f"{git_tree}:{self.path}"]) + return subprocess.check_output(["git", "show", f"{git_tree}:{self.path}"]) # noqa: S607 return self.path.read_bytes() @@ -410,21 +489,21 @@ def read(self, git_tree="HEAD") -> bytes: @click.option("--target-stack-version", "-t", help="Minimum stack version to filter the staging area", required=True) @click.option("--dry-run", is_flag=True, help="List the changes that would be made") @click.option("--exception-list", help="List of files to 
skip staging", default="") -def prune_staging_area(target_stack_version: str, dry_run: bool, exception_list: str): +def prune_staging_area(target_stack_version: str, dry_run: bool, exception_list: str) -> None: """Prune the git staging area to remove changes to incompatible rules.""" exceptions = { "detection_rules/etc/packages.yaml", } exceptions.update(exception_list.split(",")) - target_stack_version = Version.parse(target_stack_version, optional_minor_and_patch=True) + target_stack_version_parsed = Version.parse(target_stack_version, optional_minor_and_patch=True) # load a structured summary of the diff from git - git_output = subprocess.check_output(["git", "diff", "--name-status", "HEAD"]) + git_output = subprocess.check_output(["git", "diff", "--name-status", "HEAD"]) # noqa: S607 changes = [GitChangeEntry.from_line(line) for line in git_output.decode("utf-8").splitlines()] # track which changes need to be reverted because of incompatibilities - reversions: List[GitChangeEntry] = [] + reversions: list[GitChangeEntry] = [] for change in changes: if str(change.path) in exceptions: @@ -437,10 +516,11 @@ def prune_staging_area(target_stack_version: str, dry_run: bool, exception_list: if str(change.path.absolute()).startswith(str(rules_dir)) and change.path.suffix == ".toml": # bypass TOML validation in case there were schema changes dict_contents = RuleCollection.deserialize_toml_string(change.read()) - min_stack_version: Optional[str] = dict_contents.get("metadata", {}).get("min_stack_version") + min_stack_version: str | None = dict_contents.get("metadata", {}).get("min_stack_version") - if min_stack_version is not None and \ - (target_stack_version < Version.parse(min_stack_version, optional_minor_and_patch=True)): + if min_stack_version is not None and ( + target_stack_version_parsed < Version.parse(min_stack_version, optional_minor_and_patch=True) + ): # rule is incompatible, add to the list of reversions to make later reversions.append(change) break @@ -454,93 +534,100 @@ def prune_staging_area(target_stack_version: str, dry_run: bool, exception_list: change.revert(dry_run=dry_run) -@dev_group.command('update-lock-versions') -@click.argument('rule-ids', nargs=-1, required=False) +@dev_group.command("update-lock-versions") +@click.argument("rule-ids", nargs=-1, required=False) @click.pass_context -@click.option('--force', is_flag=True, help='Force update without confirmation') -def update_lock_versions(ctx: click.Context, rule_ids: Tuple[str, ...], force: bool): +@click.option("--force", is_flag=True, help="Force update without confirmation") +def update_lock_versions(ctx: click.Context, rule_ids: tuple[str, ...], force: bool) -> list[definitions.UUIDString]: """Update rule hashes in version.lock.json file without bumping version.""" rules = RuleCollection.default() - - if rule_ids: - rules = rules.filter(lambda r: r.id in rule_ids) - else: - rules = rules.filter(production_filter) + rules = rules.filter(lambda r: r.id in rule_ids) if rule_ids else rules.filter(production_filter) if not force and not click.confirm( - f'Are you sure you want to update hashes for {len(rules)} rules without a version bump?' + f"Are you sure you want to update hashes for {len(rules)} rules without a version bump?" ): - return + return [] if RULES_CONFIG.bypass_version_lock: - click.echo('WARNING: You cannot run this command when the versioning strategy is configured to bypass the ' - 'version lock. 
Set `bypass_version_lock` to `False` in the rules config to use the version lock.') + click.echo( + "WARNING: You cannot run this command when the versioning strategy is configured to bypass the " + "version lock. Set `bypass_version_lock` to `False` in the rules config to use the version lock." + ) ctx.exit() # this command may not function as expected anymore due to previous changes eliminating the use of add_new=False - changed, new, _ = loaded_version_lock.manage_versions(rules, exclude_version_update=True, save_changes=True) + changed, _, _ = loaded_version_lock.manage_versions(rules, exclude_version_update=True, save_changes=True) if not changed: - click.echo('No hashes updated') + click.echo("No hashes updated") return changed -@dev_group.command('kibana-diff') -@click.option('--rule-id', '-r', multiple=True, help='Optionally specify rule ID') -@click.option('--repo', default='elastic/kibana', help='Repository where branch is located') -@click.option('--branch', '-b', default='main', help='Specify the kibana branch to diff against') -@click.option('--threads', '-t', type=click.IntRange(1), default=50, help='Number of threads to use to download rules') -def kibana_diff(rule_id, repo, branch, threads): +@dev_group.command("kibana-diff") +@click.option("--rule-id", "-r", multiple=True, help="Optionally specify rule ID") +@click.option("--repo", default="elastic/kibana", help="Repository where branch is located") +@click.option("--branch", "-b", default="main", help="Specify the kibana branch to diff against") +@click.option("--threads", "-t", type=click.IntRange(1), default=50, help="Number of threads to use to download rules") +def kibana_diff(rule_id: list[str], repo: str, branch: str, threads: int) -> dict[str, Any]: """Diff rules against their version represented in kibana if exists.""" from .misc import get_kibana_rules rules = RuleCollection.default() - - if rule_id: - rules = rules.filter(lambda r: r.id in rule_id).id_map - else: - rules = rules.filter(production_filter).id_map + rules = rules.filter(lambda r: r.id in rule_id).id_map if rule_id else rules.filter(production_filter).id_map repo_hashes = {r.id: r.contents.get_hash(include_version=True) for r in rules.values()} - kibana_rules = {r['rule_id']: r for r in get_kibana_rules(repo=repo, branch=branch, threads=threads).values()} - kibana_hashes = {r['rule_id']: dict_hash(r) for r in kibana_rules.values()} + kibana_rules = {r["rule_id"]: r for r in get_kibana_rules(repo=repo, branch=branch, threads=threads).values()} + kibana_hashes = {r["rule_id"]: dict_hash(r) for r in kibana_rules.values()} missing_from_repo = list(set(kibana_hashes).difference(set(repo_hashes))) missing_from_kibana = list(set(repo_hashes).difference(set(kibana_hashes))) - rule_diff = [] - for rule_id, rule_hash in repo_hashes.items(): - if rule_id in missing_from_kibana: + rule_diff: list[str] = [] + for _rule_id, _rule_hash in repo_hashes.items(): + if _rule_id in missing_from_kibana: continue - if rule_hash != kibana_hashes[rule_id]: + if _rule_hash != kibana_hashes[_rule_id]: rule_diff.append( - f'versions - repo: {rules[rule_id].contents.autobumped_version}, ' - f'kibana: {kibana_rules[rule_id]["version"]} -> ' - f'{rule_id} - {rules[rule_id].contents.name}' + f"versions - repo: {rules[_rule_id].contents.autobumped_version}, " + f"kibana: {kibana_rules[_rule_id]['version']} -> " + f"{_rule_id} - {rules[_rule_id].contents.name}" ) - diff = { - 'missing_from_kibana': [f'{r} - {rules[r].name}' for r in missing_from_kibana], - 'diff': rule_diff, - 
'missing_from_repo': [f'{r} - {kibana_rules[r]["name"]}' for r in missing_from_repo] + diff: dict[str, Any] = { + "missing_from_kibana": [f"{r} - {rules[r].name}" for r in missing_from_kibana], + "diff": rule_diff, + "missing_from_repo": [f"{r} - {kibana_rules[r]['name']}" for r in missing_from_repo], } - diff['stats'] = {k: len(v) for k, v in diff.items()} - diff['stats'].update(total_repo_prod_rules=len(rules), total_gh_prod_rules=len(kibana_rules)) + diff["stats"] = {k: len(v) for k, v in diff.items()} + diff["stats"].update(total_repo_prod_rules=len(rules), total_gh_prod_rules=len(kibana_rules)) click.echo(json.dumps(diff, indent=2, sort_keys=True)) return diff @dev_group.command("integrations-pr") -@click.argument("local-repo", type=click.Path(exists=True, file_okay=False, dir_okay=True), - default=get_path("..", "integrations")) -@click.option("--token", required=True, prompt=get_github_token() is None, default=get_github_token(), - help="GitHub token to use for the PR", hide_input=True) -@click.option("--pkg-directory", "-d", help="Directory to save the package in cloned repository", - default=Path("packages", "security_detection_engine")) +@click.argument( + "local-repo", + type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path), + default=get_path(["..", "integrations"]), +) +@click.option( + "--token", + required=True, + prompt=get_github_token() is None, + default=get_github_token(), + help="GitHub token to use for the PR", + hide_input=True, +) +@click.option( + "--pkg-directory", + "-d", + help="Directory to save the package in cloned repository", + default=Path("packages", "security_detection_engine"), +) @click.option("--base-branch", "-b", help="Base branch in target repository", default="main") @click.option("--branch-name", "-n", help="New branch for the rules commit") @click.option("--github-repo", "-r", help="Repository to use for the branch", default="elastic/integrations") @@ -549,9 +636,19 @@ def kibana_diff(rule_id, repo, branch, threads): @click.option("--draft", is_flag=True, help="Open the PR as a draft") @click.option("--remote", help="Override the remote from 'origin'", default="origin") @click.pass_context -def integrations_pr(ctx: click.Context, local_repo: str, token: str, draft: bool, - pkg_directory: str, base_branch: str, remote: str, - branch_name: Optional[str], github_repo: str, assign: Tuple[str, ...], label: Tuple[str, ...]): +def integrations_pr( # noqa: PLR0913, PLR0915 + ctx: click.Context, + local_repo: Path, + token: str, + draft: bool, + pkg_directory: str, + base_branch: str, + remote: str, + branch_name: str | None, + github_repo: str, + assign: tuple[str, ...], + label: tuple[str, ...], +) -> None: """Create a pull request to publish the Fleet package to elastic/integrations.""" github = GithubClient(token) github.assert_github() @@ -559,11 +656,17 @@ def integrations_pr(ctx: click.Context, local_repo: str, token: str, draft: bool repo = client.get_repo(github_repo) # Use elastic-package to format and lint - gopath = utils.gopath().strip("'\"") - assert gopath is not None, "$GOPATH isn't set" + gopath = utils.gopath() - err = 'elastic-package missing, run: go install github.com/elastic/elastic-package@latest and verify go bin path' - assert subprocess.check_output(['elastic-package'], stderr=subprocess.DEVNULL), err + if not gopath: + raise ValueError("GOPATH not found") + + gopath = gopath.strip("'\"") + + if not subprocess.check_output(["elastic-package"], stderr=subprocess.DEVNULL): # noqa: S607 + raise ValueError( + 
"elastic-package missing, run: go install github.com/elastic/elastic-package@latest and verify go bin path" + ) local_repo = Path(local_repo).resolve() stack_version = Package.load_configs()["name"] @@ -588,44 +691,52 @@ def integrations_pr(ctx: click.Context, local_repo: str, token: str, draft: bool # refresh the local clone of the repository git = utils.make_git("-C", local_repo) - git("checkout", base_branch) - git("pull", remote, base_branch) + _ = git("checkout", base_branch) + _ = git("pull", remote, base_branch) # Switch to a new branch in elastic/integrations branch_name = branch_name or f"detection-rules/{package_version}-{short_commit_hash}" - git("checkout", "-b", branch_name) + _ = git("checkout", "-b", branch_name) # Load the changelog in memory, before it's removed. Come back for it after the PR is created target_directory = local_repo / pkg_directory changelog_path = target_directory / "changelog.yml" - changelog_entries: list = yaml.safe_load(changelog_path.read_text(encoding="utf-8")) - - changelog_entries.insert(0, { - "version": package_version, - "changes": [ - # This will be changed later - {"description": "Release security rules update", "type": "enhancement", - "link": "https://github.com/elastic/integrations/pulls/0000"} - ] - }) + changelog_entries: list[dict[str, Any]] = yaml.safe_load(changelog_path.read_text(encoding="utf-8")) + + changelog_entries.insert( + 0, + { + "version": package_version, + "changes": [ + # This will be changed later + { + "description": "Release security rules update", + "type": "enhancement", + "link": "https://github.com/elastic/integrations/pulls/0000", + } + ], + }, + ) # Remove existing assets and replace everything shutil.rmtree(target_directory) actual_target_directory = shutil.copytree(release_dir, target_directory) - assert Path(actual_target_directory).absolute() == Path(target_directory).absolute(), \ - f"Expected a copy to {pkg_directory}" + if Path(actual_target_directory).absolute() != Path(target_directory).absolute(): + raise ValueError(f"Expected a copy to {pkg_directory}") # Add the changelog back - def save_changelog(): + def save_changelog() -> None: with changelog_path.open("wt") as f: # add a note for other maintainers of elastic/integrations to be careful with versions - f.write("# newer versions go on top\n") - f.write("# NOTE: please use pre-release versions (e.g. -beta.0) until a package is ready for production\n") + _ = f.write("# newer versions go on top\n") + _ = f.write( + "# NOTE: please use pre-release versions (e.g. 
-beta.0) until a package is ready for production\n" + ) yaml.dump(changelog_entries, f, allow_unicode=True, default_flow_style=False, indent=2, sort_keys=False) save_changelog() - def elastic_pkg(*args): + def elastic_pkg(*args: Any) -> None: """Run a command with $GOPATH/bin/elastic-package in the package directory.""" prev = Path.cwd() os.chdir(target_directory) @@ -633,16 +744,16 @@ def elastic_pkg(*args): try: elastic_pkg_cmd = [str(Path(gopath, "bin", "elastic-package"))] elastic_pkg_cmd.extend(list(args)) - return subprocess.check_call(elastic_pkg_cmd) + _ = subprocess.check_call(elastic_pkg_cmd) finally: os.chdir(str(prev)) elastic_pkg("format") # Upload the files to a branch - git("add", pkg_directory) - git("commit", "-m", message) - git("push", "--set-upstream", remote, branch_name) + _ = git("add", pkg_directory) + _ = git("commit", "-m", message) + _ = git("push", "--set-upstream", remote, branch_name) # Create a pull request (not done yet, but we need the PR number) body = textwrap.dedent(f""" @@ -673,14 +784,17 @@ def elastic_pkg(*args): None """) # noqa: E501 - pr = repo.create_pull(title=message, body=body, base=base_branch, head=branch_name, - maintainer_can_modify=True, draft=draft) + pr = repo.create_pull( + title=message, body=body, base=base_branch, head=branch_name, maintainer_can_modify=True, draft=draft + ) # labels could also be comma separated - label = {lbl for cs_labels in label for lbl in cs_labels.split(",") if lbl} + cs_labels_split = {lbl for cs_labels in label for lbl in cs_labels.split(",") if lbl} + + labels = sorted(list(label) + list(cs_labels_split)) - if label: - pr.add_to_labels(*sorted(label)) + if labels: + pr.add_to_labels(*labels) if assign: pr.add_to_assignees(*assign) @@ -693,30 +807,29 @@ def elastic_pkg(*args): save_changelog() # format the yml file with elastic-package - elastic_pkg("format") - elastic_pkg("lint") + _ = elastic_pkg("format") + _ = elastic_pkg("lint") # Push the updated changelog to the PR branch - git("add", pkg_directory) - git("commit", "-m", f"Add changelog entry for {package_version}") - git("push") + _ = git("add", pkg_directory) + _ = git("commit", "-m", f"Add changelog entry for {package_version}") + _ = git("push") -@dev_group.command('license-check') -@click.option('--ignore-directory', '-i', multiple=True, help='Directories to skip (relative to base)') +@dev_group.command("license-check") +@click.option("--ignore-directory", "-i", multiple=True, help="Directories to skip (relative to base)") @click.pass_context -def license_check(ctx, ignore_directory): +def license_check(ctx: click.Context, ignore_directory: list[str]) -> None: """Check that all code files contain a valid license.""" ignore_directory += ("env",) failed = False - base_path = get_path() - for path in base_path.rglob('*.py'): - relative_path = path.relative_to(base_path) + for path in utils.ROOT_DIR.rglob("*.py"): + relative_path = path.relative_to(utils.ROOT_DIR) if relative_path.parts[0] in ignore_directory: continue - with io.open(path, "rt", encoding="utf-8") as f: + with path.open(encoding="utf-8") as f: contents = f.read() # skip over shebang lines @@ -733,144 +846,158 @@ def license_check(ctx, ignore_directory): ctx.exit(int(failed)) -@dev_group.command('test-version-lock') -@click.argument('branches', nargs=-1, required=True) -@click.option('--remote', '-r', default='origin', help='Override the remote from "origin"') +@dev_group.command("test-version-lock") +@click.argument("branches", nargs=-1, required=True) +@click.option("--remote", "-r", 
default="origin", help='Override the remote from "origin"') @click.pass_context -def test_version_lock(ctx: click.Context, branches: tuple, remote: str): +def test_version_lock(ctx: click.Context, branches: list[str], remote: str) -> None: """Simulate the incremental step in the version locking to find version change violations.""" - git = utils.make_git('-C', '.') - current_branch = git('rev-parse', '--abbrev-ref', 'HEAD') + git = utils.make_git("-C", ".") + current_branch = git("rev-parse", "--abbrev-ref", "HEAD") try: - click.echo(f'iterating lock process for branches: {branches}') + click.echo(f"iterating lock process for branches: {branches}") for branch in branches: click.echo(branch) - git('checkout', f'{remote}/{branch}') - subprocess.check_call(['python', '-m', 'detection_rules', 'dev', 'build-release', '-u']) + _ = git("checkout", f"{remote}/{branch}") + _ = subprocess.check_call(["python", "-m", "detection_rules", "dev", "build-release", "-u"]) # noqa: S607 finally: - rules_config = ctx.obj['rules_config'] - diff = git('--no-pager', 'diff', str(rules_config.version_lock_file)) - outfile = get_path() / 'lock-diff.txt' - outfile.write_text(diff) - click.echo(f'diff saved to {outfile}') + rules_config = ctx.obj["rules_config"] + diff = git("--no-pager", "diff", str(rules_config.version_lock_file)) + outfile = utils.ROOT_DIR / "lock-diff.txt" + _ = outfile.write_text(diff) + click.echo(f"diff saved to {outfile}") - click.echo('reverting changes in version.lock') - git('checkout', '-f') - git('checkout', current_branch) + click.echo("reverting changes in version.lock") + _ = git("checkout", "-f") + _ = git("checkout", current_branch) -@dev_group.command('package-stats') -@click.option('--token', '-t', help='GitHub token to search API authenticated (may exceed threshold without auth)') -@click.option('--threads', default=50, help='Number of threads to download rules from GitHub') +@dev_group.command("package-stats") +@click.option("--token", "-t", help="GitHub token to search API authenticated (may exceed threshold without auth)") +@click.option("--threads", default=50, help="Number of threads to download rules from GitHub") @click.pass_context -def package_stats(ctx, token, threads): +def package_stats(ctx: click.Context, token: str | None, threads: int) -> None: """Get statistics for current rule package.""" - current_package: Package = ctx.invoke(build_release, verbose=False, release=None) - release = f'v{current_package.name}.0' - new, modified, errors = rule_loader.load_github_pr_rules(labels=[release], token=token, threads=threads) - - click.echo(f'Total rules as of {release} package: {len(current_package.rules)}') - click.echo(f'New rules: {len(current_package.new_ids)}') - click.echo(f'Modified rules: {len(current_package.changed_ids)}') - click.echo(f'Deprecated rules: {len(current_package.removed_ids)}') - - click.echo('\n-----\n') - click.echo('Rules in active PRs for current package: ') - click.echo(f'New rules: {len(new)}') - click.echo(f'Modified rules: {len(modified)}') - - -@dev_group.command('search-rule-prs') -@click.argument('query', required=False) -@click.option('--no-loop', '-n', is_flag=True, help='Run once with no loop') -@click.option('--columns', '-c', multiple=True, help='Specify columns to add the table') -@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql") -@click.option('--token', '-t', help='GitHub token to search API authenticated (may exceed threshold without auth)') -@click.option('--threads', default=50, help='Number of 
threads to download rules from GitHub') + current_package: Package = ctx.invoke(build_release, verbose=False) + release = f"v{current_package.name}.0" + new, modified, _ = rule_loader.load_github_pr_rules(labels=[release], token=token, threads=threads) + + click.echo(f"Total rules as of {release} package: {len(current_package.rules)}") + click.echo(f"New rules: {len(current_package.new_ids)}") + click.echo(f"Modified rules: {len(current_package.changed_ids)}") + click.echo(f"Deprecated rules: {len(current_package.removed_ids)}") + + click.echo("\n-----\n") + click.echo("Rules in active PRs for current package: ") + click.echo(f"New rules: {len(new)}") + click.echo(f"Modified rules: {len(modified)}") + + +@dev_group.command("search-rule-prs") +@click.argument("query", required=False) +@click.option("--no-loop", "-n", is_flag=True, help="Run once with no loop") +@click.option("--columns", "-c", multiple=True, help="Specify columns to add the table") +@click.option("--language", type=click.Choice(["eql", "kql"]), default="kql") +@click.option("--token", "-t", help="GitHub token to search API authenticated (may exceed threshold without auth)") +@click.option("--threads", default=50, help="Number of threads to download rules from GitHub") @click.pass_context -def search_rule_prs(ctx, no_loop, query, columns, language, token, threads): +def search_rule_prs( # noqa: PLR0913 + ctx: click.Context, + no_loop: bool, + query: str | None, + columns: list[str], + language: Literal["eql", "kql"], + token: str | None, + threads: int, +) -> None: """Use KQL or EQL to find matching rules from active GitHub PRs.""" - from uuid import uuid4 - from .main import search_rules - all_rules: Dict[Path, TOMLRule] = {} - new, modified, errors = rule_loader.load_github_pr_rules(token=token, threads=threads) + all_rules: dict[Path, TOMLRule] = {} + new, modified, _ = rule_loader.load_github_pr_rules(token=token, threads=threads) - def add_github_meta(this_rule: TOMLRule, status: str, original_rule_id: Optional[definitions.UUIDString] = None): + def add_github_meta( + this_rule: TOMLRule, + status: str, + original_rule_id: definitions.UUIDString | None = None, + ) -> None: pr = this_rule.gh_pr data = rule.contents.data extend_meta = { - 'status': status, - 'github': { - 'base': pr.base.label, - 'comments': [c.body for c in pr.get_comments()], - 'commits': pr.commits, - 'created_at': str(pr.created_at), - 'head': pr.head.label, - 'is_draft': pr.draft, - 'labels': [lbl.name for lbl in pr.get_labels()], - 'last_modified': str(pr.last_modified), - 'title': pr.title, - 'url': pr.html_url, - 'user': pr.user.login - } + "status": status, + "github": { + "base": pr.base.label, + "comments": [c.body for c in pr.get_comments()], + "commits": pr.commits, + "created_at": str(pr.created_at), + "head": pr.head.label, + "is_draft": pr.draft, + "labels": [lbl.name for lbl in pr.get_labels()], + "last_modified": str(pr.last_modified), + "title": pr.title, + "url": pr.html_url, + "user": pr.user.login, + }, } if original_rule_id: - extend_meta['original_rule_id'] = original_rule_id + extend_meta["original_rule_id"] = original_rule_id data = dataclasses.replace(rule.contents.data, rule_id=str(uuid4())) - rule_path = Path(f'pr-{pr.number}-{rule.path}') + rule_path = Path(f"pr-{pr.number}-{rule.path}") new_meta = dataclasses.replace(rule.contents.metadata, extended=extend_meta) contents = dataclasses.replace(rule.contents, metadata=new_meta, data=data) new_rule = TOMLRule(path=rule_path, contents=contents) + if not new_rule.path: + raise 
ValueError("No rule path found") all_rules[new_rule.path] = new_rule - for rule_id, rule in new.items(): - add_github_meta(rule, 'new') + for rule in new.values(): + add_github_meta(rule, "new") for rule_id, rules in modified.items(): for rule in rules: - add_github_meta(rule, 'modified', rule_id) + add_github_meta(rule, "modified", rule_id) loop = not no_loop ctx.invoke(search_rules, query=query, columns=columns, language=language, rules=all_rules, pager=loop) while loop: - query = click.prompt(f'Search loop - enter new {language} query or ctrl-z to exit') - columns = click.prompt('columns', default=','.join(columns)).split(',') + query = click.prompt(f"Search loop - enter new {language} query or ctrl-z to exit") + columns = click.prompt("columns", default=",".join(columns)).split(",") ctx.invoke(search_rules, query=query, columns=columns, language=language, rules=all_rules, pager=True) -@dev_group.command('deprecate-rule') -@click.argument('rule-file', type=Path) -@click.option('--deprecation-folder', '-d', type=Path, required=True, - help='Location to move the deprecated rule file to') +@dev_group.command("deprecate-rule") +@click.argument("rule-file", type=Path) +@click.option( + "--deprecation-folder", "-d", type=Path, required=True, help="Location to move the deprecated rule file to" +) @click.pass_context -def deprecate_rule(ctx: click.Context, rule_file: Path, deprecation_folder: Path): +def deprecate_rule(ctx: click.Context, rule_file: Path, deprecation_folder: Path) -> None: """Deprecate a rule.""" version_info = loaded_version_lock.version_lock rule_collection = RuleCollection() contents = rule_collection.load_file(rule_file).contents - rule = TOMLRule(path=rule_file, contents=contents) + rule = TOMLRule(path=rule_file, contents=contents) # type: ignore[reportArgumentType] if rule.contents.id not in version_info and not RULES_CONFIG.bypass_version_lock: - click.echo('Rule has not been version locked and so does not need to be deprecated. ' - 'Delete the file or update the maturity to `development` instead.') + click.echo( + "Rule has not been version locked and so does not need to be deprecated. " + "Delete the file or update the maturity to `development` instead." 
+ ) ctx.exit() - today = time.strftime('%Y/%m/%d') + today = time.strftime("%Y/%m/%d") deprecated_path = deprecation_folder / rule_file.name # create the new rule and save it - new_meta = dataclasses.replace(rule.contents.metadata, - updated_date=today, - deprecation_date=today, - maturity='deprecated') + new_meta = dataclasses.replace( + rule.contents.metadata, updated_date=today, deprecation_date=today, maturity="deprecated" + ) contents = dataclasses.replace(rule.contents, metadata=new_meta) new_rule = TOMLRule(contents=contents, path=deprecated_path) deprecated_path.parent.mkdir(parents=True, exist_ok=True) @@ -878,72 +1005,86 @@ def deprecate_rule(ctx: click.Context, rule_file: Path, deprecation_folder: Path # remove the old rule rule_file.unlink() - click.echo(f'Rule moved to {deprecated_path} - remember to git add this file') - - -@dev_group.command('update-navigator-gists') -@click.option('--directory', type=Path, default=CURRENT_RELEASE_PATH.joinpath('extras', 'navigator_layers'), - help='Directory containing only navigator files.') -@click.option('--token', required=True, prompt=get_github_token() is None, default=get_github_token(), - help='GitHub token to push to gist', hide_input=True) -@click.option('--gist-id', default=NAVIGATOR_GIST_ID, help='Gist ID to be updated (must exist).') -@click.option('--print-markdown', is_flag=True, help='Print the generated urls') -@click.option('--update-coverage', is_flag=True, help=f'Update the {REPO_DOCS_DIR}/ATT&CK-coverage.md file') -def update_navigator_gists(directory: Path, token: str, gist_id: str, print_markdown: bool, - update_coverage: bool) -> list: + click.echo(f"Rule moved to {deprecated_path} - remember to git add this file") + + +@dev_group.command("update-navigator-gists") +@click.option( + "--directory", + type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True, path_type=Path), + default=CURRENT_RELEASE_PATH.joinpath("extras", "navigator_layers"), + help="Directory containing only navigator files.", +) +@click.option( + "--token", + required=True, + prompt=get_github_token() is None, + default=get_github_token(), + help="GitHub token to push to gist", + hide_input=True, +) +@click.option("--gist-id", default=NAVIGATOR_GIST_ID, help="Gist ID to be updated (must exist).") +@click.option("--print-markdown", is_flag=True, help="Print the generated urls") +@click.option("--update-coverage", is_flag=True, help=f"Update the {REPO_DOCS_DIR}/ATT&CK-coverage.md file") +def update_navigator_gists( + directory: Path, + token: str, + gist_id: str, + print_markdown: bool, + update_coverage: bool, +) -> list[str]: """Update the gists with new navigator files.""" - assert directory.exists(), f'{directory} does not exist' - def raw_permalink(raw_link): + def raw_permalink(raw_link: str) -> str: # Gist file URLs change with each revision, but can be permalinked to the latest by removing the hash after raw - prefix, _, suffix = raw_link.rsplit('/', 2) - return '/'.join([prefix, suffix]) + prefix, _, suffix = raw_link.rsplit("/", 2) + return f"{prefix}/{suffix}" - file_map = {f: f.read_text() for f in directory.glob('*.json')} + file_map = {f: f.read_text() for f in directory.glob("*.json")} try: - response = update_gist(token, - file_map, - description='ATT&CK Navigator layer files.', - gist_id=gist_id, - pre_purge=True) + response = update_gist( + token, file_map, description="ATT&CK Navigator layer files.", gist_id=gist_id, pre_purge=True + ) except requests.exceptions.HTTPError as exc: if exc.response.status_code == 
requests.status_codes.codes.not_found: - raise client_error('Gist not found: verify the gist_id exists and the token has access to it', exc=exc) - else: - raise + raise raise_client_error( + "Gist not found: verify the gist_id exists and the token has access to it", exc=exc + ) from exc + raise response_data = response.json() - raw_urls = {name: raw_permalink(data['raw_url']) for name, data in response_data['files'].items()} + raw_urls = {name: raw_permalink(data["raw_url"]) for name, data in response_data["files"].items()} - base_url = 'https://mitre-attack.github.io/attack-navigator/#layerURL={}&leave_site_dialog=false&tabs=false' + base_url = "https://mitre-attack.github.io/attack-navigator/#layerURL={}&leave_site_dialog=false&tabs=false" # pull out full and platform coverage to print on top of markdown table - all_url = base_url.format(urllib.parse.quote_plus(raw_urls.pop('Elastic-detection-rules-all.json'))) - platforms_url = base_url.format(urllib.parse.quote_plus(raw_urls.pop('Elastic-detection-rules-platforms.json'))) + all_url = base_url.format(urllib.parse.quote_plus(raw_urls.pop("Elastic-detection-rules-all.json"))) + platforms_url = base_url.format(urllib.parse.quote_plus(raw_urls.pop("Elastic-detection-rules-platforms.json"))) generated_urls = [all_url, platforms_url] - markdown_links = [] + markdown_links: list[str] = [] for name, gist_url in raw_urls.items(): query = urllib.parse.quote_plus(gist_url) - url = f'https://mitre-attack.github.io/attack-navigator/#layerURL={query}&leave_site_dialog=false&tabs=false' + url = f"https://mitre-attack.github.io/attack-navigator/#layerURL={query}&leave_site_dialog=false&tabs=false" generated_urls.append(url) - link_name = name.split('.')[0] - markdown_links.append(f'|[{link_name}]({url})|') + link_name = name.split(".")[0] + markdown_links.append(f"|[{link_name}]({url})|") markdown = [ - f'**Full coverage**: {NAVIGATOR_BADGE}', - '\n', - f'**Coverage by platform**: [navigator]({platforms_url})', - '\n', - '| other navigator links by rule attributes |', - '|------------------------------------------|', - ] + markdown_links + f"**Full coverage**: {NAVIGATOR_BADGE}", + "\n", + f"**Coverage by platform**: [navigator]({platforms_url})", + "\n", + "| other navigator links by rule attributes |", + "|------------------------------------------|", + *markdown_links, + ] if print_markdown: - click.echo('\n'.join(markdown) + '\n') + click.echo("\n".join(markdown) + "\n") if update_coverage: - coverage_file_path = get_path(REPO_DOCS_DIR, 'ATT&CK-coverage.md') + coverage_file_path = get_path([REPO_DOCS_DIR, "ATT&CK-coverage.md"]) header_lines = textwrap.dedent("""# Rule coverage ATT&CK navigator layer files are generated when a package is built with `make release` or @@ -958,59 +1099,68 @@ def raw_permalink(raw_link): The source files for these links are regenerated with every successful merge to main. These represent coverage from the state of rules in the `main` branch. 
""") - updated_file = header_lines + '\n\n' + '\n'.join(markdown) + '\n' + updated_file = header_lines + "\n\n" + "\n".join(markdown) + "\n" # Replace the old URLs with the new ones - with open(coverage_file_path, 'w') as md_file: - md_file.write(updated_file) - click.echo(f'Updated ATT&CK coverage URL(s) in {coverage_file_path}' + '\n') + with coverage_file_path.open("w") as md_file: + _ = md_file.write(updated_file) + click.echo(f"Updated ATT&CK coverage URL(s) in {coverage_file_path}" + "\n") - click.echo(f'Gist update status on {len(generated_urls)} files: {response.status_code} {response.reason}') + click.echo(f"Gist update status on {len(generated_urls)} files: {response.status_code} {response.reason}") return generated_urls -@dev_group.command('trim-version-lock') -@click.argument('stack_version') -@click.option('--skip-rule-updates', is_flag=True, help='Skip updating the rules') -@click.option('--dry-run', is_flag=True, help='Print the changes rather than saving the file') +@dev_group.command("trim-version-lock") +@click.argument("stack_version") +@click.option("--skip-rule-updates", is_flag=True, help="Skip updating the rules") +@click.option("--dry-run", is_flag=True, help="Print the changes rather than saving the file") @click.pass_context -def trim_version_lock(ctx: click.Context, stack_version: str, skip_rule_updates: bool, dry_run: bool): +def trim_version_lock( # noqa: PLR0912, PLR0915 + ctx: click.Context, + stack_version: str, + skip_rule_updates: bool, + dry_run: bool, +) -> None: """Trim all previous entries within the version lock file which are lower than the min_version.""" stack_versions = get_stack_versions() - assert stack_version in stack_versions, \ - f'Unknown min_version ({stack_version}), expected: {", ".join(stack_versions)}' + if stack_version not in stack_versions: + raise ValueError(f"Unknown min_version ({stack_version}), expected: {', '.join(stack_versions)}") min_version = Version.parse(stack_version) if RULES_CONFIG.bypass_version_lock: - click.echo('WARNING: Cannot trim the version lock when the versioning strategy is configured to bypass the ' - 'version lock. Set `bypass_version_lock` to `false` in the rules config to use the version lock.') + click.echo( + "WARNING: Cannot trim the version lock when the versioning strategy is configured to bypass the " + "version lock. Set `bypass_version_lock` to `false` in the rules config to use the version lock." + ) ctx.exit() version_lock_dict = loaded_version_lock.version_lock.to_dict() - removed = defaultdict(list) - rule_msv_drops = [] + removed: dict[str, list[str]] = defaultdict(list) + rule_msv_drops: list[str] = [] - today = time.strftime('%Y/%m/%d') + today = time.strftime("%Y/%m/%d") rc: RuleCollection | None = None if dry_run: rc = RuleCollection() - else: - if not skip_rule_updates: - click.echo('Loading rules ...') - rc = RuleCollection.default() + elif not skip_rule_updates: + click.echo("Loading rules ...") + rc = RuleCollection.default() + + if not rc: + raise ValueError("No rule collection found") for rule_id, lock in version_lock_dict.items(): file_min_stack: Version | None = None - if 'min_stack_version' in lock: - file_min_stack = Version.parse((lock['min_stack_version']), optional_minor_and_patch=True) + if "min_stack_version" in lock: + file_min_stack = Version.parse((lock["min_stack_version"]), optional_minor_and_patch=True) if file_min_stack <= min_version: removed[rule_id].append( - f'locked min_stack_version <= {min_version} - {"will remove" if dry_run else "removing"}!' 
+ f"locked min_stack_version <= {min_version} - {'will remove' if dry_run else 'removing'}!" ) rule_msv_drops.append(rule_id) file_min_stack = None if not dry_run: - lock.pop('min_stack_version') + lock.pop("min_stack_version") if not skip_rule_updates: # remove the min_stack_version and min_stack_comments from rules as well (and update date) rule = rc.id_map.get(rule_id) @@ -1019,17 +1169,17 @@ def trim_version_lock(ctx: click.Context, stack_version: str, skip_rule_updates: rule.contents.metadata, updated_date=today, min_stack_version=None, - min_stack_comments=None + min_stack_comments=None, ) contents = dataclasses.replace(rule.contents, metadata=new_meta) new_rule = TOMLRule(contents=contents, path=rule.path) new_rule.save_toml() - removed[rule_id].append('rule min_stack_version dropped') + removed[rule_id].append("rule min_stack_version dropped") else: - removed[rule_id].append('rule not found to update!') + removed[rule_id].append("rule not found to update!") - if 'previous' in lock: - prev_vers = [Version.parse(v, optional_minor_and_patch=True) for v in list(lock['previous'])] + if "previous" in lock: + prev_vers = [Version.parse(v, optional_minor_and_patch=True) for v in list(lock["previous"])] outdated_vers = [v for v in prev_vers if v < min_version] if not outdated_vers: @@ -1041,60 +1191,67 @@ def trim_version_lock(ctx: click.Context, stack_version: str, skip_rule_updates: for outdated in outdated_vers: short_outdated = f"{outdated.major}.{outdated.minor}" - popped = lock['previous'].pop(str(short_outdated)) + popped = lock["previous"].pop(str(short_outdated)) # the core of the update - we only need to keep previous entries that are newer than the min supported # version (from stack-schema-map and stack-version parameter) and older than the locked # min_stack_version for a given rule, if one exists if file_min_stack and outdated == latest_version and outdated < file_min_stack: - lock['previous'][f'{min_version.major}.{min_version.minor}'] = popped - removed[rule_id].append(f'{short_outdated} updated to: {min_version.major}.{min_version.minor}') + lock["previous"][f"{min_version.major}.{min_version.minor}"] = popped + removed[rule_id].append(f"{short_outdated} updated to: {min_version.major}.{min_version.minor}") else: - removed[rule_id].append(f'{outdated} dropped') + removed[rule_id].append(f"{outdated} dropped") # remove the whole previous entry if it is now blank - if not lock['previous']: - lock.pop('previous') + if not lock["previous"]: + lock.pop("previous") - click.echo(f'Changes {"that will be " if dry_run else ""} applied:' if removed else 'No changes') - click.echo('\n'.join(f'{k}: {", ".join(v)}' for k, v in removed.items())) + click.echo(f"Changes {'that will be ' if dry_run else ''} applied:" if removed else "No changes") + click.echo("\n".join(f"{k}: {', '.join(v)}" for k, v in removed.items())) if not dry_run: - new_lock = VersionLockFile.from_dict(dict(data=version_lock_dict)) + new_lock = VersionLockFile.from_dict({"data": version_lock_dict}) new_lock.save_to_file() -@dev_group.group('diff') -def diff_group(): +@dev_group.group("diff") +def diff_group() -> None: """Commands for statistics on changes and diffs.""" -@diff_group.command('endpoint-by-attack') -@click.option('--pre', required=True, help='Tag for pre-existing rules') -@click.option('--post', required=True, help='Tag for rules post updates') -@click.option('--force', '-f', is_flag=True, help='Bypass the confirmation prompt') -@click.option('--remote', '-r', default='origin', help='Override the remote 
from "origin"') +@diff_group.command("endpoint-by-attack") +@click.option("--pre", required=True, help="Tag for pre-existing rules") +@click.option("--post", required=True, help="Tag for rules post updates") +@click.option("--force", "-f", is_flag=True, help="Bypass the confirmation prompt") +@click.option("--remote", "-r", default="origin", help='Override the remote from "origin"') @click.pass_context -def endpoint_by_attack(ctx: click.Context, pre: str, post: str, force: bool, remote: Optional[str] = 'origin'): +def endpoint_by_attack( + ctx: click.Context, + pre: str, + post: str, + force: bool, + remote: str = "origin", +) -> tuple[Any, Any, Any]: """Rule diffs across tagged branches, broken down by ATT&CK tactics.""" - if not force: - if not click.confirm(f'This will refresh tags and may overwrite local tags for: {pre} and {post}. Continue?'): - ctx.exit(1) + if not force and not click.confirm( + f"This will refresh tags and may overwrite local tags for: {pre} and {post}. Continue?" + ): + ctx.exit(1) changed, new, deprecated = get_release_diff(pre, post, remote) - oses = ('windows', 'linux', 'macos') + oses = ("windows", "linux", "macos") - def delta_stats(rule_map) -> List[dict]: - stats = defaultdict(lambda: defaultdict(int)) - os_totals = defaultdict(int) - tactic_totals = defaultdict(int) + def delta_stats(rule_map: dict[str, TOMLRule] | dict[str, DeprecatedRule]) -> list[dict[str, Any]]: + stats: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + os_totals: dict[str, int] = defaultdict(int) + tactic_totals: dict[str, int] = defaultdict(int) - for rule_id, rule in rule_map.items(): - threat = rule.contents.data.get('threat') - os_types = [i.lower() for i in rule.contents.data.get('tags') or [] if i.lower() in oses] + for rule in rule_map.values(): + threat = rule.contents.data.get("threat") + os_types: list[str] = [i.lower() for i in rule.contents.data.get("tags") or [] if i.lower() in oses] # type: ignore[reportUnknownVariableType] if not threat or not os_types: continue if isinstance(threat[0], dict): - tactics = sorted(set(e['tactic']['name'] for e in threat)) + tactics = sorted({e["tactic"]["name"] for e in threat}) else: tactics = ThreatMapping.flatten(threat).tactic_names for tactic in tactics: @@ -1104,138 +1261,178 @@ def delta_stats(rule_map) -> List[dict]: stats[tactic][os_type] += 1 # structure stats for table - rows = [] + rows: list[dict[str, Any]] = [] for tac, stat in stats.items(): - row = {'tactic': tac, 'total': tactic_totals[tac]} + row: dict[str, Any] = {"tactic": tac, "total": tactic_totals[tac]} for os_type, count in stat.items(): - row[os_type] = count + row[os_type] = count # noqa: PERF403 rows.append(row) - rows.append(dict(tactic='total_by_os', **os_totals)) - + rows.append(dict(tactic="total_by_os", **os_totals)) return rows - fields = ['tactic', 'linux', 'macos', 'windows', 'total'] + fields = ["tactic", "linux", "macos", "windows", "total"] changed_stats = delta_stats(changed) - table = Table.from_list(fields, changed_stats) - click.echo(f'Changed rules {len(changed)}\n{table}\n') + table = Table.from_list(fields, changed_stats) # type: ignore[reportUnknownMemberType] + click.echo(f"Changed rules {len(changed)}\n{table}\n") new_stats = delta_stats(new) - table = Table.from_list(fields, new_stats) - click.echo(f'New rules {len(new)}\n{table}\n') + table = Table.from_list(fields, new_stats) # type: ignore[reportUnknownMemberType] + click.echo(f"New rules {len(new)}\n{table}\n") dep_stats = delta_stats(deprecated) - table = 
Table.from_list(fields, dep_stats) - click.echo(f'Deprecated rules {len(deprecated)}\n{table}\n') + table = Table.from_list(fields, dep_stats) # type: ignore[reportUnknownMemberType] + click.echo(f"Deprecated rules {len(deprecated)}\n{table}\n") return changed_stats, new_stats, dep_stats -@dev_group.group('test') -def test_group(): +@dev_group.group("test") +def test_group() -> None: """Commands for testing against stack resources.""" -@test_group.command('event-search') -@click.argument('query') -@click.option('--index', '-i', multiple=True, help='Index patterns to search against') -@click.option('--eql/--lucene', '-e/-l', 'language', default=None, help='Query language used (default: kql)') -@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search') -@click.option('--count', '-c', is_flag=True, help='Return count of results only') -@click.option('--max-results', '-m', type=click.IntRange(1, 1000), default=100, - help='Max results to return (capped at 1000)') -@click.option('--verbose', '-v', is_flag=True, default=True) -@add_client('elasticsearch') -def event_search(query, index, language, date_range, count, max_results, verbose=True, - elasticsearch_client: Elasticsearch = None): +@test_group.command("event-search") +@click.argument("query") +@click.option("--index", "-i", multiple=True, help="Index patterns to search against") +@click.option("--eql/--lucene", "-e/-l", "language", default=None, help="Query language used (default: kql)") +@click.option("--date-range", "-d", type=(str, str), default=("now-7d", "now"), help="Date range to scope search") +@click.option("--count", "-c", is_flag=True, help="Return count of results only") +@click.option( + "--max-results", + "-m", + type=click.IntRange(1, 1000), + default=100, + help="Max results to return (capped at 1000)", +) +@click.option("--verbose", "-v", is_flag=True, default=True) +@add_client(["elasticsearch"]) +def event_search( # noqa: PLR0913 + query: str, + index: list[str], + language: str | None, + date_range: tuple[str, str], + count: bool, + max_results: int, + elasticsearch_client: Elasticsearch, + verbose: bool = True, +) -> Any | list[Any]: """Search using a query against an Elasticsearch instance.""" start_time, end_time = date_range - index = index or ('*',) - language_used = "kql" if language is None else "eql" if language is True else "lucene" + index = index or ["*"] + language_used = "kql" if language is None else "eql" if language else "lucene" collector = CollectEvents(elasticsearch_client, max_results) if verbose: - click.echo(f'searching {",".join(index)} from {start_time} to {end_time}') - click.echo(f'{language_used}: {query}') + click.echo(f"searching {','.join(index)} from {start_time} to {end_time}") + click.echo(f"{language_used}: {query}") if count: results = collector.count(query, language_used, index, start_time, end_time) - click.echo(f'total results: {results}') + click.echo(f"total results: {results}") else: results = collector.search(query, language_used, index, start_time, end_time, max_results) - click.echo(f'total results: {len(results)} (capped at {max_results})') + click.echo(f"total results: {len(results)} (capped at {max_results})") click.echo_via_pager(json.dumps(results, indent=2, sort_keys=True)) return results -@test_group.command('rule-event-search') +@test_group.command("rule-event-search") @single_collection -@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search') 
-@click.option('--count', '-c', is_flag=True, help='Return count of results only') -@click.option('--max-results', '-m', type=click.IntRange(1, 1000), default=100, - help='Max results to return (capped at 1000)') -@click.option('--verbose', '-v', is_flag=True) +@click.option("--date-range", "-d", type=(str, str), default=("now-7d", "now"), help="Date range to scope search") +@click.option("--count", "-c", is_flag=True, help="Return count of results only") +@click.option( + "--max-results", + "-m", + type=click.IntRange(1, 1000), + default=100, + help="Max results to return (capped at 1000)", +) +@click.option("--verbose", "-v", is_flag=True) @click.pass_context -@add_client('elasticsearch') -def rule_event_search(ctx, rule, date_range, count, max_results, verbose, - elasticsearch_client: Elasticsearch = None): +@add_client(["elasticsearch"]) +def rule_event_search( # noqa: PLR0913 + ctx: click.Context, + rule: Any, + date_range: tuple[str, str], + count: bool, + max_results: int, + elasticsearch_client: Elasticsearch, + verbose: bool = False, +) -> None: """Search using a rule file against an Elasticsearch instance.""" if isinstance(rule.contents.data, QueryRuleData): if verbose: - click.echo(f'Searching rule: {rule.name}') + click.echo(f"Searching rule: {rule.name}") data = rule.contents.data rule_lang = data.language - if rule_lang == 'kuery': + if rule_lang == "kuery": language_flag = None - elif rule_lang == 'eql': + elif rule_lang == "eql": language_flag = True else: language_flag = False - index = data.index or ['*'] - ctx.invoke(event_search, query=data.query, index=index, language=language_flag, - date_range=date_range, count=count, max_results=max_results, verbose=verbose, - elasticsearch_client=elasticsearch_client) + index = data.index or ["*"] + ctx.invoke( + event_search, + query=data.query, + index=index, + language=language_flag, + date_range=date_range, + count=count, + max_results=max_results, + verbose=verbose, + elasticsearch_client=elasticsearch_client, + ) else: - client_error('Rule is not a query rule!') + raise_client_error("Rule is not a query rule!") -@test_group.command('rule-survey') -@click.argument('query', required=False) -@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search') -@click.option('--dump-file', type=click.Path(dir_okay=False), - default=get_path('surveys', f'{time.strftime("%Y%m%dT%H%M%SL")}.json'), - help='Save details of results (capped at 1000 results/rule)') -@click.option('--hide-zero-counts', '-z', is_flag=True, help='Exclude rules with zero hits from printing') -@click.option('--hide-errors', '-e', is_flag=True, help='Exclude rules with errors from printing') +@test_group.command("rule-survey") +@click.argument("query", required=False) +@click.option("--date-range", "-d", type=(str, str), default=("now-7d", "now"), help="Date range to scope search") +@click.option( + "--dump-file", + type=click.Path(dir_okay=False, path_type=Path), + default=get_path(["surveys", f"{time.strftime('%Y%m%dT%H%M%SL')}.json"]), + help="Save details of results (capped at 1000 results/rule)", +) +@click.option("--hide-zero-counts", "-z", is_flag=True, help="Exclude rules with zero hits from printing") +@click.option("--hide-errors", "-e", is_flag=True, help="Exclude rules with errors from printing") @click.pass_context -@add_client('elasticsearch', 'kibana', add_to_ctx=True) -def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_counts, hide_errors, - elasticsearch_client: 
Elasticsearch = None, kibana_client: Kibana = None): +@add_client(["elasticsearch", "kibana"], add_to_ctx=True) +def rule_survey( # noqa: PLR0913 + ctx: click.Context, + query: str, + date_range: tuple[str, str], + dump_file: Path, + hide_zero_counts: bool, + hide_errors: bool, + elasticsearch_client: Elasticsearch, + kibana_client: Kibana, +) -> list[dict[str, int]]: """Survey rule counts.""" - from kibana.resources import Signal from .main import search_rules - # from .eswrap import parse_unique_field_results - - survey_results = [] + survey_results: list[dict[str, int]] = [] start_time, end_time = date_range if query: rules = RuleCollection() - paths = [Path(r['file']) for r in ctx.invoke(search_rules, query=query, verbose=False)] + paths = [Path(r["file"]) for r in ctx.invoke(search_rules, query=query, verbose=False)] rules.load_files(paths) else: rules = RuleCollection.default().filter(production_filter) - click.echo(f'Running survey against {len(rules)} rules') - click.echo(f'Saving detailed dump to: {dump_file}') + click.echo(f"Running survey against {len(rules)} rules") + click.echo(f"Saving detailed dump to: {dump_file}") collector = CollectEvents(elasticsearch_client) details = collector.search_from_rule(rules, start_time=start_time, end_time=end_time) @@ -1243,72 +1440,75 @@ def rule_survey(ctx: click.Context, query, date_range, dump_file, hide_zero_coun # add alerts with kibana_client: - range_dsl = {'query': {'bool': {'filter': []}}} - add_range_to_dsl(range_dsl['query']['bool']['filter'], start_time, end_time) - alerts = {a['_source']['signal']['rule']['rule_id']: a['_source'] - for a in Signal.search(range_dsl, size=10000)['hits']['hits']} - - # for alert in alerts: - # rule_id = alert['signal']['rule']['rule_id'] - # rule = rules.id_map[rule_id] - # unique_results = parse_unique_field_results(rule.contents.data.type, rule.contents.data.unique_fields, alert) + range_dsl: dict[str, Any] = {"query": {"bool": {"filter": []}}} + add_range_to_dsl(range_dsl["query"]["bool"]["filter"], start_time, end_time) + alerts: dict[str, Any] = { + a["_source"]["signal"]["rule"]["rule_id"]: a["_source"] + for a in Signal.search(range_dsl, size=10000)["hits"]["hits"] # type: ignore[reportUnknownMemberType] + } for rule_id, count in counts.items(): alert_count = len(alerts.get(rule_id, [])) if alert_count > 0: - count['alert_count'] = alert_count + count["alert_count"] = alert_count details[rule_id].update(count) - search_count = count['search_count'] - if not alert_count and (hide_zero_counts and search_count == 0) or (hide_errors and search_count == -1): + search_count = count["search_count"] + if (not alert_count and (hide_zero_counts and search_count == 0)) or (hide_errors and search_count == -1): continue survey_results.append(count) - fields = ['rule_id', 'name', 'search_count', 'alert_count'] - table = Table.from_list(fields, survey_results) + fields = ["rule_id", "name", "search_count", "alert_count"] + table = Table.from_list(fields, survey_results) # type: ignore[reportUnknownMemberType] - if len(survey_results) > 200: + if len(survey_results) > 200: # noqa: PLR2004 click.echo_via_pager(table) else: click.echo(table) - os.makedirs(get_path('surveys'), exist_ok=True) - with open(dump_file, 'w') as f: + get_path(["surveys"]).mkdir(exist_ok=True) + with dump_file.open("w") as f: json.dump(details, f, indent=2, sort_keys=True) return survey_results -@dev_group.group('utils') -def utils_group(): +@dev_group.group("utils") +def utils_group() -> None: """Commands for dev utility methods.""" 
-@utils_group.command('get-branches') -@click.option('--outfile', '-o', type=Path, default=get_etc_path("target-branches.yaml"), help='File to save output to') -def get_branches(outfile: Path): +@utils_group.command("get-branches") +@click.option( + "--outfile", + "-o", + type=Path, + default=get_etc_path(["target-branches.yaml"]), + help="File to save output to", +) +def get_branches(outfile: Path) -> None: branch_list = get_stack_versions(drop_patch=True) target_branches = json.dumps(branch_list[:-1]) + "\n" - outfile.write_text(target_branches) + _ = outfile.write_text(target_branches) -@dev_group.group('integrations') -def integrations_group(): +@dev_group.group("integrations") +def integrations_group() -> None: """Commands for dev integrations methods.""" -@integrations_group.command('build-manifests') -@click.option('--overwrite', '-o', is_flag=True, help="Overwrite the existing integrations-manifest.json.gz file") +@integrations_group.command("build-manifests") +@click.option("--overwrite", "-o", is_flag=True, help="Overwrite the existing integrations-manifest.json.gz file") @click.option("--integration", "-i", type=str, help="Adds an integration tag to the manifest file") @click.option("--prerelease", "-p", is_flag=True, default=False, help="Include prerelease versions") -def build_integration_manifests(overwrite: bool, integration: str, prerelease: bool = False): +def build_integration_manifests(overwrite: bool, integration: str, prerelease: bool = False) -> None: """Builds consolidated integrations manifests file.""" click.echo("loading rules to determine all integration tags") - def flatten(tag_list: List[str]) -> List[str]: - return list(set([tag for tags in tag_list for tag in (flatten(tags) if isinstance(tags, list) else [tags])])) + def flatten(tag_list: list[str | list[str]] | list[str]) -> list[str]: + return list({tag for tags in tag_list for tag in (flatten(tags) if isinstance(tags, list) else [tags])}) if integration: build_integrations_manifest(overwrite=False, integration=integration, prerelease=prerelease) @@ -1320,11 +1520,12 @@ def flatten(tag_list: List[str]) -> List[str]: build_integrations_manifest(overwrite, rule_integrations=unique_integration_tags) -@integrations_group.command('build-schemas') -@click.option('--overwrite', '-o', is_flag=True, help="Overwrite the entire integrations-schema.json.gz file") -@click.option('--integration', '-i', type=str, - help="Adds a single integration schema to the integrations-schema.json.gz file") -def build_integration_schemas(overwrite: bool, integration: str): +@integrations_group.command("build-schemas") +@click.option("--overwrite", "-o", is_flag=True, help="Overwrite the entire integrations-schema.json.gz file") +@click.option( + "--integration", "-i", type=str, help="Adds a single integration schema to the integrations-schema.json.gz file" +) +def build_integration_schemas(overwrite: bool, integration: str) -> None: """Builds consolidated integrations schemas file.""" click.echo("Building integration schemas...") @@ -1337,51 +1538,62 @@ def build_integration_schemas(overwrite: bool, integration: str): click.echo(f"Time taken to generate schemas: {(end_time - start_time) / 60:.2f} minutes") -@integrations_group.command('show-latest-compatible') -@click.option('--package', '-p', help='Name of package') -@click.option('--stack_version', '-s', required=True, help='Rule stack version') +@integrations_group.command("show-latest-compatible") +@click.option("--package", "-p", help="Name of package") 
+@click.option("--stack_version", "-s", required=True, help="Rule stack version") def show_latest_compatible_version(package: str, stack_version: str) -> None: """Prints the latest integration compatible version for specified package based on stack version supplied.""" packages_manifest = None try: packages_manifest = load_integrations_manifests() - except Exception as e: - click.echo(f"Error loading integrations manifests: {str(e)}") + except Exception as e: # noqa: BLE001 + click.echo(f"Error loading integrations manifests: {e!s}") return try: - version = find_latest_compatible_version(package, "", - Version.parse(stack_version, optional_minor_and_patch=True), - packages_manifest) + version = find_latest_compatible_version( + package, "", Version.parse(stack_version, optional_minor_and_patch=True), packages_manifest + ) click.echo(f"Compatible integration {version=}") - except Exception as e: - click.echo(f"Error finding compatible version: {str(e)}") + except Exception as e: # noqa: BLE001 + click.echo(f"Error finding compatible version: {e!s}") return -@dev_group.group('schemas') -def schemas_group(): +@dev_group.group("schemas") +def schemas_group() -> None: """Commands for dev schema methods.""" @schemas_group.command("update-rule-data") -def update_rule_data_schemas(): - classes = [BaseRuleData] + list(typing.get_args(AnyRuleData)) +def update_rule_data_schemas() -> None: + classes = [BaseRuleData, *typing.get_args(AnyRuleData)] for cls in classes: - cls.save_schema() + _ = cls.save_schema() @schemas_group.command("generate") -@click.option("--token", required=True, prompt=get_github_token() is None, default=get_github_token(), - help="GitHub token to use for the PR", hide_input=True) -@click.option("--schema", "-s", required=True, type=click.Choice(["endgame", "ecs", "beats", "endpoint"]), - help="Schema to generate") +@click.option( + "--token", + required=True, + prompt=get_github_token() is None, + default=get_github_token(), + help="GitHub token to use for the PR", + hide_input=True, +) +@click.option( + "--schema", + "-s", + required=True, + type=click.Choice(["endgame", "ecs", "beats", "endpoint"]), + help="Schema to generate", +) @click.option("--schema-version", "-sv", help="Tagged version from TBD. 
e.g., 1.9.0") @click.option("--endpoint-target", "-t", type=str, default="endpoint", help="Target endpoint schema") @click.option("--overwrite", is_flag=True, help="Overwrite if versions exist") -def generate_schema(token: str, schema: str, schema_version: str, endpoint_target: str, overwrite: bool): +def generate_schema(token: str, schema: str, schema_version: str, endpoint_target: str, overwrite: bool) -> None: """Download schemas and generate flattend schema.""" github = GithubClient(token) client = github.authenticated_client @@ -1413,8 +1625,9 @@ def generate_schema(token: str, schema: str, schema_version: str, endpoint_targe repo = client.get_repo("elastic/endpoint-package") contents = repo.get_contents("custom_schemas") optional_endpoint_targets = [ - Path(f.path).name.replace("custom_", "").replace(".yml", "") - for f in contents if f.name.endswith(".yml") or Path(f.path).name == endpoint_target + Path(f.path).name.replace("custom_", "").replace(".yml", "") # type: ignore[reportUnknownMemberType] + for f in contents # type: ignore[reportUnknownVariableType] + if f.name.endswith(".yml") or Path(f.path).name == endpoint_target # type: ignore[reportUnknownMemberType] ] if not endpoint_target: @@ -1426,147 +1639,154 @@ def generate_schema(token: str, schema: str, schema_version: str, endpoint_targe click.echo(f"Done generating {schema} schema") -@dev_group.group('attack') -def attack_group(): +@dev_group.group("attack") +def attack_group() -> None: """Commands for managing Mitre ATT&CK data and mappings.""" -@attack_group.command('refresh-data') -def refresh_attack_data() -> dict: +@attack_group.command("refresh-data") +def refresh_attack_data() -> dict[str, Any] | None: """Refresh the ATT&CK data file.""" data, _ = attack.refresh_attack_data() return data -@attack_group.command('refresh-redirect-mappings') -def refresh_threat_mappings(): +@attack_group.command("refresh-redirect-mappings") +def refresh_threat_mappings() -> None: """Refresh the ATT&CK redirect file and update all rule threat mappings.""" # refresh the attack_technique_redirects - click.echo('refreshing data in attack_technique_redirects.json') + click.echo("refreshing data in attack_technique_redirects.json") attack.refresh_redirected_techniques_map() -@attack_group.command('update-rules') -def update_attack_in_rules() -> List[Optional[TOMLRule]]: +@attack_group.command("update-rules") +def update_attack_in_rules() -> list[TOMLRule]: """Update threat mappings attack data in all rules.""" - new_rules = [] + new_rules: list[TOMLRule] = [] redirected_techniques = attack.load_techniques_redirect() - today = time.strftime('%Y/%m/%d') + today = time.strftime("%Y/%m/%d") rules = RuleCollection.default() for rule in rules.rules: needs_update = False - valid_threat: List[ThreatMapping] = [] - threat_pending_update = {} + updated_threat_map: dict[str, ThreatMapping] = {} threat = rule.contents.data.threat or [] for entry in threat: - tactic = entry.tactic.name - technique_ids = [] - technique_names = [] + tactic_id = entry.tactic.id + tactic_name = entry.tactic.name + technique_ids: list[str] = [] + technique_names: list[str] = [] for technique in entry.technique or []: technique_ids.append(technique.id) technique_names.append(technique.name) - technique_ids.extend([st.id for st in technique.subtechnique or []]) - technique_names.extend([st.name for st in technique.subtechnique or []]) + if technique.subtechnique: + technique_ids.extend([st.id for st in technique.subtechnique]) + technique_names.extend([st.name for st in 
technique.subtechnique]) - # check redirected techniques by ID - # redirected techniques are technique IDs that have changed but represent the same technique - if any([tid for tid in technique_ids if tid in redirected_techniques]): + if any(tid for tid in technique_ids if tid in redirected_techniques): needs_update = True - threat_pending_update[tactic] = technique_ids - click.echo(f"'{rule.contents.name}' requires update - technique ID change") - - # check for name change - # happens if technique ID is the same but name changes - expected_technique_names = [attack.technique_lookup[str(tid)]["name"] for tid in technique_ids] - if any([tname for tname in technique_names if tname not in expected_technique_names]): + click.echo(f"'{rule.contents.name}' requires update - technique ID change for tactic '{tactic_name}'") + elif any( + tname + for tname in technique_names + if tname + not in [ + attack.technique_lookup[str(tid)]["name"] + for tid in technique_ids + if str(tid) in attack.technique_lookup + ] + ): needs_update = True - threat_pending_update[tactic] = technique_ids - click.echo(f"'{rule.contents.name}' requires update - technique name change") + click.echo(f"'{rule.contents.name}' requires update - technique name change for tactic '{tactic_name}'") + if needs_update: + try: + updated_threat_entry = attack.build_threat_map_entry(tactic_name, *technique_ids) + updated_threat_map[tactic_id] = ThreatMapping.from_dict(updated_threat_entry) + except ValueError as exc: + raise ValueError(f"{rule.id} - {rule.name}: {exc}") from exc else: - valid_threat.append(entry) + updated_threat_map[tactic_id] = entry if needs_update: - for tactic, techniques in threat_pending_update.items(): - try: - updated_threat = attack.build_threat_map_entry(tactic, *techniques) - except ValueError as err: - raise ValueError(f'{rule.id} - {rule.name}: {err}') - - tm = ThreatMapping.from_dict(updated_threat) - valid_threat.append(tm) + final_threat_list = list(updated_threat_map.values()) + final_threat_list.sort(key=lambda x: x.tactic.name) new_meta = dataclasses.replace(rule.contents.metadata, updated_date=today) - new_data = dataclasses.replace(rule.contents.data, threat=valid_threat) + new_data = dataclasses.replace(rule.contents.data, threat=final_threat_list) new_contents = dataclasses.replace(rule.contents, data=new_data, metadata=new_meta) new_rule = TOMLRule(contents=new_contents, path=rule.path) new_rule.save_toml() new_rules.append(new_rule) if new_rules: - click.echo(f'\nFinished - {len(new_rules)} rules updated!') + click.echo(f"\nFinished - {len(new_rules)} rules updated!") else: - click.echo('No rule changes needed') + click.echo("No rule changes needed") return new_rules -@dev_group.group('transforms') -def transforms_group(): +@dev_group.group("transforms") +def transforms_group() -> None: """Commands for managing TOML [transform].""" -def guide_plugin_convert_(contents: Optional[str] = None, default: Optional[str] = '' - ) -> Optional[Dict[str, Dict[str, list]]]: +def guide_plugin_convert_( + contents: str | None = None, + default: str | None = "", +) -> dict[str, dict[str, list[str]]] | None: """Convert investigation guide plugin format to toml""" - contents = contents or click.prompt('Enter plugin contents', default=default) + contents = contents or click.prompt("Enter plugin contents", default=default) if not contents: - return + return None - parsed = re.match(r'!{(?P\w+)(?P{.+})}', contents.strip()) + parsed = re.match(r"!{(?P\w+)(?P{.+})}", contents.strip()) + if not parsed: + raise 
ValueError("No plugin name found") try: - plugin = parsed.group('plugin') - data = parsed.group('data') + plugin = parsed.group("plugin") + data = parsed.group("data") except AttributeError as e: - raise client_error('Unrecognized pattern', exc=e) - loaded = {'transform': {plugin: [json.loads(data)]}} - click.echo(pytoml.dumps(loaded)) + raise raise_client_error("Unrecognized pattern", exc=e) from e + loaded = {"transform": {plugin: [json.loads(data)]}} + click.echo(pytoml.dumps(loaded)) # type: ignore[reportUnknownMemberType] return loaded -@transforms_group.command('guide-plugin-convert') -def guide_plugin_convert(contents: Optional[str] = None, default: Optional[str] = '' - ) -> Optional[Dict[str, Dict[str, list]]]: +@transforms_group.command("guide-plugin-convert") +def guide_plugin_convert( + contents: str | None = None, default: str | None = "" +) -> dict[str, dict[str, list[str]]] | None: """Convert investigation guide plugin format to toml.""" return guide_plugin_convert_(contents=contents, default=default) -@transforms_group.command('guide-plugin-to-rule') -@click.argument('rule-path', type=Path) +@transforms_group.command("guide-plugin-to-rule") +@click.argument("rule-path", type=Path) @click.pass_context def guide_plugin_to_rule(ctx: click.Context, rule_path: Path, save: bool = True) -> TOMLRule: """Convert investigation guide plugin format to toml and save to rule.""" rc = RuleCollection() rule = rc.load_file(rule_path) - transforms = defaultdict(list) - existing_transform = rule.contents.transform - transforms.update(existing_transform.to_dict() if existing_transform is not None else {}) + transforms: dict[str, list[Any]] = defaultdict(list) + existing_transform: RuleTransform | None = rule.contents.transform # type: ignore[reportAssignmentType] + transforms.update(existing_transform.to_dict() if existing_transform else {}) - click.secho('(blank line to continue)', fg='yellow') + click.secho("(blank line to continue)", fg="yellow") while True: loaded = ctx.invoke(guide_plugin_convert) if not loaded: break - data = loaded['transform'] + data = loaded["transform"] for plugin, entries in data.items(): transforms[plugin].extend(entries) transform = RuleTransform.from_dict(transforms) - new_contents = TOMLRuleContents(data=rule.contents.data, metadata=rule.contents.metadata, transform=transform) + new_contents = TOMLRuleContents(data=rule.contents.data, metadata=rule.contents.metadata, transform=transform) # type: ignore[reportArgumentType] updated_rule = TOMLRule(contents=new_contents, path=rule.path) if save: diff --git a/detection_rules/docs.py b/detection_rules/docs.py index 27c455d203f..95e7b761875 100644 --- a/detection_rules/docs.py +++ b/detection_rules/docs.py @@ -4,18 +4,21 @@ # 2.0. 
"""Create summary documents for a rule package.""" + import itertools import json import re import shutil import textwrap +import typing from collections import defaultdict from dataclasses import asdict, dataclass from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import Any -import xlsxwriter +import xlsxwriter # type: ignore[reportMissingTypeStubs] +import xlsxwriter.format # type: ignore[reportMissingTypeStubs] from semver import Version from .attack import attack_tm, matrix, tactics, technique_lookup @@ -30,38 +33,41 @@ class PackageDocument(xlsxwriter.Workbook): """Excel document for summarizing a rules package.""" - def __init__(self, path, package: Package): + def __init__(self, path: str, package: Package) -> None: """Create an excel workbook for the package.""" - self._default_format = {'font_name': 'Helvetica', 'font_size': 12} - super(PackageDocument, self).__init__(path) + self._default_format = {"font_name": "Helvetica", "font_size": 12} + super().__init__(path) # type: ignore[reportUnknownMemberType] self.package = package self.deprecated_rules = package.deprecated_rules self.production_rules = package.rules - self.percent = self.add_format({'num_format': '0%'}) - self.bold = self.add_format({'bold': True}) - self.default_header_format = self.add_format({'bold': True, 'bg_color': '#FFBE33'}) - self.center = self.add_format({'align': 'center', 'valign': 'center'}) - self.bold_center = self.add_format({'bold': True, 'align': 'center', 'valign': 'center'}) - self.right_align = self.add_format({'align': 'right'}) + self.percent = self.add_format({"num_format": "0%"}) + self.bold = self.add_format({"bold": True}) + self.default_header_format = self.add_format({"bold": True, "bg_color": "#FFBE33"}) + self.center = self.add_format({"align": "center", "valign": "center"}) + self.bold_center = self.add_format({"bold": True, "align": "center", "valign": "center"}) + self.right_align = self.add_format({"align": "right"}) self._coverage = self._get_attack_coverage() - def add_format(self, properties=None): + def add_format(self, properties: dict[str, Any] | None = None) -> xlsxwriter.format.Format: """Add a format to the doc.""" properties = properties or {} for key in self._default_format: if key not in properties: properties[key] = self._default_format[key] - return super(PackageDocument, self).add_format(properties) + return super().add_format(properties) # type: ignore[reportUnknownMemberType] - def _get_attack_coverage(self): - coverage = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) + def _get_attack_coverage(self) -> dict[str, Any]: + coverage: dict[str, dict[str, dict[str, int]]] = defaultdict(lambda: defaultdict(lambda: defaultdict(int))) for rule in self.package.rules: threat = rule.contents.data.threat + if not rule.path: + raise ValueError("No rule path found") + sub_dir = Path(rule.path).parent.name if threat: @@ -74,16 +80,17 @@ def _get_attack_coverage(self): return coverage - def populate(self): + def populate(self) -> None: """Populate the different pages.""" self.add_summary() self.add_rule_details() self.add_attack_matrix() - self.add_rule_details(self.deprecated_rules, 'Deprecated Rules') + self.add_rule_details(self.deprecated_rules, "Deprecated Rules") - def add_summary(self): + @typing.no_type_check + def add_summary(self) -> None: """Add the summary worksheet.""" - worksheet = self.add_worksheet('Summary') + worksheet = self.add_worksheet("Summary") worksheet.freeze_panes(1, 
0) worksheet.set_column(0, 0, 25) worksheet.set_column(1, 1, 10) @@ -92,66 +99,88 @@ def add_summary(self): worksheet.merge_range(row, 0, row, 1, "SUMMARY", self.bold_center) row += 1 - worksheet.write(row, 0, "Package Name") - worksheet.write(row, 1, self.package.name, self.right_align) + _ = worksheet.write(row, 0, "Package Name") + _ = worksheet.write(row, 1, self.package.name, self.right_align) row += 1 - tactic_counts = defaultdict(int) + tactic_counts: dict[str, int] = defaultdict(int) for rule in self.package.rules: threat = rule.contents.data.threat if threat: for entry in threat: tactic_counts[entry.tactic.name] += 1 - worksheet.write(row, 0, "Total Production Rules") - worksheet.write(row, 1, len(self.production_rules)) + _ = worksheet.write(row, 0, "Total Production Rules") + _ = worksheet.write(row, 1, len(self.production_rules)) row += 2 - worksheet.write(row, 0, "Total Deprecated Rules") - worksheet.write(row, 1, len(self.deprecated_rules)) + _ = worksheet.write(row, 0, "Total Deprecated Rules") + _ = worksheet.write(row, 1, len(self.deprecated_rules)) row += 1 - worksheet.write(row, 0, "Total Rules") - worksheet.write(row, 1, len(self.package.rules)) + _ = worksheet.write(row, 0, "Total Rules") + _ = worksheet.write(row, 1, len(self.package.rules)) row += 2 worksheet.merge_range(row, 0, row, 3, f"MITRE {attack_tm} TACTICS", self.bold_center) row += 1 for tactic in tactics: - worksheet.write(row, 0, tactic) - worksheet.write(row, 1, tactic_counts[tactic]) + _ = worksheet.write(row, 0, tactic) + _ = worksheet.write(row, 1, tactic_counts[tactic]) num_techniques = len(self._coverage[tactic]) total_techniques = len(matrix[tactic]) percent = float(num_techniques) / float(total_techniques) - worksheet.write(row, 2, percent, self.percent) - worksheet.write(row, 3, f'{num_techniques}/{total_techniques}', self.right_align) + _ = worksheet.write(row, 2, percent, self.percent) + _ = worksheet.write(row, 3, f"{num_techniques}/{total_techniques}", self.right_align) row += 1 - def add_rule_details(self, rules: Optional[Union[DeprecatedCollection, RuleCollection]] = None, - name='Rule Details'): + def add_rule_details( + self, + rules: DeprecatedCollection | RuleCollection | None = None, + name: str = "Rule Details", + ) -> None: """Add a worksheet for detailed metadata of rules.""" if rules is None: rules = self.production_rules - worksheet = self.add_worksheet(name) - worksheet.freeze_panes(1, 1) - headers = ('Name', 'ID', 'Version', 'Type', 'Language', 'Index', 'Tags', - f'{attack_tm} Tactics', f'{attack_tm} Techniques', 'Description') + worksheet = self.add_worksheet(name) # type: ignore[reportUnknownVariableType] + worksheet.freeze_panes(1, 1) # type: ignore[reportUnknownVariableType] + headers = ( + "Name", + "ID", + "Version", + "Type", + "Language", + "Index", + "Tags", + f"{attack_tm} Tactics", + f"{attack_tm} Techniques", + "Description", + ) for column, header in enumerate(headers): - worksheet.write(0, column, header, self.default_header_format) + _ = worksheet.write(0, column, header, self.default_header_format) # type: ignore[reportUnknownMemberType] - column_max_widths = [0 for i in range(len(headers))] + column_max_widths = [0 for _ in range(len(headers))] metadata_fields = ( - 'name', 'rule_id', 'version', 'type', 'language', 'index', 'tags', 'tactics', 'techniques', 'description' + "name", + "rule_id", + "version", + "type", + "language", + "index", + "tags", + "tactics", + "techniques", + "description", ) for row, rule in enumerate(rules, 1): - rule_contents = {'tactics': 
'', 'techniques': ''} + rule_contents = {"tactics": "", "techniques": ""} if isinstance(rules, RuleCollection): - flat_mitre = ThreatMapping.flatten(rule.contents.data.threat) - rule_contents = {'tactics': flat_mitre.tactic_names, 'techniques': flat_mitre.technique_ids} + flat_mitre = ThreatMapping.flatten(rule.contents.data.threat) # type: ignore[reportAttributeAccessIssue] + rule_contents = {"tactics": flat_mitre.tactic_names, "techniques": flat_mitre.technique_ids} rule_contents.update(rule.contents.to_api_format()) @@ -159,9 +188,9 @@ def add_rule_details(self, rules: Optional[Union[DeprecatedCollection, RuleColle value = rule_contents.get(field) if value is None: continue - elif isinstance(value, list): - value = ', '.join(value) - worksheet.write(row, column, value) + if isinstance(value, list): + value = ", ".join(value) + _ = worksheet.write(row, column, value) # type: ignore[reportUnknownMemberType] column_max_widths[column] = max(column_max_widths[column], len(str(value))) # cap description width at 80 @@ -169,37 +198,43 @@ def add_rule_details(self, rules: Optional[Union[DeprecatedCollection, RuleColle # this is still not perfect because the font used is not monospaced, but it gets it close for index, width in enumerate(column_max_widths): - worksheet.set_column(index, index, width) + _ = worksheet.set_column(index, index, width) # type: ignore[reportUnknownMemberType] - worksheet.autofilter(0, 0, len(rules) + 1, len(headers) - 1) + _ = worksheet.autofilter(0, 0, len(rules) + 1, len(headers) - 1) # type: ignore[reportUnknownMemberType] - def add_attack_matrix(self): + def add_attack_matrix(self) -> None: """Add a worksheet for ATT&CK coverage.""" - worksheet = self.add_worksheet(attack_tm + ' Coverage') - worksheet.freeze_panes(1, 0) - header = self.add_format({'font_size': 12, 'bold': True, 'bg_color': '#005B94', 'font_color': 'white'}) - default = self.add_format({'font_size': 10, 'text_wrap': True}) - bold = self.add_format({'font_size': 10, 'bold': True, 'text_wrap': True}) - technique_url = 'https://attack.mitre.org/techniques/' + worksheet = self.add_worksheet(attack_tm + " Coverage") # type: ignore[reportUnknownMemberType] + worksheet.freeze_panes(1, 0) # type: ignore[reportUnknownMemberType] + header = self.add_format({"font_size": 12, "bold": True, "bg_color": "#005B94", "font_color": "white"}) + default = self.add_format({"font_size": 10, "text_wrap": True}) + bold = self.add_format({"font_size": 10, "bold": True, "text_wrap": True}) + technique_url = "https://attack.mitre.org/techniques/" for column, tactic in enumerate(tactics): - worksheet.write(0, column, tactic, header) - worksheet.set_column(column, column, 20) + _ = worksheet.write(0, column, tactic, header) # type: ignore[reportUnknownMemberType] + _ = worksheet.set_column(column, column, 20) # type: ignore[reportUnknownMemberType] for row, technique_id in enumerate(matrix[tactic], 1): technique = technique_lookup[technique_id] fmt = bold if technique_id in self._coverage[tactic] else default coverage = self._coverage[tactic].get(technique_id) - coverage_str = '' + coverage_str = "" if coverage: - coverage_str = '\n\n' - coverage_str += '\n'.join(f'{sub_dir}: {count}' for sub_dir, count in coverage.items()) + coverage_str = "\n\n" + coverage_str += "\n".join(f"{sub_dir}: {count}" for sub_dir, count in coverage.items()) - worksheet.write_url(row, column, technique_url + technique_id.replace('.', '/'), cell_format=fmt, - string=technique['name'], tip=f'{technique_id}{coverage_str}') + _ = worksheet.write_url( # 
type: ignore[reportUnknownMemberType] + row, + column, + technique_url + technique_id.replace(".", "/"), + cell_format=fmt, + string=technique["name"], + tip=f"{technique_id}{coverage_str}", + ) - worksheet.autofilter(0, 0, max([len(v) for k, v in matrix.items()]) + 1, len(tactics) - 1) + _ = worksheet.autofilter(0, 0, max([len(v) for _, v in matrix.items()]) + 1, len(tactics) - 1) # type: ignore[reportUnknownMemberType] # product rule docs @@ -207,41 +242,40 @@ def add_attack_matrix(self): class AsciiDoc: - @classmethod - def bold_kv(cls, key: str, value: str): - return f'*{key}*: {value}' + def bold_kv(cls, key: str, value: str) -> str: + return f"*{key}*: {value}" @classmethod - def description_list(cls, value: Dict[str, str], linesep='\n\n'): - return f'{linesep}'.join(f'{k}::\n{v}' for k, v in value.items()) + def description_list(cls, value: dict[str, str], linesep: str = "\n\n") -> str: + return f"{linesep}".join(f"{k}::\n{v}" for k, v in value.items()) @classmethod - def bulleted(cls, value: str, depth=1): - return f'{"*" * depth} {value}' + def bulleted(cls, value: str, depth: int = 1) -> str: + return f"{'*' * depth} {value}" @classmethod - def bulleted_list(cls, values: Iterable): - return '* ' + '\n* '.join(values) + def bulleted_list(cls, values: list[str]) -> str: + return "* " + "\n* ".join(values) @classmethod - def code(cls, value: str, code='js'): + def code(cls, value: str, code: str = "js") -> str: line_sep = "-" * 34 - return f'[source, {code}]\n{line_sep}\n{value}\n{line_sep}' + return f"[source, {code}]\n{line_sep}\n{value}\n{line_sep}" @classmethod - def title(cls, depth: int, value: str): - return f'{"=" * depth} {value}' + def title(cls, depth: int, value: str) -> str: + return f"{'=' * depth} {value}" @classmethod - def inline_anchor(cls, value: str): - return f'[[{value}]]' + def inline_anchor(cls, value: str) -> str: + return f"[[{value}]]" @classmethod - def table(cls, data: dict) -> str: - entries = [f'| {k} | {v}' for k, v in data.items()] - table = ['[width="100%"]', '|==='] + entries + ['|==='] - return '\n'.join(table) + def table(cls, data: dict[str, Any]) -> str: + entries = [f"| {k} | {v}" for k, v in data.items()] + table = ['[width="100%"]', "|===", *entries, "|==="] + return "\n".join(table) class SecurityDocs: @@ -252,35 +286,50 @@ class KibanaSecurityDocs: """Generate docs for prebuilt rules in Elastic documentation.""" @staticmethod - def cmp_value(value): + def cmp_value(value: Any) -> Any: if isinstance(value, list): - cmp_new = tuple(value) + cmp_new = tuple(value) # type: ignore[reportUnknownArgumentType] elif isinstance(value, dict): cmp_new = json.dumps(value, sort_keys=True, indent=2) else: cmp_new = value - return cmp_new + return cmp_new # type: ignore[reportUnknownVariableType] class IntegrationSecurityDocs: """Generate docs for prebuilt rules in Elastic documentation.""" - def __init__(self, registry_version: str, directory: Path, overwrite=False, - updated_rules: Optional[Dict[str, TOMLRule]] = None, new_rules: Optional[Dict[str, TOMLRule]] = None, - deprecated_rules: Optional[Dict[str, TOMLRule]] = None, update_message: str = ""): + def __init__( # noqa: PLR0913 + self, + registry_version: str, + directory: Path, + overwrite: bool = False, + updated_rules: dict[str, TOMLRule] | None = None, + new_rules: dict[str, TOMLRule] | None = None, + deprecated_rules: dict[str, DeprecatedRule] | None = None, + update_message: str = "", + ) -> None: self.new_rules = new_rules self.updated_rules = updated_rules self.deprecated_rules = 
deprecated_rules - self.included_rules = list(itertools.chain(new_rules.values(), - updated_rules.values(), - deprecated_rules.values())) + self.included_rules: list[TOMLRule | DeprecatedRule] = [] + if new_rules: + self.included_rules += new_rules.values() + + if updated_rules: + self.included_rules += updated_rules.values() + + if deprecated_rules: + self.included_rules += deprecated_rules.values() all_rules = RuleCollection.default().rules self.sorted_rules = sorted(all_rules, key=lambda rule: rule.name) self.registry_version_str, self.base_name, self.prebuilt_rule_base = self.parse_registry(registry_version) self.directory = directory - self.package_directory = directory / "docs" / "detections" / "prebuilt-rules" / "downloadable-packages" / self.base_name # noqa: E501 + self.package_directory = ( + directory / "docs" / "detections" / "prebuilt-rules" / "downloadable-packages" / self.base_name + ) self.rule_details = directory / "docs" / "detections" / "prebuilt-rules" / "rule-details" self.update_message = update_message @@ -290,18 +339,20 @@ def __init__(self, registry_version: str, directory: Path, overwrite=False, self.package_directory.mkdir(parents=True, exist_ok=overwrite) @staticmethod - def parse_registry(registry_version: str) -> (str, str, str): - registry_version = Version.parse(registry_version, optional_minor_and_patch=True) - short_registry_version = [str(n) for n in registry_version[:3]] - registry_version_str = '.'.join(short_registry_version) + def parse_registry(registry_version_val: str) -> tuple[str, str, str]: + registry_version = Version.parse(registry_version_val, optional_minor_and_patch=True) + + parts = registry_version[:3] + short_registry_version = [str(n) for n in parts] # type: ignore[reportOptionalIterable] + registry_version_str = ".".join(short_registry_version) base_name = "-".join(short_registry_version) - prebuilt_rule_base = f'prebuilt-rule-{base_name}' + prebuilt_rule_base = f"prebuilt-rule-{base_name}" return registry_version_str, base_name, prebuilt_rule_base - def generate_appendix(self): + def generate_appendix(self) -> None: # appendix - appendix = self.package_directory / f'prebuilt-rules-{self.base_name}-appendix.asciidoc' + appendix = self.package_directory / f"prebuilt-rules-{self.base_name}-appendix.asciidoc" appendix_header = textwrap.dedent(f""" ["appendix",role="exclude",id="prebuilt-rule-{self.base_name}-prebuilt-rules-{self.base_name}-appendix"] @@ -311,13 +362,13 @@ def generate_appendix(self): """).lstrip() # noqa: E501 - include_format = f'include::{self.prebuilt_rule_base}-' + '{}.asciidoc[]' - appendix_lines = [appendix_header] + [include_format.format(name_to_title(r.name)) for r in self.included_rules] - appendix_str = '\n'.join(appendix_lines) + '\n' - appendix.write_text(appendix_str) + include_format = f"include::{self.prebuilt_rule_base}-" + "{}.asciidoc[]" + appendix_lines = [appendix_header] + [include_format.format(name_to_title(r.name)) for r in self.included_rules] # type: ignore[reportArgumentType] + appendix_str = "\n".join(appendix_lines) + "\n" + _ = appendix.write_text(appendix_str) - def generate_summary(self): - summary = self.package_directory / f'prebuilt-rules-{self.base_name}-summary.asciidoc' + def generate_summary(self) -> None: + summary = self.package_directory / f"prebuilt-rules-{self.base_name}-summary.asciidoc" summary_header = textwrap.dedent(f""" [[prebuilt-rule-{self.base_name}-prebuilt-rules-{self.base_name}-summary]] @@ -332,25 +383,34 @@ def generate_summary(self): |Rule |Description |Status 
|Version """).lstrip() # noqa: E501 - rule_entries = [] + rule_entries: list[str] = [] for rule in self.included_rules: - if rule.contents.metadata.get('maturity') == 'development': + if rule.contents.metadata.get("maturity") == "development": continue - title_name = name_to_title(rule.name) - status = 'new' if rule.id in self.new_rules else 'update' if rule.id in self.updated_rules else 'deprecated' - description = rule.contents.to_api_format()['description'] + title_name = name_to_title(rule.name) # type: ignore[reportArgumentType] + + if self.new_rules and rule.id in self.new_rules: + status = "new" + elif self.updated_rules and rule.id in self.updated_rules: + status = "update" + else: + status = "deprecated" + + description = rule.contents.to_api_format()["description"] version = rule.contents.autobumped_version - rule_entries.append(f'|<> ' - f'| {description} | {status} | {version} \n') + rule_entries.append( + f"|<> " + f"| {description} | {status} | {version} \n" + ) - summary_lines = [summary_header] + rule_entries + ['|=============================================='] - summary_str = '\n'.join(summary_lines) + '\n' - summary.write_text(summary_str) + summary_lines = [summary_header, *rule_entries, "|=============================================="] + summary_str = "\n".join(summary_lines) + "\n" + _ = summary.write_text(summary_str) - def generate_rule_reference(self): + def generate_rule_reference(self) -> None: """Generate rule reference page for prebuilt rules.""" - summary = self.directory / "docs" / "detections" / "prebuilt-rules" / 'prebuilt-rules-reference.asciidoc' - rule_list = self.directory / "docs" / "detections" / "prebuilt-rules" / 'rule-desc-index.asciidoc' + summary = self.directory / "docs" / "detections" / "prebuilt-rules" / "prebuilt-rules-reference.asciidoc" + rule_list = self.directory / "docs" / "detections" / "prebuilt-rules" / "rule-desc-index.asciidoc" summary_header = textwrap.dedent(""" [[prebuilt-rules]] @@ -368,88 +428,88 @@ def generate_rule_reference(self): |============================================== |Rule |Description |Tags |Added |Version - """).lstrip() # noqa: E501 + """).lstrip() - rule_entries = [] - rule_includes = [] + rule_entries: list[str] = [] + rule_includes: list[str] = [] for rule in self.sorted_rules: if isinstance(rule, DeprecatedRule): continue - if rule.contents.metadata.get('maturity') == 'development': + if rule.contents.metadata.get("maturity") == "development": continue title_name = name_to_title(rule.name) # skip rules not built for this package - built_rules = [x.name for x in self.rule_details.glob('*.asciidoc')] + built_rules = [x.name for x in self.rule_details.glob("*.asciidoc")] if f"{title_name}.asciidoc" not in built_rules: continue - rule_includes.append(f'include::rule-details/{title_name}.asciidoc[]') - tags = ', '.join(f'[{tag}]' for tag in rule.contents.data.tags) - description = rule.contents.to_api_format()['description'] + rule_includes.append(f"include::rule-details/{title_name}.asciidoc[]") + tags = ", ".join(f"[{tag}]" for tag in rule.contents.data.tags) # type: ignore[reportOptionalIterable] + description = rule.contents.to_api_format()["description"] version = rule.contents.autobumped_version added = rule.contents.metadata.min_stack_version - rule_entries.append(f'|<<{title_name}, {rule.name}>> |{description} |{tags} |{added} |{version}\n') + rule_entries.append(f"|<<{title_name}, {rule.name}>> |{description} |{tags} |{added} |{version}\n") - summary_lines = [summary_header] + rule_entries + 
['|=============================================='] - summary_str = '\n'.join(summary_lines) + '\n' - summary.write_text(summary_str) + summary_lines = [summary_header, *rule_entries, "|=============================================="] + summary_str = "\n".join(summary_lines) + "\n" + _ = summary.write_text(summary_str) # update rule-desc-index.asciidoc - rule_list.write_text('\n'.join(rule_includes)) + _ = rule_list.write_text("\n".join(rule_includes)) - def generate_rule_details(self): + def generate_rule_details(self) -> None: """Generate rule details for each prebuilt rule.""" included_rules = [x.name for x in self.included_rules] for rule in self.sorted_rules: - if rule.contents.metadata.get('maturity') == 'development': + if rule.contents.metadata.get("maturity") == "development": continue rule_detail = IntegrationRuleDetail(rule.id, rule.contents.to_api_format(), {}, self.base_name) - rule_path = self.package_directory / f'{self.prebuilt_rule_base}-{name_to_title(rule.name)}.asciidoc' - prebuilt_rule_path = self.rule_details / f'{name_to_title(rule.name)}.asciidoc' # noqa: E501 + rule_path = self.package_directory / f"{self.prebuilt_rule_base}-{name_to_title(rule.name)}.asciidoc" + prebuilt_rule_path = self.rule_details / f"{name_to_title(rule.name)}.asciidoc" if rule.name in included_rules: # only include updates - rule_path.write_text(rule_detail.generate()) + _ = rule_path.write_text(rule_detail.generate()) # add all available rules to the rule details directory - prebuilt_rule_path.write_text(rule_detail.generate(title=f'{name_to_title(rule.name)}')) + _ = prebuilt_rule_path.write_text(rule_detail.generate(title=f"{name_to_title(rule.name)}")) - def generate_manual_updates(self): + def generate_manual_updates(self) -> None: """ Generate manual updates for prebuilt rules downloadable updates and index. 
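For context on the file naming used throughout `IntegrationSecurityDocs`, the `parse_registry` helper above derives the three strings that name the generated summary, appendix, and rule-detail files. A minimal standalone sketch of that derivation, assuming the `semver` package this module already imports and an illustrative registry version of `"8.12"`:

```python
from semver import Version


def parse_registry(registry_version_val: str) -> tuple[str, str, str]:
    """Illustrative restatement of IntegrationSecurityDocs.parse_registry."""
    registry_version = Version.parse(registry_version_val, optional_minor_and_patch=True)
    short = [str(n) for n in registry_version[:3]]      # e.g. ["8", "12", "0"]
    registry_version_str = ".".join(short)              # "8.12.0"
    base_name = "-".join(short)                         # "8-12-0"
    prebuilt_rule_base = f"prebuilt-rule-{base_name}"   # "prebuilt-rule-8-12-0"
    return registry_version_str, base_name, prebuilt_rule_base


print(parse_registry("8.12"))  # ('8.12.0', '8-12-0', 'prebuilt-rule-8-12-0')
```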
""" updates = {} # Update downloadable rule updates entry - today = datetime.today().strftime('%d %b %Y') + today = datetime.today().strftime("%d %b %Y") # noqa: DTZ002 - updates['downloadable-updates.asciidoc'] = { - 'table_entry': ( - f'|<> | {today} | {len(self.new_rules)} | ' - f'{len(self.updated_rules)} | ' + updates["downloadable-updates.asciidoc"] = { + "table_entry": ( + f"|<> | {today} | {len(self.new_rules or [])} | " + f"{len(self.updated_rules or [])} | " + ), + "table_include": ( + f"include::downloadable-packages/{self.base_name}/" + f"prebuilt-rules-{self.base_name}-summary.asciidoc[leveloffset=+1]" ), - 'table_include': ( - f'include::downloadable-packages/{self.base_name}/' - f'prebuilt-rules-{self.base_name}-summary.asciidoc[leveloffset=+1]' - ) } - updates['index.asciidoc'] = { - 'index_include': ( - f'include::detections/prebuilt-rules/downloadable-packages/{self.base_name}/' - f'prebuilt-rules-{self.base_name}-appendix.asciidoc[]' + updates["index.asciidoc"] = { + "index_include": ( + f"include::detections/prebuilt-rules/downloadable-packages/{self.base_name}/" + f"prebuilt-rules-{self.base_name}-appendix.asciidoc[]" ) } # Add index.asciidoc:index_include in docs/index.asciidoc - docs_index = self.package_directory.parent.parent.parent.parent / 'index.asciidoc' - docs_index.write_text(docs_index.read_text() + '\n' + updates['index.asciidoc']['index_include'] + '\n') + docs_index = self.package_directory.parent.parent.parent.parent / "index.asciidoc" + _ = docs_index.write_text(docs_index.read_text() + "\n" + updates["index.asciidoc"]["index_include"] + "\n") # Add table_entry to docs/detections/prebuilt-rules/prebuilt-rules-downloadable-updates.asciidoc - downloadable_updates = self.package_directory.parent.parent / 'prebuilt-rules-downloadable-updates.asciidoc' + downloadable_updates = self.package_directory.parent.parent / "prebuilt-rules-downloadable-updates.asciidoc" version = Version.parse(self.registry_version_str) last_version = f"{version.major}.{version.minor - 1}" update_url = f"https://www.elastic.co/guide/en/security/{last_version}/prebuilt-rules-downloadable-updates.html" @@ -469,21 +529,22 @@ def generate_manual_updates(self): |Update version |Date | New rules | Updated rules | Notes """).lstrip() # noqa: E501 - new_content = updates['downloadable-updates.asciidoc']['table_entry'] + '\n' + self.update_message + new_content = updates["downloadable-updates.asciidoc"]["table_entry"] + "\n" + self.update_message self.add_content_to_table_top(downloadable_updates, summary_header, new_content) # Add table_include to/docs/detections/prebuilt-rules/prebuilt-rules-downloadable-updates.asciidoc # Reset the historic information at the beginning of each minor version - historic_data = downloadable_updates.read_text() if Version.parse(self.registry_version_str).patch > 1 else '' - downloadable_updates.write_text(historic_data + # noqa: W504 - updates['downloadable-updates.asciidoc']['table_include'] + '\n') + historic_data = downloadable_updates.read_text() if Version.parse(self.registry_version_str).patch > 1 else "" + _ = downloadable_updates.write_text( + historic_data + updates["downloadable-updates.asciidoc"]["table_include"] + "\n" + ) - def add_content_to_table_top(self, file_path: Path, summary_header: str, new_content: str): + def add_content_to_table_top(self, file_path: Path, summary_header: str, new_content: str) -> None: """Insert content at the top of a Markdown table right after the specified header.""" file_contents = file_path.read_text() # Find the 
header in the file - header = '|Update version |Date | New rules | Updated rules | Notes\n' + header = "|Update version |Date | New rules | Updated rules | Notes\n" header_index = file_contents.find(header) if header_index == -1: @@ -496,7 +557,7 @@ def add_content_to_table_top(self, file_path: Path, summary_header: str, new_con updated_contents = summary_header + f"\n{new_content}\n" + file_contents[insert_position:] # Write the updated contents back to the file - file_path.write_text(updated_contents) + _ = file_path.write_text(updated_contents) def generate(self) -> Path: self.generate_appendix() @@ -510,146 +571,159 @@ def generate(self) -> Path: class IntegrationRuleDetail: """Rule detail page generation.""" - def __init__(self, rule_id: str, rule: dict, changelog: Dict[str, dict], package_str: str): + def __init__( + self, + rule_id: str, + rule: dict[str, Any], + changelog: dict[str, dict[str, Any]], + package_str: str, + ) -> None: self.rule_id = rule_id self.rule = rule self.changelog = changelog self.package = package_str - self.rule_title = f'prebuilt-rule-{self.package}-{name_to_title(self.rule["name"])}' + self.rule_title = f"prebuilt-rule-{self.package}-{name_to_title(self.rule['name'])}" # set some defaults - self.rule.setdefault('max_signals', 100) - self.rule.setdefault('interval', '5m') + self.rule.setdefault("max_signals", 100) + self.rule.setdefault("interval", "5m") - def generate(self, title: str = None) -> str: + def generate(self, title: str | None = None) -> str: """Generate the rule detail page.""" title = title or self.rule_title page = [ AsciiDoc.inline_anchor(title), - AsciiDoc.title(3, self.rule['name']), - '', - self.rule['description'], - '', + AsciiDoc.title(3, self.rule["name"]), + "", + self.rule["description"], + "", self.metadata_str(), - '' + "", ] - if 'note' in self.rule: - page.extend([self.guide_str(), '']) - if 'setup' in self.rule: - page.extend([self.setup_str(), '']) - if 'query' in self.rule: - page.extend([self.query_str(), '']) - if 'threat' in self.rule: - page.extend([self.threat_mapping_str(), '']) + if "note" in self.rule: + page.extend([self.guide_str(), ""]) + if "setup" in self.rule: + page.extend([self.setup_str(), ""]) + if "query" in self.rule: + page.extend([self.query_str(), ""]) + if "threat" in self.rule: + page.extend([self.threat_mapping_str(), ""]) - return '\n'.join(page) + return "\n".join(page) def metadata_str(self) -> str: """Add the metadata section to the rule detail page.""" fields = { - 'type': 'Rule type', - 'index': 'Rule indices', - 'severity': 'Severity', - 'risk_score': 'Risk score', - 'interval': 'Runs every', - 'from': 'Searches indices from', - 'max_signals': 'Maximum alerts per execution', - 'references': 'References', - 'tags': 'Tags', - 'version': 'Version', - 'author': 'Rule authors', - 'license': 'Rule license' + "type": "Rule type", + "index": "Rule indices", + "severity": "Severity", + "risk_score": "Risk score", + "interval": "Runs every", + "from": "Searches indices from", + "max_signals": "Maximum alerts per execution", + "references": "References", + "tags": "Tags", + "version": "Version", + "author": "Rule authors", + "license": "Rule license", } - values = [] + values: list[str] = [] for field, friendly_name in fields.items(): value = self.rule.get(field) or self.changelog.get(field) - if isinstance(value, list): - str_value = f'\n\n{AsciiDoc.bulleted_list(value)}' + if value is None: + str_value = "None" + elif isinstance(value, list): + str_value = f"\n\n{AsciiDoc.bulleted_list(value)}" # 
type: ignore[reportUnknownArgumentType] else: str_value = str(value) - if field == 'from': - str_value += ' ({ref}/common-options.html#date-math[Date Math format], see also <>)' + if field == "from": + str_value += ( + " ({ref}/common-options.html#date-math[Date Math format], see also <>)" + ) - values.extend([AsciiDoc.bold_kv(friendly_name, str_value), '']) + values.extend([AsciiDoc.bold_kv(friendly_name, str_value), ""]) - return '\n'.join(values) + return "\n".join(values) def guide_str(self) -> str: """Add the guide section to the rule detail page.""" - guide = convert_markdown_to_asciidoc(self.rule['note']) - return f'{AsciiDoc.title(4, "Investigation guide")}\n\n\n{guide}' + guide = convert_markdown_to_asciidoc(self.rule["note"]) + return f"{AsciiDoc.title(4, 'Investigation guide')}\n\n\n{guide}" def setup_str(self) -> str: """Add the setup section to the rule detail page.""" - setup = convert_markdown_to_asciidoc(self.rule['setup']) - return f'{AsciiDoc.title(4, "Setup")}\n\n\n{setup}' + setup = convert_markdown_to_asciidoc(self.rule["setup"]) + return f"{AsciiDoc.title(4, 'Setup')}\n\n\n{setup}" def query_str(self) -> str: """Add the query section to the rule detail page.""" - return f'{AsciiDoc.title(4, "Rule query")}\n\n\n{AsciiDoc.code(self.rule["query"])}' + return f"{AsciiDoc.title(4, 'Rule query')}\n\n\n{AsciiDoc.code(self.rule['query'])}" def threat_mapping_str(self) -> str: """Add the threat mapping section to the rule detail page.""" - values = [AsciiDoc.bold_kv('Framework', 'MITRE ATT&CK^TM^'), ''] + values = [AsciiDoc.bold_kv("Framework", "MITRE ATT&CK^TM^"), ""] - for entry in self.rule['threat']: - tactic = entry['tactic'] + for entry in self.rule["threat"]: + tactic = entry["tactic"] entry_values = [ - AsciiDoc.bulleted('Tactic:'), - AsciiDoc.bulleted(f'Name: {tactic["name"]}', depth=2), - AsciiDoc.bulleted(f'ID: {tactic["id"]}', depth=2), - AsciiDoc.bulleted(f'Reference URL: {tactic["reference"]}', depth=2) + AsciiDoc.bulleted("Tactic:"), + AsciiDoc.bulleted(f"Name: {tactic['name']}", depth=2), + AsciiDoc.bulleted(f"ID: {tactic['id']}", depth=2), + AsciiDoc.bulleted(f"Reference URL: {tactic['reference']}", depth=2), ] - techniques = entry.get('technique', []) + techniques = entry.get("technique", []) for technique in techniques: - entry_values.extend([ - AsciiDoc.bulleted('Technique:'), - AsciiDoc.bulleted(f'Name: {technique["name"]}', depth=2), - AsciiDoc.bulleted(f'ID: {technique["id"]}', depth=2), - AsciiDoc.bulleted(f'Reference URL: {technique["reference"]}', depth=2) - ]) - - subtechniques = technique.get('subtechnique', []) + entry_values.extend( + [ + AsciiDoc.bulleted("Technique:"), + AsciiDoc.bulleted(f"Name: {technique['name']}", depth=2), + AsciiDoc.bulleted(f"ID: {technique['id']}", depth=2), + AsciiDoc.bulleted(f"Reference URL: {technique['reference']}", depth=2), + ] + ) + + subtechniques = technique.get("subtechnique", []) for subtechnique in subtechniques: - entry_values.extend([ - AsciiDoc.bulleted('Sub-technique:'), - AsciiDoc.bulleted(f'Name: {subtechnique["name"]}', depth=2), - AsciiDoc.bulleted(f'ID: {subtechnique["id"]}', depth=2), - AsciiDoc.bulleted(f'Reference URL: {subtechnique["reference"]}', depth=2) - ]) + entry_values.extend( + [ + AsciiDoc.bulleted("Sub-technique:"), + AsciiDoc.bulleted(f"Name: {subtechnique['name']}", depth=2), + AsciiDoc.bulleted(f"ID: {subtechnique['id']}", depth=2), + AsciiDoc.bulleted(f"Reference URL: {subtechnique['reference']}", depth=2), + ] + ) values.extend(entry_values) - return '\n'.join(values) + return 
"\n".join(values) def name_to_title(name: str) -> str: """Convert a rule name to tile.""" - initial = re.sub(r'[^\w]|_', r'-', name.lower().strip()) - return re.sub(r'-{2,}', '-', initial).strip('-') + initial = re.sub(r"[^\w]|_", r"-", name.lower().strip()) + return re.sub(r"-{2,}", "-", initial).strip("-") def convert_markdown_to_asciidoc(text: str) -> str: """Convert investigation guides and setup content from markdown to asciidoc.""" # Format the content after the stripped headers (#) to bold text with newlines. - markdown_header_pattern = re.compile(r'^(#+)\s*(.*?)$', re.MULTILINE) - text = re.sub(markdown_header_pattern, lambda m: f'\n*{m.group(2).strip()}*\n', text) + markdown_header_pattern = re.compile(r"^(#+)\s*(.*?)$", re.MULTILINE) + text = re.sub(markdown_header_pattern, lambda m: f"\n*{m.group(2).strip()}*\n", text) # Convert Markdown links to AsciiDoc format - markdown_link_pattern = re.compile(r'\[([^\]]+)\]\(([^)]+)\)') - text = re.sub(markdown_link_pattern, lambda m: f'{m.group(2)}[{m.group(1)}]', text) - - return text + markdown_link_pattern = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") + return re.sub(markdown_link_pattern, lambda m: f"{m.group(2)}[{m.group(1)}]", text) @dataclass class UpdateEntry: """A class schema for downloadable update entries.""" + update_version: str date: str new_rules: int @@ -661,22 +735,23 @@ class UpdateEntry: @dataclass class DownloadableUpdates: """A class for managing downloadable updates.""" - packages: List[UpdateEntry] + + packages: list[UpdateEntry] @classmethod - def load_updates(cls): + def load_updates(cls) -> "DownloadableUpdates": """Load the package.""" - prebuilt = load_etc_dump("downloadable_updates.json") - packages = [UpdateEntry(**entry) for entry in prebuilt['packages']] + prebuilt = load_etc_dump(["downloadable_updates.json"]) + packages = [UpdateEntry(**entry) for entry in prebuilt["packages"]] return cls(packages) - def save_updates(self): + def save_updates(self) -> None: """Save the package.""" sorted_package = sorted(self.packages, key=lambda entry: Version.parse(entry.update_version), reverse=True) - data = {'packages': [asdict(entry) for entry in sorted_package]} - save_etc_dump(data, "downloadable_updates.json") + data = {"packages": [asdict(entry) for entry in sorted_package]} + save_etc_dump(data, ["downloadable_updates.json"]) - def add_entry(self, entry: UpdateEntry, overwrite: bool = False): + def add_entry(self, entry: UpdateEntry, overwrite: bool = False) -> None: """Add an entry to the package.""" existing_entry_index = -1 for index, existing_entry in enumerate(self.packages): @@ -696,66 +771,71 @@ class MDX: """A class for generating Markdown content.""" @classmethod - def bold(cls, value: str): + def bold(cls, value: str) -> str: """Return a bold str in Markdown.""" - return f'**{value}**' + return f"**{value}**" @classmethod - def bold_kv(cls, key: str, value: str): + def bold_kv(cls, key: str, value: str) -> str: """Return a bold key-value pair in Markdown.""" - return f'**{key}**: {value}' + return f"**{key}**: {value}" @classmethod - def description_list(cls, value: Dict[str, str], linesep='\n\n'): + def description_list(cls, value: dict[str, str], linesep: str = "\n\n") -> str: """Create a description list in Markdown.""" - return f'{linesep}'.join(f'**{k}**:\n\n{v}' for k, v in value.items()) + return f"{linesep}".join(f"**{k}**:\n\n{v}" for k, v in value.items()) @classmethod - def bulleted(cls, value: str, depth=1): + def bulleted(cls, value: str, depth: int = 1) -> str: """Create a bulleted list item 
with a specified depth.""" - return f'{" " * (depth - 1)}* {value}' + return f"{' ' * (depth - 1)}* {value}" @classmethod - def bulleted_list(cls, values: Iterable): + def bulleted_list(cls, values: list[str]) -> str: """Create a bulleted list from an iterable.""" - return '\n* ' + '\n* '.join(values) + return "\n* " + "\n* ".join(values) @classmethod - def code(cls, value: str, language='js'): + def code(cls, value: str, language: str = "js") -> str: """Return a code block with the specified language.""" return f"```{language}\n{value}```" @classmethod - def title(cls, depth: int, value: str): + def title(cls, depth: int, value: str) -> str: """Create a title with the specified depth.""" - return f'{"#" * depth} {value}' + return f"{'#' * depth} {value}" @classmethod - def inline_anchor(cls, value: str): + def inline_anchor(cls, value: str) -> str: """Create an inline anchor with the specified value.""" return f'' @classmethod - def table(cls, data: dict) -> str: + def table(cls, data: dict[str, Any]) -> str: """Create a table from a dictionary.""" - entries = [f'| {k} | {v}' for k, v in data.items()] - table = ['|---|---|'] + entries - return '\n'.join(table) + entries = [f"| {k} | {v}" for k, v in data.items()] + table = ["|---|---|", *entries] + return "\n".join(table) class IntegrationSecurityDocsMDX: """Generate docs for prebuilt rules in Elastic documentation using MDX.""" - def __init__(self, release_version: str, directory: Path, overwrite: bool = False, - historical_package: Optional[Dict[str, dict]] = - None, new_package: Optional[Dict[str, TOMLRule]] = None, - note: Optional[str] = "Rule Updates."): + def __init__( # noqa: PLR0913 + self, + release_version: str, + directory: Path, + overwrite: bool = False, + new_package: Package | None = None, + historical_package: dict[str, Any] | None = None, + note: str | None = "Rule Updates.", + ) -> None: self.historical_package = historical_package self.new_package = new_package self.rule_changes = self.get_rule_changes() - self.included_rules = list(itertools.chain(self.rule_changes["new"], - self.rule_changes["updated"], - self.rule_changes["deprecated"])) + self.included_rules = list( + itertools.chain(self.rule_changes["new"], self.rule_changes["updated"], self.rule_changes["deprecated"]) + ) self.release_version_str, self.base_name, self.prebuilt_rule_base = self.parse_release(release_version) self.package_directory = directory / self.base_name @@ -768,62 +848,69 @@ def __init__(self, release_version: str, directory: Path, overwrite: bool = Fals self.package_directory.mkdir(parents=True, exist_ok=overwrite) @staticmethod - def parse_release(release_version: str) -> (str, str, str): + def parse_release(release_version_val: str) -> tuple[str, str, str]: """Parse the release version into a string, base name, and prebuilt rule base.""" - release_version = Version.parse(release_version) - short_release_version = [str(n) for n in release_version[:3]] - release_version_str = '.'.join(short_release_version) + release_version = Version.parse(release_version_val) + parts = release_version[:3] + short_release_version = [str(n) for n in parts] # type: ignore[reportOptionalIterable] + release_version_str = ".".join(short_release_version) base_name = "-".join(short_release_version) - prebuilt_rule_base = f'prebuilt-rule-{base_name}' + prebuilt_rule_base = f"prebuilt-rule-{base_name}" return release_version_str, base_name, prebuilt_rule_base - def get_rule_changes(self): + def get_rule_changes(self) -> dict[str, list[TOMLRule | DeprecatedRule]]: 
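The doc generation in this module keys file names and link targets off `name_to_title`, and the AsciiDoc rule pages rewrite investigation guides with `convert_markdown_to_asciidoc`; both helpers appear in the hunk above. A small self-contained sketch that restates those two regex transforms, using hypothetical inputs to show the expected output:

```python
import re


def name_to_title(name: str) -> str:
    """Slugify a rule name: lowercase, non-word characters collapsed to single hyphens."""
    initial = re.sub(r"[^\w]|_", r"-", name.lower().strip())
    return re.sub(r"-{2,}", "-", initial).strip("-")


def convert_markdown_to_asciidoc(text: str) -> str:
    """Rewrite Markdown headers as bold text and Markdown links as AsciiDoc links."""
    header_pattern = re.compile(r"^(#+)\s*(.*?)$", re.MULTILINE)
    text = re.sub(header_pattern, lambda m: f"\n*{m.group(2).strip()}*\n", text)
    link_pattern = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
    return re.sub(link_pattern, lambda m: f"{m.group(2)}[{m.group(1)}]", text)


# Hypothetical inputs for illustration:
print(name_to_title("Suspicious PowerShell Execution (via cmd.exe)"))
# suspicious-powershell-execution-via-cmd-exe
print(convert_markdown_to_asciidoc("## Triage\nSee [MITRE](https://attack.mitre.org)."))
# Prints "*Triage*" as a standalone bolded line, then "See https://attack.mitre.org[MITRE]."
```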
"""Compare the rules from the new_package against rules in the historical_package.""" - rule_changes = defaultdict(list) - rule_changes["new"], rule_changes["updated"], rule_changes["deprecated"] = [], [], [] + rule_changes: dict[str, list[TOMLRule | DeprecatedRule]] = { + "new": [], + "updated": [], + "deprecated": [], + } - historical_rule_ids = set(self.historical_package.keys()) + historical_package: dict[str, Any] = self.historical_package or {} + historical_rule_ids: set[str] = set(historical_package.keys()) - # Identify new and updated rules - for rule in self.new_package.rules: - rule_to_api_format = rule.contents.to_api_format() + if self.new_package: + # Identify new and updated rules + for rule in self.new_package.rules: + rule_to_api_format = rule.contents.to_api_format() - latest_version = rule_to_api_format["version"] - rule_id = f'{rule.id}_{latest_version}' + latest_version = rule_to_api_format["version"] + rule_id = f"{rule.id}_{latest_version}" - if rule_id not in historical_rule_ids and latest_version == 1: - rule_changes['new'].append(rule) - elif rule_id not in historical_rule_ids: - rule_changes['updated'].append(rule) + if rule_id not in historical_rule_ids and latest_version == 1: + rule_changes["new"].append(rule) + elif rule_id not in historical_rule_ids: + rule_changes["updated"].append(rule) # Identify deprecated rules # if rule is in the historical but not in the current package, its deprecated - deprecated_rule_ids = [] - for _, content in self.historical_package.items(): + deprecated_rule_ids: list[str] = [] + for content in historical_package.values(): rule_id = content["attributes"]["rule_id"] - if rule_id in self.new_package.deprecated_rules.id_map.keys(): + if self.new_package and rule_id in self.new_package.deprecated_rules.id_map: deprecated_rule_ids.append(rule_id) deprecated_rule_ids = list(set(deprecated_rule_ids)) for rule_id in deprecated_rule_ids: - rule_changes['deprecated'].append(self.new_package.deprecated_rules.id_map[rule_id]) + if self.new_package: + rule_changes["deprecated"].append(self.new_package.deprecated_rules.id_map[rule_id]) return dict(rule_changes) - def generate_current_rule_summary(self): + def generate_current_rule_summary(self) -> None: """Generate a summary of all available current rules in the latest package.""" - slug = f'prebuilt-rules-{self.base_name}-all-available-summary.mdx' + slug = f"prebuilt-rules-{self.base_name}-all-available-summary.mdx" summary = self.package_directory / slug - title = f'Latest rules for Stack Version ^{self.release_version_str}' + title = f"Latest rules for Stack Version ^{self.release_version_str}" summary_header = textwrap.dedent(f""" --- id: {slug} slug: /security-rules/{slug} title: {title} - date: {datetime.today().strftime('%Y-%d-%m')} + date: {datetime.today().strftime("%Y-%d-%m")} tags: ["rules", "security", "detection-rules"] --- @@ -833,25 +920,29 @@ def generate_current_rule_summary(self): | Rule | Description | Tags | Version |---|---|---|---| - """).lstrip() + """).lstrip() # noqa: DTZ002 - rule_entries = [] - for rule in self.new_package.rules: - title_name = name_to_title(rule.name) - to_api_format = rule.contents.to_api_format() - tags = ", ".join(to_api_format["tags"]) - rule_entries.append(f'| [{title_name}](rules/{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx) | ' - f'{to_api_format["description"]} | {tags} | ' - f'{to_api_format["version"]}') + rule_entries: list[str] = [] + + if self.new_package: + for rule in self.new_package.rules: + title_name = 
name_to_title(rule.name) + to_api_format = rule.contents.to_api_format() + tags = ", ".join(to_api_format["tags"]) + rule_entries.append( + f"| [{title_name}](rules/{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx) | " + f"{to_api_format['description']} | {tags} | " + f"{to_api_format['version']}" + ) rule_entries = sorted(rule_entries) - rule_entries = '\n'.join(rule_entries) + rule_entries_str = "\n".join(rule_entries) - summary.write_text(summary_header + rule_entries) + _ = summary.write_text(summary_header + rule_entries_str) - def generate_update_summary(self): + def generate_update_summary(self) -> None: """Generate a summary of all rule updates based on the latest package.""" - slug = f'prebuilt-rules-{self.base_name}-update-summary.mdx' + slug = f"prebuilt-rules-{self.base_name}-update-summary.mdx" summary = self.package_directory / slug title = "Current Available Rules" @@ -860,7 +951,7 @@ def generate_update_summary(self): id: {slug} slug: /security-rules/{slug} title: {title} - date: {datetime.today().strftime('%Y-%d-%m')} + date: {datetime.today().strftime("%Y-%d-%m")} tags: ["rules", "security", "detection-rules"] --- @@ -870,54 +961,60 @@ def generate_update_summary(self): | Rule | Description | Status | Version |---|---|---|---| - """).lstrip() + """).lstrip() # noqa: DTZ002 - rule_entries = [] + rule_entries: list[str] = [] new_rule_id_list = [rule.id for rule in self.rule_changes["new"]] updated_rule_id_list = [rule.id for rule in self.rule_changes["updated"]] for rule in self.included_rules: + if not rule.name: + raise ValueError("No rule name found") title_name = name_to_title(rule.name) - status = 'new' if rule.id in new_rule_id_list else 'update' if rule.id in updated_rule_id_list \ - else 'deprecated' + status = ( + "new" if rule.id in new_rule_id_list else "update" if rule.id in updated_rule_id_list else "deprecated" + ) to_api_format = rule.contents.to_api_format() - rule_entries.append(f'| [{title_name}](rules/{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx) | ' - f'{to_api_format["description"]} | {status} | ' - f'{to_api_format["version"]}') + rule_entries.append( + f"| [{title_name}](rules/{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx) | " + f"{to_api_format['description']} | {status} | " + f"{to_api_format['version']}" + ) rule_entries = sorted(rule_entries) - rule_entries = '\n'.join(rule_entries) + rule_entries_str = "\n".join(rule_entries) - summary.write_text(summary_header + rule_entries) + _ = summary.write_text(summary_header + rule_entries_str) - def generate_rule_details(self): + def generate_rule_details(self) -> None: """Generate a markdown file for each rule.""" rules_dir = self.package_directory / "rules" rules_dir.mkdir(exist_ok=True) - for rule in self.new_package.rules: - slug = f'{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx' - rule_detail = IntegrationRuleDetailMDX(rule.id, rule.contents.to_api_format(), {}, self.base_name) - rule_path = rules_dir / slug - tags = ', '.join(f"\"{tag}\"" for tag in rule.contents.data.tags) - frontmatter = textwrap.dedent(f""" - --- - id: {slug} - slug: /security-rules/{slug} - title: {rule.name} - date: {datetime.today().strftime('%Y-%d-%m')} - tags: [{tags}] - --- - - """).lstrip() - rule_path.write_text(frontmatter + rule_detail.generate()) - - def generate_downloadable_updates_summary(self): + if self.new_package: + for rule in self.new_package.rules: + slug = f"{self.prebuilt_rule_base}-{name_to_title(rule.name)}.mdx" + rule_detail = IntegrationRuleDetailMDX(rule.id, 
rule.contents.to_api_format(), {}, self.base_name) + rule_path = rules_dir / slug + tags = ", ".join(f'"{tag}"' for tag in rule.contents.data.tags) # type: ignore[reportOptionalIterable] + frontmatter = textwrap.dedent(f""" + --- + id: {slug} + slug: /security-rules/{slug} + title: {rule.name} + date: {datetime.today().strftime("%Y-%d-%m")} + tags: [{tags}] + --- + + """).lstrip() # noqa: DTZ002 + _ = rule_path.write_text(frontmatter + rule_detail.generate()) + + def generate_downloadable_updates_summary(self) -> None: """Generate a summary of all the downloadable updates.""" - docs_url = 'https://www.elastic.co/guide/en/security/current/rules-ui-management.html#update-prebuilt-rules' - slug = 'prebuilt-rules-downloadable-packages-summary.mdx' + docs_url = "https://www.elastic.co/guide/en/security/current/rules-ui-management.html#update-prebuilt-rules" + slug = "prebuilt-rules-downloadable-packages-summary.mdx" title = "Downloadable rule updates" summary = self.package_directory / slug - today = datetime.today().strftime('%d %b %Y') + today = datetime.today().strftime("%d %b %Y") # noqa: DTZ002 package_list = DownloadableUpdates.load_updates() ref = f"./prebuilt-rules-{self.base_name}-update-summary.mdx" @@ -927,8 +1024,8 @@ def generate_downloadable_updates_summary(self): date=today, new_rules=len(self.rule_changes["new"]), updated_rules=len(self.rule_changes["updated"]), - note=self.note, - url=ref + note=self.note or "", + url=ref, ) package_list.add_entry(new_entry, self.overwrite) @@ -941,7 +1038,7 @@ def generate_downloadable_updates_summary(self): id: {slug} slug: /security-rules/{slug} title: {title} - date: {datetime.today().strftime('%Y-%d-%m')} + date: {datetime.today().strftime("%Y-%d-%m")} tags: ["rules", "security", "detection-rules"] --- @@ -955,15 +1052,22 @@ def generate_downloadable_updates_summary(self): |Update version |Date | New rules | Updated rules | Notes |---|---|---|---|---| - """).lstrip() + """).lstrip() # noqa: DTZ002 - entries = [] - for entry in sorted(package_list.packages, key=lambda entry: Version.parse(entry.update_version), reverse=True): - entries.append(f'| [{entry.update_version}]({entry.url}) | {today} |' - f' {entry.new_rules} | {entry.updated_rules} | {entry.note}| ') + entries: list[str] = [ + ( + f"| [{entry.update_version}]({entry.url}) | {today} |" + f" {entry.new_rules} | {entry.updated_rules} | {entry.note}| " + ) + for entry in sorted( + package_list.packages, + key=lambda entry: Version.parse(entry.update_version), + reverse=True, + ) + ] - entries = '\n'.join(entries) - summary.write_text(summary_header + entries) + entries_str = "\n".join(entries) + _ = summary.write_text(summary_header + entries_str) def generate(self) -> Path: """Generate the updates.""" @@ -986,7 +1090,13 @@ def generate(self) -> Path: class IntegrationRuleDetailMDX: """Generates a rule detail page in Markdown.""" - def __init__(self, rule_id: str, rule: dict, changelog: Dict[str, dict], package_str: str): + def __init__( + self, + rule_id: str, + rule: dict[str, Any], + changelog: dict[str, dict[str, Any]], + package_str: str, + ) -> None: """Initialize with rule ID, rule details, changelog, and package string. 
>>> rule_file = "/path/to/rule.toml" @@ -999,30 +1109,23 @@ def __init__(self, rule_id: str, rule: dict, changelog: Dict[str, dict], package self.rule = rule self.changelog = changelog self.package = package_str - self.rule_title = f'prebuilt-rule-{self.package}-{name_to_title(self.rule["name"])}' + self.rule_title = f"prebuilt-rule-{self.package}-{name_to_title(self.rule['name'])}" # set some defaults - self.rule.setdefault('max_signals', 100) - self.rule.setdefault('interval', '5m') + self.rule.setdefault("max_signals", 100) + self.rule.setdefault("interval", "5m") def generate(self) -> str: """Generate the rule detail page in Markdown.""" - page = [ - MDX.title(1, self.rule["name"]), - '', - self.rule['description'], - '', - self.metadata_str(), - '' - ] - if 'note' in self.rule: - page.extend([self.guide_str(), '']) - if 'query' in self.rule: - page.extend([self.query_str(), '']) - if 'threat' in self.rule: - page.extend([self.threat_mapping_str(), '']) + page = [MDX.title(1, self.rule["name"]), "", self.rule["description"], "", self.metadata_str(), ""] + if "note" in self.rule: + page.extend([self.guide_str(), ""]) + if "query" in self.rule: + page.extend([self.query_str(), ""]) + if "threat" in self.rule: + page.extend([self.threat_mapping_str(), ""]) - return '\n'.join(page) + return "\n".join(page) def metadata_str(self) -> str: """Generate the metadata section for the rule detail page.""" @@ -1030,73 +1133,79 @@ def metadata_str(self) -> str: date_math_doc = "https://www.elastic.co/guide/en/elasticsearch/reference/current/common-options.html#date-math" loopback_doc = "https://www.elastic.co/guide/en/security/current/rules-ui-create.html#rule-schedule" fields = { - 'type': 'Rule type', - 'index': 'Rule indices', - 'severity': 'Severity', - 'risk_score': 'Risk score', - 'interval': 'Runs every', - 'from': 'Searches indices from', - 'max_signals': 'Maximum alerts per execution', - 'references': 'References', - 'tags': 'Tags', - 'version': 'Version', - 'author': 'Rule authors', - 'license': 'Rule license' + "type": "Rule type", + "index": "Rule indices", + "severity": "Severity", + "risk_score": "Risk score", + "interval": "Runs every", + "from": "Searches indices from", + "max_signals": "Maximum alerts per execution", + "references": "References", + "tags": "Tags", + "version": "Version", + "author": "Rule authors", + "license": "Rule license", } - values = [] + values: list[str] = [] for field, friendly_name in fields.items(): value = self.rule.get(field) or self.changelog.get(field) - if isinstance(value, list): - str_value = MDX.bulleted_list(value) + if value is None: + str_value = "NONE" + elif isinstance(value, list): + str_value = MDX.bulleted_list(value) # type: ignore[reportUnknownArgumentType] else: str_value = str(value) - if field == 'from': - str_value += f' ([Date Math format]({date_math_doc}), [Additional look-back time]({loopback_doc}))' + if field == "from": + str_value += f" ([Date Math format]({date_math_doc}), [Additional look-back time]({loopback_doc}))" values.append(MDX.bold_kv(friendly_name, str_value)) - return '\n\n'.join(values) + return "\n\n".join(values) def guide_str(self) -> str: """Generate the investigation guide section for the rule detail page.""" - return f'{MDX.title(2, "Investigation guide")}\n\n{MDX.code(self.rule["note"], "markdown")}' + return f"{MDX.title(2, 'Investigation guide')}\n\n{MDX.code(self.rule['note'], 'markdown')}" def query_str(self) -> str: """Generate the rule query section for the rule detail page.""" - return f'{MDX.title(2, 
"Rule query")}\n\n{MDX.code(self.rule["query"], "sql")}' + return f"{MDX.title(2, 'Rule query')}\n\n{MDX.code(self.rule['query'], 'sql')}" def threat_mapping_str(self) -> str: """Generate the threat mapping section for the rule detail page.""" - values = [MDX.bold_kv('Framework', 'MITRE ATT&CK^TM^')] + values = [MDX.bold_kv("Framework", "MITRE ATT&CK^TM^")] - for entry in self.rule['threat']: - tactic = entry['tactic'] + for entry in self.rule["threat"]: + tactic = entry["tactic"] entry_values = [ - MDX.bulleted(MDX.bold('Tactic:')), - MDX.bulleted(f'Name: {tactic["name"]}', depth=2), - MDX.bulleted(f'ID: {tactic["id"]}', depth=2), - MDX.bulleted(f'Reference URL: {tactic["reference"]}', depth=2) + MDX.bulleted(MDX.bold("Tactic:")), + MDX.bulleted(f"Name: {tactic['name']}", depth=2), + MDX.bulleted(f"ID: {tactic['id']}", depth=2), + MDX.bulleted(f"Reference URL: {tactic['reference']}", depth=2), ] - techniques = entry.get('technique', []) + techniques = entry.get("technique", []) for technique in techniques: - entry_values.extend([ - MDX.bulleted('Technique:'), - MDX.bulleted(f'Name: {technique["name"]}', depth=3), - MDX.bulleted(f'ID: {technique["id"]}', depth=3), - MDX.bulleted(f'Reference URL: {technique["reference"]}', depth=3) - ]) - - subtechniques = technique.get('subtechnique', []) + entry_values.extend( + [ + MDX.bulleted("Technique:"), + MDX.bulleted(f"Name: {technique['name']}", depth=3), + MDX.bulleted(f"ID: {technique['id']}", depth=3), + MDX.bulleted(f"Reference URL: {technique['reference']}", depth=3), + ] + ) + + subtechniques = technique.get("subtechnique", []) for subtechnique in subtechniques: - entry_values.extend([ - MDX.bulleted('Sub-technique:'), - MDX.bulleted(f'Name: {subtechnique["name"]}', depth=3), - MDX.bulleted(f'ID: {subtechnique["id"]}', depth=3), - MDX.bulleted(f'Reference URL: {subtechnique["reference"]}', depth=4) - ]) + entry_values.extend( + [ + MDX.bulleted("Sub-technique:"), + MDX.bulleted(f"Name: {subtechnique['name']}", depth=3), + MDX.bulleted(f"ID: {subtechnique['id']}", depth=3), + MDX.bulleted(f"Reference URL: {subtechnique['reference']}", depth=4), + ] + ) values.extend(entry_values) - return '\n'.join(values) + return "\n".join(values) diff --git a/detection_rules/ecs.py b/detection_rules/ecs.py index e3fe2a66247..5417b642521 100644 --- a/detection_rules/ecs.py +++ b/detection_rules/ecs.py @@ -4,32 +4,33 @@ # 2.0. 
"""ECS Schemas management.""" + import copy -import glob import json import os import shutil +from pathlib import Path +from typing import Any -import eql -import eql.types +import eql # type: ignore[reportMissingTypeStubs] +import eql.types # type: ignore[reportMissingTypeStubs] import requests -from semver import Version import yaml +from semver import Version from .config import CUSTOM_RULES_DIR, parse_rules_config from .custom_schemas import get_custom_schemas from .integrations import load_integrations_schemas -from .utils import (DateTimeEncoder, cached, get_etc_path, gzip_compress, - load_etc_dump, read_gzip, unzip) +from .utils import DateTimeEncoder, cached, get_etc_path, gzip_compress, load_etc_dump, read_gzip, unzip ECS_NAME = "ecs_schemas" -ECS_SCHEMAS_DIR = get_etc_path(ECS_NAME) +ECS_SCHEMAS_DIR = get_etc_path([ECS_NAME]) ENDPOINT_NAME = "endpoint_schemas" -ENDPOINT_SCHEMAS_DIR = get_etc_path(ENDPOINT_NAME) +ENDPOINT_SCHEMAS_DIR = get_etc_path([ENDPOINT_NAME]) RULES_CONFIG = parse_rules_config() -def add_field(schema, name, info): +def add_field(schema: dict[str, Any], name: str, info: Any) -> None: """Nest a dotted field within a dictionary.""" if "." not in name: schema[name] = info @@ -41,7 +42,7 @@ def add_field(schema, name, info): add_field(schema, remaining, info) -def _recursive_merge(existing, new, depth=0): +def _recursive_merge(existing: dict[str, Any], new: dict[str, Any], depth: int = 0) -> dict[str, Any]: """Return an existing dict merged into a new one.""" for key, value in existing.items(): if isinstance(value, dict): @@ -49,33 +50,33 @@ def _recursive_merge(existing, new, depth=0): new = copy.deepcopy(new) node = new.setdefault(key, {}) - _recursive_merge(value, node, depth + 1) + _ = _recursive_merge(value, node, depth + 1) # type: ignore[reportUnknownArgumentType] else: new[key] = value return new -def get_schema_files(): +def get_schema_files() -> list[Path]: """Get schema files from ecs directory.""" - return glob.glob(os.path.join(ECS_SCHEMAS_DIR, '*', '*.json.gz'), recursive=True) + return list(ECS_SCHEMAS_DIR.glob("**/*.json.gz")) -def get_schema_map(): +def get_schema_map() -> dict[str, Any]: """Get local schema files by version.""" - schema_map = {} + schema_map: dict[str, Any] = {} for file_name in get_schema_files(): path, name = os.path.split(file_name) - name = name.split('.')[0] - version = os.path.basename(path) + name = name.split(".")[0] + version = Path(path).name schema_map.setdefault(version, {})[name] = file_name return schema_map @cached -def get_schemas(): +def get_schemas() -> dict[str, Any]: """Get local schemas.""" schema_map = get_schema_map() @@ -86,40 +87,41 @@ def get_schemas(): return schema_map -def get_max_version(include_master=False): +def get_max_version(include_master: bool = False) -> str: """Get maximum available schema version.""" versions = get_schema_map().keys() - if include_master and any([v.startswith('master') for v in versions]): - return list(ECS_SCHEMAS_DIR.glob('master*'))[0].name + if include_master and any(v.startswith("master") for v in versions): + paths = list(ECS_SCHEMAS_DIR.glob("master*")) + return paths[0].name - return str(max([Version.parse(v) for v in versions if not v.startswith('master')])) + return str(max([Version.parse(v) for v in versions if not v.startswith("master")])) @cached -def get_schema(version=None, name='ecs_flat'): +def get_schema(version: str | None = None, name: str = "ecs_flat") -> dict[str, Any]: """Get schema by version.""" - if version == 'master': + if version == "master": 
version = get_max_version(include_master=True) return get_schemas()[version or str(get_max_version())][name] @cached -def get_eql_schema(version=None, index_patterns=None): +def get_eql_schema(version: str | None = None, index_patterns: list[str] | None = None) -> dict[str, Any]: """Return schema in expected format for eql.""" - schema = get_schema(version, name='ecs_flat') - str_types = ('text', 'ip', 'keyword', 'date', 'object', 'geo_point') - num_types = ('float', 'integer', 'long') + schema = get_schema(version, name="ecs_flat") + str_types = ("text", "ip", "keyword", "date", "object", "geo_point") + num_types = ("float", "integer", "long") schema = schema.copy() - def convert_type(t): - return 'string' if t in str_types else 'number' if t in num_types else 'boolean' + def convert_type(t: str) -> str: + return "string" if t in str_types else "number" if t in num_types else "boolean" - converted = {} + converted: dict[str, Any] = {} for field, schema_info in schema.items(): - field_type = schema_info.get('type', '') + field_type = schema_info.get("type", "") add_field(converted, field, convert_type(field_type)) # add non-ecs schema @@ -141,21 +143,21 @@ def convert_type(t): return converted -def flatten(schema): - flattened = {} +def flatten(schema: dict[str, Any]) -> dict[str, Any]: + flattened: dict[str, Any] = {} for k, v in schema.items(): if isinstance(v, dict): - flattened.update((k + "." + vk, vv) for vk, vv in flatten(v).items()) + flattened.update((k + "." + vk, vv) for vk, vv in flatten(v).items()) # type: ignore[reportUnknownArgumentType] else: flattened[k] = v return flattened @cached -def get_all_flattened_schema() -> dict: +def get_all_flattened_schema() -> dict[str, Any]: """Load all schemas into a flattened dictionary.""" - all_flattened_schema = {} - for _, schema in get_non_ecs_schema().items(): + all_flattened_schema: dict[str, Any] = {} + for schema in get_non_ecs_schema().values(): all_flattened_schema.update(flatten(schema)) ecs_schemas = get_schemas() @@ -163,12 +165,12 @@ def get_all_flattened_schema() -> dict: for index, info in ecs_schemas[version]["ecs_flat"].items(): all_flattened_schema.update({index: info["type"]}) - for _, integration_schema in load_integrations_schemas().items(): - for index, index_schema in integration_schema.items(): + for integration_schema in load_integrations_schemas().values(): + for index_schema in integration_schema.values(): # Detect if ML integration if "jobs" in index_schema: ml_schemas = {k: v for k, v in index_schema.items() if k != "jobs"} - for _, ml_schema in ml_schemas.items(): + for ml_schema in ml_schemas.values(): all_flattened_schema.update(flatten(ml_schema)) else: all_flattened_schema.update(flatten(index_schema)) @@ -177,33 +179,33 @@ def get_all_flattened_schema() -> dict: @cached -def get_non_ecs_schema(): +def get_non_ecs_schema() -> Any: """Load non-ecs schema.""" - return load_etc_dump('non-ecs-schema.json') + return load_etc_dump(["non-ecs-schema.json"]) @cached -def get_custom_index_schema(index_name: str, stack_version: str = None): +def get_custom_index_schema(index_name: str, stack_version: str | None = None) -> Any: """Load custom schema.""" custom_schemas = get_custom_schemas(stack_version) index_schema = custom_schemas.get(index_name, {}) - ccs_schema = custom_schemas.get(index_name.replace('::', ':').split(":", 1)[-1], {}) + ccs_schema = custom_schemas.get(index_name.replace("::", ":").split(":", 1)[-1], {}) index_schema.update(ccs_schema) return index_schema @cached -def get_index_schema(index_name): 
+def get_index_schema(index_name: str) -> Any: """Load non-ecs schema.""" non_ecs_schema = get_non_ecs_schema() index_schema = non_ecs_schema.get(index_name, {}) - ccs_schema = non_ecs_schema.get(index_name.replace('::', ':').split(":", 1)[-1], {}) + ccs_schema = non_ecs_schema.get(index_name.replace("::", ":").split(":", 1)[-1], {}) index_schema.update(ccs_schema) return index_schema -def flatten_multi_fields(schema): - converted = {} +def flatten_multi_fields(schema: dict[str, Any]) -> dict[str, Any]: + converted: dict[str, Any] = {} for field, info in schema.items(): converted[field] = info["type"] for subfield in info.get("multi_fields", []): @@ -213,43 +215,49 @@ def flatten_multi_fields(schema): class KqlSchema2Eql(eql.Schema): - type_mapping = { + type_mapping = { # noqa: RUF012 "keyword": eql.types.TypeHint.String, "ip": eql.types.TypeHint.String, "float": eql.types.TypeHint.Numeric, - # "double": eql.types.TypeHint.Numeric, - # "long": eql.types.TypeHint.Numeric, - # "short": eql.types.TypeHint.Numeric, "integer": eql.types.TypeHint.Numeric, "boolean": eql.types.TypeHint.Boolean, } - def __init__(self, kql_schema): + def __init__(self, kql_schema: dict[str, Any]) -> None: self.kql_schema = kql_schema - eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False) + eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False) # type: ignore[reportUnknownMemberType] - def validate_event_type(self, event_type): + def validate_event_type(self, _: Any) -> bool: # type: ignore[reportIncompatibleMethodOverride] # allow all event types to fill in X: # `X` where .... return True - def get_event_type_hint(self, event_type, path): - from kql.parser import elasticsearch_type_family + def get_event_type_hint(self, _: Any, path: list[str]) -> tuple[Any, None]: # type: ignore[reportIncompatibleMethodOverride] + from kql.parser import elasticsearch_type_family # type: ignore[reportMissingTypeStubs] dotted = ".".join(path) elasticsearch_type = self.kql_schema.get(dotted) + if not elasticsearch_type: + return None, None + es_type_family = elasticsearch_type_family(elasticsearch_type) eql_hint = self.type_mapping.get(es_type_family) if eql_hint is not None: return eql_hint, None + return None, None + @cached -def get_kql_schema(version=None, indexes=None, beat_schema=None) -> dict: +def get_kql_schema( + version: str | None = None, + indexes: list[str] | None = None, + beat_schema: dict[str, Any] | None = None, +) -> dict[str, Any]: """Get schema for KQL.""" - indexes = indexes or () - converted = flatten_multi_fields(get_schema(version, name='ecs_flat')) + indexes = indexes or [] + converted = flatten_multi_fields(get_schema(version, name="ecs_flat")) # non-ecs schema for index_name in indexes: @@ -269,70 +277,77 @@ def get_kql_schema(version=None, indexes=None, beat_schema=None) -> dict: return converted -def download_schemas(refresh_master=True, refresh_all=False, verbose=True): +def download_schemas(refresh_master: bool = True, refresh_all: bool = False, verbose: bool = True) -> None: """Download additional schemas from ecs releases.""" existing = [Version.parse(v) for v in get_schema_map()] if not refresh_all else [] - url = 'https://api.github.com/repos/elastic/ecs/releases' - releases = requests.get(url) + url = "https://api.github.com/repos/elastic/ecs/releases" + releases = requests.get(url, timeout=30) for release in releases.json(): - version = Version.parse(release.get('tag_name', '').lstrip('v')) + version = 
Version.parse(release.get("tag_name", "").lstrip("v")) # we don't ever want beta if not version or version < Version.parse("1.0.1") or version in existing: continue - schema_dir = os.path.join(ECS_SCHEMAS_DIR, str(version)) + schema_dir = ECS_SCHEMAS_DIR / str(version) + schema_dir.mkdir(exist_ok=True) - with unzip(requests.get(release['zipball_url']).content) as archive: + resp = requests.get(release["zipball_url"], timeout=30) + with unzip(resp.content) as archive: name_list = archive.namelist() base = name_list[0] - # members = [m for m in name_list if m.startswith('{}{}/'.format(base, 'use-cases')) and m.endswith('.yml')] - members = ['{}generated/ecs/ecs_flat.yml'.format(base), '{}generated/ecs/ecs_nested.yml'.format(base)] - saved = [] + members = [f"{base}generated/ecs/ecs_flat.yml", f"{base}generated/ecs/ecs_nested.yml"] + saved: list[str] = [] for member in members: - file_name = os.path.basename(member) - os.makedirs(schema_dir, exist_ok=True) + file_name = Path(member).name # load as yaml, save as json contents = yaml.safe_load(archive.read(member)) out_file = file_name.replace(".yml", ".json.gz") compressed = gzip_compress(json.dumps(contents, sort_keys=True, cls=DateTimeEncoder)) - new_path = get_etc_path(ECS_NAME, str(version), out_file) - with open(new_path, 'wb') as f: - f.write(compressed) + new_path = get_etc_path([ECS_NAME, str(version), out_file]) + with new_path.open("wb") as f: + _ = f.write(compressed) saved.append(out_file) if verbose: - print('Saved files to {}: \n\t- {}'.format(schema_dir, '\n\t- '.join(saved))) + print("Saved files to {}: \n\t- {}".format(schema_dir, "\n\t- ".join(saved))) # handle working master separately if refresh_master: - master_ver = requests.get('https://raw.githubusercontent.com/elastic/ecs/master/version') + master_ver = requests.get( + "https://raw.githubusercontent.com/elastic/ecs/master/version", + timeout=30, + ) master_ver = Version.parse(master_ver.text.strip()) - master_schema = requests.get('https://raw.githubusercontent.com/elastic/ecs/master/generated/ecs/ecs_flat.yml') + master_schema = requests.get( + "https://raw.githubusercontent.com/elastic/ecs/master/generated/ecs/ecs_flat.yml", + timeout=30, + ) master_schema = yaml.safe_load(master_schema.text) # prepend with underscore so that we can differentiate the fact that this is a working master version # but first clear out any existing masters, since we only ever want 1 at a time - existing_master = glob.glob(os.path.join(ECS_SCHEMAS_DIR, 'master_*')) + existing_master = ECS_SCHEMAS_DIR.glob("master_*") for m in existing_master: shutil.rmtree(m, ignore_errors=True) - master_dir = "master_{}".format(master_ver) - os.makedirs(get_etc_path(ECS_NAME, master_dir), exist_ok=True) + master_dir = f"master_{master_ver}" + master_dir_path = get_etc_path([ECS_NAME, master_dir]) + master_dir_path.mkdir(exist_ok=True) compressed = gzip_compress(json.dumps(master_schema, sort_keys=True, cls=DateTimeEncoder)) - new_path = get_etc_path(ECS_NAME, master_dir, "ecs_flat.json.gz") - with open(new_path, 'wb') as f: - f.write(compressed) + new_path = get_etc_path([ECS_NAME, master_dir, "ecs_flat.json.gz"]) + with new_path.open("wb") as f: + _ = f.write(compressed) if verbose: - print('Saved files to {}: \n\t- {}'.format(master_dir, 'ecs_flat.json.gz')) + print("Saved files to {}: \n\t- {}".format(master_dir, "ecs_flat.json.gz")) def download_endpoint_schemas(target: str, overwrite: bool = True) -> None: @@ -340,9 +355,9 @@ def download_endpoint_schemas(target: str, overwrite: bool = True) -> None: 
# location of custom schema YAML files url = "https://raw.githubusercontent.com/elastic/endpoint-package/main/custom_schemas" - r = requests.get(f"{url}/custom_{target}.yml") - if r.status_code == 404: - r = requests.get(f"{url}/{target}/custom_{target}.yaml") + r = requests.get(f"{url}/custom_{target}.yml", timeout=30) + if r.status_code == 404: # noqa: PLR2004 + r = requests.get(f"{url}/{target}/custom_{target}.yaml", timeout=30) r.raise_for_status() schema = yaml.safe_load(r.text)[0] root_name = schema["name"] @@ -351,11 +366,11 @@ def download_endpoint_schemas(target: str, overwrite: bool = True) -> None: # iterate over nested fields and flatten them for f in fields: - if 'multi_fields' in f: - for mf in f['multi_fields']: - flattened[f"{root_name}.{f['name']}.{mf['name']}"] = mf['type'] + if "multi_fields" in f: + for mf in f["multi_fields"]: + flattened[f"{root_name}.{f['name']}.{mf['name']}"] = mf["type"] else: - flattened[f"{root_name}.{f['name']}"] = f['type'] + flattened[f"{root_name}.{f['name']}"] = f["type"] # save schema to disk ENDPOINT_SCHEMAS_DIR.mkdir(parents=True, exist_ok=True) @@ -363,16 +378,16 @@ def download_endpoint_schemas(target: str, overwrite: bool = True) -> None: new_path = ENDPOINT_SCHEMAS_DIR / f"endpoint_{target}.json.gz" if overwrite: shutil.rmtree(new_path, ignore_errors=True) - with open(new_path, 'wb') as f: - f.write(compressed) + with new_path.open("wb") as f: + _ = f.write(compressed) print(f"Saved endpoint schema to {new_path}") @cached -def get_endpoint_schemas() -> dict: +def get_endpoint_schemas() -> dict[str, Any]: """Load endpoint schemas.""" - schema = {} - existing = glob.glob(os.path.join(ENDPOINT_SCHEMAS_DIR, '*.json.gz')) + schema: dict[str, Any] = {} + existing = ENDPOINT_SCHEMAS_DIR.glob("*.json.gz") for f in existing: schema.update(json.loads(read_gzip(f))) return schema diff --git a/detection_rules/endgame.py b/detection_rules/endgame.py index 4ed6bd6246d..1c26a04046d 100644 --- a/detection_rules/endgame.py +++ b/detection_rules/endgame.py @@ -4,11 +4,14 @@ # 2.0. """Endgame Schemas management.""" + import json import shutil import sys +from typing import Any -import eql +import eql # type: ignore[reportMissingTypeStubs] +from github import Github from .utils import ETC_DIR, DateTimeEncoder, cached, gzip_compress, read_gzip @@ -18,12 +21,12 @@ class EndgameSchemaManager: """Endgame Class to download, convert, and save endgame schemas from endgame-evecs.""" - def __init__(self, github_client, endgame_version: str): + def __init__(self, github_client: Github, endgame_version: str) -> None: self.repo = github_client.get_repo("elastic/endgame-evecs") self.endgame_version = endgame_version self.endgame_schema = self.download_endgame_schema() - def download_endgame_schema(self) -> dict: + def download_endgame_schema(self) -> dict[str, Any]: """Download schema from endgame-evecs.""" # Use the static mapping.json file downloaded from the endgame-evecs repo. 
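# --- Illustrative aside, not part of the diff: the endpoint-schema flattening in
# `download_endpoint_schemas` (ecs.py hunk above) prefixes every field with its root name, and a
# field that declares multi_fields contributes only those sub-mappings; a sketch with a made-up
# YAML field entry:
root_name = "process"
field = {
    "name": "command_line",
    "type": "wildcard",
    "multi_fields": [{"name": "caseless", "type": "keyword"}, {"name": "text", "type": "text"}],
}
# per the loop above, this field yields:
# {"process.command_line.caseless": "keyword", "process.command_line.text": "text"}
# `flatten_multi_fields` in the same file treats the ECS flat schema similarly but also keeps the
# base field's own type.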
@@ -31,57 +34,56 @@ def download_endgame_schema(self) -> dict: main_branch_sha = main_branch.commit.sha schema_path = "pkg/mapper/ecs/schema.json" contents = self.repo.get_contents(schema_path, ref=main_branch_sha) - endgame_mapping = json.loads(contents.decoded_content.decode()) - - return endgame_mapping + return json.loads(contents.decoded_content.decode()) # type: ignore[reportAttributeAccessIssue] - def save_schemas(self, overwrite: bool = False): + def save_schemas(self, overwrite: bool = False) -> None: """Save the endgame schemas to the etc/endgame_schemas directory.""" schemas_dir = ENDGAME_SCHEMA_DIR / self.endgame_version if schemas_dir.exists() and not overwrite: raise FileExistsError(f"{schemas_dir} exists, use overwrite to force") - else: - shutil.rmtree(str(schemas_dir.resolve()), ignore_errors=True) - schemas_dir.mkdir() + shutil.rmtree(str(schemas_dir.resolve()), ignore_errors=True) + schemas_dir.mkdir() # write the raw schema to disk raw_os_schema = self.endgame_schema os_schema_path = schemas_dir / "endgame_ecs_mapping.json.gz" compressed = gzip_compress(json.dumps(raw_os_schema, sort_keys=True, cls=DateTimeEncoder)) - os_schema_path.write_bytes(compressed) + _ = os_schema_path.write_bytes(compressed) print(f"Endgame raw schema file saved: {os_schema_path}") class EndgameSchema(eql.Schema): """Endgame schema for query validation.""" - type_mapping = { - "keyword": eql.types.TypeHint.String, - "ip": eql.types.TypeHint.String, - "float": eql.types.TypeHint.Numeric, - "integer": eql.types.TypeHint.Numeric, - "boolean": eql.types.TypeHint.Boolean, - "text": eql.types.TypeHint.String, + type_mapping: dict[str, Any] = { # noqa: RUF012 + "keyword": eql.types.TypeHint.String, # type: ignore[reportAttributeAccessIssue] + "ip": eql.types.TypeHint.String, # type: ignore[reportAttributeAccessIssue] + "float": eql.types.TypeHint.Numeric, # type: ignore[reportAttributeAccessIssue] + "integer": eql.types.TypeHint.Numeric, # type: ignore[reportAttributeAccessIssue] + "boolean": eql.types.TypeHint.Boolean, # type: ignore[reportAttributeAccessIssue] + "text": eql.types.TypeHint.String, # type: ignore[reportAttributeAccessIssue] } - def __init__(self, endgame_schema): + def __init__(self, endgame_schema: dict[str, Any]) -> None: self.endgame_schema = endgame_schema - eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False) + eql.Schema.__init__(self, {}, allow_any=True, allow_generic=False, allow_missing=False) # type: ignore[reportUnknownMemberType] + + def get_event_type_hint(self, _: str, path: list[str]) -> None | tuple[Any, None]: # type: ignore[reportIncompatibleMethodOverride] + from kql.parser import elasticsearch_type_family # type: ignore[reportMissingTypeStubs] - def get_event_type_hint(self, event_type, path): - from kql.parser import elasticsearch_type_family dotted = ".".join(str(p) for p in path) elasticsearch_type = self.endgame_schema.get(dotted) - es_type_family = elasticsearch_type_family(elasticsearch_type) + es_type_family = elasticsearch_type_family(elasticsearch_type) # type: ignore[reportArgumentType] eql_hint = self.type_mapping.get(es_type_family) - if eql_hint is not None: + if eql_hint: return eql_hint, None + return None @cached -def read_endgame_schema(endgame_version: str, warn=False) -> dict: +def read_endgame_schema(endgame_version: str, warn: bool = False) -> dict[str, Any] | None: """Load Endgame json schema. 
The schemas must be generated with the `download_endgame_schema()` method.""" # expect versions to be in format of N.N.N or master/main @@ -92,10 +94,7 @@ def read_endgame_schema(endgame_version: str, warn=False) -> dict: if warn: relative_path = endgame_schema_path.relative_to(ENDGAME_SCHEMA_DIR) print(f"Missing file to validate: {relative_path}, skipping", file=sys.stderr) - return - else: - raise FileNotFoundError(str(endgame_schema_path)) - - schema = json.loads(read_gzip(endgame_schema_path)) + return None + raise FileNotFoundError(str(endgame_schema_path)) - return schema + return json.loads(read_gzip(endgame_schema_path)) diff --git a/detection_rules/eswrap.py b/detection_rules/eswrap.py index eaca05b553d..66686fe820a 100644 --- a/detection_rules/eswrap.py +++ b/detection_rules/eswrap.py @@ -4,70 +4,82 @@ # 2.0. """Elasticsearch cli commands.""" + import json -import os import sys import time from collections import defaultdict -from typing import List, Union +from pathlib import Path +from typing import IO, Any import click import elasticsearch +import kql # type: ignore[reportMissingTypeStubs] from elasticsearch import Elasticsearch from elasticsearch.client import AsyncSearchClient -import kql from .config import parse_rules_config from .main import root -from .misc import add_params, client_error, elasticsearch_options, get_elasticsearch_client, nested_get +from .misc import add_params, elasticsearch_options, get_elasticsearch_client, nested_get, raise_client_error from .rule import TOMLRule - from .rule_loader import RuleCollection -from .utils import format_command_options, normalize_timing_and_sort, unix_time_to_formatted, get_path +from .utils import event_sort, format_command_options, get_path, normalize_timing_and_sort, unix_time_to_formatted -COLLECTION_DIR = get_path('collections') -MATCH_ALL = {'bool': {'filter': [{'match_all': {}}]}} +COLLECTION_DIR = get_path(["collections"]) +MATCH_ALL: dict[str, dict[str, Any]] = {"bool": {"filter": [{"match_all": {}}]}} RULES_CONFIG = parse_rules_config() -def add_range_to_dsl(dsl_filter, start_time, end_time='now'): +def add_range_to_dsl(dsl_filter: list[dict[str, Any]], start_time: str, end_time: str = "now") -> None: dsl_filter.append( - {"range": {"@timestamp": {"gt": start_time, "lte": end_time, "format": "strict_date_optional_time"}}} + { + "range": { + "@timestamp": { + "gt": start_time, + "lte": end_time, + "format": "strict_date_optional_time", + }, + }, + } ) -def parse_unique_field_results(rule_type: str, unique_fields: List[str], search_results: dict): - parsed_results = defaultdict(lambda: defaultdict(int)) - hits = search_results['hits'] - hits = hits['hits'] if rule_type != 'eql' else hits.get('events') or hits.get('sequences', []) +def parse_unique_field_results( + rule_type: str, + unique_fields: list[str], + search_results: dict[str, Any], +) -> dict[str, Any]: + parsed_results: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + hits = search_results["hits"] + hits = hits["hits"] if rule_type != "eql" else hits.get("events") or hits.get("sequences", []) for hit in hits: for field in unique_fields: - if 'events' in hit: - match = [] - for event in hit['events']: - matched = nested_get(event['_source'], field) - match.extend([matched] if not isinstance(matched, list) else matched) + if "events" in hit: + match: list[Any] = [] + for event in hit["events"]: + matched = nested_get(event["_source"], field) + match.extend([matched] if not isinstance(matched, list) else matched) # type: 
ignore[reportUnknownArgumentType] if not match: continue else: - match = nested_get(hit['_source'], field) + match = nested_get(hit["_source"], field) if not match: continue - match = ','.join(sorted(match)) if isinstance(match, list) else match - parsed_results[field][match] += 1 + match = ",".join(sorted(match)) if isinstance(match, list) else match # type: ignore[reportUnknownArgumentType] + parsed_results[field][match] += 1 # type: ignore[reportUnknownArgumentType] # if rule.type == eql, structure is different - return {'results': parsed_results} if parsed_results else {} + return {"results": parsed_results} if parsed_results else {} class Events: """Events collected from Elasticsearch.""" - def __init__(self, events): - self.events: dict = self._normalize_event_timing(events) + def __init__(self, events: dict[str, Any]) -> None: + self.events = self._normalize_event_timing(events) @staticmethod - def _normalize_event_timing(events): + def _normalize_event_timing(events: dict[str, Any]) -> dict[str, Any]: """Normalize event timestamps and sort.""" for agent_type, _events in events.items(): events[agent_type] = normalize_timing_and_sort(_events) @@ -75,338 +87,460 @@ def _normalize_event_timing(events): return events @staticmethod - def _get_dump_dir(rta_name=None, host_id=None, host_os_family=None): + def _get_dump_dir( + rta_name: str | None = None, + host_id: str | None = None, + host_os_family: str | None = None, + ) -> Path: """Prepare and get the dump path.""" if rta_name and host_os_family: - dump_dir = get_path('unit_tests', 'data', 'true_positives', rta_name, host_os_family) - os.makedirs(dump_dir, exist_ok=True) - return dump_dir - else: - time_str = time.strftime('%Y%m%dT%H%M%SL') - dump_dir = os.path.join(COLLECTION_DIR, host_id or 'unknown_host', time_str) - os.makedirs(dump_dir, exist_ok=True) + dump_dir = get_path(["unit_tests", "data", "true_positives", rta_name, host_os_family]) + dump_dir.mkdir(parents=True, exist_ok=True) return dump_dir + time_str = time.strftime("%Y%m%dT%H%M%SL") + dump_dir = COLLECTION_DIR / (host_id or "unknown_host") / time_str + dump_dir.mkdir(parents=True, exist_ok=True) + return dump_dir - def evaluate_against_rule(self, rule_id, verbose=True): + def evaluate_against_rule(self, rule_id: str, verbose: bool = True) -> list[Any]: """Evaluate a rule against collected events and update mapping.""" - from .utils import combine_sources, evaluate - rule = RuleCollection.default().id_map.get(rule_id) - assert rule is not None, f"Unable to find rule with ID {rule_id}" + if not rule: + raise ValueError(f"Unable to find rule with ID {rule_id}") merged_events = combine_sources(*self.events.values()) filtered = evaluate(rule, merged_events, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) if verbose: - click.echo('Matching results found') + click.echo("Matching results found") return filtered - def echo_events(self, pager=False, pretty=True): + def echo_events(self, pager: bool = False, pretty: bool = True) -> None: """Print events to stdout.""" echo_fn = click.echo_via_pager if pager else click.echo echo_fn(json.dumps(self.events, indent=2 if pretty else None, sort_keys=True)) - def save(self, rta_name=None, dump_dir=None, host_id=None): + def save(self, rta_name: str | None = None, dump_dir: Path | None = None, host_id: str | None = None) -> None: """Save collected events.""" - assert self.events, 'Nothing to save. Run Collector.run() method first or verify logging' + if not self.events: + raise ValueError("Nothing to save. 
Run Collector.run() method first or verify logging") host_os_family = None - for key in self.events.keys(): - if self.events.get(key, {})[0].get('host', {}).get('id') == host_id: - host_os_family = self.events.get(key, {})[0].get('host', {}).get('os').get('family') + for key in self.events: + if self.events.get(key, {})[0].get("host", {}).get("id") == host_id: + host_os_family = self.events.get(key, {})[0].get("host", {}).get("os").get("family") break if not host_os_family: - click.echo('Unable to determine host.os.family for host_id: {}'.format(host_id)) - host_os_family = click.prompt("Please enter the host.os.family for this host_id", - type=click.Choice(["windows", "macos", "linux"]), default="windows") + click.echo(f"Unable to determine host.os.family for host_id: {host_id}") + host_os_family = click.prompt( + "Please enter the host.os.family for this host_id", + type=click.Choice(["windows", "macos", "linux"]), + default="windows", + ) dump_dir = dump_dir or self._get_dump_dir(rta_name=rta_name, host_id=host_id, host_os_family=host_os_family) for source, events in self.events.items(): - path = os.path.join(dump_dir, source + '.ndjson') - with open(path, 'w') as f: - f.writelines([json.dumps(e, sort_keys=True) + '\n' for e in events]) - click.echo('{} events saved to: {}'.format(len(events), path)) + path = dump_dir / (source + ".ndjson") + with path.open("w") as f: + f.writelines([json.dumps(e, sort_keys=True) + "\n" for e in events]) + click.echo(f"{len(events)} events saved to: {path}") -class CollectEvents(object): +class CollectEvents: """Event collector for elastic stack.""" - def __init__(self, client, max_events=3000): - self.client: Elasticsearch = client + def __init__(self, client: Elasticsearch, max_events: int = 3000) -> None: + self.client = client self.max_events = max_events - def _build_timestamp_map(self, index_str): + def _build_timestamp_map(self, index: str) -> dict[str, Any]: """Build a mapping of indexes to timestamp data formats.""" - mappings = self.client.indices.get_mapping(index=index_str) - timestamp_map = {n: m['mappings'].get('properties', {}).get('@timestamp', {}) for n, m in mappings.items()} - return timestamp_map + mappings = self.client.indices.get_mapping(index=index) + return {n: m["mappings"].get("properties", {}).get("@timestamp", {}) for n, m in mappings.items()} - def _get_last_event_time(self, index_str, dsl=None): + def _get_last_event_time(self, index: str, dsl: dict[str, Any] | None = None) -> None | str: """Get timestamp of most recent event.""" - last_event = self.client.search(query=dsl, index=index_str, size=1, sort='@timestamp:desc')['hits']['hits'] + last_event = self.client.search(query=dsl, index=index, size=1, sort="@timestamp:desc")["hits"]["hits"] if not last_event: - return + return None last_event = last_event[0] - index = last_event['_index'] - timestamp = last_event['_source']['@timestamp'] + index = last_event["_index"] + timestamp = last_event["_source"]["@timestamp"] - timestamp_map = self._build_timestamp_map(index_str) - event_date_format = timestamp_map[index].get('format', '').split('||') + timestamp_map = self._build_timestamp_map(index) + event_date_format = timestamp_map[index].get("format", "").split("||") # there are many native supported date formats and even custom data formats, but most, including beats use the # default `strict_date_optional_time`. It would be difficult to try to account for all possible formats, so this # will work on the default and unix time. 
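# --- Illustrative aside, not part of the diff: `parse_unique_field_results` (top of this
# eswrap.py hunk) tallies how often each unique value of the requested fields appears in a search
# response; a sketch with hypothetical hits for a non-EQL rule type (nested_get resolves the
# dotted path):
from detection_rules.eswrap import parse_unique_field_results

search_results = {"hits": {"hits": [
    {"_source": {"process": {"name": "cmd.exe"}}},
    {"_source": {"process": {"name": "cmd.exe"}}},
    {"_source": {"process": {"name": "powershell.exe"}}},
]}}
parse_unique_field_results("query", ["process.name"], search_results)
# -> roughly {"results": {"process.name": {"cmd.exe": 2, "powershell.exe": 1}}}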
- if set(event_date_format) & {'epoch_millis', 'epoch_second'}: + if set(event_date_format) & {"epoch_millis", "epoch_second"}: timestamp = unix_time_to_formatted(timestamp) return timestamp @staticmethod - def _prep_query(query, language, index, start_time=None, end_time=None): + def _prep_query( + query: str | dict[str, Any], + language: str, + index: str | list[str] | tuple[str], + start_time: str | None = None, + end_time: str | None = None, + ) -> tuple[str, dict[str, Any], str | None]: """Prep a query for search.""" - index_str = ','.join(index if isinstance(index, (list, tuple)) else index.split(',')) - lucene_query = query if language == 'lucene' else None - - if language in ('kql', 'kuery'): - formatted_dsl = {'query': kql.to_dsl(query)} - elif language == 'eql': - formatted_dsl = {'query': query, 'filter': MATCH_ALL} - elif language == 'lucene': - formatted_dsl = {'query': {'bool': {'filter': []}}} - elif language == 'dsl': - formatted_dsl = {'query': query} + index_str = ",".join(index if isinstance(index, (list | tuple)) else index.split(",")) + lucene_query = str(query) if language == "lucene" else None + + if language in ("kql", "kuery"): + formatted_dsl = {"query": kql.to_dsl(query)} # type: ignore[reportUnknownMemberType] + elif language == "eql": + formatted_dsl = {"query": query, "filter": MATCH_ALL} + elif language == "lucene": + formatted_dsl: dict[str, Any] = {"query": {"bool": {"filter": []}}} + elif language == "dsl": + formatted_dsl = {"query": query} else: - raise ValueError(f'Unknown search language: {language}') + raise ValueError(f"Unknown search language: {language}") if start_time or end_time: - end_time = end_time or 'now' - dsl = formatted_dsl['filter']['bool']['filter'] if language == 'eql' else \ - formatted_dsl['query']['bool'].setdefault('filter', []) + end_time = end_time or "now" + dsl = ( + formatted_dsl["filter"]["bool"]["filter"] + if language == "eql" + else formatted_dsl["query"]["bool"].setdefault("filter", []) + ) + if not start_time: + raise ValueError("No start time provided") + add_range_to_dsl(dsl, start_time, end_time) return index_str, formatted_dsl, lucene_query - def search(self, query, language, index: Union[str, list] = '*', start_time=None, end_time=None, size=None, - **kwargs): + def search( # noqa: PLR0913 + self, + query: str | dict[str, Any], + language: str, + index: str | list[str] = "*", + start_time: str | None = None, + end_time: str | None = None, + size: int | None = None, + **kwargs: Any, + ) -> list[Any]: """Search an elasticsearch instance.""" - index_str, formatted_dsl, lucene_query = self._prep_query(query=query, language=language, index=index, - start_time=start_time, end_time=end_time) + index_str, formatted_dsl, lucene_query = self._prep_query( + query=query, language=language, index=index, start_time=start_time, end_time=end_time + ) formatted_dsl.update(size=size or self.max_events) - if language == 'eql': - results = self.client.eql.search(body=formatted_dsl, index=index_str, **kwargs)['hits'] - results = results.get('events') or results.get('sequences', []) + if language == "eql": + results = self.client.eql.search(body=formatted_dsl, index=index_str, **kwargs)["hits"] + results = results.get("events") or results.get("sequences", []) else: - results = self.client.search(body=formatted_dsl, q=lucene_query, index=index_str, - allow_no_indices=True, ignore_unavailable=True, **kwargs)['hits']['hits'] + results = self.client.search( + body=formatted_dsl, + q=lucene_query, + index=index_str, + allow_no_indices=True, + 
ignore_unavailable=True, + **kwargs, + )["hits"]["hits"] return results - def search_from_rule(self, rules: RuleCollection, start_time=None, end_time='now', size=None): + def search_from_rule( + self, + rules: RuleCollection, + start_time: str | None = None, + end_time: str = "now", + size: int | None = None, + ) -> dict[str, Any]: """Search an elasticsearch instance using a rule.""" async_client = AsyncSearchClient(self.client) - survey_results = {} - multi_search = [] - multi_search_rules = [] - async_searches = [] - eql_searches = [] + survey_results: dict[str, Any] = {} + multi_search: list[dict[str, Any]] = [] + multi_search_rules: list[TOMLRule] = [] + async_searches: list[tuple[TOMLRule, Any]] = [] + eql_searches: list[tuple[TOMLRule, dict[str, Any]]] = [] for rule in rules: - if not rule.contents.data.get('query'): + if not rule.contents.data.get("query"): continue - language = rule.contents.data.get('language') - query = rule.contents.data.query + language = rule.contents.data.get("language") + query = rule.contents.data.query # type: ignore[reportAttributeAccessIssue] rule_type = rule.contents.data.type - index_str, formatted_dsl, lucene_query = self._prep_query(query=query, - language=language, - index=rule.contents.data.get('index', '*'), - start_time=start_time, - end_time=end_time) + index_str, formatted_dsl, _ = self._prep_query( + query=query, # type: ignore[reportUnknownArgumentType] + language=language, # type: ignore[reportUnknownArgumentType] + index=rule.contents.data.get("index", "*"), # type: ignore[reportUnknownArgumentType] + start_time=start_time, + end_time=end_time, + ) formatted_dsl.update(size=size or self.max_events) # prep for searches: msearch for kql | async search for lucene | eql client search for eql - if language == 'kuery': + if language == "kuery": multi_search_rules.append(rule) - multi_search.append({'index': index_str, 'allow_no_indices': 'true', 'ignore_unavailable': 'true'}) + multi_search.append({"index": index_str, "allow_no_indices": "true", "ignore_unavailable": "true"}) multi_search.append(formatted_dsl) - elif language == 'lucene': + elif language == "lucene": # wait for 0 to try and force async with no immediate results (not guaranteed) - result = async_client.submit(body=formatted_dsl, q=query, index=index_str, - allow_no_indices=True, ignore_unavailable=True, - wait_for_completion_timeout=0) - if result['is_running'] is True: - async_searches.append((rule, result['id'])) + result = async_client.submit( + body=formatted_dsl, + q=query, # type: ignore[reportUnknownArgumentType] + index=index_str, + allow_no_indices=True, + ignore_unavailable=True, + wait_for_completion_timeout=0, + ) + if result["is_running"] is True: + async_searches.append((rule, result["id"])) else: - survey_results[rule.id] = parse_unique_field_results(rule_type, ['process.name'], - result['response']) - elif language == 'eql': - eql_body = { - 'index': index_str, - 'params': {'ignore_unavailable': 'true', 'allow_no_indices': 'true'}, - 'body': {'query': query, 'filter': formatted_dsl['filter']} + survey_results[rule.id] = parse_unique_field_results( + rule_type, ["process.name"], result["response"] + ) + elif language == "eql": + eql_body: dict[str, Any] = { + "index": index_str, + "params": {"ignore_unavailable": "true", "allow_no_indices": "true"}, + "body": {"query": query, "filter": formatted_dsl["filter"]}, } eql_searches.append((rule, eql_body)) # assemble search results multi_search_results = self.client.msearch(searches=multi_search) - for index, result in 
enumerate(multi_search_results['responses']): + for index, result in enumerate(multi_search_results["responses"]): try: rule = multi_search_rules[index] - survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, - rule.contents.data.unique_fields, result) + survey_results[rule.id] = parse_unique_field_results( + rule.contents.data.type, + rule.contents.data.unique_fields, # type: ignore[reportAttributeAccessIssje] + result, + ) except KeyError: - survey_results[multi_search_rules[index].id] = {'error_retrieving_results': True} + survey_results[multi_search_rules[index].id] = {"error_retrieving_results": True} for entry in eql_searches: - rule: TOMLRule - search_args: dict rule, search_args = entry try: result = self.client.eql.search(**search_args) - survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, - rule.contents.data.unique_fields, result) + survey_results[rule.id] = parse_unique_field_results( + rule.contents.data.type, + rule.contents.data.unique_fields, # type: ignore[reportAttributeAccessIssue] + result, # type: ignore[reportAttributeAccessIssue] + ) except (elasticsearch.NotFoundError, elasticsearch.RequestError) as e: - survey_results[rule.id] = {'error_retrieving_results': True, 'error': e.info['error']['reason']} + survey_results[rule.id] = {"error_retrieving_results": True, "error": e.info["error"]["reason"]} for entry in async_searches: rule: TOMLRule rule, async_id = entry - result = async_client.get(id=async_id)['response'] - survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, ['process.name'], result) + result = async_client.get(id=async_id)["response"] + survey_results[rule.id] = parse_unique_field_results(rule.contents.data.type, ["process.name"], result) return survey_results - def count(self, query, language, index: Union[str, list], start_time=None, end_time='now'): + def count( + self, + query: str, + language: str, + index: str | list[str], + start_time: str | None = None, + end_time: str | None = "now", + ) -> Any: """Get a count of documents from elasticsearch.""" - index_str, formatted_dsl, lucene_query = self._prep_query(query=query, language=language, index=index, - start_time=start_time, end_time=end_time) + index_str, formatted_dsl, lucene_query = self._prep_query( + query=query, + language=language, + index=index, + start_time=start_time, + end_time=end_time, + ) # EQL API has no count endpoint - if language == 'eql': - results = self.search(query=query, language=language, index=index, start_time=start_time, end_time=end_time, - size=1000) + if language == "eql": + results = self.search( + query=query, + language=language, + index=index, + start_time=start_time, + end_time=end_time, + size=1000, + ) return len(results) - else: - return self.client.count(body=formatted_dsl, index=index_str, q=lucene_query, allow_no_indices=True, - ignore_unavailable=True)['count'] - - def count_from_rule(self, rules: RuleCollection, start_time=None, end_time='now'): + resp = self.client.count( + body=formatted_dsl, + index=index_str, + q=lucene_query, + allow_no_indices=True, + ignore_unavailable=True, + ) + + return resp["count"] + + def count_from_rule( + self, + rules: RuleCollection, + start_time: str | None = None, + end_time: str | None = "now", + ) -> dict[str, Any]: """Get a count of documents from elasticsearch using a rule.""" - survey_results = {} + survey_results: dict[str, Any] = {} for rule in rules.rules: - rule_results = {'rule_id': rule.id, 'name': rule.name} + rule_results: dict[str, 
Any] = {"rule_id": rule.id, "name": rule.name} - if not rule.contents.data.get('query'): + if not rule.contents.data.get("query"): continue try: - rule_results['search_count'] = self.count(query=rule.contents.data.query, - language=rule.contents.data.language, - index=rule.contents.data.get('index', '*'), - start_time=start_time, - end_time=end_time) + rule_results["search_count"] = self.count( + query=rule.contents.data.query, # type: ignore[reportAttributeAccessIssue] + language=rule.contents.data.language, # type: ignore[reportAttributeAccessIssue] + index=rule.contents.data.get("index", "*"), # type: ignore[reportAttributeAccessIssue] + start_time=start_time, + end_time=end_time, + ) except (elasticsearch.NotFoundError, elasticsearch.RequestError): - rule_results['search_count'] = -1 + rule_results["search_count"] = -1 survey_results[rule.id] = rule_results return survey_results +def evaluate(rule: TOMLRule, events: list[Any], normalize_kql_keywords: bool = False) -> list[Any]: + """Evaluate a query against events.""" + evaluator = kql.get_evaluator(kql.parse(rule.query), normalize_kql_keywords=normalize_kql_keywords) # type: ignore[reportUnknownMemberType] + return list(filter(evaluator, events)) # type: ignore[reportUnknownMemberType] + + +def combine_sources(sources: list[Any]) -> list[Any]: + """Combine lists of events from multiple sources.""" + combined: list[Any] = [] + for source in sources: + combined.extend(source.copy()) + + return event_sort(combined) + + class CollectEventsWithDSL(CollectEvents): """Collect events from elasticsearch.""" @staticmethod - def _group_events_by_type(events): + def _group_events_by_type(events: list[Any]) -> dict[str, list[Any]]: """Group events by agent.type.""" - event_by_type = {} + event_by_type: dict[str, list[Any]] = {} for event in events: - event_by_type.setdefault(event['_source']['agent']['type'], []).append(event['_source']) + event_by_type.setdefault(event["_source"]["agent"]["type"], []).append(event["_source"]) return event_by_type - def run(self, dsl, indexes, start_time): + def run(self, dsl: dict[str, Any], indexes: str | list[str], start_time: str) -> Events: """Collect the events.""" - results = self.search(dsl, language='dsl', index=indexes, start_time=start_time, end_time='now', size=5000, - sort=[{'@timestamp': {'order': 'asc'}}]) + results = self.search( + dsl, + language="dsl", + index=indexes, + start_time=start_time, + end_time="now", + size=5000, + sort=[{"@timestamp": {"order": "asc"}}], + ) events = self._group_events_by_type(results) return Events(events) -@root.command('normalize-data') -@click.argument('events-file', type=click.File('r')) -def normalize_data(events_file): +@root.command("normalize-data") +@click.argument("events-file", type=Path) +def normalize_data(events_file: Path) -> None: """Normalize Elasticsearch data timestamps and sort.""" - file_name = os.path.splitext(os.path.basename(events_file.name))[0] - events = Events({file_name: [json.loads(e) for e in events_file.readlines()]}) - events.save(dump_dir=os.path.dirname(events_file.name)) + file_name = events_file.name + content = events_file.read_text() + lines = content.splitlines() + + events = Events({file_name: [json.loads(line) for line in lines]}) + events.save(dump_dir=events_file.parent) -@root.group('es') + +@root.group("es") @add_params(*elasticsearch_options) @click.pass_context -def es_group(ctx: click.Context, **kwargs): +def es_group(ctx: click.Context, **kwargs: Any) -> None: """Commands for integrating with Elasticsearch.""" - 
ctx.ensure_object(dict) + _ = ctx.ensure_object(dict) # type: ignore[reportUnknownVariableType] # only initialize an es client if the subcommand is invoked without help (hacky) if sys.argv[-1] in ctx.help_option_names: - click.echo('Elasticsearch client:') + click.echo("Elasticsearch client:") click.echo(format_command_options(ctx)) else: - ctx.obj['es'] = get_elasticsearch_client(ctx=ctx, **kwargs) + ctx.obj["es"] = get_elasticsearch_client(ctx=ctx, **kwargs) -@es_group.command('collect-events') -@click.argument('host-id') -@click.option('--query', '-q', help='KQL query to scope search') -@click.option('--index', '-i', multiple=True, help='Index(es) to search against (default: all indexes)') -@click.option('--rta-name', '-r', help='Name of RTA in order to save events directly to unit tests data directory') -@click.option('--rule-id', help='Updates rule mapping in rule-mapping.yaml file (requires --rta-name)') -@click.option('--view-events', is_flag=True, help='Print events after saving') +@es_group.command("collect-events") +@click.argument("host-id") +@click.option("--query", "-q", help="KQL query to scope search") +@click.option("--index", "-i", multiple=True, help="Index(es) to search against (default: all indexes)") +@click.option("--rta-name", "-r", help="Name of RTA in order to save events directly to unit tests data directory") +@click.option("--rule-id", help="Updates rule mapping in rule-mapping.yaml file (requires --rta-name)") +@click.option("--view-events", is_flag=True, help="Print events after saving") @click.pass_context -def collect_events(ctx, host_id, query, index, rta_name, rule_id, view_events): +def collect_events( # noqa: PLR0913 + ctx: click.Context, + host_id: str, + query: str, + index: list[str], + rta_name: str, + rule_id: str, + view_events: bool, +) -> Events: """Collect events from Elasticsearch.""" - client: Elasticsearch = ctx.obj['es'] - dsl = kql.to_dsl(query) if query else MATCH_ALL - dsl['bool'].setdefault('filter', []).append({'bool': {'should': [{'match_phrase': {'host.id': host_id}}]}}) + client: Elasticsearch = ctx.obj["es"] + dsl = kql.to_dsl(query) if query else MATCH_ALL # type: ignore[reportUnknownMemberType] + dsl["bool"].setdefault("filter", []).append( # type: ignore[reportUnknownMemberType] + { + "bool": { + "should": [{"match_phrase": {"host.id": host_id}}], + }, + } + ) try: collector = CollectEventsWithDSL(client) start = time.time() - click.pause('Press any key once detonation is complete ...') - start_time = f'now-{round(time.time() - start) + 5}s' - events = collector.run(dsl, index or '*', start_time) + click.pause("Press any key once detonation is complete ...") + start_time = f"now-{round(time.time() - start) + 5}s" + events = collector.run(dsl, index or "*", start_time) # type: ignore[reportUnknownArgument] events.save(rta_name=rta_name, host_id=host_id) if rta_name and rule_id: - events.evaluate_against_rule(rule_id) + _ = events.evaluate_against_rule(rule_id) if view_events and events.events: events.echo_events(pager=True) - return events except AssertionError as e: - error_msg = 'No events collected! Verify events are streaming and that the agent-hostname is correct' - client_error(error_msg, e, ctx=ctx) + error_msg = "No events collected! 
Verify events are streaming and that the agent-hostname is correct" + raise_client_error(error_msg, e, ctx=ctx) + + return events -@es_group.command('index-rules') -@click.option('--query', '-q', help='Optional KQL query to limit to specific rules') -@click.option('--from-file', '-f', type=click.File('r'), help='Load a previously saved uploadable bulk file') -@click.option('--save_files', '-s', is_flag=True, help='Optionally save the bulk request to a file') +@es_group.command("index-rules") +@click.option("--query", "-q", help="Optional KQL query to limit to specific rules") +@click.option("--from-file", "-f", type=click.File("r"), help="Load a previously saved uploadable bulk file") +@click.option("--save_files", "-s", is_flag=True, help="Optionally save the bulk request to a file") @click.pass_context -def index_repo(ctx: click.Context, query, from_file, save_files): +def index_repo(ctx: click.Context, query: str, from_file: IO[Any] | None, save_files: bool) -> None: """Index rules based on KQL search results to an elasticsearch instance.""" from .main import generate_rules_index - es_client: Elasticsearch = ctx.obj['es'] + es_client: Elasticsearch = ctx.obj["es"] if from_file: bulk_upload_docs = from_file.read() @@ -414,10 +548,10 @@ def index_repo(ctx: click.Context, query, from_file, save_files): # light validation only try: index_body = [json.loads(line) for line in bulk_upload_docs.splitlines()] - click.echo(f'{len([r for r in index_body if "rule" in r])} rules included') + click.echo(f"{len([r for r in index_body if 'rule' in r])} rules included") except json.JSONDecodeError: - client_error(f'Improperly formatted bulk request file: {from_file.name}') + raise_client_error(f"Improperly formatted bulk request file: {from_file.name}") else: - bulk_upload_docs, importable_rules_docs = ctx.invoke(generate_rules_index, query=query, save_files=save_files) + bulk_upload_docs, _ = ctx.invoke(generate_rules_index, query=query, save_files=save_files) - es_client.bulk(bulk_upload_docs) + _ = es_client.bulk(operations=bulk_upload_docs) diff --git a/detection_rules/exception.py b/detection_rules/exception.py index 1b89d6ab8dc..ead11ef5a3c 100644 --- a/detection_rules/exception.py +++ b/detection_rules/exception.py @@ -3,18 +3,19 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. 
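# --- Illustrative aside, not part of the diff: `CollectEvents._prep_query` (eswrap.py above)
# normalizes a query per language before dispatch (msearch for kuery, async search for lucene,
# the EQL client for eql); a sketch of the plain-DSL path with made-up values:
from detection_rules.eswrap import CollectEvents

query = {"bool": {"filter": [{"term": {"host.id": "abc123"}}]}}
index_str, formatted_dsl, lucene_query = CollectEvents._prep_query(
    query=query, language="dsl", index=["logs-*", "metrics-*"], start_time="now-1h"
)
# index_str == "logs-*,metrics-*"; lucene_query is None; formatted_dsl == {"query": query} with a
# {"range": {"@timestamp": {"gt": "now-1h", "lte": "now", ...}}} clause appended to the bool filter.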
"""Rule exceptions data.""" + from collections import defaultdict from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import List, Optional, Union, Tuple, get_args +from typing import Any, get_args -import pytoml +import pytoml # type: ignore[reportMissingTypeStubs] from marshmallow import EXCLUDE, ValidationError, validates_schema +from .config import parse_rules_config from .mixins import MarshmallowDataclassMixin from .schemas import definitions -from .config import parse_rules_config RULES_CONFIG = parse_rules_config() @@ -24,21 +25,23 @@ @dataclass(frozen=True) class ExceptionMeta(MarshmallowDataclassMixin): """Data stored in an exception's [metadata] section of TOML.""" + creation_date: definitions.Date list_name: str - rule_ids: List[definitions.UUIDString] - rule_names: List[str] + rule_ids: list[definitions.UUIDString] + rule_names: list[str] updated_date: definitions.Date # Optional fields - deprecation_date: Optional[definitions.Date] - comments: Optional[str] - maturity: Optional[definitions.Maturity] + deprecation_date: definitions.Date | None = None + comments: str | None = None + maturity: definitions.Maturity | None = None @dataclass(frozen=True) class BaseExceptionItemEntry(MarshmallowDataclassMixin): """Shared object between nested and non-nested exception items.""" + field: str type: definitions.ExceptionEntryType @@ -46,101 +49,110 @@ class BaseExceptionItemEntry(MarshmallowDataclassMixin): @dataclass(frozen=True) class NestedExceptionItemEntry(BaseExceptionItemEntry, MarshmallowDataclassMixin): """Nested exception item entry.""" - entries: List['ExceptionItemEntry'] + + entries: list["ExceptionItemEntry"] @validates_schema - def validate_nested_entry(self, data: dict, **kwargs): + def validate_nested_entry(self, data: dict[str, Any], **_: Any) -> None: """More specific validation.""" - if data.get('list') is not None: - raise ValidationError('Nested entries cannot define a list') + if data.get("list"): + raise ValidationError("Nested entries cannot define a list") @dataclass(frozen=True) class ExceptionItemEntry(BaseExceptionItemEntry, MarshmallowDataclassMixin): """Exception item entry.""" + @dataclass(frozen=True) class ListObject: """List object for exception item entry.""" + id: str type: definitions.EsDataTypes - list: Optional[ListObject] operator: definitions.ExceptionEntryOperator - value: Optional[Union[str, List[str]]] + list_vals: ListObject | None = None + value: str | None | list[str] = None @validates_schema - def validate_entry(self, data: dict, **kwargs): + def validate_entry(self, data: dict[str, Any], **_: Any) -> None: """Validate the entry based on its type.""" - value = data.get('value', '') - if data['type'] in ('exists', 'list') and value is not None: - raise ValidationError(f'Entry of type {data["type"]} cannot have a value') - elif data['type'] in ('match', 'wildcard') and not isinstance(value, str): - raise ValidationError(f'Entry of type {data["type"]} must have a string value') - elif data['type'] == 'match_any' and not isinstance(value, list): - raise ValidationError(f'Entry of type {data["type"]} must have a list of strings as a value') + value = data.get("value", "") + if data["type"] in ("exists", "list") and value is not None: + raise ValidationError(f"Entry of type {data['type']} cannot have a value") + if data["type"] in ("match", "wildcard") and not isinstance(value, str): + raise ValidationError(f"Entry of type {data['type']} must have a string value") + if data["type"] == 
"match_any" and not isinstance(value, list): + raise ValidationError(f"Entry of type {data['type']} must have a list of strings as a value") @dataclass(frozen=True) class ExceptionItem(MarshmallowDataclassMixin): """Base exception item.""" + @dataclass(frozen=True) class Comment: """Comment object for exception item.""" + comment: str - comments: List[Optional[Comment]] + comments: list[Comment | None] description: str - entries: List[Union[ExceptionItemEntry, NestedExceptionItemEntry]] + entries: list[ExceptionItemEntry | NestedExceptionItemEntry] list_id: str - item_id: Optional[str] # api sets field when not provided - meta: Optional[dict] + item_id: str | None # api sets field when not provided + meta: dict[str, Any] | None name: str - namespace_type: Optional[definitions.ExceptionNamespaceType] # defaults to "single" if not provided - tags: Optional[List[str]] + namespace_type: definitions.ExceptionNamespaceType | None # defaults to "single" if not provided + tags: list[str] | None type: definitions.ExceptionItemType @dataclass(frozen=True) class EndpointException(ExceptionItem, MarshmallowDataclassMixin): """Endpoint exception item.""" - _tags: List[definitions.ExceptionItemEndpointTags] + + _tags: list[definitions.ExceptionItemEndpointTags] @validates_schema - def validate_endpoint(self, data: dict, **kwargs): + def validate_endpoint(self, data: dict[str, Any], **_: Any) -> None: """Validate the endpoint exception.""" - for entry in data['entries']: - if entry['operator'] == "excluded": + for entry in data["entries"]: + if entry["operator"] == "excluded": raise ValidationError("Endpoint exceptions cannot have an `excluded` operator") @dataclass(frozen=True) class DetectionException(ExceptionItem, MarshmallowDataclassMixin): """Detection exception item.""" - expire_time: Optional[str] # fields.DateTime] # maybe this is isoformat? + + expire_time: str | None # fields.DateTime] # maybe this is isoformat? 
@dataclass(frozen=True) class ExceptionContainer(MarshmallowDataclassMixin): """Exception container.""" + description: str - list_id: Optional[str] - meta: Optional[dict] + list_id: str | None + meta: dict[str, Any] | None name: str - namespace_type: Optional[definitions.ExceptionNamespaceType] - tags: Optional[List[str]] + namespace_type: definitions.ExceptionNamespaceType | None + tags: list[str] | None type: definitions.ExceptionContainerType - def to_rule_entry(self) -> dict: + def to_rule_entry(self) -> dict[str, Any]: """Returns a dict of the format required in rule.exception_list.""" # requires KSO id to be consider valid structure - return dict(namespace_type=self.namespace_type, type=self.type, list_id=self.list_id) + return {"namespace_type": self.namespace_type, "type": self.type, "list_id": self.list_id} @dataclass(frozen=True) class Data(MarshmallowDataclassMixin): """Data stored in an exception's [exception] section of TOML.""" + container: ExceptionContainer - items: Optional[List[DetectionException]] # Union[DetectionException, EndpointException]] + items: list[DetectionException] | None @dataclass(frozen=True) @@ -148,25 +160,24 @@ class TOMLExceptionContents(MarshmallowDataclassMixin): """Data stored in an exception file.""" metadata: ExceptionMeta - exceptions: List[Data] + exceptions: list[Data] @classmethod - def from_exceptions_dict(cls, exceptions_dict: dict, rule_list: list[dict]) -> "TOMLExceptionContents": + def from_exceptions_dict( + cls, exceptions_dict: dict[str, Any], rule_list: list[dict[str, Any]] + ) -> "TOMLExceptionContents": """Create a TOMLExceptionContents from a kibana rule resource.""" - rule_ids = [] - rule_names = [] + rule_ids: list[str] = [] + rule_names: list[str] = [] for rule in rule_list: rule_ids.append(rule["id"]) rule_names.append(rule["name"]) # Format date to match schema - creation_date = datetime.strptime(exceptions_dict["container"]["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ").strftime( - "%Y/%m/%d" - ) - updated_date = datetime.strptime(exceptions_dict["container"]["updated_at"], "%Y-%m-%dT%H:%M:%S.%fZ").strftime( - "%Y/%m/%d" - ) + container = exceptions_dict["container"] + creation_date = datetime.strptime(container["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%Y/%m/%d") # noqa: DTZ007 + updated_date = datetime.strptime(container["updated_at"], "%Y-%m-%dT%H:%M:%S.%fZ").strftime("%Y/%m/%d") # noqa: DTZ007 metadata = { "creation_date": creation_date, "list_name": exceptions_dict["container"]["name"], @@ -177,15 +188,14 @@ def from_exceptions_dict(cls, exceptions_dict: dict, rule_list: list[dict]) -> " return cls.from_dict({"metadata": metadata, "exceptions": [exceptions_dict]}, unknown=EXCLUDE) - def to_api_format(self) -> List[dict]: + def to_api_format(self) -> list[dict[str, Any]]: """Convert the TOML Exception to the API format.""" - converted = [] + converted: list[dict[str, Any]] = [] for exception in self.exceptions: converted.append(exception.container.to_dict()) if exception.items: - for item in exception.items: - converted.append(item.to_dict()) + converted.extend([item.to_dict() for item in exception.items]) return converted @@ -193,17 +203,19 @@ def to_api_format(self) -> List[dict]: @dataclass(frozen=True) class TOMLException: """TOML exception object.""" + contents: TOMLExceptionContents - path: Optional[Path] = None + path: Path | None = None @property - def name(self): + def name(self) -> str: """Return the name of the exception list.""" return self.contents.metadata.list_name - def save_toml(self): + def 
save_toml(self) -> None: """Save the exception to a TOML file.""" - assert self.path is not None, f"Can't save exception {self.name} without a path" + if not self.path: + raise ValueError(f"Can't save exception {self.name} without a path") # Check if self.path has a .toml extension path = self.path if path.suffix != ".toml": @@ -213,55 +225,60 @@ def save_toml(self): contents_dict = self.contents.to_dict() # Sort the dictionary so that 'metadata' is at the top sorted_dict = dict(sorted(contents_dict.items(), key=lambda item: item[0] != "metadata")) - pytoml.dump(sorted_dict, f) + pytoml.dump(sorted_dict, f) # type: ignore[reportUnknownMemberType] -def parse_exceptions_results_from_api(results: List[dict]) -> tuple[dict, dict, List[str], List[dict]]: +def parse_exceptions_results_from_api( + results: list[dict[str, Any]], +) -> tuple[dict[str, Any], dict[str, Any], list[str], list[dict[str, Any]]]: """Parse exceptions results from the API into containers and items.""" - exceptions_containers = {} - exceptions_items = defaultdict(list) - errors = [] - unparsed_results = [] + exceptions_containers: dict[str, Any] = {} + exceptions_items: dict[str, list[Any]] = defaultdict(list) + unparsed_results: list[dict[str, Any]] = [] for result in results: result_type = result.get("type") list_id = result.get("list_id") - if result_type in get_args(definitions.ExceptionContainerType): - exceptions_containers[list_id] = result - elif result_type in get_args(definitions.ExceptionItemType): - exceptions_items[list_id].append(result) + if result_type and list_id: + if result_type in get_args(definitions.ExceptionContainerType): + exceptions_containers[list_id] = result + elif result_type in get_args(definitions.ExceptionItemType): + exceptions_items[list_id].append(result) else: unparsed_results.append(result) - return exceptions_containers, exceptions_items, errors, unparsed_results + return exceptions_containers, exceptions_items, [], unparsed_results -def build_exception_objects(exceptions_containers: List[dict], exceptions_items: List[dict], - exception_list_rule_table: dict, exceptions_directory: Path, save_toml: bool = False, - skip_errors: bool = False, verbose=False, - ) -> Tuple[List[TOMLException], List[str], List[str]]: +def build_exception_objects( # noqa: PLR0913 + exceptions_containers: dict[str, Any], + exceptions_items: dict[str, Any], + exception_list_rule_table: dict[str, Any], + exceptions_directory: Path | None, + save_toml: bool = False, + skip_errors: bool = False, + verbose: bool = False, +) -> tuple[list[TOMLException], list[str], list[str]]: """Build TOMLException objects from a list of exception dictionaries.""" - output = [] - errors = [] - toml_exceptions = [] + output: list[str] = [] + errors: list[str] = [] + toml_exceptions: list[TOMLException] = [] for container in exceptions_containers.values(): try: - list_id = container.get("list_id") - items = exceptions_items.get(list_id) + list_id = container["list_id"] + items = exceptions_items[list_id] contents = TOMLExceptionContents.from_exceptions_dict( {"container": container, "items": items}, - exception_list_rule_table.get(list_id), + exception_list_rule_table[list_id], ) filename = f"{list_id}_exceptions.toml" if RULES_CONFIG.exception_dir is None and not exceptions_directory: - raise FileNotFoundError( + raise FileNotFoundError( # noqa: TRY301 "No Exceptions directory is specified. Please specify either in the config or CLI." 
) exceptions_path = ( - Path(exceptions_directory) / filename - if exceptions_directory - else RULES_CONFIG.exception_dir / filename + Path(exceptions_directory) / filename if exceptions_directory else RULES_CONFIG.exception_dir / filename ) if verbose: output.append(f"[+] Building exception(s) for {exceptions_path}") diff --git a/detection_rules/generic_loader.py b/detection_rules/generic_loader.py index 41f91bee82b..d58e4f61f92 100644 --- a/detection_rules/generic_loader.py +++ b/detection_rules/generic_loader.py @@ -4,26 +4,29 @@ # 2.0. """Load generic toml formatted files for exceptions and actions.""" + +from collections.abc import Callable, Iterable, Iterator from pathlib import Path -from typing import Callable, Dict, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Any -import pytoml +import pytoml # type: ignore[reportMissingTypeStubs] from .action import TOMLAction, TOMLActionContents from .action_connector import TOMLActionConnector, TOMLActionConnectorContents from .config import parse_rules_config from .exception import TOMLException, TOMLExceptionContents from .rule_loader import dict_filter -from .schemas import definitions +if TYPE_CHECKING: + from .schemas import definitions RULES_CONFIG = parse_rules_config() -GenericCollectionTypes = Union[TOMLAction, TOMLActionConnector, TOMLException] -GenericCollectionContentTypes = Union[TOMLActionContents, TOMLActionConnectorContents, TOMLExceptionContents] +GenericCollectionTypes = TOMLAction | TOMLActionConnector | TOMLException +GenericCollectionContentTypes = TOMLActionContents | TOMLActionConnectorContents | TOMLExceptionContents -def metadata_filter(**metadata) -> Callable[[GenericCollectionTypes], bool]: +def metadata_filter(**metadata: Any) -> Callable[[GenericCollectionTypes], bool]: """Get a filter callback based off item metadata""" flt = dict_filter(metadata) @@ -37,49 +40,49 @@ def callback(item: GenericCollectionTypes) -> bool: class GenericCollection: """Generic collection for action and exception objects.""" - items: list + items: list[GenericCollectionTypes] __default = None - def __init__(self, items: Optional[List[GenericCollectionTypes]] = None): - self.id_map: Dict[definitions.UUIDString, GenericCollectionTypes] = {} - self.file_map: Dict[Path, GenericCollectionTypes] = {} - self.name_map: Dict[definitions.RuleName, GenericCollectionTypes] = {} - self.items: List[GenericCollectionTypes] = [] - self.errors: Dict[Path, Exception] = {} + def __init__(self, items: list[GenericCollectionTypes] | None = None) -> None: + self.id_map: dict[definitions.UUIDString, GenericCollectionTypes] = {} + self.file_map: dict[Path, GenericCollectionTypes] = {} + self.name_map: dict[definitions.RuleName, GenericCollectionTypes] = {} + self.items: list[GenericCollectionTypes] = [] + self.errors: dict[Path, Exception] = {} self.frozen = False - self._toml_load_cache: Dict[Path, dict] = {} + self._toml_load_cache: dict[Path, dict[str, Any]] = {} - for items in (items or []): - self.add_item(items) + for item in items or []: + self.add_item(item) def __len__(self) -> int: """Get the total amount of exceptions in the collection.""" return len(self.items) - def __iter__(self) -> Iterable[GenericCollectionTypes]: + def __iter__(self) -> Iterator[GenericCollectionTypes]: """Iterate over all items in the collection.""" return iter(self.items) def __contains__(self, item: GenericCollectionTypes) -> bool: """Check if an item is in the map by comparing IDs.""" - return item.id in self.id_map + return item.id in self.id_map # 
type: ignore[reportAttributeAccessIssue] - def filter(self, cb: Callable[[TOMLException], bool]) -> 'GenericCollection': + def filter(self, cb: Callable[[TOMLException], bool]) -> "GenericCollection": """Retrieve a filtered collection of items.""" filtered_collection = GenericCollection() - for item in filter(cb, self.items): + for item in filter(cb, self.items): # type: ignore[reportCallIssue] filtered_collection.add_item(item) return filtered_collection @staticmethod - def deserialize_toml_string(contents: Union[bytes, str]) -> dict: + def deserialize_toml_string(contents: bytes | str) -> dict[str, Any]: """Deserialize a TOML string into a dictionary.""" - return pytoml.loads(contents) + return pytoml.loads(contents) # type: ignore[reportUnknownVariableType] - def _load_toml_file(self, path: Path) -> dict: + def _load_toml_file(self, path: Path) -> dict[str, Any]: """Load a TOML file into a dictionary.""" if path in self._toml_load_cache: return self._toml_load_cache[path] @@ -92,22 +95,25 @@ def _load_toml_file(self, path: Path) -> dict: self._toml_load_cache[path] = toml_dict return toml_dict - def _get_paths(self, directory: Path, recursive=True) -> List[Path]: + def _get_paths(self, directory: Path, recursive: bool = True) -> list[Path]: """Get all TOML files in a directory.""" - return sorted(directory.rglob('*.toml') if recursive else directory.glob('*.toml')) + return sorted(directory.rglob("*.toml") if recursive else directory.glob("*.toml")) def _assert_new(self, item: GenericCollectionTypes) -> None: """Assert that the item is new and can be added to the collection.""" file_map = self.file_map name_map = self.name_map - assert not self.frozen, f"Unable to add item {item.name} to a frozen collection" - assert item.name not in name_map, \ - f"Rule Name {item.name} collides with {name_map[item.name].name}" + if self.frozen: + raise ValueError(f"Unable to add item {item.name} to a frozen collection") + + if item.name in name_map: + raise ValueError(f"Rule Name {item.name} collides with {name_map[item.name].name}") if item.path is not None: item_path = item.path.resolve() - assert item_path not in file_map, f"Item file {item_path} already loaded" + if item_path in file_map: + raise ValueError(f"Item file {item_path} already loaded") file_map[item_path] = item def add_item(self, item: GenericCollectionTypes) -> None: @@ -116,15 +122,19 @@ def add_item(self, item: GenericCollectionTypes) -> None: self.name_map[item.name] = item self.items.append(item) - def load_dict(self, obj: dict, path: Optional[Path] = None) -> GenericCollectionTypes: + def load_dict(self, obj: dict[str, Any], path: Path | None = None) -> GenericCollectionTypes: """Load a dictionary into the collection.""" - if 'exceptions' in obj: + if "exceptions" in obj: contents = TOMLExceptionContents.from_dict(obj) item = TOMLException(path=path, contents=contents) - elif 'actions' in obj: + elif "actions" in obj: contents = TOMLActionContents.from_dict(obj) + if not path: + raise ValueError("No path value provided") item = TOMLAction(path=path, contents=contents) - elif 'action_connectors' in obj: + elif "action_connectors" in obj: + if not path: + raise ValueError("No path value provided") contents = TOMLActionConnectorContents.from_dict(obj) item = TOMLActionConnector(path=path, contents=contents) else: @@ -140,11 +150,10 @@ def load_file(self, path: Path) -> GenericCollectionTypes: # use the default generic loader as a cache. 
# if it already loaded the item, then we can just use it from that - if self.__default is not None and self is not self.__default: - if path in self.__default.file_map: - item = self.__default.file_map[path] - self.add_item(item) - return item + if self.__default and self is not self.__default and path in self.__default.file_map: + item = self.__default.file_map[path] + self.add_item(item) + return item obj = self._load_toml_file(path) return self.load_dict(obj, path=path) @@ -155,10 +164,13 @@ def load_file(self, path: Path) -> GenericCollectionTypes: def load_files(self, paths: Iterable[Path]) -> None: """Load multiple files into the collection.""" for path in paths: - self.load_file(path) + _ = self.load_file(path) def load_directory( - self, directory: Path, recursive=True, toml_filter: Optional[Callable[[dict], bool]] = None + self, + directory: Path, + recursive: bool = True, + toml_filter: Callable[[dict[str, Any]], bool] | None = None, ) -> None: """Load all TOML files in a directory.""" paths = self._get_paths(directory, recursive=recursive) @@ -168,7 +180,10 @@ def load_directory( self.load_files(paths) def load_directories( - self, directories: Iterable[Path], recursive=True, toml_filter: Optional[Callable[[dict], bool]] = None + self, + directories: Iterable[Path], + recursive: bool = True, + toml_filter: Callable[[dict[str, Any]], bool] | None = None, ) -> None: """Load all TOML files in multiple directories.""" for path in directories: @@ -179,7 +194,7 @@ def freeze(self) -> None: self.frozen = True @classmethod - def default(cls) -> 'GenericCollection': + def default(cls) -> "GenericCollection": """Return the default item collection, which retrieves from default config location.""" if cls.__default is None: collection = GenericCollection() diff --git a/detection_rules/ghwrap.py b/detection_rules/ghwrap.py index 133bd7132c2..1f71cbbd3e9 100644 --- a/detection_rules/ghwrap.py +++ b/detection_rules/ghwrap.py @@ -12,50 +12,41 @@ import shutil import time from dataclasses import dataclass, field -from datetime import datetime +from datetime import UTC, datetime from pathlib import Path -from typing import Dict, Optional, Tuple +from typing import Any from zipfile import ZipFile import click import requests +from github import Github +from github.GitRelease import GitRelease +from github.GitReleaseAsset import GitReleaseAsset +from github.Repository import Repository from requests import Response from .schemas import definitions -# this is primarily for type hinting - all use of the github client should come from GithubClient class -try: - from github import Github - from github.Repository import Repository - from github.GitRelease import GitRelease - from github.GitReleaseAsset import GitReleaseAsset -except ImportError: - # for type hinting - Github = None # noqa: N806 - Repository = None # noqa: N806 - GitRelease = None # noqa: N806 - GitReleaseAsset = None # noqa: N806 - - -def get_gh_release(repo: Repository, release_name: Optional[str] = None, tag_name: Optional[str] = None) -> GitRelease: + +def get_gh_release(repo: Repository, release_name: str | None = None, tag_name: str | None = None) -> GitRelease | None: """Get a list of GitHub releases by repo.""" - assert release_name or tag_name, 'Must specify a release_name or tag_name' + if not release_name and not tag_name: + raise ValueError("Must specify a release_name or tag_name") releases = repo.get_releases() for release in releases: - if release_name and release_name == release.title: - return release - elif tag_name and 
tag_name == release.tag_name: + if (release_name and release_name == release.title) or (tag_name and tag_name == release.tag_name): return release + return None -def load_zipped_gh_assets_with_metadata(url: str) -> Tuple[str, dict]: +def load_zipped_gh_assets_with_metadata(url: str) -> tuple[str, dict[str, Any]]: """Download and unzip a GitHub assets.""" - response = requests.get(url) + response = requests.get(url, timeout=30) zipped_asset = ZipFile(io.BytesIO(response.content)) zipped_sha256 = hashlib.sha256(response.content).hexdigest() - assets = {} + assets: dict[str, Any] = {} for zipped in zipped_asset.filelist: if zipped.is_dir(): continue @@ -64,29 +55,29 @@ def load_zipped_gh_assets_with_metadata(url: str) -> Tuple[str, dict]: sha256 = hashlib.sha256(contents).hexdigest() assets[zipped.filename] = { - 'contents': contents, - 'metadata': { - 'compress_size': zipped.compress_size, + "contents": contents, + "metadata": { + "compress_size": zipped.compress_size, # zipfile provides only a 6 tuple datetime; -1 means DST is unknown; 0's set tm_wday and tm_yday - 'created_at': time.strftime('%Y-%m-%dT%H:%M:%SZ', zipped.date_time + (0, 0, -1)), - 'sha256': sha256, - 'size': zipped.file_size, - } + "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", (*zipped.date_time, 0, 0, -1)), + "sha256": sha256, + "size": zipped.file_size, + }, } return zipped_sha256, assets -def load_json_gh_asset(url: str) -> dict: +def load_json_gh_asset(url: str) -> dict[str, Any]: """Load and return the contents of a json asset file.""" - response = requests.get(url) + response = requests.get(url, timeout=30) response.raise_for_status() return response.json() -def download_gh_asset(url: str, path: str, overwrite=False): +def download_gh_asset(url: str, path: str, overwrite: bool = False) -> None: """Download and unzip a GitHub asset.""" - zipped = requests.get(url) + zipped = requests.get(url, timeout=30) z = ZipFile(io.BytesIO(zipped.content)) Path(path).mkdir(exist_ok=True) @@ -94,41 +85,40 @@ def download_gh_asset(url: str, path: str, overwrite=False): shutil.rmtree(path, ignore_errors=True) z.extractall(path) - click.echo(f'files saved to {path}') + click.echo(f"files saved to {path}") z.close() -def update_gist(token: str, - file_map: Dict[Path, str], - description: str, - gist_id: str, - public=False, - pre_purge=False) -> Response: +def update_gist( # noqa: PLR0913 + token: str, + file_map: dict[Path, str], + description: str, + gist_id: str, + public: bool = False, + pre_purge: bool = False, +) -> Response: """Update existing gist.""" - url = f'https://api.github.com/gists/{gist_id}' - headers = { - 'accept': 'application/vnd.github.v3+json', - 'Authorization': f'token {token}' - } - body = { - 'description': description, - 'files': {}, # {path.name: {'content': contents} for path, contents in file_map.items()}, - 'public': public + url = f"https://api.github.com/gists/{gist_id}" + headers = {"accept": "application/vnd.github.v3+json", "Authorization": f"token {token}"} + body: dict[str, Any] = { + "description": description, + "files": {}, # {path.name: {'content': contents} for path, contents in file_map.items()}, + "public": public, } if pre_purge: # retrieve all existing file names which are not in the file_map and overwrite them to empty to delete files - response = requests.get(url) + response = requests.get(url, timeout=30) response.raise_for_status() data = response.json() - files = list(data['files']) - body['files'] = {file: {} for file in files if file not in file_map} - response = 
requests.patch(url, headers=headers, json=body) + files = list(data["files"]) + body["files"] = {file: {} for file in files if file not in file_map} + response = requests.patch(url, headers=headers, json=body, timeout=30) response.raise_for_status() - body['files'] = {path.name: {'content': contents} for path, contents in file_map.items()} - response = requests.patch(url, headers=headers, json=body) + body["files"] = {path.name: {"content": contents} for path, contents in file_map.items()} + response = requests.patch(url, headers=headers, json=body, timeout=30) response.raise_for_status() return response @@ -136,34 +126,33 @@ def update_gist(token: str, class GithubClient: """GitHub client wrapper.""" - def __init__(self, token: Optional[str] = None): + def __init__(self, token: str | None = None) -> None: """Get an unauthenticated client, verified authenticated client, or a default client.""" self.assert_github() - self.client: Github = Github(token) + self.client = Github(token) self.unauthenticated_client = Github() self.__token = token self.__authenticated_client = None @classmethod - def assert_github(cls): + def assert_github(cls) -> None: if not Github: - raise ModuleNotFoundError('Missing PyGithub - try running `pip3 install .[dev]`') + raise ModuleNotFoundError("Missing PyGithub - try running `pip3 install .[dev]`") @property def authenticated_client(self) -> Github: if not self.__token: - raise ValueError('Token not defined! Re-instantiate with a token or use add_token method') + raise ValueError("Token not defined! Re-instantiate with a token or use add_token method") if not self.__authenticated_client: self.__authenticated_client = Github(self.__token) return self.__authenticated_client - def add_token(self, token): + def add_token(self, token: str) -> None: self.__token = token @dataclass class AssetManifestEntry: - compress_size: int created_at: datetime name: str @@ -173,18 +162,16 @@ class AssetManifestEntry: @dataclass class AssetManifestMetadata: - relative_url: str - entries: Dict[str, AssetManifestEntry] + entries: dict[str, AssetManifestEntry] zipped_sha256: definitions.Sha256 - created_at: datetime = field(default_factory=datetime.utcnow) - description: Optional[str] = None # populated by GitHub release asset label + created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + description: str | None = None # populated by GitHub release asset label @dataclass class ReleaseManifest: - - assets: Dict[str, AssetManifestMetadata] + assets: dict[str, AssetManifestMetadata] assets_url: str author: str # parsed from GitHub release metadata as: author[login] created_at: str @@ -194,15 +181,20 @@ class ReleaseManifest: published_at: str url: str zipball_url: str - tag_name: str = None - description: str = None # parsed from GitHub release metadata as: body + tag_name: str | None = None + description: str | None = None # parsed from GitHub release metadata as: body class ManifestManager: """Manifest handler for GitHub releases.""" - def __init__(self, repo: str = 'elastic/detection-rules', release_name: Optional[str] = None, - tag_name: Optional[str] = None, token: Optional[str] = None): + def __init__( + self, + repo: str = "elastic/detection-rules", + release_name: str | None = None, + tag_name: str | None = None, + token: str | None = None, + ) -> None: self.repo_name = repo self.release_name = release_name self.tag_name = tag_name @@ -210,34 +202,37 @@ def __init__(self, repo: str = 'elastic/detection-rules', release_name: Optional self.has_token = token is not 
None self.repo: Repository = self.gh_client.client.get_repo(repo) - self.release: GitRelease = get_gh_release(self.repo, release_name, tag_name) + release = get_gh_release(self.repo, release_name, tag_name) + if not release: + raise ValueError(f"No release found for {tag_name or release_name}") + self.release = release - if not self.release: - raise ValueError(f'No release found for {tag_name or release_name}') if not self.release_name: self.release_name = self.release.title - self.manifest_name = f'manifest-{self.release_name}.json' - self.assets: dict = self._get_enriched_assets_from_release() + self.manifest_name = f"manifest-{self.release_name}.json" + self.assets = self._get_enriched_assets_from_release() self.release_manifest = self._create() self.__release_manifest_dict = dataclasses.asdict(self.release_manifest) self.manifest_size = len(json.dumps(self.__release_manifest_dict)) @property def release_manifest_fl(self) -> io.BytesIO: - return io.BytesIO(json.dumps(self.__release_manifest_dict, sort_keys=True).encode('utf-8')) + return io.BytesIO(json.dumps(self.__release_manifest_dict, sort_keys=True).encode("utf-8")) def _create(self) -> ReleaseManifest: """Create the manifest from GitHub asset metadata and file contents.""" assets = {} for asset_name, asset_data in self.assets.items(): - entries = {} - data = asset_data['data'] - metadata = asset_data['metadata'] + entries: dict[str, AssetManifestEntry] = {} + data = asset_data["data"] + metadata = asset_data["metadata"] for file_name, file_data in data.items(): - file_metadata = file_data['metadata'] + file_metadata = file_data["metadata"] name = Path(file_name).name file_metadata.update(name=name) @@ -245,59 +240,71 @@ def _create(self) -> ReleaseManifest: entry = AssetManifestEntry(**file_metadata) entries[name] = entry - assets[asset_name] = AssetManifestMetadata(metadata['browser_download_url'], entries, - metadata['zipped_sha256'], metadata['created_at'], - metadata['label']) + assets[asset_name] = AssetManifestMetadata( + metadata["browser_download_url"], + entries, + metadata["zipped_sha256"], + metadata["created_at"], + metadata["label"], + ) release_metadata = self._parse_release_metadata() release_metadata.update(assets=assets) - release_manifest = ReleaseManifest(**release_metadata) - - return release_manifest + return ReleaseManifest(**release_metadata) - def _parse_release_metadata(self) -> dict: + def _parse_release_metadata(self) -> dict[str, Any]: """Parse relevant info from GitHub metadata for release manifest.""" - ignore = ['assets'] - manual_set_keys = ['author', 'description'] + ignore = ["assets"] + manual_set_keys = ["author", "description"] keys = [f.name for f in dataclasses.fields(ReleaseManifest) if f.name not in ignore + manual_set_keys] parsed = {k: self.release.raw_data[k] for k in keys} - parsed.update(description=self.release.raw_data['body'], author=self.release.raw_data['author']['login']) + parsed.update(description=self.release.raw_data["body"], author=self.release.raw_data["author"]["login"]) return parsed def save(self) -> GitReleaseAsset: """Save manifest files.""" if not self.has_token: - raise ValueError('You must provide a token to save a manifest to a GitHub release') + raise ValueError("You must provide a token to save a manifest to a GitHub release") - asset = self.release.upload_asset_from_memory(self.release_manifest_fl, - self.manifest_size, - self.manifest_name) - click.echo(f'Manifest saved as {self.manifest_name} to 
{self.release.html_url}') + asset = self.release.upload_asset_from_memory(self.release_manifest_fl, self.manifest_size, self.manifest_name) + click.echo(f"Manifest saved as {self.manifest_name} to {self.release.html_url}") return asset @classmethod - def load(cls, name: str, repo: str = 'elastic/detection-rules', token: Optional[str] = None) -> Optional[dict]: + def load( + cls, + name: str, + repo_name: str = "elastic/detection-rules", + token: str | None = None, + ) -> dict[str, Any] | None: """Load a manifest.""" gh_client = GithubClient(token) - repo = gh_client.client.get_repo(repo) + repo = gh_client.client.get_repo(repo_name) release = get_gh_release(repo, tag_name=name) + if not release: + raise ValueError("No release info found") + for asset in release.get_assets(): - if asset.name == f'manifest-{name}.json': + if asset.name == f"manifest-{name}.json": return load_json_gh_asset(asset.browser_download_url) + return None @classmethod - def load_all(cls, repo: str = 'elastic/detection-rules', token: Optional[str] = None - ) -> Tuple[Dict[str, dict], list]: + def load_all( + cls, + repo_name: str = "elastic/detection-rules", + token: str | None = None, + ) -> tuple[dict[str, dict[str, Any]], list[str]]: """Load a consolidated manifest.""" gh_client = GithubClient(token) - repo = gh_client.client.get_repo(repo) + repo = gh_client.client.get_repo(repo_name) - consolidated = {} - missing = set() + consolidated: dict[str, dict[str, Any]] = {} + missing: set[str] = set() for release in repo.get_releases(): name = release.tag_name - asset = next((a for a in release.get_assets() if a.name == f'manifest-{name}.json'), None) + asset = next((a for a in release.get_assets() if a.name == f"manifest-{name}.json"), None) if not asset: missing.add(name) else: @@ -306,28 +313,29 @@ def load_all(cls, repo: str = 'elastic/detection-rules', token: Optional[str] = return consolidated, list(missing) @classmethod - def get_existing_asset_hashes(cls, repo: str = 'elastic/detection-rules', token: Optional[str] = None) -> dict: + def get_existing_asset_hashes( + cls, + repo: str = "elastic/detection-rules", + token: str | None = None, + ) -> dict[str, Any]: """Load all assets with their hashes, by release.""" - flat = {} - consolidated, _ = cls.load_all(repo=repo, token=token) + flat: dict[str, Any] = {} + consolidated, _ = cls.load_all(repo_name=repo, token=token) for release, data in consolidated.items(): - for asset in data['assets'].values(): + for asset in data["assets"].values(): flat_release = flat[release] = {} - for asset_name, asset_data in asset['entries'].items(): - flat_release[asset_name] = asset_data['sha256'] + for asset_name, asset_data in asset["entries"].items(): + flat_release[asset_name] = asset_data["sha256"] return flat - def _get_enriched_assets_from_release(self) -> dict: + def _get_enriched_assets_from_release(self) -> dict[str, Any]: """Get assets and metadata from a GitHub release.""" - assets = {} + assets: dict[str, Any] = {} for asset in [a.raw_data for a in self.release.get_assets()]: - zipped_sha256, data = load_zipped_gh_assets_with_metadata(asset['browser_download_url']) + zipped_sha256, data = load_zipped_gh_assets_with_metadata(asset["browser_download_url"]) asset.update(zipped_sha256=zipped_sha256) - assets[asset['name']] = { - 'metadata': asset, - 'data': data - } + assets[asset["name"]] = {"metadata": asset, "data": data} return assets diff --git a/detection_rules/integrations.py b/detection_rules/integrations.py index 14ce262b8d1..172204888b4 100644 --- 
a/detection_rules/integrations.py +++ b/detection_rules/integrations.py @@ -4,43 +4,50 @@ # 2.0. """Functions to support and interact with Kibana integrations.""" -import glob + +import fnmatch import gzip import json import re -from collections import defaultdict, OrderedDict +from collections import OrderedDict, defaultdict +from collections.abc import Iterator from pathlib import Path -from typing import Generator, List, Tuple, Union, Optional +from typing import TYPE_CHECKING, Any +import kql # type: ignore[reportMissingTypeStubs] import requests -from semver import Version import yaml from marshmallow import EXCLUDE, Schema, fields, post_load - -import kql +from semver import Version from . import ecs -from .config import load_current_package_version from .beats import flatten_ecs_schema -from .utils import cached, get_etc_path, read_gzip, unzip +from .config import load_current_package_version from .schemas import definitions +from .utils import cached, get_etc_path, read_gzip, unzip + +if TYPE_CHECKING: + from .rule import QueryRuleData, RuleMeta -MANIFEST_FILE_PATH = get_etc_path('integration-manifests.json.gz') + +MANIFEST_FILE_PATH = get_etc_path(["integration-manifests.json.gz"]) DEFAULT_MAX_RULE_VERSIONS = 1 -SCHEMA_FILE_PATH = get_etc_path('integration-schemas.json.gz') -_notified_integrations = set() +SCHEMA_FILE_PATH = get_etc_path(["integration-schemas.json.gz"]) + + +_notified_integrations: set[str] = set() @cached -def load_integrations_manifests() -> dict: +def load_integrations_manifests() -> dict[str, Any]: """Load the consolidated integrations manifest.""" - return json.loads(read_gzip(get_etc_path('integration-manifests.json.gz'))) + return json.loads(read_gzip(get_etc_path(["integration-manifests.json.gz"]))) @cached -def load_integrations_schemas() -> dict: +def load_integrations_schemas() -> dict[str, Any]: """Load the consolidated integrations schemas.""" - return json.loads(read_gzip(get_etc_path('integration-schemas.json.gz'))) + return json.loads(read_gzip(get_etc_path(["integration-schemas.json.gz"]))) class IntegrationManifestSchema(Schema): @@ -54,51 +61,56 @@ class IntegrationManifestSchema(Schema): owner = fields.Dict(required=False) @post_load - def transform_policy_template(self, data, **kwargs): + def transform_policy_template(self, data: dict[str, Any], **_: Any) -> dict[str, Any]: if "policy_templates" in data: data["policy_templates"] = [policy["name"] for policy in data["policy_templates"]] return data -def build_integrations_manifest(overwrite: bool, rule_integrations: list = [], - integration: str = None, prerelease: bool = False) -> None: +def build_integrations_manifest( + overwrite: bool, + rule_integrations: list[str] = [], # noqa: B006 + integration: str | None = None, + prerelease: bool = False, +) -> None: """Builds a new local copy of manifest.yaml from integrations Github.""" - def write_manifests(integrations: dict) -> None: - manifest_file = gzip.open(MANIFEST_FILE_PATH, "w+") + def write_manifests(integrations: dict[str, Any]) -> None: manifest_file_bytes = json.dumps(integrations).encode("utf-8") - manifest_file.write(manifest_file_bytes) - manifest_file.close() + with gzip.open(MANIFEST_FILE_PATH, "wb") as f: + _ = f.write(manifest_file_bytes) - if overwrite: - if MANIFEST_FILE_PATH.exists(): - MANIFEST_FILE_PATH.unlink() + if overwrite and MANIFEST_FILE_PATH.exists(): + MANIFEST_FILE_PATH.unlink() - final_integration_manifests = {integration: {} for integration in rule_integrations} \ - or {integration: {}} + 
final_integration_manifests: dict[str, dict[str, Any]] = {} + if rule_integrations: + final_integration_manifests = {integration: {} for integration in rule_integrations} + elif integration: + final_integration_manifests = {integration: {}} + rule_integrations = [integration] - rule_integrations = rule_integrations or [integration] - for integration in rule_integrations: - integration_manifests = get_integration_manifests(integration, prerelease=prerelease) + for _integration in rule_integrations: + integration_manifests = get_integration_manifests(_integration, prerelease=prerelease) for manifest in integration_manifests: - validated_manifest = IntegrationManifestSchema(unknown=EXCLUDE).load(manifest) - package_version = validated_manifest.pop("version") - final_integration_manifests[integration][package_version] = validated_manifest + validated_manifest = IntegrationManifestSchema(unknown=EXCLUDE).load(manifest) # type: ignore[reportUnknownVariableType] + package_version = validated_manifest.pop("version") # type: ignore[reportOptionalMemberAccess] + final_integration_manifests[_integration][package_version] = validated_manifest if overwrite and rule_integrations: write_manifests(final_integration_manifests) elif integration and not overwrite: - manifest_file = gzip.open(MANIFEST_FILE_PATH, "rb") - manifest_file_bytes = manifest_file.read() + with gzip.open(MANIFEST_FILE_PATH, "rb") as manifest_file: + manifest_file_bytes = manifest_file.read() + manifest_file_contents = json.loads(manifest_file_bytes.decode("utf-8")) - manifest_file.close() manifest_file_contents[integration] = final_integration_manifests[integration] write_manifests(manifest_file_contents) print(f"final integrations manifests dumped: {MANIFEST_FILE_PATH}") -def build_integrations_schemas(overwrite: bool, integration: str = None) -> None: +def build_integrations_schemas(overwrite: bool, integration: str | None = None) -> None: """Builds a new local copy of integration-schemas.json.gz from EPR integrations.""" saved_integration_schemas = {} @@ -125,85 +137,96 @@ def build_integrations_schemas(overwrite: bool, integration: str = None) -> None # Loop through the packages and versions for package, versions in integration_manifests.items(): print(f"processing {package}") - final_integration_schemas.setdefault(package, {}) + final_integration_schemas.setdefault(package, {}) # type: ignore[reportUnknownMemberType] for version, manifest in versions.items(): if package in saved_integration_schemas and version in saved_integration_schemas[package]: continue # Download the zip file download_url = f"https://epr.elastic.co{manifest['download']}" - response = requests.get(download_url) + response = requests.get(download_url, timeout=30) response.raise_for_status() # Update the final integration schemas - final_integration_schemas[package].update({version: {}}) + final_integration_schemas[package].update({version: {}}) # type: ignore[reportUnknownMemberType] # Open the zip file with unzip(response.content) as zip_ref: for file in zip_ref.namelist(): file_data_bytes = zip_ref.read(file) # Check if the file is a match - if glob.fnmatch.fnmatch(file, '*/fields/*.yml'): + if fnmatch.fnmatch(file, "*/fields/*.yml"): integration_name = Path(file).parent.parent.name - final_integration_schemas[package][version].setdefault(integration_name, {}) + final_integration_schemas[package][version].setdefault(integration_name, {}) # type: ignore[reportUnknownMemberType] schema_fields = yaml.safe_load(file_data_bytes) # Parse the schema and add to the 
integration_manifests data = flatten_ecs_schema(schema_fields) - flat_data = {field['name']: field['type'] for field in data} + flat_data = {field["name"]: field["type"] for field in data} - final_integration_schemas[package][version][integration_name].update(flat_data) + final_integration_schemas[package][version][integration_name].update(flat_data) # type: ignore[reportUnknownMemberType] # add machine learning jobs to the schema - if package in list(map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)): - if glob.fnmatch.fnmatch(file, '*/ml_module/*ml.json'): - ml_module = json.loads(file_data_bytes) - job_ids = [job['id'] for job in ml_module['attributes']['jobs']] - final_integration_schemas[package][version]['jobs'] = job_ids + if package in [str.lower(x) for x in definitions.MACHINE_LEARNING_PACKAGES] and fnmatch.fnmatch( + file, "*/ml_module/*ml.json" + ): + ml_module = json.loads(file_data_bytes) + job_ids = [job["id"] for job in ml_module["attributes"]["jobs"]] + final_integration_schemas[package][version]["jobs"] = job_ids del file_data_bytes # Write the final integration schemas to disk with gzip.open(SCHEMA_FILE_PATH, "w") as schema_file: schema_file_bytes = json.dumps(final_integration_schemas).encode("utf-8") - schema_file.write(schema_file_bytes) + _ = schema_file.write(schema_file_bytes) print(f"final integrations manifests dumped: {SCHEMA_FILE_PATH}") -def find_least_compatible_version(package: str, integration: str, - current_stack_version: str, packages_manifest: dict) -> str: +def find_least_compatible_version( + package: str, + integration: str, + current_stack_version: str, + packages_manifest: dict[str, Any], +) -> str: """Finds least compatible version for specified integration based on stack version supplied.""" - integration_manifests = {k: v for k, v in sorted(packages_manifest[package].items(), - key=lambda x: Version.parse(x[0]))} - current_stack_version = Version.parse(current_stack_version, optional_minor_and_patch=True) + integration_manifests = dict(sorted(packages_manifest[package].items(), key=lambda x: Version.parse(x[0]))) + stack_version = Version.parse(current_stack_version, optional_minor_and_patch=True) # filter integration_manifests to only the latest major entries - major_versions = sorted(list(set([Version.parse(manifest_version).major - for manifest_version in integration_manifests])), reverse=True) + major_versions = sorted( + {Version.parse(manifest_version).major for manifest_version in integration_manifests}, + reverse=True, + ) for max_major in major_versions: - major_integration_manifests = \ - {k: v for k, v in integration_manifests.items() if Version.parse(k).major == max_major} + major_integration_manifests = { + k: v for k, v in integration_manifests.items() if Version.parse(k).major == max_major + } # iterates through ascending integration manifests # returns latest major version that is least compatible - for version, manifest in OrderedDict(sorted(major_integration_manifests.items(), - key=lambda x: Version.parse(x[0]))).items(): - compatible_versions = re.sub(r"\>|\<|\=|\^|\~", "", - manifest["conditions"]["kibana"]["version"]).split(" || ") + for version, manifest in OrderedDict( + sorted(major_integration_manifests.items(), key=lambda x: Version.parse(x[0])) + ).items(): + compatible_versions = re.sub(r"\>|\<|\=|\^|\~", "", manifest["conditions"]["kibana"]["version"]).split( + " || " + ) for kibana_ver in compatible_versions: - kibana_ver = Version.parse(kibana_ver) + _kibana_ver = Version.parse(kibana_ver) # check versions have 
the same major - if kibana_ver.major == current_stack_version.major: - if kibana_ver <= current_stack_version: - return f"^{version}" + if _kibana_ver.major == stack_version.major and _kibana_ver <= stack_version: + return f"^{version}" raise ValueError(f"no compatible version for integration {package}:{integration}") -def find_latest_compatible_version(package: str, integration: str, - rule_stack_version: Version, - packages_manifest: dict) -> Union[None, Tuple[str, str]]: +def find_latest_compatible_version( + package: str, + integration: str, + rule_stack_version: Version, + packages_manifest: dict[str, Any], +) -> tuple[str, list[str]]: """Finds least compatible version for specified integration based on stack version supplied.""" if not package: @@ -215,7 +238,7 @@ def find_latest_compatible_version(package: str, integration: str, # Converts the dict keys (version numbers) to Version objects for proper sorting (descending) integration_manifests = sorted(package_manifest.items(), key=lambda x: Version.parse(x[0]), reverse=True) - notice = "" + notice: list[str] = [] for version, manifest in integration_manifests: kibana_conditions = manifest.get("conditions", {}).get("kibana", {}) @@ -228,41 +251,45 @@ def find_latest_compatible_version(package: str, integration: str, if not compatible_versions: raise ValueError(f"Manifest for {package}:{integration} version {version} is missing compatible versions") - highest_compatible_version = Version.parse(max(compatible_versions, - key=lambda x: Version.parse(x))) + highest_compatible_version = Version.parse(max(compatible_versions, key=lambda x: Version.parse(x))) if highest_compatible_version > rule_stack_version: # generate notice message that a later integration version is available integration = f" {integration.strip()}" if integration else "" - notice = (f"There is a new integration {package}{integration} version {version} available!", - f"Update the rule min_stack version from {rule_stack_version} to " - f"{highest_compatible_version} if using new features in this latest version.") + notice = [ + f"There is a new integration {package}{integration} version {version} available!", + f"Update the rule min_stack version from {rule_stack_version} to " + f"{highest_compatible_version} if using new features in this latest version.", + ] if highest_compatible_version.major == rule_stack_version.major: return version, notice - else: - # Check for rules that cross majors - for compatible_version in compatible_versions: - if Version.parse(compatible_version) <= rule_stack_version: - return version, notice + # Check for rules that cross majors + for compatible_version in compatible_versions: + if Version.parse(compatible_version) <= rule_stack_version: + return version, notice raise ValueError(f"no compatible version for integration {package}:{integration}") -def get_integration_manifests(integration: str, prerelease: Optional[bool] = False, - kibana_version: Optional[str] = "") -> list: +def get_integration_manifests( + integration: str, + prerelease: bool | None = False, + kibana_version: str | None = "", +) -> list[Any]: """Iterates over specified integrations from package-storage and combines manifests per version.""" epr_search_url = "https://epr.elastic.co/search" - if not prerelease: - prerelease = "false" - else: - prerelease = "true" + prerelease_str = "true" if prerelease else "false" # link for search parameters - https://github.com/elastic/package-registry - epr_search_parameters = {"package": f"{integration}", "prerelease": prerelease, - "all": 
"true", "include_policy_templates": "true"} + epr_search_parameters = { + "package": f"{integration}", + "prerelease": prerelease_str, + "all": "true", + "include_policy_templates": "true", + } if kibana_version: epr_search_parameters["kibana.version"] = kibana_version epr_search_response = requests.get(epr_search_url, params=epr_search_parameters, timeout=10) @@ -273,46 +300,45 @@ def get_integration_manifests(integration: str, prerelease: Optional[bool] = Fal raise ValueError(f"EPR search for {integration} integration package returned empty list") sorted_manifests = sorted(manifests, key=lambda p: Version.parse(p["version"]), reverse=True) - print(f"loaded {integration} manifests from the following package versions: " - f"{[manifest['version'] for manifest in sorted_manifests]}") + print( + f"loaded {integration} manifests from the following package versions: " + f"{[manifest['version'] for manifest in sorted_manifests]}" + ) return manifests def find_latest_integration_version(integration: str, maturity: str, stack_version: Version) -> Version: """Finds the latest integration version based on maturity and stack version""" - prerelease = False if maturity == "ga" else True + prerelease = maturity != "ga" existing_pkgs = get_integration_manifests(integration, prerelease, str(stack_version)) if maturity == "ga": - existing_pkgs = [pkg for pkg in existing_pkgs if not - Version.parse(pkg["version"]).prerelease] + existing_pkgs = [pkg for pkg in existing_pkgs if not Version.parse(pkg["version"]).prerelease] if maturity == "beta": - existing_pkgs = [pkg for pkg in existing_pkgs if - Version.parse(pkg["version"]).prerelease] + existing_pkgs = [pkg for pkg in existing_pkgs if Version.parse(pkg["version"]).prerelease] return max([Version.parse(pkg["version"]) for pkg in existing_pkgs]) -def get_integration_schema_data(data, meta, package_integrations: dict) -> Generator[dict, None, None]: +# Using `Any` here because `integrations` and `rule` modules are tightly coupled +def get_integration_schema_data( + data: Any, # type: ignore[reportRedeclaration] + meta: Any, # type: ignore[reportRedeclaration] + package_integrations: list[dict[str, Any]], +) -> Iterator[dict[str, Any]]: """Iterates over specified integrations from package-storage and combines schemas per version.""" - # lazy import to avoid circular import - from .rule import ( # pylint: disable=import-outside-toplevel - ESQLRuleData, QueryRuleData, RuleMeta) - - data: QueryRuleData = data - meta: RuleMeta = meta + data: QueryRuleData = data # type: ignore[reportAssignmentType] # noqa: PLW0127 + meta: RuleMeta = meta # noqa: PLW0127 packages_manifest = load_integrations_manifests() integrations_schemas = load_integrations_schemas() # validate the query against related integration fields - if (isinstance(data, QueryRuleData) or isinstance(data, ESQLRuleData)) \ - and data.language != 'lucene' and meta.maturity == "production": - + if data.language != "lucene" and meta.maturity == "production": for stack_version, mapping in meta.get_validation_stack_versions().items(): - ecs_version = mapping['ecs'] - endgame_version = mapping['endgame'] + ecs_version = mapping["ecs"] + endgame_version = mapping["endgame"] - ecs_schema = ecs.flatten_multi_fields(ecs.get_schema(ecs_version, name='ecs_flat')) + ecs_schema = ecs.flatten_multi_fields(ecs.get_schema(ecs_version, name="ecs_flat")) for pk_int in package_integrations: package = pk_int["package"] @@ -323,20 +349,37 @@ def get_integration_schema_data(data, meta, package_integrations: dict) -> Gener 
min_stack = Version.parse(min_stack, optional_minor_and_patch=True) # Extract the integration schema fields - integration_schema, package_version = get_integration_schema_fields(integrations_schemas, package, - integration, min_stack, - packages_manifest, ecs_schema, - data) - - data = {"schema": integration_schema, "package": package, "integration": integration, - "stack_version": stack_version, "ecs_version": ecs_version, - "package_version": package_version, "endgame_version": endgame_version} - yield data - - -def get_integration_schema_fields(integrations_schemas: dict, package: str, integration: str, - min_stack: Version, packages_manifest: dict, - ecs_schema: dict, data: dict) -> dict: + integration_schema, package_version = get_integration_schema_fields( + integrations_schemas, + package, + integration, + min_stack, + packages_manifest, + ecs_schema, + data, + ) + + yield { + "schema": integration_schema, + "package": package, + "integration": integration, + "stack_version": stack_version, + "ecs_version": ecs_version, + "package_version": package_version, + "endgame_version": endgame_version, + } + + +def get_integration_schema_fields( # noqa: PLR0913 + integrations_schemas: dict[str, Any], + package: str, + integration: str, + min_stack: Version, + packages_manifest: dict[str, Any], + ecs_schema: dict[str, Any], + data: Any, # type: ignore[reportRedeclaration] +) -> tuple[dict[str, Any], str]: """Extracts the integration fields to schema based on package integrations.""" + data: QueryRuleData = data # type: ignore[reportAssignmentType] # noqa: PLW0127 package_version, notice = find_latest_compatible_version(package, integration, min_stack, packages_manifest) notify_user_if_update_available(data, notice, integration) @@ -348,25 +391,37 @@ def get_integration_schema_fields(integrations_schemas: dict, package: str, inte return integration_schema, package_version -def notify_user_if_update_available(data: dict, notice: list, integration: str) -> None: +def notify_user_if_update_available( + data: Any, # type: ignore[reportRedeclaration] + notice: list[str], + integration: str, +) -> None: """Notifies the user if an update is available, only once per integration.""" - global _notified_integrations - if notice and data.get("notify", False) and integration not in _notified_integrations: + data: QueryRuleData = data # type: ignore[reportAssignmentType] # noqa: PLW0127 + if notice and data.get("notify", False) and integration not in _notified_integrations: # flag to only warn once per integration for available upgrades _notified_integrations.add(integration) print(f"\n{data.get('name')}") - print('\n'.join(notice)) + print("\n".join(notice)) -def collect_schema_fields(integrations_schemas: dict, package: str, package_version: str, - integration: Optional[str] = None) -> dict: +def collect_schema_fields( + integrations_schemas: dict[str, Any], + package: str, + package_version: str, + integration: str | None = None, +) -> dict[str, Any]: """Collects the schema fields for a given integration.""" if integration is None: - return {field: value for dataset in integrations_schemas[package][package_version] if dataset != "jobs" - for field, value in integrations_schemas[package][package_version][dataset].items()} + return { + field: value + for dataset in integrations_schemas[package][package_version] + if dataset != "jobs" + for field, value in integrations_schemas[package][package_version][dataset].items() + } if integration not in integrations_schemas[package][package_version]: raise 
ValueError(f"Integration {integration} not found in package {package} version {package_version}") @@ -374,21 +429,20 @@ def collect_schema_fields(integrations_schemas: dict, package: str, package_vers return integrations_schemas[package][package_version][integration] -def parse_datasets(datasets: list, package_manifest: dict) -> List[Optional[dict]]: +def parse_datasets(datasets: list[str], package_manifest: dict[str, Any]) -> list[dict[str, Any]]: """Parses datasets into packaged integrations from rule data.""" - packaged_integrations = [] - for value in sorted(datasets): - + packaged_integrations: list[dict[str, Any]] = [] + for _value in sorted(datasets): # cleanup extra quotes pulled from ast field - value = value.strip('"') + value = _value.strip('"') - integration = 'Unknown' - if '.' in value: - package, integration = value.split('.', 1) + integration = "Unknown" + if "." in value: + package, integration = value.split(".", 1) # Handle cases where endpoint event datasource needs to be parsed uniquely (e.g endpoint.events.network) # as endpoint.network if package == "endpoint" and "events" in integration: - integration = integration.split('.')[1] + integration = integration.split(".")[1] else: package = value @@ -400,29 +454,34 @@ def parse_datasets(datasets: list, package_manifest: dict) -> List[Optional[dict class SecurityDetectionEngine: """Dedicated to Security Detection Engine integration.""" - def __init__(self): + def __init__(self) -> None: self.epr_url = "https://epr.elastic.co/package/security_detection_engine/" - def load_integration_assets(self, package_version: Version) -> dict: + def load_integration_assets(self, package_version: Version) -> dict[str, Any]: """Loads integration assets into memory.""" - epr_package_url = f"{self.epr_url}{str(package_version)}/" + epr_package_url = f"{self.epr_url}{package_version!s}/" epr_response = requests.get(epr_package_url, timeout=10) epr_response.raise_for_status() package_obj = epr_response.json() zip_url = f"https://epr.elastic.co{package_obj['download']}" - zip_response = requests.get(zip_url) + zip_response = requests.get(zip_url, timeout=30) with unzip(zip_response.content) as zip_package: asset_file_names = [asset for asset in zip_package.namelist() if "json" in asset] - assets = {x.split("/")[-1].replace(".json", ""): json.loads(zip_package.read(x).decode('utf-8')) - for x in asset_file_names} - return assets - - def keep_latest_versions(self, assets: dict, num_versions: int = DEFAULT_MAX_RULE_VERSIONS) -> dict: + return { + x.split("/")[-1].replace(".json", ""): json.loads(zip_package.read(x).decode("utf-8")) + for x in asset_file_names + } + + def keep_latest_versions( + self, + assets: dict[str, Any], + num_versions: int = DEFAULT_MAX_RULE_VERSIONS, + ) -> dict[str, Any]: """Keeps only the latest N versions of each rule to limit historical rule versions in our release package.""" # Dictionary to hold the sorted list of versions for each base rule ID - rule_versions = defaultdict(list) + rule_versions: dict[str, list[tuple[int, str]]] = defaultdict(list) # Separate rule ID and version, and group by base rule ID for key in assets: @@ -431,12 +490,12 @@ def keep_latest_versions(self, assets: dict, num_versions: int = DEFAULT_MAX_RUL rule_versions[base_id].append((version, key)) # Dictionary to hold the final assets with only the specified number of latest versions - filtered_assets = {} + filtered_assets: dict[str, Any] = {} # Keep only the last/latest num_versions versions for each rule # Sort versions and take the last 
num_versions # Add the latest versions of the rule to the filtered assets - for base_id, versions in rule_versions.items(): + for versions in rule_versions.values(): latest_versions = sorted(versions, key=lambda x: x[0], reverse=True)[:num_versions] for _, key in latest_versions: filtered_assets[key] = assets[key] diff --git a/detection_rules/kbwrap.py b/detection_rules/kbwrap.py index e52dd507ce7..1c0d4e72201 100644 --- a/detection_rules/kbwrap.py +++ b/detection_rules/kbwrap.py @@ -4,56 +4,59 @@ # 2.0. """Kibana cli commands.""" + import re import sys from pathlib import Path -from typing import Iterable, List, Optional +from typing import Any import click - -import kql -from kibana import Signal, RuleResource - -from .config import parse_rules_config +import kql # type: ignore[reportMissingTypeStubs] +from kibana import RuleResource, Signal # type: ignore[reportMissingTypeStubs] + +from .action_connector import ( + TOMLActionConnector, + TOMLActionConnectorContents, + build_action_connector_objects, + parse_action_connector_results_from_api, +) from .cli_utils import multi_collection -from .action_connector import (TOMLActionConnectorContents, - parse_action_connector_results_from_api, build_action_connector_objects) -from .exception import (TOMLExceptionContents, - build_exception_objects, parse_exceptions_results_from_api) +from .config import parse_rules_config +from .exception import TOMLException, TOMLExceptionContents, build_exception_objects, parse_exceptions_results_from_api from .generic_loader import GenericCollection from .main import root -from .misc import add_params, client_error, kibana_options, get_kibana_client, nested_set -from .rule import downgrade_contents_from_rule, TOMLRuleContents, TOMLRule +from .misc import add_params, get_kibana_client, kibana_options, nested_set, raise_client_error +from .rule import TOMLRule, TOMLRuleContents, downgrade_contents_from_rule from .rule_loader import RuleCollection, update_metadata_from_file from .utils import format_command_options, rulename_to_filename RULES_CONFIG = parse_rules_config() -@root.group('kibana') +@root.group("kibana") @add_params(*kibana_options) @click.pass_context -def kibana_group(ctx: click.Context, **kibana_kwargs): +def kibana_group(ctx: click.Context, **kibana_kwargs: Any) -> None: """Commands for integrating with Kibana.""" - ctx.ensure_object(dict) + _ = ctx.ensure_object(dict) # type: ignore[reportUnknownVariableType] # only initialize an kibana client if the subcommand is invoked without help (hacky) if sys.argv[-1] in ctx.help_option_names: - click.echo('Kibana client:') + click.echo("Kibana client:") click.echo(format_command_options(ctx)) else: - ctx.obj['kibana'] = get_kibana_client(**kibana_kwargs) + ctx.obj["kibana"] = get_kibana_client(**kibana_kwargs) @kibana_group.command("upload-rule") @multi_collection -@click.option('--replace-id', '-r', is_flag=True, help='Replace rule IDs with new IDs before export') +@click.option("--replace-id", "-r", is_flag=True, help="Replace rule IDs with new IDs before export") @click.pass_context -def upload_rule(ctx, rules: RuleCollection, replace_id): +def upload_rule(ctx: click.Context, rules: RuleCollection, replace_id: bool) -> list[RuleResource]: """[Deprecated] Upload a list of rule .toml files to Kibana.""" - kibana = ctx.obj['kibana'] - api_payloads = [] + kibana = ctx.obj["kibana"] + api_payloads: list[RuleResource] = [] click.secho( "WARNING: This command is deprecated as of Elastic Stack version 9.0. 
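# --- Illustrative aside --------------------------------------------------------
# A sketch of the "keep only the latest N versions per rule" grouping used by
# SecurityDetectionEngine.keep_latest_versions above. The asset key format
# ("<rule_id>_<version>") and the sample assets are assumptions for the
# example; the real code derives base_id and version from the package's asset keys.
from collections import defaultdict
from typing import Any

def keep_latest(assets: dict[str, Any], num_versions: int = 2) -> dict[str, Any]:
    rule_versions: dict[str, list[tuple[int, str]]] = defaultdict(list)
    for key in assets:
        base_id, _, version = key.rpartition("_")
        rule_versions[base_id].append((int(version), key))

    filtered: dict[str, Any] = {}
    for versions in rule_versions.values():
        for _, key in sorted(versions, key=lambda x: x[0], reverse=True)[:num_versions]:
            filtered[key] = assets[key]
    return filtered

assets = {"rule-a_1": {}, "rule-a_2": {}, "rule-a_3": {}, "rule-b_1": {}}
print(sorted(keep_latest(assets)))  # ['rule-a_2', 'rule-a_3', 'rule-b_1']
# -------------------------------------------------------------------------------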
Please use `kibana import-rules`.", @@ -64,62 +67,71 @@ def upload_rule(ctx, rules: RuleCollection, replace_id): try: payload = downgrade_contents_from_rule(rule, kibana.version, replace_id=replace_id) except ValueError as e: - client_error(f'{e} in version:{kibana.version}, for rule: {rule.name}', e, ctx=ctx) + raise_client_error(f"{e} in version:{kibana.version}, for rule: {rule.name}", e, ctx=ctx) - rule = RuleResource(payload) - api_payloads.append(rule) + api_payloads.append(RuleResource(payload)) with kibana: - results = RuleResource.bulk_create_legacy(api_payloads) + results: list[RuleResource] = RuleResource.bulk_create_legacy(api_payloads) # type: ignore[reportUnknownMemberType] - success = [] - errors = [] + success: list[str] = [] + errors: list[str] = [] for result in results: - if 'error' in result: - errors.append(f'{result["rule_id"]} - {result["error"]["message"]}') + if "error" in result: + errors.append(f"{result['rule_id']} - {result['error']['message']}") else: - success.append(result['rule_id']) + success.append(result["rule_id"]) # type: ignore[reportUnknownArgumentType] if success: - click.echo('Successful uploads:\n - ' + '\n - '.join(success)) + click.echo("Successful uploads:\n - " + "\n - ".join(success)) if errors: - click.echo('Failed uploads:\n - ' + '\n - '.join(errors)) + click.echo("Failed uploads:\n - " + "\n - ".join(errors)) return results -@kibana_group.command('import-rules') +@kibana_group.command("import-rules") @multi_collection -@click.option('--overwrite', '-o', is_flag=True, help='Overwrite existing rules') -@click.option('--overwrite-exceptions', '-e', is_flag=True, help='Overwrite exceptions in existing rules') -@click.option('--overwrite-action-connectors', '-ac', is_flag=True, - help='Overwrite action connectors in existing rules') +@click.option("--overwrite", "-o", is_flag=True, help="Overwrite existing rules") +@click.option("--overwrite-exceptions", "-e", is_flag=True, help="Overwrite exceptions in existing rules") +@click.option( + "--overwrite-action-connectors", + "-ac", + is_flag=True, + help="Overwrite action connectors in existing rules", +) @click.pass_context -def kibana_import_rules(ctx: click.Context, rules: RuleCollection, overwrite: Optional[bool] = False, - overwrite_exceptions: Optional[bool] = False, - overwrite_action_connectors: Optional[bool] = False) -> (dict, List[RuleResource]): +def kibana_import_rules( # noqa: PLR0915 + ctx: click.Context, + rules: RuleCollection, + overwrite: bool = False, + overwrite_exceptions: bool = False, + overwrite_action_connectors: bool = False, +) -> tuple[dict[str, Any], list[RuleResource]]: """Import custom rules into Kibana.""" - def _handle_response_errors(response: dict): + + def _handle_response_errors(response: dict[str, Any]) -> None: """Handle errors from the import response.""" - def _parse_list_id(s: str): + + def _parse_list_id(s: str) -> str | None: """Parse the list ID from the error message.""" match = re.search(r'list_id: "(.*?)"', s) return match.group(1) if match else None # Re-try to address known Kibana issue: https://github.com/elastic/kibana/issues/143864 - workaround_errors = [] - workaround_error_types = set() + workaround_errors: list[str] = [] + workaround_error_types: set[str] = set() flattened_exceptions = [e for sublist in exception_dicts for e in sublist] all_exception_list_ids = {exception["list_id"] for exception in flattened_exceptions} - click.echo(f'{len(response["errors"])} rule(s) failed to import!') + click.echo(f"{len(response['errors'])} rule(s) 
failed to import!") action_connector_validation_error = "Error validating create data" action_connector_type_error = "expected value of type [string] but got [undefined]" - for error in response['errors']: + for error in response["errors"]: error_message = error["error"]["message"] - click.echo(f' - {error["rule_id"]}: ({error["error"]["status_code"]}) {error_message}') + click.echo(f" - {error['rule_id']}: ({error['error']['status_code']}) {error_message}") if "references a non existent exception list" in error_message: list_id = _parse_list_id(error_message) @@ -147,15 +159,19 @@ def _parse_list_id(s: str): ) click.echo() - def _process_imported_items(imported_items_list, item_type_description, item_key): + def _process_imported_items( + imported_items_list: list[list[dict[str, Any]]], + item_type_description: str, + item_key: str, + ) -> None: """Displays appropriately formatted success message that all items imported successfully.""" all_ids = {item[item_key] for sublist in imported_items_list for item in sublist} if all_ids: - click.echo(f'{len(all_ids)} {item_type_description} successfully imported') - ids_str = '\n - '.join(all_ids) - click.echo(f' - {ids_str}') + click.echo(f"{len(all_ids)} {item_type_description} successfully imported") + ids_str = "\n - ".join(all_ids) + click.echo(f" - {ids_str}") - kibana = ctx.obj['kibana'] + kibana = ctx.obj["kibana"] rule_dicts = [r.contents.to_api_format() for r in rules] with kibana: cl = GenericCollection.default() @@ -165,26 +181,26 @@ def _process_imported_items(imported_items_list, item_type_description, item_key action_connectors_dicts = [ d.contents.to_api_format() for d in cl.items if isinstance(d.contents, TOMLActionConnectorContents) ] - response, successful_rule_ids, results = RuleResource.import_rules( + response, successful_rule_ids, results = RuleResource.import_rules( # type: ignore[reportUnknownMemberType] rule_dicts, exception_dicts, action_connectors_dicts, overwrite=overwrite, overwrite_exceptions=overwrite_exceptions, - overwrite_action_connectors=overwrite_action_connectors + overwrite_action_connectors=overwrite_action_connectors, ) if successful_rule_ids: - click.echo(f'{len(successful_rule_ids)} rule(s) successfully imported') - rule_str = '\n - '.join(successful_rule_ids) - click.echo(f' - {rule_str}') - if response['errors']: - _handle_response_errors(response) + click.echo(f"{len(successful_rule_ids)} rule(s) successfully imported") # type: ignore[reportUnknownArgumentType] + rule_str = "\n - ".join(successful_rule_ids) # type: ignore[reportUnknownArgumentType] + click.echo(f" - {rule_str}") + if response["errors"]: + _handle_response_errors(response) # type: ignore[reportUnknownArgumentType] else: - _process_imported_items(exception_dicts, 'exception list(s)', 'list_id') - _process_imported_items(action_connectors_dicts, 'action connector(s)', 'id') + _process_imported_items(exception_dicts, "exception list(s)", "list_id") + _process_imported_items(action_connectors_dicts, "action connector(s)", "id") - return response, results + return response, results # type: ignore[reportUnknownVariableType] @kibana_group.command("export-rules") @@ -195,15 +211,23 @@ def _process_imported_items(imported_items_list, item_type_description, item_key @click.option("--exceptions-directory", "-ed", required=False, type=Path, help="Directory to export exceptions to") @click.option("--default-author", "-da", type=str, required=False, help="Default author for rules missing one") @click.option("--rule-id", "-r", multiple=True, 
help="Optional Rule IDs to restrict export to") -@click.option("--rule-name", "-rn", required=False, help="Optional Rule name to restrict export to " - "(KQL, case-insensitive, supports wildcards)") +@click.option( + "--rule-name", + "-rn", + required=False, + help="Optional Rule name to restrict export to (KQL, case-insensitive, supports wildcards)", +) @click.option("--export-action-connectors", "-ac", is_flag=True, help="Include action connectors in export") @click.option("--export-exceptions", "-e", is_flag=True, help="Include exceptions in export") @click.option("--skip-errors", "-s", is_flag=True, help="Skip errors when exporting rules") @click.option("--strip-version", "-sv", is_flag=True, help="Strip the version fields from all rules") -@click.option("--no-tactic-filename", "-nt", is_flag=True, - help="Exclude tactic prefix in exported filenames for rules. " - "Use same flag for import-rules to prevent warnings and disable its unit test.") +@click.option( + "--no-tactic-filename", + "-nt", + is_flag=True, + help="Exclude tactic prefix in exported filenames for rules. " + "Use same flag for import-rules to prevent warnings and disable its unit test.", +) @click.option("--local-creation-date", "-lc", is_flag=True, help="Preserve the local creation date of the rule") @click.option("--local-updated-date", "-lu", is_flag=True, help="Preserve the local updated date of the rule") @click.option("--custom-rules-only", "-cro", is_flag=True, help="Only export custom rules") @@ -214,18 +238,28 @@ def _process_imported_items(imported_items_list, item_type_description, item_key required=False, help=( "Apply a query filter to exporting rules e.g. " - "\"alert.attributes.tags: \\\"test\\\"\" to filter for rules that have the tag \"test\"" - ) + '"alert.attributes.tags: \\"test\\"" to filter for rules that have the tag "test"' + ), ) @click.pass_context -def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_directory: Optional[Path], - exceptions_directory: Optional[Path], default_author: str, - rule_id: Optional[Iterable[str]] = None, rule_name: Optional[str] = None, - export_action_connectors: bool = False, - export_exceptions: bool = False, skip_errors: bool = False, strip_version: bool = False, - no_tactic_filename: bool = False, local_creation_date: bool = False, - local_updated_date: bool = False, custom_rules_only: bool = False, - export_query: Optional[str] = None) -> List[TOMLRule]: +def kibana_export_rules( # noqa: PLR0912, PLR0913, PLR0915 + ctx: click.Context, + directory: Path, + action_connectors_directory: Path | None, + exceptions_directory: Path | None, + default_author: str, + rule_id: list[str] | None = None, + rule_name: str | None = None, + export_action_connectors: bool = False, + export_exceptions: bool = False, + skip_errors: bool = False, + strip_version: bool = False, + no_tactic_filename: bool = False, + local_creation_date: bool = False, + local_updated_date: bool = False, + custom_rules_only: bool = False, + export_query: str | None = None, +) -> list[TOMLRule]: """Export custom rules from Kibana.""" kibana = ctx.obj["kibana"] kibana_include_details = export_exceptions or export_action_connectors @@ -237,22 +271,20 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d with kibana: # Look up rule IDs by name if --rule-name was provided if rule_name: - found = RuleResource.find(filter=f"alert.attributes.name:{rule_name}") - rule_id = [r["rule_id"] for r in found] + found = 
RuleResource.find(filter=f"alert.attributes.name:{rule_name}") # type: ignore[reportUnknownMemberType] + rule_id = [r["rule_id"] for r in found] # type: ignore[reportUnknownVariableType] query = ( - export_query if not custom_rules_only + export_query + if not custom_rules_only else ( - f"alert.attributes.params.ruleSource.type: \"internal\"" - f"{f' and ({export_query})' if export_query else ''}" + f'alert.attributes.params.ruleSource.type: "internal"{f" and ({export_query})" if export_query else ""}' ) ) - results = ( - RuleResource.bulk_export(rule_ids=list(rule_id), query=query) + results = ( # type: ignore[reportUnknownVariableType] + RuleResource.bulk_export(rule_ids=list(rule_id), query=query) # type: ignore[reportArgumentType] if query - else RuleResource.export_rules( - list(rule_id), exclude_export_details=not kibana_include_details - ) + else RuleResource.export_rules(list(rule_id), exclude_export_details=not kibana_include_details) # type: ignore[reportArgumentType] ) # Handle Exceptions Directory Location if results and exceptions_directory: @@ -274,48 +306,48 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d click.echo("No rules found to export") return [] - rules_results = results + rules_results = results # type: ignore[reportUnknownVariableType] action_connector_results = [] exception_results = [] if kibana_include_details: # Assign counts to variables - rules_count = results[-1]["exported_rules_count"] - exception_list_count = results[-1]["exported_exception_list_count"] - exception_list_item_count = results[-1]["exported_exception_list_item_count"] - action_connector_count = results[-1]["exported_action_connector_count"] + rules_count = results[-1]["exported_rules_count"] # type: ignore[reportUnknownVariableType] + exception_list_count = results[-1]["exported_exception_list_count"] # type: ignore[reportUnknownVariableType] + exception_list_item_count = results[-1]["exported_exception_list_item_count"] # type: ignore[reportUnknownVariableType] + action_connector_count = results[-1]["exported_action_connector_count"] # type: ignore[reportUnknownVariableType] # Parse rules results and exception results from API return - rules_results = results[:rules_count] - exception_results = results[rules_count:rules_count + exception_list_count + exception_list_item_count] - rules_and_exceptions_count = rules_count + exception_list_count + exception_list_item_count - action_connector_results = results[ - rules_and_exceptions_count: rules_and_exceptions_count + action_connector_count + rules_results = results[:rules_count] # type: ignore[reportUnknownVariableType] + exception_results = results[rules_count : rules_count + exception_list_count + exception_list_item_count] # type: ignore[reportUnknownVariableType] + rules_and_exceptions_count = rules_count + exception_list_count + exception_list_item_count # type: ignore[reportUnknownVariableType] + action_connector_results = results[ # type: ignore[reportUnknownVariableType] + rules_and_exceptions_count : rules_and_exceptions_count + action_connector_count ] - errors = [] - exported = [] - exception_list_rule_table = {} - action_connector_rule_table = {} - for rule_resource in rules_results: + errors: list[str] = [] + exported: list[TOMLRule] = [] + exception_list_rule_table: dict[str, list[dict[str, Any]]] = {} + action_connector_rule_table: dict[str, list[dict[str, Any]]] = {} + for rule_resource in rules_results: # type: ignore[reportUnknownVariableType] try: if strip_version: - 
rule_resource.pop("revision", None) - rule_resource.pop("version", None) - rule_resource["author"] = rule_resource.get("author") or default_author or [rule_resource.get("created_by")] + rule_resource.pop("revision", None) # type: ignore[reportUnknownMemberType] + rule_resource.pop("version", None) # type: ignore[reportUnknownMemberType] + rule_resource["author"] = rule_resource.get("author") or default_author or [rule_resource.get("created_by")] # type: ignore[reportUnknownMemberType] if isinstance(rule_resource["author"], str): rule_resource["author"] = [rule_resource["author"]] # Inherit maturity and optionally local dates from the rule if it already exists - params = { + params: dict[str, Any] = { "rule": rule_resource, "maturity": "development", } - threat = rule_resource.get("threat") - first_tactic = threat[0].get("tactic").get("name") if threat else "" + threat = rule_resource.get("threat") # type: ignore[reportUnknownMemberType] + first_tactic = threat[0].get("tactic").get("name") if threat else "" # type: ignore[reportUnknownMemberType] # Check if flag or config is set to not include tactic in the filename no_tactic_filename = no_tactic_filename or RULES_CONFIG.no_tactic_filename # Check if the flag is set to not include tactic in the filename - tactic_name = first_tactic if not no_tactic_filename else None - rule_name = rulename_to_filename(rule_resource.get("name"), tactic_name=tactic_name) + tactic_name = first_tactic if not no_tactic_filename else None # type: ignore[reportUnknownMemberType] + rule_name = rulename_to_filename(rule_resource.get("name"), tactic_name=tactic_name) # type: ignore[reportUnknownMemberType] save_path = directory / f"{rule_name}" params.update( @@ -323,12 +355,12 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d save_path, {"creation_date": local_creation_date, "updated_date": local_updated_date} ) ) - contents = TOMLRuleContents.from_rule_resource(**params) + contents = TOMLRuleContents.from_rule_resource(**params) # type: ignore[reportArgumentType] rule = TOMLRule(contents=contents, path=save_path) except Exception as e: if skip_errors: - print(f'- skipping {rule_resource.get("name")} - {type(e).__name__}') - errors.append(f'- {rule_resource.get("name")} - {e}') + print(f"- skipping {rule_resource.get('name')} - {type(e).__name__}") # type: ignore[reportUnknownMemberType] + errors.append(f"- {rule_resource.get('name')} - {e}") # type: ignore[reportUnknownMemberType] continue raise if rule.contents.data.exceptions_list: @@ -354,7 +386,7 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d exceptions_containers = {} exceptions_items = {} - exceptions_containers, exceptions_items, parse_errors, _ = parse_exceptions_results_from_api(exception_results) + exceptions_containers, exceptions_items, parse_errors, _ = parse_exceptions_results_from_api(exception_results) # type: ignore[reportArgumentType] errors.extend(parse_errors) # Build TOMLException Objects @@ -374,7 +406,7 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d # Parse action connector results from API return action_connectors = [] if export_action_connectors: - action_connector_results, _ = parse_action_connector_results_from_api(action_connector_results) + action_connector_results, _ = parse_action_connector_results_from_api(action_connector_results) # type: ignore[reportArgumentType] # Build TOMLActionConnector Objects action_connectors, ac_output, ac_errors = build_action_connector_objects( @@ 
-389,7 +421,7 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d click.echo(line) errors.extend(ac_errors) - saved = [] + saved: list[TOMLRule] = [] for rule in exported: try: rule.save_toml() @@ -402,20 +434,20 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d saved.append(rule) - saved_exceptions = [] + saved_exceptions: list[TOMLException] = [] for exception in exceptions: try: exception.save_toml() except Exception as e: if skip_errors: - print(f"- skipping {exception.rule_name} - {type(e).__name__}") - errors.append(f"- {exception.rule_name} - {e}") + print(f"- skipping {exception.rule_name} - {type(e).__name__}") # type: ignore[reportUnknownMemberType] + errors.append(f"- {exception.rule_name} - {e}") # type: ignore[reportUnknownMemberType] continue raise saved_exceptions.append(exception) - saved_action_connectors = [] + saved_action_connectors: list[TOMLActionConnector] = [] for action in action_connectors: try: action.save_toml() @@ -428,7 +460,7 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d saved_action_connectors.append(action) - click.echo(f"{len(results)} results exported") + click.echo(f"{len(results)} results exported") # type: ignore[reportUnknownArgumentType] click.echo(f"{len(exported)} rules converted") click.echo(f"{len(exceptions)} exceptions exported") click.echo(f"{len(action_connectors)} action connectors exported") @@ -437,54 +469,61 @@ def kibana_export_rules(ctx: click.Context, directory: Path, action_connectors_d click.echo(f"{len(saved_action_connectors)} action connectors saved to {action_connectors_directory}") if errors: err_file = directory / "_errors.txt" - err_file.write_text("\n".join(errors)) + _ = err_file.write_text("\n".join(errors)) click.echo(f"{len(errors)} errors saved to {err_file}") return exported -@kibana_group.command('search-alerts') -@click.argument('query', required=False) -@click.option('--date-range', '-d', type=(str, str), default=('now-7d', 'now'), help='Date range to scope search') -@click.option('--columns', '-c', multiple=True, help='Columns to display in table') -@click.option('--extend', '-e', is_flag=True, help='If columns are specified, extend the original columns') -@click.option('--max-count', '-m', default=100, help='The max number of alerts to return') +@kibana_group.command("search-alerts") +@click.argument("query", required=False) +@click.option("--date-range", "-d", type=(str, str), default=("now-7d", "now"), help="Date range to scope search") +@click.option("--columns", "-c", multiple=True, help="Columns to display in table") +@click.option("--extend", "-e", is_flag=True, help="If columns are specified, extend the original columns") +@click.option("--max-count", "-m", default=100, help="The max number of alerts to return") @click.pass_context -def search_alerts(ctx, query, date_range, columns, extend, max_count): +def search_alerts( # noqa: PLR0913 + ctx: click.Context, + query: str, + date_range: tuple[str, str], + columns: list[str], + extend: bool, + max_count: int, +) -> None: """Search detection engine alerts with KQL.""" - from eql.table import Table + from eql.table import Table # type: ignore[reportMissingTypeStubs] + from .eswrap import MATCH_ALL, add_range_to_dsl - kibana = ctx.obj['kibana'] + kibana = ctx.obj["kibana"] start_time, end_time = date_range - kql_query = kql.to_dsl(query) if query else MATCH_ALL - add_range_to_dsl(kql_query['bool'].setdefault('filter', []), start_time, end_time) + kql_query = 
kql.to_dsl(query) if query else MATCH_ALL # type: ignore[reportUnknownMemberType] + add_range_to_dsl(kql_query["bool"].setdefault("filter", []), start_time, end_time) # type: ignore[reportUnknownArgumentType] with kibana: - alerts = [a['_source'] for a in Signal.search({'query': kql_query}, size=max_count)['hits']['hits']] + alerts = [a["_source"] for a in Signal.search({"query": kql_query}, size=max_count)["hits"]["hits"]] # type: ignore[reportUnknownMemberType] # check for events with nested signal fields if alerts: - table_columns = ['host.hostname'] + table_columns = ["host.hostname"] - if 'signal' in alerts[0]: - table_columns += ['signal.rule.name', 'signal.status', 'signal.original_time'] - elif 'kibana.alert.rule.name' in alerts[0]: - table_columns += ['kibana.alert.rule.name', 'kibana.alert.status', 'kibana.alert.original_time'] + if "signal" in alerts[0]: + table_columns += ["signal.rule.name", "signal.status", "signal.original_time"] + elif "kibana.alert.rule.name" in alerts[0]: + table_columns += ["kibana.alert.rule.name", "kibana.alert.status", "kibana.alert.original_time"] else: - table_columns += ['rule.name', '@timestamp'] + table_columns += ["rule.name", "@timestamp"] if columns: columns = list(columns) table_columns = table_columns + columns if extend else columns # Table requires the data to be nested, but depending on the version, some data uses dotted keys, so # they must be nested explicitly - for alert in alerts: + for alert in alerts: # type: ignore[reportUnknownVariableType] for key in table_columns: if key in alert: - nested_set(alert, key, alert[key]) + nested_set(alert, key, alert[key]) # type: ignore[reportUnknownArgumentType] - click.echo(Table.from_list(table_columns, alerts)) + click.echo(Table.from_list(table_columns, alerts)) # type: ignore[reportUnknownMemberType] else: - click.echo('No alerts detected') - return alerts + click.echo("No alerts detected") diff --git a/detection_rules/main.py b/detection_rules/main.py index 1e1dcfcac06..5751284b9b8 100644 --- a/detection_rules/main.py +++ b/detection_rules/main.py @@ -4,83 +4,107 @@ # 2.0. 
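# --- Illustrative aside --------------------------------------------------------
# The search-alerts change above nests dotted alert keys before rendering them
# with eql's Table helper. This is a minimal stand-in for that idea; it is not
# the repo's nested_set implementation (which lives in detection_rules.misc).
from typing import Any

def nested_set(obj: dict[str, Any], dotted_key: str, value: Any) -> None:
    keys = dotted_key.split(".")
    for key in keys[:-1]:
        obj = obj.setdefault(key, {})
    obj[keys[-1]] = value

alert = {"kibana.alert.rule.name": "Example Rule"}  # made-up alert document
nested_set(alert, "kibana.alert.rule.name", alert["kibana.alert.rule.name"])
print(alert["kibana"]["alert"]["rule"]["name"])  # Example Rule
# -------------------------------------------------------------------------------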
"""CLI commands for detection_rules.""" + import dataclasses -import glob import json import os import time -from datetime import datetime - -import pytoml -from marshmallow_dataclass import class_schema +from collections.abc import Iterable +from datetime import UTC, datetime from pathlib import Path -from semver import Version -from typing import Dict, Iterable, List, Optional, get_args +from typing import TYPE_CHECKING, Any, Literal, get_args from uuid import uuid4 + import click +import pytoml # type: ignore[reportMissingTypeStubs] +from marshmallow_dataclass import class_schema +from semver import Version -from .action_connector import (TOMLActionConnectorContents, - build_action_connector_objects, parse_action_connector_results_from_api) +from .action_connector import ( + TOMLActionConnectorContents, + build_action_connector_objects, + parse_action_connector_results_from_api, +) from .attack import build_threat_map_entry -from .cli_utils import rule_prompt, multi_collection +from .cli_utils import multi_collection, rule_prompt from .config import load_current_package_version, parse_rules_config +from .exception import TOMLExceptionContents, build_exception_objects, parse_exceptions_results_from_api from .generic_loader import GenericCollection -from .exception import (TOMLExceptionContents, - build_exception_objects, parse_exceptions_results_from_api) -from .misc import ( - add_client, client_error, nested_set, parse_user_config -) -from .rule import TOMLRule, TOMLRuleContents, QueryRuleData +from .misc import add_client, nested_set, parse_user_config, raise_client_error +from .rule import DeprecatedRule, QueryRuleData, TOMLRule, TOMLRuleContents from .rule_formatter import toml_write from .rule_loader import RuleCollection, update_metadata_from_file from .schemas import all_versions, definitions, get_incompatible_fields, get_schema_file -from .utils import Ndjson, get_path, get_etc_path, clear_caches, load_dump, load_rule_contents, rulename_to_filename +from .utils import ( + Ndjson, + clear_caches, + get_etc_path, + get_path, + load_dump, # type: ignore[reportUnknownVariableType] + load_rule_contents, + rulename_to_filename, +) + +if TYPE_CHECKING: + from elasticsearch import Elasticsearch RULES_CONFIG = parse_rules_config() RULES_DIRS = RULES_CONFIG.rule_dirs @click.group( - 'detection-rules', + "detection-rules", context_settings={ - 'help_option_names': ['-h', '--help'], - 'max_content_width': int(os.getenv('DR_CLI_MAX_WIDTH', 240)), + "help_option_names": ["-h", "--help"], + "max_content_width": int(os.getenv("DR_CLI_MAX_WIDTH", 240)), # noqa: PLW1508 }, ) -@click.option('--debug/--no-debug', '-D/-N', is_flag=True, default=None, - help='Print full exception stacktrace on errors') +@click.option( + "--debug/--no-debug", + "-D/-N", + is_flag=True, + default=None, + help="Print full exception stacktrace on errors", +) @click.pass_context -def root(ctx, debug): +def root(ctx: click.Context, debug: bool) -> None: """Commands for detection-rules repository.""" - debug = debug if debug is not None else parse_user_config().get('debug') - ctx.obj = {'debug': debug, 'rules_config': RULES_CONFIG} + debug = debug if debug else parse_user_config().get("debug") + ctx.obj = {"debug": debug, "rules_config": RULES_CONFIG} if debug: - click.secho('DEBUG MODE ENABLED', fg='yellow') + click.secho("DEBUG MODE ENABLED", fg="yellow") -@root.command('create-rule') -@click.argument('path', type=Path) -@click.option('--config', '-c', type=click.Path(exists=True, dir_okay=False, path_type=Path), - 
help='Rule or config file') -@click.option('--required-only', is_flag=True, help='Only prompt for required fields') -@click.option('--rule-type', '-t', type=click.Choice(sorted(TOMLRuleContents.all_rule_types())), - help='Type of rule to create') -def create_rule(path, config, required_only, rule_type): +@root.command("create-rule") +@click.argument("path", type=Path) +@click.option( + "--config", "-c", type=click.Path(exists=True, dir_okay=False, path_type=Path), help="Rule or config file" +) +@click.option("--required-only", is_flag=True, help="Only prompt for required fields") +@click.option( + "--rule-type", "-t", type=click.Choice(sorted(TOMLRuleContents.all_rule_types())), help="Type of rule to create" +) +def create_rule(path: Path, config: Path, required_only: bool, rule_type: str): # noqa: ANN201 """Create a detection rule.""" - contents = load_rule_contents(config, single_only=True)[0] if config else {} + contents: dict[str, Any] = load_rule_contents(config, single_only=True)[0] if config else {} return rule_prompt(path, rule_type=rule_type, required_only=required_only, save=True, **contents) -@root.command('generate-rules-index') -@click.option('--query', '-q', help='Optional KQL query to limit to specific rules') -@click.option('--overwrite', is_flag=True, help='Overwrite files in an existing folder') +@root.command("generate-rules-index") +@click.option("--query", "-q", help="Optional KQL query to limit to specific rules") +@click.option("--overwrite", is_flag=True, help="Overwrite files in an existing folder") @click.pass_context -def generate_rules_index(ctx: click.Context, query, overwrite, save_files=True): +def generate_rules_index( + ctx: click.Context, + query: str, + overwrite: bool, + save_files: bool = True, +) -> tuple[Ndjson, Ndjson]: """Generate enriched indexes of rules, based on a KQL search, for indexing/importing into elasticsearch/kibana.""" from .packaging import Package if query: - rule_paths = [r['file'] for r in ctx.invoke(search_rules, query=query, verbose=False)] + rule_paths = [r["file"] for r in ctx.invoke(search_rules, query=query, verbose=False)] rules = RuleCollection() rules.load_files(Path(p) for p in rule_paths) else: @@ -92,37 +116,40 @@ def generate_rules_index(ctx: click.Context, query, overwrite, save_files=True): bulk_upload_docs, importable_rules_docs = package.create_bulk_index_body() if save_files: - path = get_path('enriched-rule-indexes', package_hash) + path = get_path(["enriched-rule-indexes", package_hash]) path.mkdir(parents=True, exist_ok=overwrite) - bulk_upload_docs.dump(path.joinpath('enriched-rules-index-uploadable.ndjson'), sort_keys=True) - importable_rules_docs.dump(path.joinpath('enriched-rules-index-importable.ndjson'), sort_keys=True) + bulk_upload_docs.dump(path.joinpath("enriched-rules-index-uploadable.ndjson"), sort_keys=True) + importable_rules_docs.dump(path.joinpath("enriched-rules-index-importable.ndjson"), sort_keys=True) - click.echo(f'files saved to: {path}') + click.echo(f"files saved to: {path}") - click.echo(f'{rule_count} rules included') + click.echo(f"{rule_count} rules included") return bulk_upload_docs, importable_rules_docs @root.command("import-rules-to-repo") -@click.argument("input-file", type=click.Path(dir_okay=False, exists=True), nargs=-1, required=False) +@click.argument("input-file", type=click.Path(dir_okay=False, exists=True, path_type=Path), nargs=-1, required=False) @click.option("--action-connector-import", "-ac", is_flag=True, help="Include action connectors in export") 
@click.option("--exceptions-import", "-e", is_flag=True, help="Include exceptions in export") @click.option("--required-only", is_flag=True, help="Only prompt for required fields") @click.option("--directory", "-d", type=click.Path(file_okay=False, exists=True), help="Load files from a directory") @click.option( - "--save-directory", "-s", type=click.Path(file_okay=False, exists=True), help="Save imported rules to a directory" + "--save-directory", + "-s", + type=click.Path(file_okay=False, exists=True, path_type=Path), + help="Save imported rules to a directory", ) @click.option( "--exceptions-directory", "-se", - type=click.Path(file_okay=False, exists=True), + type=click.Path(file_okay=False, exists=True, path_type=Path), help="Save imported exceptions to a directory", ) @click.option( "--action-connectors-directory", "-sa", - type=click.Path(file_okay=False, exists=True), + type=click.Path(file_okay=False, exists=True, path_type=Path), help="Save imported actions to a directory", ) @click.option("--skip-errors", "-ske", is_flag=True, help="Skip rule import errors") @@ -130,17 +157,32 @@ def generate_rules_index(ctx: click.Context, query, overwrite, save_files=True): @click.option("--strip-none-values", "-snv", is_flag=True, help="Strip None values from the rule") @click.option("--local-creation-date", "-lc", is_flag=True, help="Preserve the local creation date of the rule") @click.option("--local-updated-date", "-lu", is_flag=True, help="Preserve the local updated date of the rule") -def import_rules_into_repo(input_file: click.Path, required_only: bool, action_connector_import: bool, - exceptions_import: bool, directory: click.Path, save_directory: click.Path, - action_connectors_directory: click.Path, exceptions_directory: click.Path, - skip_errors: bool, default_author: str, strip_none_values: bool, local_creation_date: bool, - local_updated_date: bool): +def import_rules_into_repo( # noqa: PLR0912, PLR0913, PLR0915 + input_file: tuple[Path, ...] 
| None, + required_only: bool, + action_connector_import: bool, + exceptions_import: bool, + directory: Path | None, + save_directory: Path, + action_connectors_directory: Path | None, + exceptions_directory: Path | None, + skip_errors: bool, + default_author: str, + strip_none_values: bool, + local_creation_date: bool, + local_updated_date: bool, +) -> None: """Import rules from json, toml, or yaml files containing Kibana exported rule(s).""" - errors = [] - rule_files = glob.glob(os.path.join(directory, "**", "*.*"), recursive=True) if directory else [] - rule_files = sorted(set(rule_files + list(input_file))) + errors: list[str] = [] + + rule_files: list[Path] = [] + if directory: + rule_files = list(directory.glob("**/*.*")) + + if input_file: + rule_files = sorted({*rule_files, *input_file}) - file_contents = [] + file_contents: list[Any] = [] for rule_file in rule_files: file_contents.extend(load_rule_contents(Path(rule_file))) @@ -156,19 +198,19 @@ def import_rules_into_repo(input_file: click.Path, required_only: bool, action_c file_contents = unparsed_results - exception_list_rule_table = {} - action_connector_rule_table = {} + exception_list_rule_table: dict[str, Any] = {} + action_connector_rule_table: dict[str, Any] = {} rule_count = 0 for contents in file_contents: # Don't load exceptions as rules if contents.get("type") not in get_args(definitions.RuleType): - click.echo(f"Skipping - {contents.get("type")} is not a supported rule type") + click.echo(f"Skipping - {contents.get('type')} is not a supported rule type") continue base_path = contents.get("name") or contents.get("rule", {}).get("name") base_path = rulename_to_filename(base_path) if base_path else base_path if base_path is None: raise ValueError(f"Invalid rule file, please ensure the rule has a name field: {contents}") - rule_path = os.path.join(save_directory if save_directory is not None else RULES_DIRS[0], base_path) + rule_path = Path(os.path.join(str(save_directory) if save_directory else RULES_DIRS[0], base_path)) # noqa: PTH118 # handle both rule json formats loaded from kibana and toml data_view_id = contents.get("data_view_id") or contents.get("rule", {}).get("data_view_id") @@ -255,16 +297,21 @@ def import_rules_into_repo(input_file: click.Path, required_only: bool, action_c click.echo(f"{exceptions_count} exceptions exported") click.echo(f"{len(action_connectors)} actions connectors exported") if errors: - err_file = save_directory if save_directory is not None else RULES_DIRS[0] / "_errors.txt" - err_file.write_text("\n".join(errors)) + _dir = save_directory if save_directory else RULES_DIRS[0] + err_file = _dir / "_errors.txt" + _ = err_file.write_text("\n".join(errors)) click.echo(f"{len(errors)} errors saved to {err_file}") -@root.command('build-limited-rules') -@click.option('--stack-version', type=click.Choice(all_versions()), required=True, - help='Version to downgrade to be compatible with the older instance of Kibana') -@click.option('--output-file', '-o', type=click.Path(dir_okay=False, exists=False), required=True) -def build_limited_rules(stack_version: str, output_file: str): +@root.command("build-limited-rules") +@click.option( + "--stack-version", + type=click.Choice(all_versions()), + required=True, + help="Version to downgrade to be compatible with the older instance of Kibana", +) +@click.option("--output-file", "-o", type=click.Path(dir_okay=False, exists=False), required=True) +def build_limited_rules(stack_version: str, output_file: str) -> None: """ Import rules from json, toml, or 
Kibana exported rule file(s), filter out unsupported ones, and write to output NDJSON file. @@ -272,9 +319,10 @@ def build_limited_rules(stack_version: str, output_file: str): # Schema generation and incompatible fields detection query_rule_data = class_schema(QueryRuleData)() - fields = getattr(query_rule_data, 'fields', {}) - incompatible_fields = get_incompatible_fields(list(fields.values()), - Version.parse(stack_version, optional_minor_and_patch=True)) + fields = getattr(query_rule_data, "fields", {}) + incompatible_fields = get_incompatible_fields( + list(fields.values()), Version.parse(stack_version, optional_minor_and_patch=True) + ) # Load all rules rules = RuleCollection.default() @@ -289,10 +337,11 @@ def build_limited_rules(stack_version: str, output_file: str): api_schema = get_schema_file(stack_version, "base")["properties"]["type"]["enum"] # Function to process each rule - def process_rule(rule, incompatible_fields: List[str]): + def process_rule(rule: TOMLRule, incompatible_fields: list[str]) -> dict[str, Any] | None: if rule.contents.type not in api_schema: - click.secho(f'{rule.contents.name} - Skipping unsupported rule type: {rule.contents.get("type")}', - fg='yellow') + click.secho( + f"{rule.contents.name} - Skipping unsupported rule type: {rule.contents.get('type')}", fg="yellow" + ) return None # Remove unsupported fields from rule @@ -311,13 +360,18 @@ def process_rule(rule, incompatible_fields: List[str]): # Write ndjson_output to file ndjson_output.dump(output_path) - click.echo(f'Success: Rules written to {output_file}') + click.echo(f"Success: Rules written to {output_file}") -@root.command('toml-lint') -@click.option('--rule-file', '-f', multiple=True, type=click.Path(exists=True), - help='Specify one or more rule files.') -def toml_lint(rule_file): +@root.command("toml-lint") +@click.option( + "--rule-file", + "-f", + multiple=True, + type=click.Path(exists=True, path_type=Path), + help="Specify one or more rule files.", +) +def toml_lint(rule_file: list[Path]) -> None: """Cleanup files with some simple toml formatting.""" if rule_file: rules = RuleCollection() @@ -329,88 +383,101 @@ def toml_lint(rule_file): for rule in rules: rule.save_toml() - click.echo('TOML file linting complete') + click.echo("TOML file linting complete") -@root.command('mass-update') -@click.argument('query') -@click.option('--metadata', '-m', is_flag=True, help='Make an update to the rule metadata rather than contents.') -@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql") -@click.option('--field', type=(str, str), multiple=True, - help='Use rule-search to retrieve a subset of rules and modify values ' - '(ex: --field management.ecs_version 1.1.1).\n' - 'Note this is limited to string fields only. Nested fields should use dot notation.') +@root.command("mass-update") +@click.argument("query") +@click.option("--metadata", "-m", is_flag=True, help="Make an update to the rule metadata rather than contents.") +@click.option("--language", type=click.Choice(["eql", "kql"]), default="kql") +@click.option( + "--field", + type=(str, str), + multiple=True, + help="Use rule-search to retrieve a subset of rules and modify values " + "(ex: --field management.ecs_version 1.1.1).\n" + "Note this is limited to string fields only. 
Nested fields should use dot notation.", +) @click.pass_context -def mass_update(ctx, query, metadata, language, field): +def mass_update( + ctx: click.Context, + query: str, + metadata: bool, + language: Literal["eql", "kql"], + field: tuple[str, str], +) -> Any: """Update multiple rules based on eql results.""" rules = RuleCollection().default() results = ctx.invoke(search_rules, query=query, language=language, verbose=False) - matching_ids = set(r["rule_id"] for r in results) + matching_ids = {r["rule_id"] for r in results} rules = rules.filter(lambda r: r.id in matching_ids) for rule in rules: for key, value in field: - nested_set(rule.metadata if metadata else rule.contents, key, value) + nested_set(rule.metadata if metadata else rule.contents, key, value) # type: ignore[reportAttributeAccessIssue] - rule.validate(as_rule=True) - rule.save(as_rule=True) + rule.validate(as_rule=True) # type: ignore[reportAttributeAccessIssue] + rule.save(as_rule=True) # type: ignore[reportAttributeAccessIssue] - return ctx.invoke(search_rules, query=query, language=language, - columns=['rule_id', 'name'] + [k[0].split('.')[-1] for k in field]) + return ctx.invoke( + search_rules, + query=query, + language=language, + columns=["rule_id", "name"] + [k[0].split(".")[-1] for k in field], + ) -@root.command('view-rule') -@click.argument('rule-file', type=Path) -@click.option('--api-format/--rule-format', default=True, help='Print the rule in final api or rule format') +@root.command("view-rule") +@click.argument("rule-file", type=Path) +@click.option("--api-format/--rule-format", default=True, help="Print the rule in final api or rule format") @click.pass_context -def view_rule(ctx, rule_file, api_format): +def view_rule(_: click.Context, rule_file: Path, api_format: str) -> TOMLRule | DeprecatedRule: """View an internal rule or specified rule file.""" rule = RuleCollection().load_file(rule_file) if api_format: click.echo(json.dumps(rule.contents.to_api_format(), indent=2, sort_keys=True)) else: - click.echo(toml_write(rule.contents.to_dict())) + click.echo(toml_write(rule.contents.to_dict())) # type: ignore[reportAttributeAccessIssue] return rule -def _export_rules( +def _export_rules( # noqa: PLR0913 rules: RuleCollection, outfile: Path, - downgrade_version: Optional[definitions.SemVer] = None, - verbose=True, - skip_unsupported=False, + downgrade_version: definitions.SemVer | None = None, + verbose: bool = True, + skip_unsupported: bool = False, include_metadata: bool = False, include_action_connectors: bool = False, include_exceptions: bool = False, -): +) -> None: """Export rules and exceptions into a consolidated ndjson file.""" from .rule import downgrade_contents_from_rule - outfile = outfile.with_suffix('.ndjson') - unsupported = [] + outfile = outfile.with_suffix(".ndjson") + unsupported: list[str] = [] + output_lines: list[str] = [] if downgrade_version: - if skip_unsupported: - output_lines = [] - - for rule in rules: - try: - output_lines.append(json.dumps(downgrade_contents_from_rule(rule, downgrade_version, - include_metadata=include_metadata), - sort_keys=True)) - except ValueError as e: - unsupported.append(f'{e}: {rule.id} - {rule.name}') - continue - - else: - output_lines = [json.dumps(downgrade_contents_from_rule(r, downgrade_version, - include_metadata=include_metadata), sort_keys=True) - for r in rules] + for rule in rules: + try: + output_lines.append( + json.dumps( + downgrade_contents_from_rule(rule, downgrade_version, include_metadata=include_metadata), + sort_keys=True, + ) + ) + 
except ValueError as e: + if skip_unsupported: + unsupported.append(f"{e}: {rule.id} - {rule.name}") + else: + raise else: - output_lines = [json.dumps(r.contents.to_api_format(include_metadata=include_metadata), - sort_keys=True) for r in rules] + output_lines = [ + json.dumps(r.contents.to_api_format(include_metadata=include_metadata), sort_keys=True) for r in rules + ] # Add exceptions to api format here and add to output_lines if include_exceptions or include_action_connectors: @@ -427,14 +494,14 @@ def _export_rules( actions = [a for sublist in action_connectors for a in sublist] output_lines.extend(json.dumps(a, sort_keys=True) for a in actions) - outfile.write_text('\n'.join(output_lines) + '\n') + _ = outfile.write_text("\n".join(output_lines) + "\n") if verbose: - click.echo(f'Exported {len(rules) - len(unsupported)} rules into {outfile}') + click.echo(f"Exported {len(rules) - len(unsupported)} rules into {outfile}") if skip_unsupported and unsupported: - unsupported_str = '\n- '.join(unsupported) - click.echo(f'Skipped {len(unsupported)} unsupported rules: \n- {unsupported_str}') + unsupported_str = "\n- ".join(unsupported) + click.echo(f"Skipped {len(unsupported)} unsupported rules: \n- {unsupported_str}") @root.command("export-rules-from-repo") @@ -442,7 +509,7 @@ def _export_rules( @click.option( "--outfile", "-o", - default=Path(get_path("exports", f'{time.strftime("%Y%m%dT%H%M%SL")}.ndjson')), + default=Path(get_path(["exports", f"{time.strftime('%Y%m%dT%H%M%SL')}.ndjson"])), type=Path, help="Name of file for exported rules", ) @@ -456,7 +523,7 @@ def _export_rules( "--skip-unsupported", "-s", is_flag=True, - help="If `--stack-version` is passed, skip rule types which are unsupported " "(an error will be raised otherwise)", + help="If `--stack-version` is passed, skip rule types which are unsupported (an error will be raised otherwise)", ) @click.option("--include-metadata", type=bool, is_flag=True, default=False, help="Add metadata to the exported rules") @click.option( @@ -470,10 +537,19 @@ def _export_rules( @click.option( "--include-exceptions", "-e", type=bool, is_flag=True, default=False, help="Include Exceptions Lists in export" ) -def export_rules_from_repo(rules, outfile: Path, replace_id, stack_version, skip_unsupported, include_metadata: bool, - include_action_connectors: bool, include_exceptions: bool) -> RuleCollection: +def export_rules_from_repo( # noqa: PLR0913 + rules: RuleCollection, + outfile: Path, + replace_id: bool, + stack_version: str, + skip_unsupported: bool, + include_metadata: bool, + include_action_connectors: bool, + include_exceptions: bool, +) -> RuleCollection: """Export rule(s) and exception(s) into an importable ndjson file.""" - assert len(rules) > 0, "No rules found" + if len(rules) == 0: + raise ValueError("No rules found") if replace_id: # if we need to replace the id, take each rule object and create a copy @@ -500,87 +576,97 @@ def export_rules_from_repo(rules, outfile: Path, replace_id, stack_version, skip return rules -@root.command('validate-rule') -@click.argument('path') +@root.command("validate-rule") +@click.argument("path") @click.pass_context -def validate_rule(ctx, path): +def validate_rule(_: click.Context, path: str) -> TOMLRule | DeprecatedRule: """Check if a rule staged in rules dir validates against a schema.""" rule = RuleCollection().load_file(Path(path)) - click.echo('Rule validation successful') + click.echo("Rule validation successful") return rule -@root.command('validate-all') -def validate_all(): 
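# --- Illustrative aside --------------------------------------------------------
# A minimal sketch of the consolidated skip-or-raise loop that _export_rules now
# uses when downgrading rules for an older stack version. fake_downgrade and the
# rule dicts below are made-up stand-ins for downgrade_contents_from_rule and
# TOMLRule objects.
import json

def fake_downgrade(rule: dict) -> dict:
    if rule["type"] == "esql":  # pretend this type cannot be downgraded
        raise ValueError("unsupported rule type")
    return rule

def export_lines(rules: list[dict], skip_unsupported: bool) -> tuple[list[str], list[str]]:
    output_lines: list[str] = []
    unsupported: list[str] = []
    for rule in rules:
        try:
            output_lines.append(json.dumps(fake_downgrade(rule), sort_keys=True))
        except ValueError as e:
            if not skip_unsupported:
                raise
            unsupported.append(f"{e}: {rule['rule_id']}")
    return output_lines, unsupported

lines, skipped = export_lines(
    [{"rule_id": "a", "type": "query"}, {"rule_id": "b", "type": "esql"}],
    skip_unsupported=True,
)
print(len(lines), skipped)  # 1 ['unsupported rule type: b']
# -------------------------------------------------------------------------------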
+@root.command("validate-all") +def validate_all() -> None: """Check if all rules validates against a schema.""" - RuleCollection.default() - click.echo('Rule validation successful') - - -@root.command('rule-search') -@click.argument('query', required=False) -@click.option('--columns', '-c', multiple=True, help='Specify columns to add the table') -@click.option('--language', type=click.Choice(["eql", "kql"]), default="kql") -@click.option('--count', is_flag=True, help='Return a count rather than table') -def search_rules(query, columns, language, count, verbose=True, rules: Dict[str, TOMLRule] = None, pager=False): + _ = RuleCollection.default() + click.echo("Rule validation successful") + + +@root.command("rule-search") +@click.argument("query", required=False) +@click.option("--columns", "-c", multiple=True, help="Specify columns to add the table") +@click.option("--language", type=click.Choice(["eql", "kql"]), default="kql") +@click.option("--count", is_flag=True, help="Return a count rather than table") +def search_rules( # noqa: PLR0913 + query: str | None, + columns: list[str], + language: Literal["eql", "kql"], + count: bool, + verbose: bool = True, + rules: dict[str, TOMLRule] | None = None, + pager: bool = False, +) -> list[dict[str, Any]]: """Use KQL or EQL to find matching rules.""" - from kql import get_evaluator - from eql.table import Table - from eql.build import get_engine - from eql import parse_query - from eql.pipes import CountPipe + from eql import parse_query # type: ignore[reportMissingTypeStubs] + from eql.build import get_engine # type: ignore[reportMissingTypeStubs] + from eql.pipes import CountPipe # type: ignore[reportMissingTypeStubs] + from eql.table import Table # type: ignore[reportMissingTypeStubs] + from kql import get_evaluator # type: ignore[reportMissingTypeStubs] + from .rule import get_unique_query_fields - flattened_rules = [] + flattened_rules: list[dict[str, Any]] = [] rules = rules or {str(rule.path): rule for rule in RuleCollection.default()} for file_name, rule in rules.items(): - flat: dict = {"file": os.path.relpath(file_name)} + flat: dict[str, Any] = {"file": os.path.relpath(file_name)} flat.update(rule.contents.to_dict()) flat.update(flat["metadata"]) flat.update(flat["rule"]) - tactic_names = [] - technique_ids = [] - subtechnique_ids = [] + tactic_names: list[str] = [] + technique_ids: list[str] = [] + subtechnique_ids: list[str] = [] - for entry in flat['rule'].get('threat', []): + for entry in flat["rule"].get("threat", []): if entry["framework"] != "MITRE ATT&CK": continue - techniques = entry.get('technique', []) - tactic_names.append(entry['tactic']['name']) - technique_ids.extend([t['id'] for t in techniques]) - subtechnique_ids.extend([st['id'] for t in techniques for st in t.get('subtechnique', [])]) + techniques = entry.get("technique", []) + tactic_names.append(entry["tactic"]["name"]) + technique_ids.extend([t["id"] for t in techniques]) + subtechnique_ids.extend([st["id"] for t in techniques for st in t.get("subtechnique", [])]) - flat.update(techniques=technique_ids, tactics=tactic_names, subtechniques=subtechnique_ids, - unique_fields=get_unique_query_fields(rule)) + flat.update( + techniques=technique_ids, + tactics=tactic_names, + subtechniques=subtechnique_ids, + unique_fields=get_unique_query_fields(rule), + ) flattened_rules.append(flat) flattened_rules.sort(key=lambda dct: dct["name"]) - filtered = [] + filtered: list[dict[str, Any]] = [] if language == "kql": - evaluator = get_evaluator(query) if query else lambda x: 
True - filtered = list(filter(evaluator, flattened_rules)) + evaluator = get_evaluator(query) if query else lambda _: True # type: ignore[reportUnknownLambdaType] + filtered = list(filter(evaluator, flattened_rules)) # type: ignore[reportCallIssue] elif language == "eql": - parsed = parse_query(query, implied_any=True, implied_base=True) - evaluator = get_engine(parsed) - filtered = [result.events[0].data for result in evaluator(flattened_rules)] + parsed = parse_query(query, implied_any=True, implied_base=True) # type: ignore[reportUnknownVariableType] + evaluator = get_engine(parsed) # type: ignore[reportUnknownVariableType] + filtered = [result.events[0].data for result in evaluator(flattened_rules)] # type: ignore[reportUnknownVariableType] - if not columns and any(isinstance(pipe, CountPipe) for pipe in parsed.pipes): + if not columns and any(isinstance(pipe, CountPipe) for pipe in parsed.pipes): # type: ignore[reportAttributeAccessIssue] columns = ["key", "count", "percent"] if count: - click.echo(f'{len(filtered)} rules') + click.echo(f"{len(filtered)} rules") return filtered - if columns: - columns = ",".join(columns).split(",") - else: - columns = ["rule_id", "file", "name"] + columns = ",".join(columns).split(",") if columns else ["rule_id", "file", "name"] - table = Table.from_list(columns, filtered) + table: Table = Table.from_list(columns, filtered) # type: ignore[reportUnknownMemberType] if verbose: click.echo_via_pager(table) if pager else click.echo(table) @@ -588,70 +674,68 @@ def search_rules(query, columns, language, count, verbose=True, rules: Dict[str, return filtered -@root.command('build-threat-map-entry') -@click.argument('tactic') -@click.argument('technique-ids', nargs=-1) -def build_threat_map(tactic: str, technique_ids: Iterable[str]): +@root.command("build-threat-map-entry") +@click.argument("tactic") +@click.argument("technique-ids", nargs=-1) +def build_threat_map(tactic: str, technique_ids: Iterable[str]) -> dict[str, Any]: """Build a threat map entry.""" entry = build_threat_map_entry(tactic, *technique_ids) - rendered = pytoml.dumps({'rule': {'threat': [entry]}}) + rendered = pytoml.dumps({"rule": {"threat": [entry]}}) # type: ignore[reportUnknownMemberType] # strip out [rule] - cleaned = '\n'.join(rendered.splitlines()[2:]) + cleaned = "\n".join(rendered.splitlines()[2:]) print(cleaned) return entry @root.command("test") @click.pass_context -def test_rules(ctx): +def test_rules(ctx: click.Context) -> None: """Run unit tests over all of the rules.""" import pytest - rules_config = ctx.obj['rules_config'] + rules_config = ctx.obj["rules_config"] test_config = rules_config.test_config tests, skipped = test_config.get_test_names(formatted=True) if skipped: - click.echo(f'Tests skipped per config ({len(skipped)}):') - click.echo('\n'.join(skipped)) + click.echo(f"Tests skipped per config ({len(skipped)}):") + click.echo("\n".join(skipped)) clear_caches() if tests: - ctx.exit(pytest.main(['-v'] + tests)) + ctx.exit(pytest.main(["-v", *tests])) else: - click.echo('No tests found to execute!') + click.echo("No tests found to execute!") -@root.group('typosquat') -def typosquat_group(): +@root.group("typosquat") +def typosquat_group() -> None: """Commands for generating typosquat detections.""" -@typosquat_group.command('create-dnstwist-index') -@click.argument('input-file', type=click.Path(exists=True, dir_okay=False), required=True) +@typosquat_group.command("create-dnstwist-index") +@click.argument("input-file", type=click.Path(exists=True, dir_okay=False), 
required=True) @click.pass_context -@add_client('elasticsearch', add_func_arg=False) -def create_dnstwist_index(ctx: click.Context, input_file: click.Path): +@add_client(["elasticsearch"], add_func_arg=False) +def create_dnstwist_index(ctx: click.Context, input_file: click.Path) -> None: """Create a dnstwist index in Elasticsearch to work with a threat match rule.""" - from elasticsearch import Elasticsearch - - es_client: Elasticsearch = ctx.obj['es'] + es_client: Elasticsearch = ctx.obj["es"] - click.echo(f'Attempting to load dnstwist data from {input_file}') - dnstwist_data: dict = load_dump(str(input_file)) - click.echo(f'{len(dnstwist_data)} records loaded') + click.echo(f"Attempting to load dnstwist data from {input_file}") + dnstwist_data: list[dict[str, Any]] = load_dump(str(input_file)) # type: ignore[reportAssignmentType] + click.echo(f"{len(dnstwist_data)} records loaded") - original_domain = next(r['domain-name'] for r in dnstwist_data if r.get('fuzzer', '') == 'original*') - click.echo(f'Original domain name identified: {original_domain}') + original_domain = next(r["domain-name"] for r in dnstwist_data if r.get("fuzzer", "") == "original*") # type: ignore[reportAttributeAccessIssue] + click.echo(f"Original domain name identified: {original_domain}") - domain = original_domain.split('.')[0] - domain_index = f'dnstwist-{domain}' + domain = original_domain.split(".")[0] + domain_index = f"dnstwist-{domain}" # If index already exists, prompt user to confirm if they want to overwrite - if es_client.indices.exists(index=domain_index): - if click.confirm( - f"dnstwist index: {domain_index} already exists for {original_domain}. Do you want to overwrite?", - abort=True): - es_client.indices.delete(index=domain_index) + if es_client.indices.exists(index=domain_index) and click.confirm( + f"dnstwist index: {domain_index} already exists for {original_domain}. 
Do you want to overwrite?", + abort=True, + ): + _ = es_client.indices.delete(index=domain_index) fields = [ "dns-a", @@ -661,52 +745,52 @@ def create_dnstwist_index(ctx: click.Context, input_file: click.Path): "banner-http", "fuzzer", "original-domain", - "dns.question.registered_domain" + "dns.question.registered_domain", ] timestamp_field = "@timestamp" mappings = {"mappings": {"properties": {f: {"type": "keyword"} for f in fields}}} mappings["mappings"]["properties"][timestamp_field] = {"type": "date"} - es_client.indices.create(index=domain_index, body=mappings) + _ = es_client.indices.create(index=domain_index, body=mappings) # handle dns.question.registered_domain separately - fields.pop() - es_updates = [] - now = datetime.utcnow() + _ = fields.pop() + es_updates: list[dict[str, Any]] = [] + now = datetime.now(UTC) for item in dnstwist_data: - if item['fuzzer'] == 'original*': + if item["fuzzer"] == "original*": continue record = item.copy() - record.setdefault('dns', {}).setdefault('question', {}).setdefault('registered_domain', item.get('domain-name')) + record.setdefault("dns", {}).setdefault("question", {}).setdefault("registered_domain", item.get("domain-name")) for field in fields: - record.setdefault(field, None) + _ = record.setdefault(field, None) - record['@timestamp'] = now + record["@timestamp"] = now - es_updates.extend([{'create': {'_index': domain_index}}, record]) + es_updates.extend([{"create": {"_index": domain_index}}, record]) - click.echo(f'Indexing data for domain {original_domain}') + click.echo(f"Indexing data for domain {original_domain}") results = es_client.bulk(body=es_updates) - if results['errors']: - error = {r['create']['result'] for r in results['items'] if r['create']['status'] != 201} - client_error(f'Errors occurred during indexing:\n{error}') + if results["errors"]: + error = {r["create"]["result"] for r in results["items"] if r["create"]["status"] != 201} # noqa: PLR2004 + raise_client_error(f"Errors occurred during indexing:\n{error}") - click.echo(f'{len(results["items"])} watchlist domains added to index') - click.echo('Run `prep-rule` and import to Kibana to create alerts on this index') + click.echo(f"{len(results['items'])} watchlist domains added to index") + click.echo("Run `prep-rule` and import to Kibana to create alerts on this index") -@typosquat_group.command('prep-rule') -@click.argument('author') -def prep_rule(author: str): +@typosquat_group.command("prep-rule") +@click.argument("author") +def prep_rule(author: str) -> None: """Prep the detection threat match rule for dnstwist data with a rule_id and author.""" - rule_template_file = get_etc_path('rule_template_typosquatting_domain.json') + rule_template_file = get_etc_path(["rule_template_typosquatting_domain.json"]) template_rule = json.loads(rule_template_file.read_text()) template_rule.update(author=[author], rule_id=str(uuid4())) - updated_rule = get_path('rule_typosquatting_domain.ndjson') - updated_rule.write_text(json.dumps(template_rule, sort_keys=True)) - click.echo(f'Rule saved to: {updated_rule}. Import this to Kibana to create alerts on all dnstwist-* indexes') - click.echo('Note: you only need to import and enable this rule one time for all dnstwist-* indexes') + updated_rule = get_path(["rule_typosquatting_domain.ndjson"]) + _ = updated_rule.write_text(json.dumps(template_rule, sort_keys=True)) + click.echo(f"Rule saved to: {updated_rule}. 
Import this to Kibana to create alerts on all dnstwist-* indexes") + click.echo("Note: you only need to import and enable this rule one time for all dnstwist-* indexes") diff --git a/detection_rules/misc.py b/detection_rules/misc.py index dcf1fc51ad0..9b8f8c82e8e 100644 --- a/detection_rules/misc.py +++ b/detection_rules/misc.py @@ -4,23 +4,23 @@ # 2.0. """Misc support.""" + import os import re import time import unittest import uuid -from pathlib import Path +from collections.abc import Callable from functools import wraps -from typing import NoReturn, Optional +from pathlib import Path +from typing import IO, Any, NoReturn import click import requests +from elasticsearch import AuthenticationException, Elasticsearch +from kibana import Kibana # type: ignore[reportMissingTypeStubs] -from kibana import Kibana - -from .utils import add_params, cached, get_path, load_etc_dump - -_CONFIG = {} +from .utils import add_params, cached, load_etc_dump LICENSE_HEADER = """ Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one @@ -35,7 +35,7 @@ /* {} */ -""".strip().format("\n".join(' * ' + line for line in LICENSE_LINES)) +""".strip().format("\n".join(" * " + line for line in LICENSE_LINES)) ROOT_DIR = Path(__file__).parent.parent @@ -44,57 +44,57 @@ class ClientError(click.ClickException): """Custom CLI error to format output or full debug stacktrace.""" - def __init__(self, message, original_error=None): - super(ClientError, self).__init__(message) + def __init__(self, message: str, original_error: Exception | None = None) -> None: + super().__init__(message) self.original_error = original_error - self.original_error_type = type(original_error).__name__ if original_error else '' + self.original_error_type = type(original_error).__name__ if original_error else "" - def show(self, file=None, err=True): + def show(self, file: IO[Any] | None = None, err: bool = True) -> None: """Print the error to the console.""" - # err_msg = f' {self.original_error_type}' if self.original_error else '' - msg = f'{click.style(f"CLI Error ({self.original_error_type})", fg="red", bold=True)}: {self.format_message()}' + msg = f"{click.style(f'CLI Error ({self.original_error_type})', fg='red', bold=True)}: {self.format_message()}" click.echo(msg, err=err, file=file) -def client_error(message, exc: Exception = None, debug=None, ctx: click.Context = None, file=None, - err=None) -> NoReturn: - config_debug = True if ctx and ctx.ensure_object(dict) and ctx.obj.get('debug') is True else False +def raise_client_error( # noqa: PLR0913 + message: str, + exc: Exception | None = None, + debug: bool | None = False, + ctx: click.Context | None = None, + file: IO[Any] | None = None, + err: bool = False, +) -> NoReturn: + config_debug = bool(ctx and ctx.ensure_object(dict) and ctx.obj.get("debug")) # type: ignore[reportUnknownArgumentType] debug = debug if debug is not None else config_debug if debug: - click.echo(click.style('DEBUG: ', fg='yellow') + message, err=err, file=file) - raise - else: + click.echo(click.style("DEBUG: ", fg="yellow") + message, err=err, file=file) raise ClientError(message, original_error=exc) + raise ClientError(message, original_error=exc) -def nested_get(_dict, dot_key, default=None): +def nested_get(_dict: dict[str, Any] | None, dot_key: str | None, default: Any | None = None) -> Any: """Get a nested field from a nested dict with dot notation.""" if _dict is None or dot_key is None: return default - elif '.' 
in dot_key and isinstance(_dict, dict): - dot_key = dot_key.split('.') - this_key = dot_key.pop(0) - return nested_get(_dict.get(this_key, default), '.'.join(dot_key), default) - else: - return _dict.get(dot_key, default) + if "." in dot_key: + dot_key_parts = dot_key.split(".") + this_key = dot_key_parts.pop(0) + return nested_get(_dict.get(this_key, default), ".".join(dot_key_parts), default) + return _dict.get(dot_key, default) -def nested_set(_dict, dot_key, value): +def nested_set(_dict: dict[str, Any], dot_key: str, value: Any) -> None: """Set a nested field from a key in dot notation.""" - keys = dot_key.split('.') + keys = dot_key.split(".") for key in keys[:-1]: _dict = _dict.setdefault(key, {}) - if isinstance(_dict, dict): - _dict[keys[-1]] = value - else: - raise ValueError('dict cannot set a value to a non-dict for {}'.format(dot_key)) + _dict[keys[-1]] = value -def nest_from_dot(dots, value): +def nest_from_dot(dots: str, value: Any) -> Any: """Nest a dotted field and set the innermost value.""" - fields = dots.split('.') + fields = dots.split(".") if not fields: return {} @@ -107,149 +107,187 @@ def nest_from_dot(dots, value): return nested -def schema_prompt(name, value=None, is_required=False, **options): +def schema_prompt(name: str, value: Any | None = None, is_required: bool = False, **options: Any) -> Any: # noqa: PLR0911, PLR0912, PLR0915 """Interactively prompt based on schema requirements.""" - name = str(name) - field_type = options.get('type') - pattern = options.get('pattern') - enum = options.get('enum', []) - minimum = options.get('minimum') - maximum = options.get('maximum') - min_item = options.get('min_items', 0) - max_items = options.get('max_items', 9999) - - default = options.get('default') - if default is not None and str(default).lower() in ('true', 'false'): + field_type = options.get("type") + pattern: str | None = options.get("pattern") + enum = options.get("enum", []) + minimum = int(options["minimum"]) if "minimum" in options else None + maximum = int(options["maximum"]) if "maximum" in options else None + min_item = int(options.get("min_items", 0)) + max_items = int(options.get("max_items", 9999)) + + default = options.get("default") + if default is not None and str(default).lower() in ("true", "false"): default = str(default).lower() - if 'date' in name: - default = time.strftime('%Y/%m/%d') + if "date" in name: + default = time.strftime("%Y/%m/%d") - if name == 'rule_id': + if name == "rule_id": default = str(uuid.uuid4()) if len(enum) == 1 and is_required and field_type != "array": return enum[0] - def _check_type(_val): - if field_type in ('number', 'integer') and not str(_val).isdigit(): - print('Number expected but got: {}'.format(_val)) - return False - if pattern and (not re.match(pattern, _val) or len(re.match(pattern, _val).group(0)) != len(_val)): - print('{} did not match pattern: {}!'.format(_val, pattern)) + def _check_type(_val: Any) -> bool: # noqa: PLR0911 + if field_type in ("number", "integer") and not str(_val).isdigit(): + print(f"Number expected but got: {_val}") return False + if pattern: + match = re.match(pattern, _val) + if not match or len(match.group(0)) != len(_val): + print(f"{_val} did not match pattern: {pattern}!") + return False if enum and _val not in enum: - print('{} not in valid options: {}'.format(_val, ', '.join(enum))) + print("{} not in valid options: {}".format(_val, ", ".join(enum))) return False if minimum and (type(_val) is int and int(_val) < minimum): - print('{} is less than the minimum: 
{}'.format(str(_val), str(minimum))) + print(f"{_val!s} is less than the minimum: {minimum!s}") return False if maximum and (type(_val) is int and int(_val) > maximum): - print('{} is greater than the maximum: {}'.format(str(_val), str(maximum))) + print(f"{_val!s} is greater than the maximum: {maximum!s}") return False - if field_type == 'boolean' and _val.lower() not in ('true', 'false'): - print('Boolean expected but got: {}'.format(str(_val))) + if type(_val) is str and field_type == "boolean" and _val.lower() not in ("true", "false"): + print(f"Boolean expected but got: {_val!s}") return False return True - def _convert_type(_val): - if field_type == 'boolean' and not type(_val) is bool: - _val = True if _val.lower() == 'true' else False - return int(_val) if field_type in ('number', 'integer') else _val - - prompt = '{name}{default}{required}{multi}'.format( - name=name, - default=' [{}] ("n/a" to leave blank) '.format(default) if default else '', - required=' (required) ' if is_required else '', - multi=' (multi, comma separated) ' if field_type == 'array' else '').strip() + ': ' + def _convert_type(_val: Any) -> Any: + if field_type == "boolean" and type(_val) is not bool: + _val = _val.lower() == "true" + return int(_val) if field_type in ("number", "integer") else _val + + prompt = ( + "{name}{default}{required}{multi}".format( + name=name, + default=f' [{default}] ("n/a" to leave blank) ' if default else "", + required=" (required) " if is_required else "", + multi=" (multi, comma separated) " if field_type == "array" else "", + ).strip() + + ": " + ) while True: result = value or input(prompt) or default - if result == 'n/a': + if result == "n/a": result = None if not result: if is_required: value = None continue - else: - return + return None - if field_type == 'array': - result_list = result.split(',') + if field_type == "array": + result_list = result.split(",") if not (min_item < len(result_list) < max_items): if is_required: value = None break - else: - return [] + return [] for value in result_list: if not _check_type(value): if is_required: - value = None + value = None # noqa: PLW2901 break - else: - return [] + return [] if is_required and value is None: continue - else: - return [_convert_type(r) for r in result_list] - else: - if _check_type(result): - return _convert_type(result) - elif is_required: - value = None - continue - return + return [_convert_type(r) for r in result_list] + if _check_type(result): + return _convert_type(result) + if is_required: + value = None + continue + return None -def get_kibana_rules_map(repo='elastic/kibana', branch='master'): +def get_kibana_rules_map(repo: str = "elastic/kibana", branch: str = "master") -> dict[str, Any]: """Get list of available rules from the Kibana repo and return a list of URLs.""" + + timeout = 30 # secs + # ensure branch exists - r = requests.get(f'https://api.github.com/repos/{repo}/branches/{branch}') + r = requests.get(f"https://api.github.com/repos/{repo}/branches/{branch}", timeout=timeout) r.raise_for_status() - url = ('https://api.github.com/repos/{repo}/contents/x-pack/{legacy}plugins/{app}/server/lib/' - 'detection_engine/rules/prepackaged_rules?ref={branch}') + url = ( + "https://api.github.com/repos/{repo}/contents/x-pack/{legacy}plugins/{app}/server/lib/" + "detection_engine/rules/prepackaged_rules?ref={branch}" + ) + + r = requests.get(url.format(legacy="", app="security_solution", branch=branch, repo=repo), timeout=timeout) + r.raise_for_status() - gh_rules = requests.get(url.format(legacy='', 
app='security_solution', branch=branch, repo=repo)).json() + gh_rules = r.json() # pre-7.9 app was siem - if isinstance(gh_rules, dict) and gh_rules.get('message', '') == 'Not Found': - gh_rules = requests.get(url.format(legacy='', app='siem', branch=branch, repo=repo)).json() + if isinstance(gh_rules, dict) and gh_rules.get("message", "") == "Not Found": # type: ignore[reportUnknownMemberType] + gh_rules = requests.get(url.format(legacy="", app="siem", branch=branch, repo=repo), timeout=timeout).json() # pre-7.8 the siem was under the legacy directory - if isinstance(gh_rules, dict) and gh_rules.get('message', '') == 'Not Found': - gh_rules = requests.get(url.format(legacy='legacy/', app='siem', branch=branch, repo=repo)).json() + if isinstance(gh_rules, dict) and gh_rules.get("message", "") == "Not Found": # type: ignore[reportUnknownMemberType] + gh_rules = requests.get( + url.format(legacy="legacy/", app="siem", branch=branch, repo=repo), timeout=timeout + ).json() + + if isinstance(gh_rules, dict) and gh_rules.get("message", "") == "Not Found": # type: ignore[reportUnknownMemberType] + raise ValueError(f"rules directory does not exist for {repo} branch: {branch}") + + if not isinstance(gh_rules, list): + raise TypeError("Expected to receive a list") - if isinstance(gh_rules, dict) and gh_rules.get('message', '') == 'Not Found': - raise ValueError(f'rules directory does not exist for {repo} branch: {branch}') + results: dict[str, Any] = {} - return {os.path.splitext(r['name'])[0]: r['download_url'] for r in gh_rules if r['name'].endswith('.json')} + for r in gh_rules: # type: ignore[reportUnknownMemberType] + if "name" not in r: + raise ValueError("Name value is expected") + name = r["name"] # type: ignore[reportUnknownMemberType] -def get_kibana_rules(*rule_paths, repo='elastic/kibana', branch='master', verbose=True, threads=50): + if not isinstance(name, str): + raise TypeError("String value is expected for name") + + if name.endswith(".json"): + key = Path(name).name + val = r["download_url"] # type: ignore[reportUnknownMemberType] + results[key] = val + + return results + + +def get_kibana_rules( + repo: str = "elastic/kibana", + branch: str = "master", + verbose: bool = True, + threads: int = 50, + rule_paths: list[str] | None = None, +) -> dict[str, Any]: """Retrieve prepackaged rules from kibana repo.""" from multiprocessing.pool import ThreadPool - kibana_rules = {} + kibana_rules: dict[str, Any] = {} if verbose: - thread_use = f' using {threads} threads' if threads > 1 else '' - click.echo(f'Downloading rules from {repo} {branch} branch in kibana repo{thread_use} ...') + thread_use = f" using {threads} threads" if threads > 1 else "" + click.echo(f"Downloading rules from {repo} {branch} branch in kibana repo{thread_use} ...") - rule_paths = [os.path.splitext(os.path.basename(p))[0] for p in rule_paths] - rules_mapping = [(n, u) for n, u in get_kibana_rules_map(repo=repo, branch=branch).items() if n in rule_paths] \ - if rule_paths else get_kibana_rules_map(repo=repo, branch=branch).items() + rule_paths = [os.path.splitext(os.path.basename(p))[0] for p in (rule_paths or [])] # noqa: PTH119, PTH122 + rules_mapping = ( + [(n, u) for n, u in get_kibana_rules_map(repo=repo, branch=branch).items() if n in rule_paths] + if rule_paths + else get_kibana_rules_map(repo=repo, branch=branch).items() + ) - def download_worker(rule_info): + def download_worker(rule_info: tuple[str, str]) -> None: n, u = rule_info - kibana_rules[n] = requests.get(u).json() + kibana_rules[n] = requests.get(u, 
timeout=30).json() pool = ThreadPool(processes=threads) - pool.map(download_worker, rules_mapping) + _ = pool.map(download_worker, rules_mapping) pool.close() pool.join() @@ -259,81 +297,92 @@ def download_worker(rule_info): @cached def load_current_package_version() -> str: """Load the current package version from config file.""" - return load_etc_dump('packages.yaml')['package']['name'] + data = load_etc_dump(["packages.yaml"]) + return data["package"]["name"] -def get_default_config() -> Optional[Path]: - return next(get_path().glob('.detection-rules-cfg.*'), None) +def get_default_config() -> Path | None: + return next(ROOT_DIR.glob(".detection-rules-cfg.*"), None) @cached -def parse_user_config(): +def parse_user_config() -> dict[str, Any]: """Parse a default config file.""" - import eql + import eql # type: ignore[reportMissingTypeStubs] config_file = get_default_config() config = {} if config_file and config_file.exists(): - config = eql.utils.load_dump(str(config_file)) - - click.secho(f'Loaded config file: {config_file}', fg='yellow') + config = eql.utils.load_dump(str(config_file)) # type: ignore[reportUnknownMemberType] + click.secho(f"Loaded config file: {config_file}", fg="yellow") return config -def discover_tests(start_dir: str = 'tests', pattern: str = 'test*.py', top_level_dir: Optional[str] = None): +def discover_tests(start_dir: str = "tests", pattern: str = "test*.py", top_level_dir: str | None = None) -> list[str]: """Discover all unit tests in a directory.""" - def list_tests(s, tests=None): - if tests is None: - tests = [] + + tests: list[str] = [] + + def list_tests(s: unittest.TestSuite) -> None: for test in s: if isinstance(test, unittest.TestSuite): - list_tests(test, tests) + list_tests(test) else: tests.append(test.id()) - return tests loader = unittest.defaultTestLoader suite = loader.discover(start_dir, pattern=pattern, top_level_dir=top_level_dir or str(ROOT_DIR)) - return list_tests(suite) + list_tests(suite) + return tests -def getdefault(name): +def getdefault(name: str) -> Callable[[], Any]: """Callback function for `default` to get an environment variable.""" envvar = f"DR_{name.upper()}" config = parse_user_config() return lambda: os.environ.get(envvar, config.get(name)) -def get_elasticsearch_client(cloud_id: str = None, elasticsearch_url: str = None, es_user: str = None, - es_password: str = None, ctx: click.Context = None, api_key: str = None, **kwargs): +def get_elasticsearch_client( # noqa: PLR0913 + cloud_id: str | None = None, + elasticsearch_url: str | None = None, + es_user: str | None = None, + es_password: str | None = None, + ctx: click.Context | None = None, + api_key: str | None = None, + **kwargs: Any, +) -> Elasticsearch: """Get an authenticated elasticsearch client.""" - from elasticsearch import AuthenticationException, Elasticsearch if not (cloud_id or elasticsearch_url): - client_error("Missing required --cloud-id or --elasticsearch-url") + raise_client_error("Missing required --cloud-id or --elasticsearch-url") # don't prompt for these until there's a cloud id or elasticsearch URL - basic_auth: (str, str) | None = None + basic_auth: tuple[str, str] | None = None if not api_key: es_user = es_user or click.prompt("es_user") es_password = es_password or click.prompt("es_password", hide_input=True) + if not es_user or not es_password: + raise ValueError("Both username and password must be provided") basic_auth = (es_user, es_password) hosts = [elasticsearch_url] if elasticsearch_url else None - timeout = kwargs.pop('timeout', 60) - 
kwargs['verify_certs'] = not kwargs.pop('ignore_ssl_errors', False) + timeout = kwargs.pop("timeout", 60) + kwargs["verify_certs"] = not kwargs.pop("ignore_ssl_errors", False) try: - client = Elasticsearch(hosts=hosts, cloud_id=cloud_id, http_auth=basic_auth, timeout=timeout, api_key=api_key, - **kwargs) + client = Elasticsearch( + hosts=hosts, cloud_id=cloud_id, http_auth=basic_auth, timeout=timeout, api_key=api_key, **kwargs + ) # force login to test auth - client.info() - return client + _ = client.info() except AuthenticationException as e: - error_msg = f'Failed authentication for {elasticsearch_url or cloud_id}' - client_error(error_msg, e, ctx=ctx, err=True) + error_msg = f"Failed authentication for {elasticsearch_url or cloud_id}" + raise_client_error(error_msg, e, ctx=ctx, err=True) + else: + return client def get_kibana_client( @@ -343,71 +392,68 @@ def get_kibana_client( kibana_url: str | None = None, space: str | None = None, ignore_ssl_errors: bool = False, - **kwargs -): + **kwargs: Any, +) -> Kibana: """Get an authenticated Kibana client.""" if not (cloud_id or kibana_url): - client_error("Missing required --cloud-id or --kibana-url") + raise_client_error("Missing required --cloud-id or --kibana-url") verify = not ignore_ssl_errors return Kibana(cloud_id=cloud_id, kibana_url=kibana_url, space=space, verify=verify, api_key=api_key, **kwargs) client_options = { - 'kibana': { - 'kibana_url': click.Option(['--kibana-url'], default=getdefault('kibana_url')), - 'cloud_id': click.Option(['--cloud-id'], default=getdefault('cloud_id'), help="ID of the cloud instance."), - 'api_key': click.Option(['--api-key'], default=getdefault('api_key')), - 'space': click.Option(['--space'], default=None, help='Kibana space'), - 'ignore_ssl_errors': click.Option(['--ignore-ssl-errors'], default=getdefault('ignore_ssl_errors')) + "kibana": { + "kibana_url": click.Option(["--kibana-url"], default=getdefault("kibana_url")), + "cloud_id": click.Option(["--cloud-id"], default=getdefault("cloud_id"), help="ID of the cloud instance."), + "api_key": click.Option(["--api-key"], default=getdefault("api_key")), + "space": click.Option(["--space"], default=None, help="Kibana space"), + "ignore_ssl_errors": click.Option(["--ignore-ssl-errors"], default=getdefault("ignore_ssl_errors")), + }, + "elasticsearch": { + "cloud_id": click.Option(["--cloud-id"], default=getdefault("cloud_id")), + "api_key": click.Option(["--api-key"], default=getdefault("api_key")), + "elasticsearch_url": click.Option(["--elasticsearch-url"], default=getdefault("elasticsearch_url")), + "es_user": click.Option(["--es-user", "-eu"], default=getdefault("es_user")), + "es_password": click.Option(["--es-password", "-ep"], default=getdefault("es_password")), + "timeout": click.Option(["--timeout", "-et"], default=60, help="Timeout for elasticsearch client"), + "ignore_ssl_errors": click.Option(["--ignore-ssl-errors"], default=getdefault("ignore_ssl_errors")), }, - 'elasticsearch': { - 'cloud_id': click.Option(['--cloud-id'], default=getdefault("cloud_id")), - 'api_key': click.Option(['--api-key'], default=getdefault('api_key')), - 'elasticsearch_url': click.Option(['--elasticsearch-url'], default=getdefault("elasticsearch_url")), - 'es_user': click.Option(['--es-user', '-eu'], default=getdefault("es_user")), - 'es_password': click.Option(['--es-password', '-ep'], default=getdefault("es_password")), - 'timeout': click.Option(['--timeout', '-et'], default=60, help='Timeout for elasticsearch client'), - 'ignore_ssl_errors': 
click.Option(['--ignore-ssl-errors'], default=getdefault('ignore_ssl_errors')) - } } -kibana_options = list(client_options['kibana'].values()) -elasticsearch_options = list(client_options['elasticsearch'].values()) +kibana_options = list(client_options["kibana"].values()) +elasticsearch_options = list(client_options["elasticsearch"].values()) -def add_client(*client_type, add_to_ctx=True, add_func_arg=True): +def add_client(client_types: list[str], add_to_ctx: bool = True, add_func_arg: bool = True) -> Callable[..., Any]: """Wrapper to add authed client.""" - from elasticsearch import Elasticsearch - from elasticsearch.exceptions import AuthenticationException - from kibana import Kibana - - def _wrapper(func): - client_ops_dict = {} - client_ops_keys = {} - for c_type in client_type: - ops = client_options.get(c_type) + + def _wrapper(func: Callable[..., Any]) -> Callable[..., Any]: + client_ops_dict: dict[str, click.Option] = {} + client_ops_keys: dict[str, list[str]] = {} + for c_type in client_types: + ops = client_options[c_type] client_ops_dict.update(ops) client_ops_keys[c_type] = list(ops) if not client_ops_dict: - raise ValueError(f'Unknown client: {client_type} in {func.__name__}') + client_types_str = ", ".join(client_types) + raise ValueError(f"Unknown client: {client_types_str} in {func.__name__}") client_ops = list(client_ops_dict.values()) @wraps(func) @add_params(*client_ops) - def _wrapped(*args, **kwargs): - ctx: click.Context = next((a for a in args if isinstance(a, click.Context)), None) - es_client_args = {k: kwargs.pop(k, None) for k in client_ops_keys.get('elasticsearch', [])} + def _wrapped(*args: Any, **kwargs: Any) -> Any: # noqa: PLR0912 + ctx: click.Context | None = next((a for a in args if isinstance(a, click.Context)), None) + es_client_args = {k: kwargs.pop(k, None) for k in client_ops_keys.get("elasticsearch", [])} # shared args like cloud_id - kibana_client_args = {k: kwargs.pop(k, es_client_args.get(k)) for k in client_ops_keys.get('kibana', [])} + kibana_client_args = {k: kwargs.pop(k, es_client_args.get(k)) for k in client_ops_keys.get("kibana", [])} - if 'elasticsearch' in client_type: + if "elasticsearch" in client_types: # for nested ctx invocation, no need to re-auth if an existing client is already passed - elasticsearch_client: Elasticsearch = kwargs.get('elasticsearch_client') + elasticsearch_client: Elasticsearch | None = kwargs.get("elasticsearch_client") try: - if elasticsearch_client and isinstance(elasticsearch_client, Elasticsearch) and \ - elasticsearch_client.info(): + if elasticsearch_client and elasticsearch_client.info(): pass else: elasticsearch_client = get_elasticsearch_client(**es_client_args) @@ -415,15 +461,14 @@ def _wrapped(*args, **kwargs): elasticsearch_client = get_elasticsearch_client(**es_client_args) if add_func_arg: - kwargs['elasticsearch_client'] = elasticsearch_client + kwargs["elasticsearch_client"] = elasticsearch_client if ctx and add_to_ctx: - ctx.obj['es'] = elasticsearch_client + ctx.obj["es"] = elasticsearch_client - if 'kibana' in client_type: + if "kibana" in client_types: # for nested ctx invocation, no need to re-auth if an existing client is already passed - kibana_client: Kibana = kwargs.get('kibana_client') - if kibana_client and isinstance(kibana_client, Kibana): - + kibana_client: Kibana | None = kwargs.get("kibana_client") + if kibana_client: try: with kibana_client: if kibana_client.version: @@ -435,9 +480,9 @@ def _wrapped(*args, **kwargs): kibana_client = get_kibana_client(**kibana_client_args) if 
add_func_arg: - kwargs['kibana_client'] = kibana_client + kwargs["kibana_client"] = kibana_client if ctx and add_to_ctx: - ctx.obj['kibana'] = kibana_client + ctx.obj["kibana"] = kibana_client return func(*args, **kwargs) diff --git a/detection_rules/mixins.py b/detection_rules/mixins.py index b22677d2920..e02a2ce3751 100644 --- a/detection_rules/mixins.py +++ b/detection_rules/mixins.py @@ -6,43 +6,42 @@ """Generic mixin classes.""" import dataclasses +import json from pathlib import Path -from typing import Any, Optional, TypeVar, Type, Literal +from typing import Any, Literal -import json +import marshmallow import marshmallow_dataclass import marshmallow_dataclass.union_field -import marshmallow_jsonschema -import marshmallow_union -import marshmallow -from marshmallow import Schema, ValidationError, validates_schema, fields as marshmallow_fields +import marshmallow_jsonschema # type: ignore[reportMissingTypeStubs] +import marshmallow_union # type: ignore[reportMissingTypeStubs] +from marshmallow import Schema, ValidationError, validates_schema +from marshmallow import fields as marshmallow_fields +from semver import Version from .config import load_current_package_version from .schemas import definitions from .schemas.stack_compat import get_incompatible_fields -from semver import Version from .utils import cached, dict_hash -T = TypeVar('T') -ClassT = TypeVar('ClassT') # bound=dataclass? -UNKNOWN_VALUES = Literal['raise', 'exclude', 'include'] +UNKNOWN_VALUES = Literal["raise", "exclude", "include"] -def _strip_none_from_dict(obj: T) -> T: +def _strip_none_from_dict(obj: Any) -> Any: """Strip none values from a dict recursively.""" if isinstance(obj, dict): - return {key: _strip_none_from_dict(value) for key, value in obj.items() if value is not None} + return {key: _strip_none_from_dict(value) for key, value in obj.items() if value is not None} # type: ignore[reportUnknownVariableType] if isinstance(obj, list): - return [_strip_none_from_dict(o) for o in obj] + return [_strip_none_from_dict(o) for o in obj] # type: ignore[reportUnknownVariableType] if isinstance(obj, tuple): - return tuple(_strip_none_from_dict(list(obj))) + return tuple(_strip_none_from_dict(list(obj))) # type: ignore[reportUnknownVariableType] return obj -def patch_jsonschema(obj: dict) -> dict: +def patch_jsonschema(obj: Any) -> dict[str, Any]: """Patch marshmallow-jsonschema output to look more like JSL.""" - def dive(child: dict) -> dict: + def dive(child: dict[str, Any]) -> dict[str, Any]: if "$ref" in child: name = child["$ref"].split("/")[-1] definition = obj["definitions"][name] @@ -58,10 +57,12 @@ def dive(child: dict) -> dict: child["anyOf"] = [dive(c) for c in child["anyOf"]] elif isinstance(child["type"], list): - if 'null' in child["type"]: - child["type"] = [t for t in child["type"] if t != 'null'] + type_vals: list[str] = child["type"] # type: ignore[reportUnknownVariableType] + + if "null" in type_vals: + child["type"] = [t for t in type_vals if t != "null"] - if len(child["type"]) == 1: + if len(type_vals) == 1: child["type"] = child["type"][0] if "items" in child: @@ -86,29 +87,9 @@ def dive(child: dict) -> dict: class BaseSchema(Schema): """Base schema for marshmallow dataclasses with unknown.""" - class Meta: - """Meta class for marshmallow schema.""" - - -def exclude_class_schema( - clazz, base_schema: type[Schema] = BaseSchema, unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, **kwargs -) -> type[Schema]: - """Get a marshmallow schema for a dataclass with unknown=EXCLUDE.""" - 
base_schema.Meta.unknown = unknown - return marshmallow_dataclass.class_schema(clazz, base_schema=base_schema, **kwargs) - -def recursive_class_schema( - clazz, base_schema: type[Schema] = BaseSchema, unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, **kwargs -) -> type[Schema]: - """Recursively apply the unknown parameter for nested schemas.""" - schema = exclude_class_schema(clazz, base_schema=base_schema, unknown=unknown, **kwargs) - for field in dataclasses.fields(clazz): - if dataclasses.is_dataclass(field.type): - nested_cls = field.type - nested_schema = recursive_class_schema(nested_cls, base_schema=base_schema, **kwargs) - setattr(schema, field.name, nested_schema) - return schema + class Meta: # type: ignore[reportIncompatibleVariableOverride] + """Meta class for marshmallow schema.""" class MarshmallowDataclassMixin: @@ -116,35 +97,33 @@ class MarshmallowDataclassMixin: @classmethod @cached - def __schema(cls: ClassT, unknown: Optional[UNKNOWN_VALUES] = None) -> Schema: + def __schema(cls, unknown: UNKNOWN_VALUES | None = None) -> Schema: """Get the marshmallow schema for the data class""" if unknown: return recursive_class_schema(cls, unknown=unknown)() - else: - return marshmallow_dataclass.class_schema(cls)() + return marshmallow_dataclass.class_schema(cls)() - def get(self, key: str, default: Optional[Any] = None): + def get(self, key: str, default: Any = None) -> Any: """Get a key from the query data without raising attribute errors.""" return getattr(self, key, default) @classmethod @cached - def jsonschema(cls): + def jsonschema(cls) -> dict[str, Any]: """Get the jsonschema representation for this class.""" - jsonschema = PatchedJSONSchema().dump(cls.__schema()) - jsonschema = patch_jsonschema(jsonschema) - return jsonschema + jsonschema = PatchedJSONSchema().dump(cls.__schema()) # type: ignore[reportUnknownMemberType] + return patch_jsonschema(jsonschema) @classmethod - def from_dict(cls: Type[ClassT], obj: dict, unknown: Optional[UNKNOWN_VALUES] = None) -> ClassT: + def from_dict(cls, obj: dict[str, Any], unknown: UNKNOWN_VALUES | None = None) -> Any: """Deserialize and validate a dataclass from a dict using marshmallow.""" schema = cls.__schema(unknown=unknown) return schema.load(obj) - def to_dict(self, strip_none_values=True) -> dict: + def to_dict(self, strip_none_values: bool = True) -> dict[str, Any]: """Serialize a dataclass to a dictionary using marshmallow.""" schema = self.__schema() - serialized: dict = schema.dump(self) + serialized = schema.dump(self) if strip_none_values: serialized = _strip_none_from_dict(serialized) @@ -152,69 +131,103 @@ def to_dict(self, strip_none_values=True) -> dict: return serialized +def exclude_class_schema( + cls: type, + base_schema: type[Schema] = BaseSchema, + unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, + **kwargs: dict[str, Any], +) -> type[Schema]: + """Get a marshmallow schema for a dataclass with unknown=EXCLUDE.""" + base_schema.Meta.unknown = unknown # type: ignore[reportAttributeAccessIssue] + return marshmallow_dataclass.class_schema(cls, base_schema=base_schema, **kwargs) + + +def recursive_class_schema( + cls: type, + base_schema: type[Schema] = BaseSchema, + unknown: UNKNOWN_VALUES = marshmallow.EXCLUDE, + **kwargs: dict[str, Any], +) -> type[Schema]: + """Recursively apply the unknown parameter for nested schemas.""" + schema = exclude_class_schema(cls, base_schema=base_schema, unknown=unknown, **kwargs) + for field in dataclasses.fields(cls): + if dataclasses.is_dataclass(field.type): + nested_cls = field.type + 
nested_schema = recursive_class_schema( + nested_cls, # type: ignore[reportArgumentType] + base_schema=base_schema, + unknown=unknown, + **kwargs, + ) + setattr(schema, field.name, nested_schema) + return schema + + class LockDataclassMixin: """Mixin class for version and deprecated rules lock files.""" @classmethod @cached - def __schema(cls: ClassT) -> Schema: + def __schema(cls) -> Schema: """Get the marshmallow schema for the data class""" return marshmallow_dataclass.class_schema(cls)() - def get(self, key: str, default: Optional[Any] = None): + def get(self, key: str, default: Any = None) -> Any: """Get a key from the query data without raising attribute errors.""" return getattr(self, key, default) @classmethod - def from_dict(cls: Type[ClassT], obj: dict) -> ClassT: + def from_dict(cls, obj: dict[str, Any]) -> Any: """Deserialize and validate a dataclass from a dict using marshmallow.""" schema = cls.__schema() try: loaded = schema.load(obj) except ValidationError as e: - err_msg = json.dumps(e.messages, indent=2) - raise ValidationError(f'Validation error loading: {cls.__name__}\n{err_msg}') from None + err_msg = json.dumps(e.normalized_messages(), indent=2) + raise ValidationError(f"Validation error loading: {cls.__name__}\n{err_msg}") from e return loaded - def to_dict(self, strip_none_values=True) -> dict: + def to_dict(self, strip_none_values: bool = True) -> dict[str, Any]: """Serialize a dataclass to a dictionary using marshmallow.""" schema = self.__schema() - serialized: dict = schema.dump(self) + serialized: dict[str, Any] = schema.dump(self) if strip_none_values: serialized = _strip_none_from_dict(serialized) - return serialized['data'] + return serialized["data"] @classmethod - def load_from_file(cls: Type[ClassT], lock_file: Optional[Path] = None) -> ClassT: + def load_from_file(cls, lock_file: Path | None = None) -> Any: """Load and validate a version lock file.""" - path: Path = getattr(cls, 'file_path', lock_file) + path = getattr(cls, "file_path", lock_file) + if not path: + raise ValueError("No file path found") contents = json.loads(path.read_text()) - loaded = cls.from_dict(dict(data=contents)) - return loaded + return cls.from_dict({"data": contents}) def sha256(self) -> definitions.Sha256: """Get the sha256 hash of the version lock contents.""" contents = self.to_dict() return dict_hash(contents) - def save_to_file(self, lock_file: Optional[Path] = None): + def save_to_file(self, lock_file: Path | None = None) -> None: """Save and validate a version lock file.""" - path: Path = lock_file or getattr(self, 'file_path', None) - assert path, 'No path passed or set' + path = lock_file or getattr(self, "file_path", None) + if not path: + raise ValueError("No file path found") contents = self.to_dict() - path.write_text(json.dumps(contents, indent=2, sort_keys=True)) + _ = path.write_text(json.dumps(contents, indent=2, sort_keys=True)) class StackCompatMixin: """Mixin to restrict schema compatibility to defined stack versions.""" @validates_schema - def validate_field_compatibility(self, data: dict, **kwargs): + def validate_field_compatibility(self, data: dict[str, Any], **_: dict[str, Any]) -> None: """Verify stack-specific fields are properly applied to schema.""" package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - schema_fields = getattr(self, 'fields', {}) + schema_fields = getattr(self, "fields", {}) incompatible = get_incompatible_fields(list(schema_fields.values()), package_version) if not incompatible: return @@ 
-223,29 +236,33 @@ def validate_field_compatibility(self, data: dict, **kwargs): for field, bounds in incompatible.items(): min_compat, max_compat = bounds if data.get(field) is not None: - raise ValidationError(f'Invalid field: "{field}" for stack version: {package_version}, ' - f'min compatibility: {min_compat}, max compatibility: {max_compat}') + raise ValidationError( + f'Invalid field: "{field}" for stack version: {package_version}, ' + f"min compatibility: {min_compat}, max compatibility: {max_compat}" + ) class PatchedJSONSchema(marshmallow_jsonschema.JSONSchema): - # Patch marshmallow-jsonschema to support marshmallow-dataclass[union] - def _get_schema_for_field(self, obj, field): + def _get_schema_for_field(self, obj: Any, field: Any) -> Any: """Patch marshmallow_jsonschema.base.JSONSchema to support marshmallow-dataclass[union].""" if isinstance(field, marshmallow_fields.Raw) and field.allow_none and not field.validate: # raw fields shouldn't be type string but type any. bug in marshmallow_dataclass:__init__.py: - # if typ is Any: - # metadata.setdefault("allow_none", True) - # return marshmallow.fields.Raw(**metadata) return {"type": ["string", "number", "object", "array", "boolean", "null"]} if isinstance(field, marshmallow_dataclass.union_field.Union): # convert to marshmallow_union.Union - field = marshmallow_union.Union([subfield for _, subfield in field.union_fields], - metadata=field.metadata, - required=field.required, name=field.name, - parent=field.parent, root=field.root, error_messages=field.error_messages, - default_error_messages=field.default_error_messages, default=field.default, - allow_none=field.allow_none) - - return super()._get_schema_for_field(obj, field) + field = marshmallow_union.Union( + [subfield for _, subfield in field.union_fields], + metadata=field.metadata, # type: ignore[reportUnknownMemberType] + required=field.required, + name=field.name, # type: ignore[reportUnknownMemberType] + parent=field.parent, # type: ignore[reportUnknownMemberType] + root=field.root, # type: ignore[reportUnknownMemberType] + error_messages=field.error_messages, + default_error_messages=field.default_error_messages, + default=field.default, # type: ignore[reportUnknownMemberType] + allow_none=field.allow_none, + ) + + return super()._get_schema_for_field(obj, field) # type: ignore[reportUnknownMemberType] diff --git a/detection_rules/ml.py b/detection_rules/ml.py index 13573ae51d4..d8930c0467e 100644 --- a/detection_rules/ml.py +++ b/detection_rules/ml.py @@ -6,15 +6,15 @@ """Schemas and dataclasses for experimental ML features.""" import io +import json import zipfile from dataclasses import dataclass from functools import cached_property, lru_cache from pathlib import Path -from typing import Dict, List, Literal, Optional +from typing import Any, Literal import click import elasticsearch -import json import requests from elasticsearch import Elasticsearch from elasticsearch.client import IngestClient, LicenseClient, MlClient @@ -23,15 +23,20 @@ from .schemas import definitions from .utils import get_path, unzip_to_dict +ML_PATH = get_path(["machine-learning"]) -ML_PATH = get_path('machine-learning') - -def info_from_tag(tag: str) -> (Literal['ml'], definitions.MachineLearningType, str, int): +def info_from_tag(tag: str) -> tuple[Literal["ml"], str, str, int]: try: - ml, release_type, release_date, release_number = tag.split('-') + ml, release_type, release_date, release_number = tag.split("-") except ValueError as exc: - raise ValueError(f'{tag} is not of valid release 
format: ml-type-date-number. {exc}') + raise ValueError(f"{tag} is not of valid release format: ml-type-date-number. {exc}") from exc + + if ml != "ml": + raise ValueError(f"Invalid type from the tag: {ml}") + + if release_type not in definitions.MACHINE_LEARNING_PACKAGES: + raise ValueError(f"Unexpected release type encountered: {release_type}") return ml, release_type, release_date, int(release_number) @@ -45,15 +50,15 @@ class MachineLearningClient: """Class for experimental machine learning release clients.""" es_client: Elasticsearch - bundle: dict + bundle: dict[str, Any] @cached_property def model_id(self) -> str: - return next(data['model_id'] for name, data in self.bundle.items() if Path(name).stem.lower().endswith('model')) + return next(data["model_id"] for name, data in self.bundle.items() if Path(name).stem.lower().endswith("model")) @cached_property def bundle_type(self) -> str: - return self.model_id.split('_')[0].lower() + return self.model_id.split("_")[0].lower() @cached_property def ml_client(self) -> MlClient: @@ -66,204 +71,214 @@ def ingest_client(self) -> IngestClient: @cached_property def license(self) -> str: license_client = LicenseClient(self.es_client) - return license_client.get()['license']['type'].lower() + return license_client.get()["license"]["type"].lower() @staticmethod @lru_cache - def ml_manifests() -> Dict[str, ReleaseManifest]: + def ml_manifests() -> dict[str, ReleaseManifest]: return get_ml_model_manifests_by_model_id() - def verify_license(self): - valid_license = self.license in ('platinum', 'enterprise') + def verify_license(self) -> None: + valid_license = self.license in ("platinum", "enterprise") if not valid_license: - err_msg = 'Your subscription level does not support Machine Learning. See ' \ - 'https://www.elastic.co/subscriptions for more information.' - raise InvalidLicenseError(err_msg) + raise InvalidLicenseError( + "Your subscription level does not support Machine Learning. See " + "https://www.elastic.co/subscriptions for more information." + ) @classmethod - def from_release(cls, es_client: Elasticsearch, release_tag: str, - repo: str = 'elastic/detection-rules') -> 'MachineLearningClient': + def from_release( + cls, es_client: Elasticsearch, release_tag: str, repo: str = "elastic/detection-rules" + ) -> "MachineLearningClient": """Load from a GitHub release.""" - full_type = '-'.join(info_from_tag(release_tag)[:2]) - release_url = f'https://api.github.com/repos/{repo}/releases/tags/{release_tag}' - release = requests.get(release_url) + + ml, release_type, _, _ = info_from_tag(release_tag) + + full_type = f"{ml}-{release_type}" + release_url = f"https://api.github.com/repos/{repo}/releases/tags/{release_tag}" + release = requests.get(release_url, timeout=30) release.raise_for_status() # check that the release only has a single zip file - assets = [a for a in release.json()['assets'] if - a['name'].startswith(full_type) and a['name'].endswith('.zip')] - assert len(assets) == 1, f'Malformed release: expected 1 {full_type} zip file, found: {len(assets)}!' 
+ assets = [a for a in release.json()["assets"] if a["name"].startswith(full_type) and a["name"].endswith(".zip")] + if len(assets) != 1: + raise ValueError(f"Malformed release: expected 1 {full_type} zip file, found: {len(assets)}!") - zipped_url = assets[0]['browser_download_url'] - zipped_raw = requests.get(zipped_url) + zipped_url = assets[0]["browser_download_url"] + zipped_raw = requests.get(zipped_url, timeout=30) zipped_bundle = zipfile.ZipFile(io.BytesIO(zipped_raw.content)) bundle = unzip_to_dict(zipped_bundle) return cls(es_client=es_client, bundle=bundle) @classmethod - def from_directory(cls, es_client: Elasticsearch, directory: Path) -> 'MachineLearningClient': + def from_directory(cls, es_client: Elasticsearch, directory: Path) -> "MachineLearningClient": """Load from an unzipped local directory.""" bundle = json.loads(directory.read_text()) return cls(es_client=es_client, bundle=bundle) - def remove(self) -> dict: + def remove(self) -> dict[str, dict[str, Any]]: """Remove machine learning files from a stack.""" - results = dict(script={}, pipeline={}, model={}) + results = {"script": {}, "pipeline": {}, "model": {}} # type: ignore[reportUnknownVariableType] for pipeline in list(self.get_related_pipelines()): - results['pipeline'][pipeline] = self.ingest_client.delete_pipeline(pipeline) + results["pipeline"][pipeline] = self.ingest_client.delete_pipeline(id=pipeline) for script in list(self.get_related_scripts()): - results['script'][script] = self.es_client.delete_script(script) + results["script"][script] = self.es_client.delete_script(id=script) - results['model'][self.model_id] = self.ml_client.delete_trained_model(self.model_id) - return results + results["model"][self.model_id] = self.ml_client.delete_trained_model(model_id=self.model_id) + return results # type: ignore[reportUnknownVariableType] - def setup(self) -> dict: + def setup(self) -> dict[str, Any]: """Setup machine learning bundle on a stack.""" self.verify_license() - results = dict(script={}, pipeline={}, model={}) + results = {"script": {}, "pipeline": {}, "model": {}} # type: ignore[reportUnknownVariableType] # upload in order: model, scripts, then pipelines - parsed_bundle = dict(model={}, script={}, pipeline={}) + parsed_bundle = {"model": {}, "script": {}, "pipeline": {}} # type: ignore[reportUnknownVariableType] for filename, data in self.bundle.items(): fp = Path(filename) - file_type = fp.stem.split('_')[-1] + file_type = fp.stem.split("_")[-1] parsed_bundle[file_type][fp.stem] = data - model = list(parsed_bundle['model'].values())[0] - results['model'][model['model_id']] = self.upload_model(model['model_id'], model) + model = next(parsed_bundle["model"].values()) # type: ignore[reportArgumentType] + results["model"][model["model_id"]] = self.upload_model(model["model_id"], model) # type: ignore[reportUnknownArgumentType] - for script_name, script in parsed_bundle['script'].items(): - results['script'][script_name] = self.upload_script(script_name, script) + for script_name, script in parsed_bundle["script"].items(): # type: ignore[reportArgumentType] + results["script"][script_name] = self.upload_script(script_name, script) # type: ignore[reportUnknownArgumentType] - for pipeline_name, pipeline in parsed_bundle['pipeline'].items(): - results['pipeline'][pipeline_name] = self.upload_ingest_pipeline(pipeline_name, pipeline) + for pipeline_name, pipeline in parsed_bundle["pipeline"].items(): # type: ignore[reportArgumentType] + results["pipeline"][pipeline_name] = 
self.upload_ingest_pipeline(pipeline_name, pipeline) # type: ignore[reportUnknownArgumentType] - return results + return results # type: ignore[reportUnknownVariableType] - def get_all_scripts(self) -> Dict[str, dict]: + def get_all_scripts(self) -> dict[str, dict[str, Any]]: """Get all scripts from an elasticsearch instance.""" - return self.es_client.cluster.state()['metadata']['stored_scripts'] + return self.es_client.cluster.state()["metadata"]["stored_scripts"] - def get_related_scripts(self) -> Dict[str, dict]: + def get_related_scripts(self) -> dict[str, dict[str, Any]]: """Get all scripts which start with ml_*.""" scripts = self.get_all_scripts() - return {n: s for n, s in scripts.items() if n.lower().startswith(f'ml_{self.bundle_type}')} + return {n: s for n, s in scripts.items() if n.lower().startswith(f"ml_{self.bundle_type}")} - def get_related_pipelines(self) -> Dict[str, dict]: + def get_related_pipelines(self) -> dict[str, dict[str, Any]]: """Get all pipelines which start with ml_*.""" pipelines = self.ingest_client.get_pipeline() - return {n: s for n, s in pipelines.items() if n.lower().startswith(f'ml_{self.bundle_type}')} + return {n: s for n, s in pipelines.items() if n.lower().startswith(f"ml_{self.bundle_type}")} - def get_related_model(self) -> Optional[dict]: + def get_related_model(self) -> dict[str, Any] | None: """Get a model from an elasticsearch instance matching the model_id.""" for model in self.get_all_existing_model_files(): - if model['model_id'] == self.model_id: + if model["model_id"] == self.model_id: return model + return None - def get_all_existing_model_files(self) -> dict: + def get_all_existing_model_files(self) -> list[dict[str, Any]]: """Get available models from a stack.""" - return self.ml_client.get_trained_models()['trained_model_configs'] + return self.ml_client.get_trained_models()["trained_model_configs"] @classmethod - def get_existing_model_ids(cls, es_client: Elasticsearch) -> List[str]: + def get_existing_model_ids(cls, es_client: Elasticsearch) -> list[str]: """Get model IDs for existing ML models.""" ml_client = MlClient(es_client) - return [m['model_id'] for m in ml_client.get_trained_models()['trained_model_configs'] - if m['model_id'] in cls.ml_manifests()] + return [ + m["model_id"] + for m in ml_client.get_trained_models()["trained_model_configs"] + if m["model_id"] in cls.ml_manifests() + ] @classmethod def check_model_exists(cls, es_client: Elasticsearch, model_id: str) -> bool: """Check if a model exists on a stack by model id.""" ml_client = MlClient(es_client) - return model_id in [m['model_id'] for m in ml_client.get_trained_models()['trained_model_configs']] + return model_id in [m["model_id"] for m in ml_client.get_trained_models()["trained_model_configs"]] - def get_related_files(self) -> dict: + def get_related_files(self) -> dict[str, Any]: """Check for the presence and status of ML bundle files on a stack.""" - files = { - 'pipeline': self.get_related_pipelines(), - 'script': self.get_related_scripts(), - 'model': self.get_related_model(), - 'release': self.get_related_release() + return { + "pipeline": self.get_related_pipelines(), + "script": self.get_related_scripts(), + "model": self.get_related_model(), + "release": self.get_related_release(), } - return files def get_related_release(self) -> ReleaseManifest: """Get the GitHub release related to a model.""" - return self.ml_manifests.get(self.model_id) + return self.ml_manifests.get(self.model_id) # type: ignore[reportAttributeAccessIssue] @classmethod - def 
get_all_ml_files(cls, es_client: Elasticsearch) -> dict: + def get_all_ml_files(cls, es_client: Elasticsearch) -> dict[str, Any]: """Get all scripts, pipelines, and models which start with ml_*.""" pipelines = IngestClient(es_client).get_pipeline() - scripts = es_client.cluster.state()['metadata']['stored_scripts'] - models = MlClient(es_client).get_trained_models()['trained_model_configs'] + scripts = es_client.cluster.state()["metadata"]["stored_scripts"] + models = MlClient(es_client).get_trained_models()["trained_model_configs"] manifests = get_ml_model_manifests_by_model_id() - files = { - 'pipeline': {n: s for n, s in pipelines.items() if n.lower().startswith('ml_')}, - 'script': {n: s for n, s in scripts.items() if n.lower().startswith('ml_')}, - 'model': {m['model_id']: {'model': m, 'release': manifests[m['model_id']]} - for m in models if m['model_id'] in manifests}, + return { + "pipeline": {n: s for n, s in pipelines.items() if n.lower().startswith("ml_")}, + "script": {n: s for n, s in scripts.items() if n.lower().startswith("ml_")}, + "model": { + m["model_id"]: {"model": m, "release": manifests[m["model_id"]]} + for m in models + if m["model_id"] in manifests + }, } - return files @classmethod - def remove_ml_scripts_pipelines(cls, es_client: Elasticsearch, ml_type: List[str]) -> dict: + def remove_ml_scripts_pipelines(cls, es_client: Elasticsearch, ml_type: list[str]) -> dict[str, Any]: """Remove all ML script and pipeline files.""" - results = dict(script={}, pipeline={}) + results = {"script": {}, "pipeline": {}} # type: ignore[reportUnknownVariableType] ingest_client = IngestClient(es_client) files = cls.get_all_ml_files(es_client=es_client) for file_type, data in files.items(): for name in list(data): - this_type = name.split('_')[1].lower() + this_type = name.split("_")[1].lower() if this_type not in ml_type: continue - if file_type == 'script': - results[file_type][name] = es_client.delete_script(name) - elif file_type == 'pipeline': - results[file_type][name] = ingest_client.delete_pipeline(name) + if file_type == "script": + results[file_type][name] = es_client.delete_script(id=name) + elif file_type == "pipeline": + results[file_type][name] = ingest_client.delete_pipeline(id=name) - return results + return results # type: ignore[reportUnknownVariableType] - def upload_model(self, model_id: str, body: dict) -> dict: + def upload_model(self, model_id: str, body: dict[str, Any]) -> Any: """Upload an ML model file.""" return self.ml_client.put_trained_model(model_id=model_id, body=body) - def upload_script(self, script_id: str, body: dict) -> dict: + def upload_script(self, script_id: str, body: dict[str, Any]) -> Any: """Install a script file.""" return self.es_client.put_script(id=script_id, body=body) - def upload_ingest_pipeline(self, pipeline_id: str, body: dict) -> dict: + def upload_ingest_pipeline(self, pipeline_id: str, body: dict[str, Any]) -> Any: """Install a pipeline file.""" return self.ingest_client.put_pipeline(id=pipeline_id, body=body) @staticmethod - def _build_script_error(exc: elasticsearch.RequestError, pipeline_file: str): + def _build_script_error(exc: elasticsearch.RequestError, pipeline_file: str) -> str: """Build an error for a failed script upload.""" - error = exc.info['error'] - cause = error['caused_by'] + error = exc.info["error"] + cause = error["caused_by"] error_msg = [ - f'Script error while uploading {pipeline_file}: {cause["type"]} - {cause["reason"]}', - ' '.join(f'{k}: {v}' for k, v in error['position'].items()), - 
'\n'.join(error['script_stack']) + f"Script error while uploading {pipeline_file}: {cause['type']} - {cause['reason']}", + " ".join(f"{k}: {v}" for k, v in error["position"].items()), + "\n".join(error["script_stack"]), ] - return click.style('\n'.join(error_msg), fg='red') + return click.style("\n".join(error_msg), fg="red") -def get_ml_model_manifests_by_model_id(repo: str = 'elastic/detection-rules') -> Dict[str, ReleaseManifest]: +def get_ml_model_manifests_by_model_id(repo_name: str = "elastic/detection-rules") -> dict[str, ReleaseManifest]: """Load all ML DGA model release manifests by model id.""" - manifests, _ = ManifestManager.load_all(repo=repo) - model_manifests = {} - - for manifest_name, manifest in manifests.items(): - for asset_name, asset in manifest['assets'].items(): - for entry_name, entry_data in asset['entries'].items(): - if entry_name.startswith('dga') and entry_name.endswith('model.json'): - model_id, _ = entry_name.rsplit('_', 1) + manifests, _ = ManifestManager.load_all(repo_name=repo_name) + model_manifests: dict[str, ReleaseManifest] = {} + + for manifest in manifests.values(): + for asset in manifest["assets"].values(): + for entry_name in asset["entries"]: + if entry_name.startswith("dga") and entry_name.endswith("model.json"): + model_id, _ = entry_name.rsplit("_", 1) model_manifests[model_id] = ReleaseManifest(**manifest) break diff --git a/detection_rules/navigator.py b/detection_rules/navigator.py index 125ab1bbee0..f467cfceac8 100644 --- a/detection_rules/navigator.py +++ b/detection_rules/navigator.py @@ -5,21 +5,20 @@ """Create summary documents for a rule package.""" -from functools import reduce +import json from collections import defaultdict -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field +from functools import reduce from pathlib import Path -from typing import Dict, List, Optional -from marshmallow import pre_load +from typing import Any -import json +from marshmallow import pre_load from .attack import CURRENT_ATTACK_VERSION from .mixins import MarshmallowDataclassMixin from .rule import TOMLRule from .schemas import definitions - _DEFAULT_PLATFORMS = [ "Azure AD", "Containers", @@ -31,17 +30,15 @@ "Office 365", "PRE", "SaaS", - "Windows" + "Windows", ] -_DEFAULT_NAVIGATOR_LINKS = { - "label": "repo", - "url": "https://github.com/elastic/detection-rules" -} +_DEFAULT_NAVIGATOR_LINKS = {"label": "repo", "url": "https://github.com/elastic/detection-rules"} @dataclass class NavigatorMetadata(MarshmallowDataclassMixin): """Metadata for ATT&CK navigator objects.""" + name: str value: str @@ -49,6 +46,7 @@ class NavigatorMetadata(MarshmallowDataclassMixin): @dataclass class NavigatorLinks(MarshmallowDataclassMixin): """Metadata for ATT&CK navigator objects.""" + label: str url: str @@ -56,40 +54,42 @@ class NavigatorLinks(MarshmallowDataclassMixin): @dataclass class Techniques(MarshmallowDataclassMixin): """ATT&CK navigator techniques array class.""" + techniqueID: str tactic: str score: int - metadata: List[NavigatorMetadata] - links: List[NavigatorLinks] + metadata: list[NavigatorMetadata] + links: list[NavigatorLinks] - color: str = '' - comment: str = '' + color: str = "" + comment: str = "" enabled: bool = True showSubtechniques: bool = False @pre_load - def set_score(self, data: dict, **kwargs): - data['score'] = len(data['metadata']) + def set_score(self, data: dict[str, Any], **_: Any) -> dict[str, Any]: + data["score"] = len(data["metadata"]) return data @dataclass class 
Navigator(MarshmallowDataclassMixin): """ATT&CK navigator class.""" + @dataclass class Versions: attack: str - layer: str = '4.4' - navigator: str = '4.5.5' + layer: str = "4.4" + navigator: str = "4.5.5" @dataclass class Filters: - platforms: list = field(default_factory=_DEFAULT_PLATFORMS.copy) + platforms: list[str] = field(default_factory=_DEFAULT_PLATFORMS.copy) @dataclass class Layout: - layout: str = 'side' - aggregateFunction: str = 'average' + layout: str = "side" + aggregateFunction: str = "average" showID: bool = True showName: bool = True showAggregateScores: bool = False @@ -97,124 +97,121 @@ class Layout: @dataclass class Gradient: - colors: list = field(default_factory=['#d3e0fa', '#0861fb'].copy) + colors: list[str] = field(default_factory=["#d3e0fa", "#0861fb"].copy) minValue: int = 0 maxValue: int = 10 # not all defaults set name: str versions: Versions - techniques: List[Techniques] + techniques: list[Techniques] # all defaults set - filters: Filters = fields(Filters) - layout: Layout = fields(Layout) - gradient: Gradient = fields(Gradient) + filters: Filters = field(default_factory=Filters) + layout: Layout = field(default_factory=Layout) + gradient: Gradient = field(default_factory=Gradient) - domain: str = 'enterprise-attack' - description: str = 'Elastic detection-rules coverage' + domain: str = "enterprise-attack" + description: str = "Elastic detection-rules coverage" hideDisabled: bool = False - legendItems: list = field(default_factory=list) - links: List[NavigatorLinks] = field(default_factory=[_DEFAULT_NAVIGATOR_LINKS].copy) - metadata: Optional[List[NavigatorLinks]] = field(default_factory=list) + legendItems: list[Any] = field(default_factory=list) # type: ignore[reportUnknownVariableType] + + links: list[NavigatorLinks] = field(default_factory=[_DEFAULT_NAVIGATOR_LINKS].copy) # type: ignore[reportAssignmentType] + metadata: list[NavigatorLinks] | None = field(default_factory=list) # type: ignore[reportAssignmentType] showTacticRowBackground: bool = False selectTechniquesAcrossTactics: bool = False selectSubtechniquesWithParent: bool = False sorting: int = 0 - tacticRowBackground: str = '#dddddd' + tacticRowBackground: str = "#dddddd" -def technique_dict() -> dict: - return {'metadata': [], 'links': []} +def technique_dict() -> dict[str, Any]: + return {"metadata": [], "links": []} class NavigatorBuilder: """Rule navigator mappings and management.""" - def __init__(self, detection_rules: List[TOMLRule]): + def __init__(self, detection_rules: list[TOMLRule]) -> None: self.detection_rules = detection_rules - self.layers = { - 'all': defaultdict(lambda: defaultdict(technique_dict)), - 'platforms': defaultdict(lambda: defaultdict(technique_dict)), - + self.layers: dict[str, Any] = { + "all": defaultdict(lambda: defaultdict(technique_dict)), # type: ignore[reportUnknownLambdaType] + "platforms": defaultdict(lambda: defaultdict(technique_dict)), # type: ignore[reportUnknownLambdaType] # these will build multiple layers - 'indexes': defaultdict(lambda: defaultdict(lambda: defaultdict(technique_dict))), - 'tags': defaultdict(lambda: defaultdict(lambda: defaultdict(technique_dict))) + "indexes": defaultdict(lambda: defaultdict(lambda: defaultdict(technique_dict))), # type: ignore[reportUnknownLambdaType] + "tags": defaultdict(lambda: defaultdict(lambda: defaultdict(technique_dict))), # type: ignore[reportUnknownLambdaType] } self.process_rules() @staticmethod - def meta_dict(name: str, value: any) -> dict: - meta = { - 'name': name, - 'value': value - } - return meta 
+ def meta_dict(name: str, value: Any) -> dict[str, Any]: + return {"name": name, "value": value} @staticmethod - def links_dict(label: str, url: any) -> dict: - links = { - 'label': label, - 'url': url - } - return links + def links_dict(label: str, url: Any) -> dict[str, Any]: + return {"label": label, "url": url} - def rule_links_dict(self, rule: TOMLRule) -> dict: + def rule_links_dict(self, rule: TOMLRule) -> dict[str, Any]: """Create a links dictionary for a rule.""" - base_url = 'https://github.com/elastic/detection-rules/blob/main/rules/' - base_path = str(rule.get_base_rule_dir()) + base_url = "https://github.com/elastic/detection-rules/blob/main/rules/" + base_path = rule.get_base_rule_dir() - if base_path is None: + if not base_path: raise ValueError("Could not find a valid base path for the rule") - url = f'{base_url}{base_path}' + base_path_str = str(base_path) + url = f"{base_url}{base_path_str}" return self.links_dict(rule.name, url) - def get_layer(self, layer_name: str, layer_key: Optional[str] = None) -> dict: + def get_layer(self, layer_name: str, layer_key: str | None = None) -> dict[str, Any]: """Safely retrieve a layer with optional sub-keys.""" return self.layers[layer_name][layer_key] if layer_key else self.layers[layer_name] - def _update_all(self, rule: TOMLRule, tactic: str, technique_id: str): - value = f'{rule.contents.data.type}/{rule.contents.data.get("language")}' - self.add_rule_to_technique(rule, 'all', tactic, technique_id, value) + def _update_all(self, rule: TOMLRule, tactic: str, technique_id: str) -> None: + value = f"{rule.contents.data.type}/{rule.contents.data.get('language')}" + self.add_rule_to_technique(rule, "all", tactic, technique_id, value) - def _update_platforms(self, rule: TOMLRule, tactic: str, technique_id: str): + def _update_platforms(self, rule: TOMLRule, tactic: str, technique_id: str) -> None: + if not rule.path: + raise ValueError("No rule path found") value = rule.path.parent.name - self.add_rule_to_technique(rule, 'platforms', tactic, technique_id, value) + self.add_rule_to_technique(rule, "platforms", tactic, technique_id, value) - def _update_indexes(self, rule: TOMLRule, tactic: str, technique_id: str): - for index in rule.contents.data.get('index') or []: + def _update_indexes(self, rule: TOMLRule, tactic: str, technique_id: str) -> None: + for index in rule.contents.data.get("index") or []: # type: ignore[reportUnknownVariableType] value = rule.id - self.add_rule_to_technique(rule, 'indexes', tactic, technique_id, value, layer_key=index.lower()) + self.add_rule_to_technique(rule, "indexes", tactic, technique_id, value, layer_key=index.lower()) # type: ignore[reportUnknownVariableType] - def _update_tags(self, rule: TOMLRule, tactic: str, technique_id: str): - for tag in rule.contents.data.get('tags', []): + def _update_tags(self, rule: TOMLRule, tactic: str, technique_id: str) -> None: + for _tag in rule.contents.data.get("tags") or []: # type: ignore[reportUnknownVariableType] value = rule.id - expected_prefixes = set([tag.split(":")[0] + ":" for tag in definitions.EXPECTED_RULE_TAGS]) - tag = reduce(lambda s, substr: s.replace(substr, ''), expected_prefixes, tag).lstrip() - layer_key = tag.replace(' ', '-').lower() - self.add_rule_to_technique(rule, 'tags', tactic, technique_id, value, layer_key=layer_key) - - def add_rule_to_technique(self, - rule: TOMLRule, - layer_name: str, - tactic: str, - technique_id: str, - value: str, - layer_key: Optional[str] = None): + expected_prefixes = {tag.split(":")[0] + ":" for tag in 
definitions.EXPECTED_RULE_TAGS} + tag = reduce(lambda s, substr: s.replace(substr, ""), expected_prefixes, _tag).lstrip() # type: ignore[reportUnknownMemberType] + layer_key = tag.replace(" ", "-").lower() # type: ignore[reportUnknownVariableType] + self.add_rule_to_technique(rule, "tags", tactic, technique_id, value, layer_key=layer_key) # type: ignore[reportUnknownArgumentType] + + def add_rule_to_technique( # noqa: PLR0913 + self, + rule: TOMLRule, + layer_name: str, + tactic: str, + technique_id: str, + value: str, + layer_key: str | None = None, + ) -> None: """Add a rule to a technique metadata and links.""" layer = self.get_layer(layer_name, layer_key) - layer[tactic][technique_id]['metadata'].append(self.meta_dict(rule.name, value)) - layer[tactic][technique_id]['links'].append(self.rule_links_dict(rule)) + layer[tactic][technique_id]["metadata"].append(self.meta_dict(rule.name, value)) + layer[tactic][technique_id]["links"].append(self.rule_links_dict(rule)) - def process_rule(self, rule: TOMLRule, tactic: str, technique_id: str): + def process_rule(self, rule: TOMLRule, tactic: str, technique_id: str) -> None: self._update_all(rule, tactic, technique_id) self._update_platforms(rule, tactic, technique_id) self._update_indexes(rule, tactic, technique_id) self._update_tags(rule, tactic, technique_id) - def process_rules(self): + def process_rules(self) -> None: """Adds rule to each applicable layer, including multi-layers.""" for rule in self.detection_rules: threat = rule.contents.data.threat @@ -230,63 +227,62 @@ def process_rules(self): for sub in technique_entry.subtechnique: self.process_rule(rule, tactic, sub.id) - def build_navigator(self, layer_name: str, layer_key: Optional[str] = None) -> Navigator: - populated_techniques = [] + def build_navigator(self, layer_name: str, layer_key: str | None = None) -> Navigator: + populated_techniques: list[dict[str, Any]] = [] layer = self.get_layer(layer_name, layer_key) - base_name = f'{layer_name}-{layer_key}' if layer_key else layer_name - base_name = base_name.replace('*', 'WILDCARD') - name = f'Elastic-detection-rules-{base_name}' + base_name = f"{layer_name}-{layer_key}" if layer_key else layer_name + base_name = base_name.replace("*", "WILDCARD") + name = f"Elastic-detection-rules-{base_name}" for tactic, techniques in layer.items(): - tactic_normalized = '-'.join(tactic.lower().split()) + tactic_normalized = "-".join(tactic.lower().split()) for technique_id, rules_data in techniques.items(): rules_data.update(tactic=tactic_normalized, techniqueID=technique_id) - techniques = Techniques.from_dict(rules_data) + _techniques = Techniques.from_dict(rules_data) - populated_techniques.append(techniques.to_dict()) + populated_techniques.append(_techniques.to_dict()) base_nav_obj = { - 'name': name, - 'techniques': populated_techniques, - 'versions': {'attack': CURRENT_ATTACK_VERSION} + "name": name, + "techniques": populated_techniques, + "versions": {"attack": CURRENT_ATTACK_VERSION}, } - navigator = Navigator.from_dict(base_nav_obj) - return navigator + return Navigator.from_dict(base_nav_obj) - def build_all(self) -> List[Navigator]: - built = [] + def build_all(self) -> list[Navigator]: + built: list[Navigator] = [] for layer_name, data in self.layers.items(): # this is a single layer - if 'defense evasion' in data: + if "defense evasion" in data: built.append(self.build_navigator(layer_name)) else: # multi layers - for layer_key, sub_data in data.items(): - built.append(self.build_navigator(layer_name, layer_key)) + 
built.extend([self.build_navigator(layer_name, layer_key) for layer_key in data]) return built @staticmethod - def _save(built: Navigator, directory: Path, verbose=True) -> Path: - path = directory.joinpath(built.name).with_suffix('.json') - path.write_text(json.dumps(built.to_dict(), indent=2)) + def _save(built: Navigator, directory: Path, verbose: bool = True) -> Path: + path = directory.joinpath(built.name).with_suffix(".json") + _ = path.write_text(json.dumps(built.to_dict(), indent=2)) if verbose: - print(f'saved: {path}') + print(f"saved: {path}") return path - def save_layer(self, - layer_name: str, - directory: Path, - layer_key: Optional[str] = None, - verbose=True - ) -> (Path, dict): + def save_layer( + self, + layer_name: str, + directory: Path, + layer_key: str | None = None, + verbose: bool = True, + ) -> tuple[Path, Navigator]: built = self.build_navigator(layer_name, layer_key) return self._save(built, directory, verbose), built - def save_all(self, directory: Path, verbose=True) -> Dict[Path, Navigator]: - paths = {} + def save_all(self, directory: Path, verbose: bool = True) -> dict[Path, Navigator]: + paths: dict[Path, Navigator] = {} for built in self.build_all(): path = self._save(built, directory, verbose) diff --git a/detection_rules/packaging.py b/detection_rules/packaging.py index a3bef733f7c..270947b84a8 100644 --- a/detection_rules/packaging.py +++ b/detection_rules/packaging.py @@ -4,42 +4,38 @@ # 2.0. """Packaging and preparation for releases.""" + import base64 -import datetime import hashlib import json -import os import shutil import textwrap from collections import defaultdict +from datetime import UTC, date, datetime from pathlib import Path -from typing import Dict, Optional, Tuple -from semver import Version +from typing import Any import click import yaml +from semver import Version from .config import load_current_package_version, parse_rules_config from .misc import JS_LICENSE, cached -from .navigator import NavigatorBuilder, Navigator -from .rule import TOMLRule, QueryRuleData, ThreatMapping +from .navigator import Navigator, NavigatorBuilder +from .rule import QueryRuleData, ThreatMapping, TOMLRule from .rule_loader import DeprecatedCollection, RuleCollection from .schemas import definitions -from .utils import Ndjson, get_path, get_etc_path +from .utils import Ndjson, get_etc_path, get_path from .version_lock import loaded_version_lock - RULES_CONFIG = parse_rules_config() -RELEASE_DIR = get_path("releases") +RELEASE_DIR = get_path(["releases"]) PACKAGE_FILE = str(RULES_CONFIG.packages_file) -NOTICE_FILE = get_path('NOTICE.txt') -FLEET_PKG_LOGO = get_etc_path("security-logo-color-64px.svg") - +NOTICE_FILE = get_path(["NOTICE.txt"]) +FLEET_PKG_LOGO = get_etc_path(["security-logo-color-64px.svg"]) -# CHANGELOG_FILE = Path(get_etc_path('rules-changelog.json')) - -def filter_rule(rule: TOMLRule, config_filter: dict, exclude_fields: Optional[dict] = None) -> bool: +def filter_rule(rule: TOMLRule, config_filter: dict[str, Any], exclude_fields: dict[str, Any] | None = None) -> bool: """Filter a rule based off metadata and a package configuration.""" flat_rule = rule.contents.flattened_dict() @@ -47,15 +43,15 @@ def filter_rule(rule: TOMLRule, config_filter: dict, exclude_fields: Optional[di if key not in flat_rule: return False - values = set([v.lower() if isinstance(v, str) else v for v in values]) + values_set = {v.lower() if isinstance(v, str) else v for v in values} rule_value = flat_rule[key] if isinstance(rule_value, list): - rule_values = 
{v.lower() if isinstance(v, str) else v for v in rule_value} + rule_values: set[Any] = {v.lower() if isinstance(v, str) else v for v in rule_value} # type: ignore[reportUnknownVariableType] else: rule_values = {rule_value.lower() if isinstance(rule_value, str) else rule_value} - if len(rule_values & values) == 0: + if len(rule_values & values_set) == 0: return False exclude_fields = exclude_fields or {} @@ -65,9 +61,12 @@ def filter_rule(rule: TOMLRule, config_filter: dict, exclude_fields: Optional[di unique_fields = get_unique_query_fields(rule) for index, fields in exclude_fields.items(): - if unique_fields and (rule.contents.data.index_or_dataview == index or index == 'any'): - if set(unique_fields) & set(fields): - return False + if ( + unique_fields + and (rule.contents.data.index_or_dataview == index or index == "any") # type: ignore[reportAttributeAccessIssue] # noqa: PLR1714 + and (set(unique_fields) & set(fields)) + ): + return False return True @@ -75,13 +74,21 @@ def filter_rule(rule: TOMLRule, config_filter: dict, exclude_fields: Optional[di CURRENT_RELEASE_PATH = RELEASE_DIR / load_current_package_version() -class Package(object): +class Package: """Packaging object for siem rules and releases.""" - def __init__(self, rules: RuleCollection, name: str, release: Optional[bool] = False, - min_version: Optional[int] = None, max_version: Optional[int] = None, - registry_data: Optional[dict] = None, verbose: Optional[bool] = True, - generate_navigator: bool = False, historical: bool = False): + def __init__( # noqa: PLR0913 + self, + rules: RuleCollection, + name: str, + release: bool | None = False, + min_version: int | None = None, + max_version: int | None = None, + registry_data: dict[str, Any] | None = None, + generate_navigator: bool = False, + verbose: bool = True, + historical: bool = False, + ) -> None: """Initialize a package.""" self.name = name self.rules = rules @@ -92,44 +99,50 @@ def __init__(self, rules: RuleCollection, name: str, release: Optional[bool] = F self.historical = historical if min_version is not None: - self.rules = self.rules.filter(lambda r: min_version <= r.contents.saved_version) + self.rules = self.rules.filter(lambda r: min_version <= r.contents.saved_version) # type: ignore[reportOperatorIssue] if max_version is not None: - self.rules = self.rules.filter(lambda r: max_version >= r.contents.saved_version) + self.rules = self.rules.filter(lambda r: max_version >= r.contents.saved_version) # type: ignore[reportOperatorIssue] - assert not RULES_CONFIG.bypass_version_lock, "Packaging can not be used when version locking is bypassed." 
- self.changed_ids, self.new_ids, self.removed_ids = \ - loaded_version_lock.manage_versions(self.rules, verbose=verbose, save_changes=False) + if RULES_CONFIG.bypass_version_lock: + raise ValueError("Packaging can not be used when version locking is bypassed.") + self.changed_ids, self.new_ids, self.removed_ids = loaded_version_lock.manage_versions( + self.rules, + verbose=verbose, + save_changes=False, + ) @classmethod - def load_configs(cls): + def load_configs(cls) -> Any: """Load configs from packages.yaml.""" - return RULES_CONFIG.packages['package'] + return RULES_CONFIG.packages["package"] @staticmethod - def _package_kibana_notice_file(save_dir): + def _package_kibana_notice_file(save_dir: Path) -> None: """Convert and save notice file with package.""" - with open(NOTICE_FILE, 'rt') as f: + with NOTICE_FILE.open() as f: notice_txt = f.read() - with open(os.path.join(save_dir, 'notice.ts'), 'wt') as f: - commented_notice = [f' * {line}'.rstrip() for line in notice_txt.splitlines()] - lines = ['/* eslint-disable @kbn/eslint/require-license-header */', '', '/* @notice'] - lines = lines + commented_notice + [' */', ''] - f.write('\n'.join(lines)) + with (save_dir / "notice.ts").open("w") as f: + commented_notice = [f" * {line}".rstrip() for line in notice_txt.splitlines()] + lines = ["/* eslint-disable @kbn/eslint/require-license-header */", "", "/* @notice"] + lines = lines + commented_notice + [" */", ""] + _ = f.write("\n".join(lines)) - def _package_kibana_index_file(self, save_dir): + def _package_kibana_index_file(self, save_dir: Path) -> None: """Convert and save index file with package.""" - sorted_rules = sorted(self.rules, key=lambda k: (k.contents.metadata.creation_date, os.path.basename(k.path))) + sorted_rules = sorted(self.rules, key=lambda k: (k.contents.metadata.creation_date, k.path.name)) # type: ignore[reportOptionalMemberAccess] comments = [ - '// Auto generated file from either:', - '// - scripts/regen_prepackage_rules_index.sh', - '// - detection-rules repo using CLI command build-release', - '// Do not hand edit. Run script/command to regenerate package information instead', + "// Auto generated file from either:", + "// - scripts/regen_prepackage_rules_index.sh", + "// - detection-rules repo using CLI command build-release", + "// Do not hand edit. 
Run script/command to regenerate package information instead", + ] + rule_imports = [ + f"import rule{i} from './{r.path.stem + '.json'}';" # type: ignore[reportOptionalMemberAccess] + for i, r in enumerate(sorted_rules, 1) ] - rule_imports = [f"import rule{i} from './{os.path.splitext(os.path.basename(r.path))[0] + '.json'}';" - for i, r in enumerate(sorted_rules, 1)] - const_exports = ['export const rawRules = ['] + const_exports = ["export const rawRules = ["] const_exports.extend(f" rule{i}," for i, _ in enumerate(sorted_rules, 1)) const_exports.append("];") const_exports.append("") @@ -141,47 +154,54 @@ def _package_kibana_index_file(self, save_dir): index_ts.append("") index_ts.extend(const_exports) - with open(os.path.join(save_dir, 'index.ts'), 'wt') as f: - f.write('\n'.join(index_ts)) + with (save_dir / "index.ts").open("w") as f: + _ = f.write("\n".join(index_ts)) - def save_release_files(self, directory: str, changed_rules: list, new_rules: list, removed_rules: list): + def save_release_files( + self, + directory: Path, + changed_rules: list[definitions.UUIDString], + new_rules: list[str], + removed_rules: list[str], + ) -> None: """Release a package.""" summary, changelog = self.generate_summary_and_changelog(changed_rules, new_rules, removed_rules) - with open(os.path.join(directory, f'{self.name}-summary.txt'), 'w') as f: - f.write(summary) - with open(os.path.join(directory, f'{self.name}-changelog-entry.md'), 'w') as f: - f.write(changelog) + with (directory / f"{self.name}-summary.txt").open("w") as f: + _ = f.write(summary) + with (directory / f"{self.name}-changelog-entry.md").open("w") as f: + _ = f.write(changelog) if self.generate_navigator: - self.generate_attack_navigator(Path(directory)) + _ = self.generate_attack_navigator(Path(directory)) consolidated = json.loads(self.get_consolidated()) - with open(os.path.join(directory, f'{self.name}-consolidated-rules.json'), 'w') as f: + with (directory / f"{self.name}-consolidated-rules.json").open("w") as f: json.dump(consolidated, f, sort_keys=True, indent=2) consolidated_rules = Ndjson(consolidated) - consolidated_rules.dump(Path(directory).joinpath(f'{self.name}-consolidated-rules.ndjson'), sort_keys=True) + consolidated_rules.dump(Path(directory).joinpath(f"{self.name}-consolidated-rules.ndjson"), sort_keys=True) - self.generate_xslx(os.path.join(directory, f'{self.name}-summary.xlsx')) + self.generate_xslx(str(directory / f"{self.name}-summary.xlsx")) bulk_upload, rules_ndjson = self.create_bulk_index_body() - bulk_upload.dump(Path(directory).joinpath(f'{self.name}-enriched-rules-index-uploadable.ndjson'), - sort_keys=True) - rules_ndjson.dump(Path(directory).joinpath(f'{self.name}-enriched-rules-index-importable.ndjson'), - sort_keys=True) - - def get_consolidated(self, as_api=True): + bulk_upload.dump( + directory / f"{self.name}-enriched-rules-index-uploadable.ndjson", + sort_keys=True, + ) + rules_ndjson.dump( + directory / f"{self.name}-enriched-rules-index-importable.ndjson", + sort_keys=True, + ) + + def get_consolidated(self, as_api: bool = True) -> str: """Get a consolidated package of the rules in a single file.""" - full_package = [] - for rule in self.rules: - full_package.append(rule.contents.to_api_format() if as_api else rule.contents.to_dict()) - + full_package = [rule.contents.to_api_format() if as_api else rule.contents.to_dict() for rule in self.rules] return json.dumps(full_package, sort_keys=True) - def save(self, verbose=True): + def save(self, verbose: bool = True) -> None: """Save a package 
and all artifacts.""" save_dir = RELEASE_DIR / self.name - rules_dir = save_dir / 'rules' - extras_dir = save_dir / 'extras' + rules_dir = save_dir / "rules" + extras_dir = save_dir / "extras" # remove anything that existed before shutil.rmtree(save_dir, ignore_errors=True) @@ -189,7 +209,9 @@ def save(self, verbose=True): extras_dir.mkdir(parents=True, exist_ok=True) for rule in self.rules: - rule.save_json(rules_dir / Path(rule.path.name).with_suffix('.json')) + if not rule.path: + raise ValueError("Rule path is not found") + rule.save_json(rules_dir / Path(rule.path.name).with_suffix(".json")) self._package_kibana_notice_file(rules_dir) self._package_kibana_index_file(rules_dir) @@ -199,43 +221,67 @@ def save(self, verbose=True): self.save_release_files(extras_dir, self.changed_ids, self.new_ids, self.removed_ids) # zip all rules only and place in extras - shutil.make_archive(extras_dir / self.name, 'zip', root_dir=rules_dir.parent, base_dir=rules_dir.name) + _ = shutil.make_archive( + str(extras_dir / self.name), + "zip", + root_dir=rules_dir.parent, + base_dir=rules_dir.name, + ) # zip everything and place in release root - shutil.make_archive( - save_dir / f"{self.name}-all", "zip", root_dir=extras_dir.parent, base_dir=extras_dir.name + _ = shutil.make_archive( + str(save_dir / f"{self.name}-all"), + "zip", + root_dir=extras_dir.parent, + base_dir=extras_dir.name, ) if verbose: - click.echo(f'Package saved to: {save_dir}') - - def export(self, outfile, downgrade_version=None, verbose=True, skip_unsupported=False): + click.echo(f"Package saved to: {save_dir}") + + def export( + self, + outfile: Path, + downgrade_version: definitions.SemVer | None = None, + verbose: bool = True, + skip_unsupported: bool = False, + ) -> None: """Export rules into a consolidated ndjson file.""" - from .main import _export_rules + from .main import _export_rules # type: ignore[reportPrivateUsage] - _export_rules(self.rules, outfile=outfile, downgrade_version=downgrade_version, verbose=verbose, - skip_unsupported=skip_unsupported) + _export_rules( + self.rules, + outfile=outfile, + downgrade_version=downgrade_version, + verbose=verbose, + skip_unsupported=skip_unsupported, + ) - def get_package_hash(self, as_api=True, verbose=True): + def get_package_hash(self, as_api: bool = True, verbose: bool = True) -> str: """Get hash of package contents.""" - contents = base64.b64encode(self.get_consolidated(as_api=as_api).encode('utf-8')) + contents = base64.b64encode(self.get_consolidated(as_api=as_api).encode("utf-8")) sha256 = hashlib.sha256(contents).hexdigest() if verbose: - click.echo('- sha256: {}'.format(sha256)) + click.echo(f"- sha256: {sha256}") return sha256 @classmethod - def from_config(cls, rule_collection: Optional[RuleCollection] = None, config: Optional[dict] = None, - verbose: Optional[bool] = False, historical: Optional[bool] = True) -> 'Package': + def from_config( + cls, + rule_collection: RuleCollection | None = None, + config: dict[str, Any] | None = None, + verbose: bool = False, + historical: bool = True, + ) -> "Package": """Load a rules package given a config.""" all_rules = rule_collection or RuleCollection.default() config = config or {} - exclude_fields = config.pop('exclude_fields', {}) + exclude_fields = config.pop("exclude_fields", {}) # deprecated rules are now embedded in the RuleCollection.deprecated - this is left here for backwards compat - config.pop('log_deprecated', False) - rule_filter = config.pop('filter', {}) + config.pop("log_deprecated", False) + rule_filter = 
config.pop("filter", {}) rules = all_rules.filter(lambda r: filter_rule(r, rule_filter, exclude_fields)) @@ -243,31 +289,34 @@ def from_config(cls, rule_collection: Optional[RuleCollection] = None, config: O rules.deprecated = all_rules.deprecated if verbose: - click.echo(f' - {len(all_rules) - len(rules)} rules excluded from package') - - package = cls(rules, verbose=verbose, historical=historical, **config) + click.echo(f" - {len(all_rules) - len(rules)} rules excluded from package") - return package + return cls(rules, verbose=verbose, historical=historical, **config) - def generate_summary_and_changelog(self, changed_rule_ids, new_rule_ids, removed_rules): + def generate_summary_and_changelog( # noqa: PLR0915 + self, + changed_rule_ids: list[definitions.UUIDString], + new_rule_ids: list[str], + removed_rules: list[str], + ) -> tuple[str, str]: """Generate stats on package.""" - summary = { - 'changed': defaultdict(list), - 'added': defaultdict(list), - 'removed': defaultdict(list), - 'unchanged': defaultdict(list) + summary: dict[str, dict[str, list[str]]] = { + "changed": defaultdict(list), + "added": defaultdict(list), + "removed": defaultdict(list), + "unchanged": defaultdict(list), } - changelog = { - 'changed': defaultdict(list), - 'added': defaultdict(list), - 'removed': defaultdict(list), - 'unchanged': defaultdict(list) + changelog: dict[str, dict[str, list[str]]] = { + "changed": defaultdict(list), + "added": defaultdict(list), + "removed": defaultdict(list), + "unchanged": defaultdict(list), } # Build an index map first longest_name = 0 - indexes = set() + indexes: set[str] = set() for rule in self.rules: longest_name = max(longest_name, len(rule.name)) index_list = getattr(rule.contents.data, "index", []) @@ -276,103 +325,115 @@ def generate_summary_and_changelog(self, changed_rule_ids, new_rule_ids, removed index_map = {index: str(i) for i, index in enumerate(sorted(indexes))} - def get_summary_rule_info(r: TOMLRule): - r = r.contents - rule_str = f'{r.name:<{longest_name}} (v:{r.autobumped_version} t:{r.data.type}' + def get_summary_rule_info(r: TOMLRule) -> str: + contents = r.contents + rule_str = f"{r.name:<{longest_name}} (v:{contents.autobumped_version} t:{contents.data.type}" if isinstance(rule.contents.data, QueryRuleData): - index = rule.contents.data.get("index") or [] - rule_str += f'-{r.data.language}' - rule_str += f'(indexes:{"".join(index_map[idx] for idx in index) or "none"}' + index: list[str] = rule.contents.data.get("index") or [] + rule_str += f"-{contents.data.language}" # type: ignore[reportAttributeAccessIssue] + rule_str += f"(indexes:{''.join(index_map[idx] for idx in index) or 'none'}" return rule_str - def get_markdown_rule_info(r: TOMLRule, sd): + def get_markdown_rule_info(r: TOMLRule, sd: str) -> str: # lookup the rule in the GitHub tag v{major.minor.patch} + if not r.path: + raise ValueError("Unknown rule path") data = r.contents.data - rules_dir_link = f'https://github.com/elastic/detection-rules/tree/v{self.name}/rules/{sd}/' + rules_dir_link = f"https://github.com/elastic/detection-rules/tree/v{self.name}/rules/{sd}/" rule_type = data.language if isinstance(data, QueryRuleData) else data.type - return f'`{r.id}` **[{r.name}]({rules_dir_link + os.path.basename(str(r.path))})** (_{rule_type}_)' + return f"`{r.id}` **[{r.name}]({rules_dir_link + r.path.name})** (_{rule_type}_)" for rule in self.rules: - sub_dir = os.path.basename(os.path.dirname(rule.path)) + if not rule.path: + raise ValueError("Unknown rule path") + sub_dir = 
rule.path.parent.name if rule.id in changed_rule_ids: - summary['changed'][sub_dir].append(get_summary_rule_info(rule)) - changelog['changed'][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) + summary["changed"][sub_dir].append(get_summary_rule_info(rule)) + changelog["changed"][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) elif rule.id in new_rule_ids: - summary['added'][sub_dir].append(get_summary_rule_info(rule)) - changelog['added'][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) + summary["added"][sub_dir].append(get_summary_rule_info(rule)) + changelog["added"][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) else: - summary['unchanged'][sub_dir].append(get_summary_rule_info(rule)) - changelog['unchanged'][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) + summary["unchanged"][sub_dir].append(get_summary_rule_info(rule)) + changelog["unchanged"][sub_dir].append(get_markdown_rule_info(rule, sub_dir)) for rule in self.deprecated_rules: - sub_dir = os.path.basename(os.path.dirname(rule.path)) + if not rule.path: + raise ValueError("Unknown rule path") + + sub_dir = rule.path.parent.name + + if not rule.name: + raise ValueError("Rule name is not found") if rule.id in removed_rules: - summary['removed'][sub_dir].append(rule.name) - changelog['removed'][sub_dir].append(rule.name) + summary["removed"][sub_dir].append(rule.name) + changelog["removed"][sub_dir].append(rule.name) - def format_summary_rule_str(rule_dict): - str_fmt = '' + def format_summary_rule_str(rule_dict: dict[str, Any]) -> str: + str_fmt = "" for sd, rules in sorted(rule_dict.items(), key=lambda x: x[0]): - str_fmt += f'\n{sd} ({len(rules)})\n' - str_fmt += '\n'.join(' - ' + s for s in sorted(rules)) - return str_fmt or '\nNone' + str_fmt += f"\n{sd} ({len(rules)})\n" + str_fmt += "\n".join(" - " + s for s in sorted(rules)) + return str_fmt or "\nNone" - def format_changelog_rule_str(rule_dict): - str_fmt = '' + def format_changelog_rule_str(rule_dict: dict[str, Any]) -> str: + str_fmt = "" for sd, rules in sorted(rule_dict.items(), key=lambda x: x[0]): - str_fmt += f'\n- **{sd}** ({len(rules)})\n' - str_fmt += '\n'.join(' - ' + s for s in sorted(rules)) - return str_fmt or '\nNone' + str_fmt += f"\n- **{sd}** ({len(rules)})\n" + str_fmt += "\n".join(" - " + s for s in sorted(rules)) + return str_fmt or "\nNone" - def rule_count(rule_dict): + def rule_count(rule_dict: dict[str, Any]) -> int: count = 0 - for _, rules in rule_dict.items(): + for rules in rule_dict.values(): count += len(rules) return count - today = str(datetime.date.today()) - summary_fmt = [f'{sf.capitalize()} ({rule_count(summary[sf])}): \n{format_summary_rule_str(summary[sf])}\n' - for sf in ('added', 'changed', 'removed', 'unchanged') if summary[sf]] - - change_fmt = [f'{sf.capitalize()} ({rule_count(changelog[sf])}): \n{format_changelog_rule_str(changelog[sf])}\n' - for sf in ('added', 'changed', 'removed') if changelog[sf]] - - summary_str = '\n'.join([ - f'Version {self.name}', - f'Generated: {today}', - f'Total Rules: {len(self.rules)}', - f'Package Hash: {self.get_package_hash(verbose=False)}', - '---', - '(v: version, t: rule_type-language)', - 'Index Map:\n{}'.format("\n".join(f" {v}: {k}" for k, v in index_map.items())), - '', - 'Rules', - *summary_fmt - ]) - - changelog_str = '\n'.join([ - f'# Version {self.name}', - f'_Released {today}_', - '', - '### Rules', - *change_fmt, - '', - '### CLI' - ]) + today = str(date.today()) # noqa: DTZ011 + summary_fmt = [ + f"{sf.capitalize()} ({rule_count(summary[sf])}): 
\n{format_summary_rule_str(summary[sf])}\n" + for sf in ("added", "changed", "removed", "unchanged") + if summary[sf] + ] + + change_fmt = [ + f"{sf.capitalize()} ({rule_count(changelog[sf])}): \n{format_changelog_rule_str(changelog[sf])}\n" + for sf in ("added", "changed", "removed") + if changelog[sf] + ] + + summary_str = "\n".join( + [ + f"Version {self.name}", + f"Generated: {today}", + f"Total Rules: {len(self.rules)}", + f"Package Hash: {self.get_package_hash(verbose=False)}", + "---", + "(v: version, t: rule_type-language)", + "Index Map:\n{}".format("\n".join(f" {v}: {k}" for k, v in index_map.items())), + "", + "Rules", + *summary_fmt, + ] + ) + + changelog_str = "\n".join( + [f"# Version {self.name}", f"_Released {today}_", "", "### Rules", *change_fmt, "", "### CLI"] + ) return summary_str, changelog_str - def generate_attack_navigator(self, path: Path) -> Dict[Path, Navigator]: + def generate_attack_navigator(self, path: Path) -> dict[Path, Navigator]: """Generate ATT&CK navigator layer files.""" - save_dir = path / 'navigator_layers' + save_dir = path / "navigator_layers" save_dir.mkdir() lb = NavigatorBuilder(self.rules.rules) return lb.save_all(save_dir, verbose=False) - def generate_xslx(self, path): + def generate_xslx(self, path: str) -> None: """Generate a detailed breakdown of a package in an excel file.""" from .docs import PackageDocument @@ -380,35 +441,33 @@ def generate_xslx(self, path): doc.populate() doc.close() - def _generate_registry_package(self, save_dir): + def _generate_registry_package(self, save_dir: Path) -> None: """Generate the artifact for the oob package-storage.""" - from .schemas.registry_package import (RegistryPackageManifestV1, - RegistryPackageManifestV3) + from .schemas.registry_package import RegistryPackageManifestV1, RegistryPackageManifestV3 # 8.12.0+ we use elastic package v3 stack_version = Version.parse(self.name, optional_minor_and_patch=True) - if stack_version >= Version.parse('8.12.0'): + if stack_version >= Version.parse("8.12.0"): manifest = RegistryPackageManifestV3.from_dict(self.registry_data) else: manifest = RegistryPackageManifestV1.from_dict(self.registry_data) - package_dir = Path(save_dir) / 'fleet' / manifest.version - docs_dir = package_dir / 'docs' - rules_dir = package_dir / 'kibana' / definitions.ASSET_TYPE + package_dir = Path(save_dir) / "fleet" / manifest.version + docs_dir = package_dir / "docs" + rules_dir = package_dir / "kibana" / definitions.ASSET_TYPE docs_dir.mkdir(parents=True) rules_dir.mkdir(parents=True) - manifest_file = package_dir / 'manifest.yml' - readme_file = docs_dir / 'README.md' - notice_file = package_dir / 'NOTICE.txt' - logo_file = package_dir / 'img' / 'security-logo-color-64px.svg' + manifest_file = package_dir / "manifest.yml" + readme_file = docs_dir / "README.md" + notice_file = package_dir / "NOTICE.txt" + logo_file = package_dir / "img" / "security-logo-color-64px.svg" manifest_file.write_text(yaml.safe_dump(manifest.to_dict())) logo_file.parent.mkdir(parents=True) shutil.copyfile(FLEET_PKG_LOGO, logo_file) - # shutil.copyfile(CHANGELOG_FILE, str(rules_dir.joinpath('CHANGELOG.json'))) for rule in self.rules: asset = rule.get_asset() @@ -416,7 +475,7 @@ def _generate_registry_package(self, save_dir): # asset['id] and the file name needs to resemble RULEID_VERSION instead of RULEID asset_id = f"{asset['attributes']['rule_id']}_{asset['attributes']['version']}" asset["id"] = asset_id - asset_path = rules_dir / f'{asset_id}.json' + asset_path = rules_dir / f"{asset_id}.json" 
asset_path.write_text(json.dumps(asset, indent=4, sort_keys=True), encoding="utf-8") @@ -432,62 +491,64 @@ def _generate_registry_package(self, save_dir): ## License Notice - """).lstrip() # noqa: E501 + """).lstrip() # notice only needs to be appended to the README for 7.13.x # in 7.14+ there's a separate modal to display this if self.name == "7.13": - textwrap.indent(notice_contents, prefix=" ") + notice_contents = textwrap.indent(notice_contents, prefix=" ") readme_file.write_text(readme_text) notice_file.write_text(notice_contents) - def create_bulk_index_body(self) -> Tuple[Ndjson, Ndjson]: + def create_bulk_index_body(self) -> tuple[Ndjson, Ndjson]: """Create a body to bulk index into a stack.""" package_hash = self.get_package_hash(verbose=False) - now = datetime.datetime.isoformat(datetime.datetime.utcnow()) - create = {'create': {'_index': f'rules-repo-{self.name}-{package_hash}'}} + now = datetime.now(UTC).isoformat() + create = {"create": {"_index": f"rules-repo-{self.name}-{package_hash}"}} # first doc is summary stats - summary_doc = { - 'group_hash': package_hash, - 'package_version': self.name, - 'rule_count': len(self.rules), - 'rule_ids': [], - 'rule_names': [], - 'rule_hashes': [], - 'source': 'repo', - 'details': {'datetime_uploaded': now} + summary_doc: dict[str, Any] = { + "group_hash": package_hash, + "package_version": self.name, + "rule_count": len(self.rules), + "rule_ids": [], + "rule_names": [], + "rule_hashes": [], + "source": "repo", + "details": {"datetime_uploaded": now}, } bulk_upload_docs = Ndjson([create, summary_doc]) importable_rules_docs = Ndjson() for rule in self.rules: - summary_doc['rule_ids'].append(rule.id) - summary_doc['rule_names'].append(rule.name) - summary_doc['rule_hashes'].append(rule.contents.get_hash()) + summary_doc["rule_ids"].append(rule.id) + summary_doc["rule_names"].append(rule.name) + summary_doc["rule_hashes"].append(rule.contents.get_hash()) if rule.id in self.new_ids: - status = 'new' + status = "new" elif rule.id in self.changed_ids: - status = 'modified' + status = "modified" else: - status = 'unmodified' + status = "unmodified" bulk_upload_docs.append(create) relative_path = str(rule.get_base_rule_dir()) - if relative_path is None: + if not relative_path: raise ValueError(f"Could not find a valid relative path for the rule: {rule.id}") - rule_doc = dict(hash=rule.contents.get_hash(), - source='repo', - datetime_uploaded=now, - status=status, - package_version=self.name, - flat_mitre=ThreatMapping.flatten(rule.contents.data.threat).to_dict(), - relative_path=relative_path) + rule_doc = { + "hash": rule.contents.get_hash(), + "source": "repo", + "datetime_uploaded": now, + "status": status, + "package_version": self.name, + "flat_mitre": ThreatMapping.flatten(rule.contents.data.threat).to_dict(), + "relative_path": relative_path, + } rule_doc.update(**rule.contents.to_api_format()) bulk_upload_docs.append(rule_doc) importable_rules_docs.append(rule_doc) @@ -495,14 +556,17 @@ def create_bulk_index_body(self) -> Tuple[Ndjson, Ndjson]: return bulk_upload_docs, importable_rules_docs @staticmethod - def add_historical_rules(historical_rules: Dict[str, dict], manifest_version: str) -> list: + def add_historical_rules( + historical_rules: dict[str, dict[str, Any]], + manifest_version: str, + ) -> list[dict[str, Any]] | None: """Adds historical rules to existing build package.""" - rules_dir = CURRENT_RELEASE_PATH / 'fleet' / manifest_version / 'kibana' / 'security_rule' + rules_dir = CURRENT_RELEASE_PATH / "fleet" / manifest_version 
/ "kibana" / "security_rule" # iterates over historical rules from previous package and writes them to disk - for _, historical_rule_contents in historical_rules.items(): + for historical_rule_contents in historical_rules.values(): rule_id = historical_rule_contents["attributes"]["rule_id"] - historical_rule_version = historical_rule_contents['attributes']['version'] + historical_rule_version = historical_rule_contents["attributes"]["version"] # checks if the rule exists in the current package first current_rule_path = list(rules_dir.glob(f"{rule_id}*.json")) @@ -512,7 +576,7 @@ def add_historical_rules(historical_rules: Dict[str, dict], manifest_version: st # load the current rule from disk current_rule_path = current_rule_path[0] current_rule_json = json.load(current_rule_path.open(encoding="UTF-8")) - current_rule_version = current_rule_json['attributes']['version'] + current_rule_version = current_rule_json["attributes"]["version"] # if the historical rule version and current rules version differ, write # the historical rule to disk @@ -524,4 +588,4 @@ def add_historical_rules(historical_rules: Dict[str, dict], manifest_version: st @cached def current_stack_version() -> str: - return Package.load_configs()['name'] + return Package.load_configs()["name"] diff --git a/detection_rules/remote_validation.py b/detection_rules/remote_validation.py index db30c5e953c..90c8d1a24f2 100644 --- a/detection_rules/remote_validation.py +++ b/detection_rules/remote_validation.py @@ -3,21 +3,21 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. +from collections.abc import Callable from dataclasses import dataclass -from datetime import datetime +from datetime import UTC, datetime from functools import cached_property from multiprocessing.pool import ThreadPool -from typing import Dict, List, Optional +from typing import Any import elasticsearch from elasticsearch import Elasticsearch +from kibana import Kibana # type: ignore[reportMissingTypeStubs] from marshmallow import ValidationError from requests import HTTPError -from kibana import Kibana - from .config import load_current_package_version -from .misc import ClientError, getdefault, get_elasticsearch_client, get_kibana_client +from .misc import ClientError, get_elasticsearch_client, get_kibana_client, getdefault from .rule import TOMLRule, TOMLRuleContents from .schemas import definitions @@ -25,13 +25,14 @@ @dataclass class RemoteValidationResult: """Dataclass for remote validation results.""" + rule_id: definitions.UUIDString rule_name: str - contents: dict + contents: dict[str, Any] rule_version: int stack_version: str - query_results: Optional[dict] - engine_results: Optional[dict] + query_results: dict[str, Any] + engine_results: dict[str, Any] class RemoteConnector: @@ -39,17 +40,17 @@ class RemoteConnector: MAX_RETRIES = 5 - def __init__(self, parse_config: bool = False, **kwargs): - es_args = ['cloud_id', 'ignore_ssl_errors', 'elasticsearch_url', 'es_user', 'es_password', 'timeout'] - kibana_args = ['cloud_id', 'ignore_ssl_errors', 'kibana_url', 'api_key', 'space'] + def __init__(self, parse_config: bool = False, **kwargs: Any) -> None: + es_args = ["cloud_id", "ignore_ssl_errors", "elasticsearch_url", "es_user", "es_password", "timeout"] + kibana_args = ["cloud_id", "ignore_ssl_errors", "kibana_url", "api_key", "space"] if parse_config: es_kwargs = {arg: getdefault(arg)() for arg in es_args} kibana_kwargs = {arg: getdefault(arg)() for arg in kibana_args} try: - if 'max_retries' not in es_kwargs: - 
es_kwargs['max_retries'] = self.MAX_RETRIES + if "max_retries" not in es_kwargs: + es_kwargs["max_retries"] = self.MAX_RETRIES self.es_client = get_elasticsearch_client(**es_kwargs, **kwargs) except ClientError: self.es_client = None @@ -59,15 +60,29 @@ def __init__(self, parse_config: bool = False, **kwargs): except HTTPError: self.kibana_client = None - def auth_es(self, *, cloud_id: Optional[str] = None, ignore_ssl_errors: Optional[bool] = None, - elasticsearch_url: Optional[str] = None, es_user: Optional[str] = None, - es_password: Optional[str] = None, timeout: Optional[int] = None, **kwargs) -> Elasticsearch: + def auth_es( # noqa: PLR0913 + self, + *, + cloud_id: str | None = None, + ignore_ssl_errors: bool | None = None, + elasticsearch_url: str | None = None, + es_user: str | None = None, + es_password: str | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> Elasticsearch: """Return an authenticated Elasticsearch client.""" - if 'max_retries' not in kwargs: - kwargs['max_retries'] = self.MAX_RETRIES - self.es_client = get_elasticsearch_client(cloud_id=cloud_id, ignore_ssl_errors=ignore_ssl_errors, - elasticsearch_url=elasticsearch_url, es_user=es_user, - es_password=es_password, timeout=timeout, **kwargs) + if "max_retries" not in kwargs: + kwargs["max_retries"] = self.MAX_RETRIES + self.es_client = get_elasticsearch_client( + cloud_id=cloud_id, + ignore_ssl_errors=ignore_ssl_errors, + elasticsearch_url=elasticsearch_url, + es_user=es_user, + es_password=es_password, + timeout=timeout, + **kwargs, + ) return self.es_client def auth_kibana( @@ -78,7 +93,7 @@ def auth_kibana( kibana_url: str | None = None, space: str | None = None, ignore_ssl_errors: bool = False, - **kwargs + **kwargs: Any, ) -> Kibana: """Return an authenticated Kibana client.""" self.kibana_client = get_kibana_client( @@ -87,7 +102,7 @@ def auth_kibana( kibana_url=kibana_url, api_key=api_key, space=space, - **kwargs + **kwargs, ) return self.kibana_client @@ -95,115 +110,139 @@ def auth_kibana( class RemoteValidator(RemoteConnector): """Client class for remote validation.""" - def __init__(self, parse_config: bool = False): - super(RemoteValidator, self).__init__(parse_config=parse_config) + def __init__(self, parse_config: bool = False) -> None: + super().__init__(parse_config=parse_config) @cached_property - def get_validate_methods(self) -> List[str]: + def get_validate_methods(self) -> list[str]: """Return all validate methods.""" - exempt = ('validate_rule', 'validate_rules') - methods = [m for m in self.__dir__() if m.startswith('validate_') and m not in exempt] - return methods + exempt = ("validate_rule", "validate_rules") + return [m for m in self.__dir__() if m.startswith("validate_") and m not in exempt] - def get_validate_method(self, name: str) -> callable: + def get_validate_method(self, name: str) -> Callable[..., Any]: """Return validate method by name.""" - assert name in self.get_validate_methods, f'validate method {name} not found' + if name not in self.get_validate_methods: + raise ValueError(f"Validate method {name} not found") return getattr(self, name) @staticmethod - def prep_for_preview(contents: TOMLRuleContents) -> dict: + def prep_for_preview(contents: TOMLRuleContents) -> dict[str, Any]: """Prepare rule for preview.""" - end_time = datetime.utcnow().isoformat() + end_time = datetime.now(UTC).isoformat() dumped = contents.to_api_format().copy() dumped.update(timeframeEnd=end_time, invocationCount=1) return dumped - def engine_preview(self, contents: TOMLRuleContents) 
-> dict: + def engine_preview(self, contents: TOMLRuleContents) -> dict[str, Any]: """Get results from detection engine preview API.""" dumped = self.prep_for_preview(contents) - return self.kibana_client.post('/api/detection_engine/rules/preview', json=dumped) + if not self.kibana_client: + raise ValueError("No Kibana client found") + return self.kibana_client.post("/api/detection_engine/rules/preview", json=dumped) # type: ignore[reportReturnType] def validate_rule(self, contents: TOMLRuleContents) -> RemoteValidationResult: """Validate a single rule query.""" - method = self.get_validate_method(f'validate_{contents.data.type}') + method = self.get_validate_method(f"validate_{contents.data.type}") query_results = method(contents) engine_results = self.engine_preview(contents) rule_version = contents.autobumped_version stack_version = load_current_package_version() - return RemoteValidationResult(contents.data.rule_id, contents.data.name, contents.to_api_format(), - rule_version, stack_version, query_results, engine_results) + if not rule_version: + raise ValueError("No rule version found") + + return RemoteValidationResult( + contents.data.rule_id, + contents.data.name, + contents.to_api_format(), + rule_version, + stack_version, + query_results, + engine_results, + ) - def validate_rules(self, rules: List[TOMLRule], threads: int = 5) -> Dict[str, RemoteValidationResult]: + def validate_rules(self, rules: list[TOMLRule], threads: int = 5) -> dict[str, RemoteValidationResult]: """Validate a collection of rules via threads.""" responses = {} - def request(c: TOMLRuleContents): + def request(c: TOMLRuleContents) -> None: try: responses[c.data.rule_id] = self.validate_rule(c) except ValidationError as e: - responses[c.data.rule_id] = e.messages + responses[c.data.rule_id] = e.messages # type: ignore[reportUnknownMemberType] pool = ThreadPool(processes=threads) - pool.map(request, [r.contents for r in rules]) + _ = pool.map(request, [r.contents for r in rules]) pool.close() pool.join() - return responses + return responses # type: ignore[reportUnknownVariableType] - def validate_esql(self, contents: TOMLRuleContents) -> dict: - query = contents.data.query + def validate_esql(self, contents: TOMLRuleContents) -> dict[str, Any]: + query = contents.data.query # type: ignore[reportAttributeAccessIssue] rule_id = contents.data.rule_id headers = {"accept": "application/json", "content-type": "application/json"} - body = {'query': f'{query} | LIMIT 0'} + body = {"query": f"{query} | LIMIT 0"} + if not self.es_client: + raise ValueError("No ES client found") try: - response = self.es_client.perform_request('POST', '/_query', headers=headers, params={'pretty': True}, - body=body) + response = self.es_client.perform_request( + "POST", + "/_query", + headers=headers, + params={"pretty": True}, + body=body, + ) except Exception as exc: if isinstance(exc, elasticsearch.BadRequestError): - raise ValidationError(f'ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}') - else: - raise Exception(f'ES|QL query failed for rule: {rule_id}, query: \n{query}') from exc + raise ValidationError(f"ES|QL query failed: {exc} for rule: {rule_id}, query: \n{query}") from exc + raise Exception(f"ES|QL query failed for rule: {rule_id}, query: \n{query}") from exc # noqa: TRY002 return response.body - def validate_eql(self, contents: TOMLRuleContents) -> dict: + def validate_eql(self, contents: TOMLRuleContents) -> dict[str, Any]: """Validate query for "eql" rule types.""" - query = contents.data.query + query = 
contents.data.query # type: ignore[reportAttributeAccessIssue] rule_id = contents.data.rule_id - index = contents.data.index - time_range = {"range": {"@timestamp": {"gt": 'now-1h/h', "lte": 'now', "format": "strict_date_optional_time"}}} - body = {'query': query} + index = contents.data.index # type: ignore[reportAttributeAccessIssue] + time_range = {"range": {"@timestamp": {"gt": "now-1h/h", "lte": "now", "format": "strict_date_optional_time"}}} + body: dict[str, Any] = {"query": query} + + if not self.es_client: + raise ValueError("No ES client found") + + if not index: + raise ValueError("Indices not found") + try: - response = self.es_client.eql.search(index=index, body=body, ignore_unavailable=True, filter=time_range) + response = self.es_client.eql.search(index=index, body=body, ignore_unavailable=True, filter=time_range) # type: ignore[reportUnknownArgumentType] except Exception as exc: if isinstance(exc, elasticsearch.BadRequestError): - raise ValidationError(f'EQL query failed: {exc} for rule: {rule_id}, query: \n{query}') - else: - raise Exception(f'EQL query failed for rule: {rule_id}, query: \n{query}') from exc + raise ValidationError(f"EQL query failed: {exc} for rule: {rule_id}, query: \n{query}") from exc + raise Exception(f"EQL query failed for rule: {rule_id}, query: \n{query}") from exc # noqa: TRY002 return response.body @staticmethod - def validate_query(self, contents: TOMLRuleContents) -> dict: + def validate_query(_: Any, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "query" rule types.""" - return {'results': 'Unable to remote validate query rules'} + return {"results": "Unable to remote validate query rules"} @staticmethod - def validate_threshold(self, contents: TOMLRuleContents) -> dict: + def validate_threshold(_: Any, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "threshold" rule types.""" - return {'results': 'Unable to remote validate threshold rules'} + return {"results": "Unable to remote validate threshold rules"} @staticmethod - def validate_new_terms(self, contents: TOMLRuleContents) -> dict: + def validate_new_terms(_: Any, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "new_terms" rule types.""" - return {'results': 'Unable to remote validate new_terms rules'} + return {"results": "Unable to remote validate new_terms rules"} @staticmethod - def validate_threat_match(self, contents: TOMLRuleContents) -> dict: + def validate_threat_match(_: Any, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "threat_match" rule types.""" - return {'results': 'Unable to remote validate threat_match rules'} + return {"results": "Unable to remote validate threat_match rules"} @staticmethod - def validate_machine_learning(self, contents: TOMLRuleContents) -> dict: + def validate_machine_learning(_: Any, __: TOMLRuleContents) -> dict[str, str]: """Validate query for "machine_learning" rule types.""" - return {'results': 'Unable to remote validate machine_learning rules'} + return {"results": "Unable to remote validate machine_learning rules"} diff --git a/detection_rules/rule.py b/detection_rules/rule.py index 06c645d814c..0c293141b70 100644 --- a/detection_rules/rule.py +++ b/detection_rules/rule.py @@ -3,6 +3,7 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. 
"""Rule object.""" + import copy import dataclasses import json @@ -14,36 +15,46 @@ from dataclasses import dataclass, field from functools import cached_property from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Literal from urllib.parse import urlparse from uuid import uuid4 -import eql +import eql # type: ignore[reportMissingTypeStubs] +import kql # type: ignore[reportMissingTypeStubs] import marshmallow -from semver import Version from marko.block import Document as MarkoDocument from marko.ext.gfm import gfm from marshmallow import ValidationError, pre_load, validates_schema - -import kql +from semver import Version from . import beats, ecs, endgame, utils from .config import load_current_package_version, parse_rules_config -from .integrations import (find_least_compatible_version, get_integration_schema_fields, - load_integrations_manifests, load_integrations_schemas, - parse_datasets) +from .integrations import ( + find_least_compatible_version, + get_integration_schema_fields, + load_integrations_manifests, + load_integrations_schemas, +) from .mixins import MarshmallowDataclassMixin, StackCompatMixin from .rule_formatter import nested_normalize, toml_write -from .schemas import (SCHEMA_DIR, definitions, downgrade, - get_min_supported_stack_version, get_stack_schemas, - strip_non_public_fields) +from .schemas import ( + SCHEMA_DIR, + definitions, + downgrade, + get_min_supported_stack_version, + get_stack_schemas, + strip_non_public_fields, +) from .schemas.stack_compat import get_restricted_fields from .utils import PatchedTemplate, cached, convert_time_span, get_nested_value, set_nested_value +from .version_lock import VersionLock, loaded_version_lock +if typing.TYPE_CHECKING: + from .remote_validation import RemoteValidator -_META_SCHEMA_REQ_DEFAULTS = {} -MIN_FLEET_PACKAGE_VERSION = '7.13.0' -TIME_NOW = time.strftime('%Y/%m/%d') + +MIN_FLEET_PACKAGE_VERSION = "7.13.0" +TIME_NOW = time.strftime("%Y/%m/%d") RULES_CONFIG = parse_rules_config() DEFAULT_PREBUILT_RULES_DIRS = RULES_CONFIG.rule_dirs DEFAULT_PREBUILT_BBR_DIRS = RULES_CONFIG.bbr_rules_dirs @@ -51,38 +62,38 @@ BUILD_FIELD_VERSIONS = { - "related_integrations": (Version.parse('8.3.0'), None), - "required_fields": (Version.parse('8.3.0'), None), - "setup": (Version.parse('8.3.0'), None) + "related_integrations": (Version.parse("8.3.0"), None), + "required_fields": (Version.parse("8.3.0"), None), + "setup": (Version.parse("8.3.0"), None), } -@dataclass +@dataclass(kw_only=True) class DictRule: """Simple object wrapper for raw rule dicts.""" - contents: dict - path: Optional[Path] = None + contents: dict[str, Any] + path: Path | None = None @property - def metadata(self) -> dict: + def metadata(self) -> dict[str, Any]: """Metadata portion of TOML file rule.""" - return self.contents.get('metadata', {}) + return self.contents.get("metadata", {}) @property - def data(self) -> dict: + def data(self) -> dict[str, Any]: """Rule portion of TOML file rule.""" - return self.contents.get('data') or self.contents + return self.contents.get("data") or self.contents @property def id(self) -> str: """Get the rule ID.""" - return self.data['rule_id'] + return self.data["rule_id"] # type: ignore[reportUnknownMemberType] @property def name(self) -> str: """Get the rule name.""" - return self.data['name'] + return self.data["name"] # type: ignore[reportUnknownMemberType] def __hash__(self) -> int: """Get the hash of the rule.""" @@ -93,35 +104,35 @@ def __repr__(self) -> 
str: return f"Rule({self.name} {self.id})" -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class RuleMeta(MarshmallowDataclassMixin): """Data stored in a rule's [metadata] section of TOML.""" + creation_date: definitions.Date updated_date: definitions.Date - deprecation_date: Optional[definitions.Date] + deprecation_date: definitions.Date | None = None # Optional fields - bypass_bbr_timing: Optional[bool] - comments: Optional[str] - integration: Optional[Union[str, List[str]]] - maturity: Optional[definitions.Maturity] - min_stack_version: Optional[definitions.SemVer] - min_stack_comments: Optional[str] - os_type_list: Optional[List[definitions.OSType]] - query_schema_validation: Optional[bool] - related_endpoint_rules: Optional[List[str]] - promotion: Optional[bool] + bypass_bbr_timing: bool | None = None + comments: str | None = None + integration: str | list[str] | None = None + maturity: definitions.Maturity | None = None + min_stack_version: definitions.SemVer | None = None + min_stack_comments: str | None = None + os_type_list: list[definitions.OSType] | None = None + query_schema_validation: bool | None = None + related_endpoint_rules: list[str] | None = None + promotion: bool | None = None # Extended information as an arbitrary dictionary - extended: Optional[Dict[str, Any]] + extended: dict[str, Any] | None = None - def get_validation_stack_versions(self) -> Dict[str, dict]: + def get_validation_stack_versions(self) -> dict[str, dict[str, Any]]: """Get a dict of beats and ecs versions per stack release.""" - stack_versions = get_stack_schemas(self.min_stack_version) - return stack_versions + return get_stack_schemas(self.min_stack_version) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class RuleTransform(MarshmallowDataclassMixin): """Data stored in a rule's [transform] section of TOML.""" @@ -131,13 +142,13 @@ class RuleTransform(MarshmallowDataclassMixin): # timelines out of scope at the moment - @dataclass(frozen=True) + @dataclass(frozen=True, kw_only=True) class OsQuery: label: str query: str - ecs_mapping: Optional[Dict[str, Dict[Literal['field', 'value'], str]]] + ecs_mapping: dict[str, dict[Literal["field", "value"], str]] | None = None - @dataclass(frozen=True) + @dataclass(frozen=True, kw_only=True) class Investigate: @dataclass(frozen=True) class Provider: @@ -148,23 +159,25 @@ class Provider: valueType: definitions.InvestigateProviderValueType label: str - description: Optional[str] - providers: List[List[Provider]] - relativeFrom: Optional[str] - relativeTo: Optional[str] + description: str | None = None + providers: list[list[Provider]] + relativeFrom: str | None = None + relativeTo: str | None = None # these must be lists in order to have more than one. 
Their index in the list is how they will be referenced in the # note string templates - osquery: Optional[List[OsQuery]] - investigate: Optional[List[Investigate]] + osquery: list[OsQuery] | None = None + investigate: list[Investigate] | None = None - def render_investigate_osquery_to_string(self) -> Dict[definitions.TransformTypes, List[str]]: + def render_investigate_osquery_to_string(self) -> dict[definitions.TransformTypes, list[str]]: obj = self.to_dict() - rendered: Dict[definitions.TransformTypes, List[str]] = {'osquery': [], 'investigate': []} + rendered: dict[definitions.TransformTypes, list[str]] = {"osquery": [], "investigate": []} for plugin, entries in obj.items(): for entry in entries: - rendered[plugin].append(f'!{{{plugin}{json.dumps(entry, sort_keys=True, separators=(",", ":"))}}}') + if plugin not in rendered: + raise ValueError(f"Unexpected field value: {plugin}") + rendered[plugin].append(f"!{{{plugin}{json.dumps(entry, sort_keys=True, separators=(',', ':'))}}}") return rendered @@ -178,60 +191,64 @@ class BaseThreatEntry: reference: str @pre_load - def modify_url(self, data: Dict[str, Any], **kwargs): + def modify_url(self, data: dict[str, Any], **_: Any) -> dict[str, Any]: """Modify the URL to support MITRE ATT&CK URLS with and without trailing forward slash.""" - if urlparse(data["reference"]).scheme: - if not data["reference"].endswith("/"): - data["reference"] += "/" + p = urlparse(data["reference"]) # type: ignore[reportUnknownVariableType] + if p.scheme and not data["reference"].endswith("/"): # type: ignore[reportUnknownMemberType] + data["reference"] += "/" return data @dataclass(frozen=True) class SubTechnique(BaseThreatEntry): """Mapping to threat subtechnique.""" + reference: definitions.SubTechniqueURL -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Technique(BaseThreatEntry): """Mapping to threat subtechnique.""" + # subtechniques are stored at threat[].technique.subtechnique[] reference: definitions.TechniqueURL - subtechnique: Optional[List[SubTechnique]] + subtechnique: list[SubTechnique] | None = None @dataclass(frozen=True) class Tactic(BaseThreatEntry): """Mapping to a threat tactic.""" + reference: definitions.TacticURL -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ThreatMapping(MarshmallowDataclassMixin): """Mapping to a threat framework.""" + framework: Literal["MITRE ATT&CK"] tactic: Tactic - technique: Optional[List[Technique]] + technique: list[Technique] | None = None @staticmethod - def flatten(threat_mappings: Optional[List]) -> 'FlatThreatMapping': + def flatten(threat_mappings: list["ThreatMapping"] | None) -> "FlatThreatMapping": """Get flat lists of tactic and technique info.""" - tactic_names = [] - tactic_ids = [] - technique_ids = set() - technique_names = set() - sub_technique_ids = set() - sub_technique_names = set() - - for entry in (threat_mappings or []): + tactic_names: list[str] = [] + tactic_ids: list[str] = [] + technique_ids: set[str] = set() + technique_names: set[str] = set() + sub_technique_ids: set[str] = set() + sub_technique_names: set[str] = set() + + for entry in threat_mappings or []: tactic_names.append(entry.tactic.name) tactic_ids.append(entry.tactic.id) - for technique in (entry.technique or []): + for technique in entry.technique or []: technique_names.add(technique.name) technique_ids.add(technique.id) - for subtechnique in (technique.subtechnique or []): + for subtechnique in technique.subtechnique or []: sub_technique_ids.add(subtechnique.id) 
sub_technique_names.add(subtechnique.name) @@ -241,48 +258,49 @@ def flatten(threat_mappings: Optional[List]) -> 'FlatThreatMapping': technique_names=sorted(technique_names), technique_ids=sorted(technique_ids), sub_technique_names=sorted(sub_technique_names), - sub_technique_ids=sorted(sub_technique_ids) + sub_technique_ids=sorted(sub_technique_ids), ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class RiskScoreMapping(MarshmallowDataclassMixin): field: str - operator: Optional[definitions.Operator] - value: Optional[str] + operator: definitions.Operator | None = None + value: str | None = None -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SeverityMapping(MarshmallowDataclassMixin): field: str - operator: Optional[definitions.Operator] - value: Optional[str] - severity: Optional[str] + operator: definitions.Operator | None = None + value: str | None = None + severity: str | None = None @dataclass(frozen=True) class FlatThreatMapping(MarshmallowDataclassMixin): - tactic_names: List[str] - tactic_ids: List[str] - technique_names: List[str] - technique_ids: List[str] - sub_technique_names: List[str] - sub_technique_ids: List[str] + tactic_names: list[str] + tactic_ids: list[str] + technique_names: list[str] + technique_ids: list[str] + sub_technique_names: list[str] + sub_technique_ids: list[str] @dataclass(frozen=True) class AlertSuppressionDuration: """Mapping to alert suppression duration.""" + unit: definitions.TimeUnits value: definitions.AlertSuppressionValue -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class AlertSuppressionMapping(MarshmallowDataclassMixin, StackCompatMixin): """Mapping to alert suppression.""" group_by: definitions.AlertSuppressionGroupBy - duration: Optional[AlertSuppressionDuration] + duration: AlertSuppressionDuration | None = None missing_fields_strategy: definitions.AlertSuppressionMissing @@ -298,19 +316,19 @@ class FilterStateStore: store: definitions.StoreType -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class FilterMeta: - alias: Optional[Union[str, None]] = None - disabled: Optional[bool] = None - negate: Optional[bool] = None - controlledBy: Optional[str] = None # identify who owns the filter - group: Optional[str] = None # allows grouping of filters - index: Optional[str] = None - isMultiIndex: Optional[bool] = None - type: Optional[str] = None - key: Optional[str] = None - params: Optional[str] = None # Expand to FilterMetaParams when needed - value: Optional[str] = None + alias: str | None = None + disabled: bool | None = None + negate: bool | None = None + controlledBy: str | None = None # identify who owns the filter + group: str | None = None # allows grouping of filters + index: str | None = None + isMultiIndex: bool | None = None + type: str | None = None + key: str | None = None + params: str | None = None # Expand to FilterMetaParams when needed + value: str | None = None @dataclass(frozen=True) @@ -319,28 +337,29 @@ class WildcardQuery: value: str
# https://github.com/elastic/detection-rules/issues/3773 meta: FilterMeta - state: Optional[FilterStateStore] = field(metadata=dict(data_key="$state")) - query: Optional[Union[Query, Dict[str, Any]]] = None + state: FilterStateStore | None = field(metadata={"data_key": "$state"}) + query: Query | dict[str, Any] | None = None -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class BaseRuleData(MarshmallowDataclassMixin, StackCompatMixin): """Base rule data.""" @dataclass class InvestigationFields: - field_names: List[definitions.NonEmptyStr] + field_names: list[definitions.NonEmptyStr] @dataclass class RequiredFields: @@ -352,53 +371,52 @@ class RequiredFields: class RelatedIntegrations: package: definitions.NonEmptyStr version: definitions.NonEmptyStr - integration: Optional[definitions.NonEmptyStr] + integration: definitions.NonEmptyStr | None = None - actions: Optional[list] - author: List[str] - building_block_type: Optional[definitions.BuildingBlockType] - description: str - enabled: Optional[bool] - exceptions_list: Optional[list] - license: Optional[str] - false_positives: Optional[List[str]] - filters: Optional[List[dict]] - # trailing `_` required since `from` is a reserved word in python - from_: Optional[str] = field(metadata=dict(data_key="from")) - interval: Optional[definitions.Interval] - investigation_fields: Optional[InvestigationFields] = field(metadata=dict(metadata=dict(min_compat="8.11"))) - max_signals: Optional[definitions.MaxSignals] - meta: Optional[Dict[str, Any]] name: definitions.RuleName - note: Optional[definitions.Markdown] - # can we remove this comment? - # explicitly NOT allowed! - # output_index: Optional[str] - references: Optional[List[str]] - related_integrations: Optional[List[RelatedIntegrations]] = field(metadata=dict(metadata=dict(min_compat="8.3"))) - required_fields: Optional[List[RequiredFields]] = field(metadata=dict(metadata=dict(min_compat="8.3"))) - revision: Optional[int] = field(metadata=dict(metadata=dict(min_compat="8.8"))) + + author: list[str] + description: str + from_: str | None = field(metadata={"data_key": "from"}) + investigation_fields: InvestigationFields | None = field(metadata={"metadata": {"min_compat": "8.11"}}) + related_integrations: list[RelatedIntegrations] | None = field(metadata={"metadata": {"min_compat": "8.3"}}) + required_fields: list[RequiredFields] | None = field(metadata={"metadata": {"min_compat": "8.3"}}) + revision: int | None = field(metadata={"metadata": {"min_compat": "8.8"}}) + setup: definitions.Markdown | None = field(metadata={"metadata": {"min_compat": "8.3"}}) + risk_score: definitions.RiskScore - risk_score_mapping: Optional[List[RiskScoreMapping]] rule_id: definitions.UUIDString - rule_name_override: Optional[str] - setup: Optional[definitions.Markdown] = field(metadata=dict(metadata=dict(min_compat="8.3"))) - severity_mapping: Optional[List[SeverityMapping]] severity: definitions.Severity - tags: Optional[List[str]] - throttle: Optional[str] - timeline_id: Optional[definitions.TimelineTemplateId] - timeline_title: Optional[definitions.TimelineTemplateTitle] - timestamp_override: Optional[str] - to: Optional[str] type: definitions.RuleType - threat: Optional[List[ThreatMapping]] - version: Optional[definitions.PositiveInteger] + + actions: list[dict[str, Any]] | None = None + building_block_type: definitions.BuildingBlockType | None = None + enabled: bool | None = None + exceptions_list: list[dict[str, str]] | None = None + false_positives: list[str] | None = None + filters: 
list[dict[str, Any]] | None = None + interval: definitions.Interval | None = None + license: str | None = None + max_signals: definitions.MaxSignals | None = None + meta: dict[str, Any] | None = None + note: definitions.Markdown | None = None + references: list[str] | None = None + risk_score_mapping: list[RiskScoreMapping] | None = None + rule_name_override: str | None = None + severity_mapping: list[SeverityMapping] | None = None + tags: list[str] | None = None + threat: list[ThreatMapping] | None = None + throttle: str | None = None + timeline_id: definitions.TimelineTemplateId | None = None + timeline_title: definitions.TimelineTemplateTitle | None = None + timestamp_override: str | None = None + to: str | None = None + version: definitions.PositiveInteger | None = None @classmethod - def save_schema(cls): + def save_schema(cls) -> None: """Save the schema as a jsonschema.""" - fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(cls) + fields: tuple[dataclasses.Field[Any], ...] = dataclasses.fields(cls) type_field = next(f for f in fields if f.name == "type") rule_type = typing.get_args(type_field.type)[0] if cls != BaseRuleData else "base" schema = cls.jsonschema() @@ -409,132 +427,126 @@ def save_schema(cls): with (version_dir / f"master.{rule_type}.json").open("w") as f: json.dump(schema, f, indent=2, sort_keys=True) - def validate_query(self, meta: RuleMeta) -> None: + def validate_query(self, _: RuleMeta) -> None: pass @cached_property - def get_restricted_fields(self) -> Optional[Dict[str, tuple]]: + def get_restricted_fields(self) -> dict[str, tuple[Version | None, Version | None]] | None: """Get stack version restricted fields.""" - fields: List[dataclasses.Field, ...] = list(dataclasses.fields(self)) + fields: list[dataclasses.Field[Any]] = list(dataclasses.fields(self)) return get_restricted_fields(fields) @cached_property - def data_validator(self) -> Optional['DataValidator']: + def data_validator(self) -> "DataValidator | None": return DataValidator(is_elastic_rule=self.is_elastic_rule, **self.to_dict()) @cached_property def notify(self) -> bool: - return os.environ.get('DR_NOTIFY_INTEGRATION_UPDATE_AVAILABLE') is not None + return os.environ.get("DR_NOTIFY_INTEGRATION_UPDATE_AVAILABLE") is not None @cached_property - def parsed_note(self) -> Optional[MarkoDocument]: + def parsed_note(self) -> MarkoDocument | None: dv = self.data_validator if dv: return dv.parsed_note + return None @property - def is_elastic_rule(self): - return 'elastic' in [a.lower() for a in self.author] + def is_elastic_rule(self) -> bool: + return "elastic" in [a.lower() for a in self.author] - def get_build_fields(self) -> {}: + def get_build_fields(self) -> dict[str, tuple[Version, None]]: """Get a list of build-time fields along with the stack versions which they will build within.""" - build_fields = {} rule_fields = {f.name: f for f in dataclasses.fields(self)} - - for fld in BUILD_FIELD_VERSIONS: - if fld in rule_fields: - build_fields[fld] = BUILD_FIELD_VERSIONS[fld] - - return build_fields + return {fld: val for fld, val in BUILD_FIELD_VERSIONS.items() if fld in rule_fields} @classmethod - def process_transforms(cls, transform: RuleTransform, obj: dict) -> dict: + def process_transforms(cls, transform: RuleTransform, obj: dict[str, Any]) -> dict[str, Any]: """Process transforms from toml [transform] called in TOMLRuleContents.to_dict.""" # only create functions that CAREFULLY mutate the obj dict - def process_note_plugins(): - """Format the note field with osquery and investigate plugin 
strings.""" - note = obj.get('note') - if not note: - return - - rendered = transform.render_investigate_osquery_to_string() - rendered_patterns = {} - for plugin, entries in rendered.items(): - rendered_patterns.update(**{f'{plugin}_{i}': e for i, e in enumerate(entries)}) + # Format the note field with osquery and investigate plugin strings + note = obj.get("note") + if not note: + return obj - note_template = PatchedTemplate(note) - rendered_note = note_template.safe_substitute(**rendered_patterns) - obj['note'] = rendered_note + rendered = transform.render_investigate_osquery_to_string() + rendered_patterns: dict[str, Any] = {} + for plugin, entries in rendered.items(): + rendered_patterns.update(**{f"{plugin}_{i}": e for i, e in enumerate(entries)}) # type: ignore[reportUnknownMemberType] - # call transform functions - if transform: - process_note_plugins() + note_template = PatchedTemplate(note) + rendered_note = note_template.safe_substitute(**rendered_patterns) + obj["note"] = rendered_note return obj @validates_schema - def validates_data(self, data, **kwargs): + def validates_data(self, data: dict[str, Any], **_: Any) -> None: """Validate fields and data for marshmallow schemas.""" # Validate version and revision fields not supplied. - disallowed_fields = [field for field in ['version', 'revision'] if data.get(field) is not None] + disallowed_fields = [field for field in ["version", "revision"] if data.get(field) is not None] if not disallowed_fields: return - error_message = " and ".join(disallowed_fields) - # If version and revision fields are supplied, and using locked versions raise an error. if BYPASS_VERSION_LOCK is not True: - msg = (f"Configuration error: Rule {data['name']} - {data['rule_id']} " - f"should not contain rules with `{error_message}` set.") + error_message = " and ".join(disallowed_fields) + msg = ( + f"Configuration error: Rule {data['name']} - {data['rule_id']} " + f"should not contain rules with `{error_message}` set." 
+ ) raise ValidationError(msg) class DataValidator: """Additional validation beyond base marshmallow schema validation.""" - def __init__(self, - name: definitions.RuleName, - is_elastic_rule: bool, - note: Optional[definitions.Markdown] = None, - interval: Optional[definitions.Interval] = None, - building_block_type: Optional[definitions.BuildingBlockType] = None, - setup: Optional[str] = None, - **extras): + def __init__( # noqa: PLR0913 + self, + name: definitions.RuleName, + is_elastic_rule: bool, + note: definitions.Markdown | None = None, + interval: definitions.Interval | None = None, + building_block_type: definitions.BuildingBlockType | None = None, + setup: str | None = None, + **extras: Any, + ) -> None: # only define fields needing additional validation self.name = name self.is_elastic_rule = is_elastic_rule self.note = note # Need to use extras because from is a reserved word in python - self.from_ = extras.get('from') + self.from_ = extras.get("from") self.interval = interval self.building_block_type = building_block_type self.setup = setup self._setup_in_note = False @cached_property - def parsed_note(self) -> Optional[MarkoDocument]: + def parsed_note(self) -> MarkoDocument | None: if self.note: return gfm.parse(self.note) + return None @property - def setup_in_note(self): + def setup_in_note(self) -> bool: return self._setup_in_note @setup_in_note.setter - def setup_in_note(self, value: bool): + def setup_in_note(self, value: bool) -> None: self._setup_in_note = value @cached_property def skip_validate_note(self) -> bool: - return os.environ.get('DR_BYPASS_NOTE_VALIDATION_AND_PARSE') is not None + return os.environ.get("DR_BYPASS_NOTE_VALIDATION_AND_PARSE") is not None @cached_property def skip_validate_bbr(self) -> bool: - return os.environ.get('DR_BYPASS_BBR_LOOKBACK_VALIDATION') is not None + return os.environ.get("DR_BYPASS_BBR_LOOKBACK_VALIDATION") is not None - def validate_bbr(self, bypass: bool = False): + def validate_bbr(self, bypass: bool = False) -> None: """Validate building block type and rule type.""" if self.skip_validate_bbr or bypass: @@ -552,7 +564,7 @@ def validate_lookback(str_time: str) -> bool: else: return False except Exception as e: - raise ValidationError(f"Invalid time format: {e}") + raise ValidationError(f"Invalid time format: {e}") from e return True def validate_interval(str_time: str) -> bool: @@ -563,7 +575,7 @@ def validate_interval(str_time: str) -> bool: if time < 60 * 60 * 1000: return False except Exception as e: - raise ValidationError(f"Invalid time format: {e}") + raise ValidationError(f"Invalid time format: {e}") from e return True bypass_instructions = "To bypass, use the environment variable `DR_BYPASS_BBR_LOOKBACK_VALIDATION`" @@ -574,7 +586,7 @@ def validate_interval(str_time: str) -> bool: "BBR require `from` and `interval` to be defined. " "Please set or bypass." + bypass_instructions ) - elif not validate_lookback(self.from_) or not validate_interval(self.interval): + if not validate_lookback(self.from_) or not validate_interval(self.interval): raise ValidationError( f"{self.name} is invalid." "Default BBR require `from` and `interval` to be at least now-119m and at least 60m respectively " @@ -582,35 +594,38 @@ def validate_interval(str_time: str) -> bool: "Please update values or bypass. 
" + bypass_instructions ) - def validate_note(self): + def validate_note(self) -> None: if self.skip_validate_note or not self.note: return + if not self.parsed_note: + return + try: for child in self.parsed_note.children: if child.get_type() == "Heading": header = gfm.renderer.render_children(child) if header.lower() == "setup": - # check that the Setup header is correctly formatted at level 2 - if child.level != 2: - raise ValidationError(f"Setup section with wrong header level: {child.level}") + if child.level != 2: # type: ignore[reportAttributeAccessIssue] # noqa: PLR2004 + raise ValidationError(f"Setup section with wrong header level: {child.level}") # type: ignore[reportAttributeAccessIssue] # noqa: TRY301 # check that the Setup header is capitalized - if child.level == 2 and header != "Setup": - raise ValidationError(f"Setup header has improper casing: {header}") + if child.level == 2 and header != "Setup": # type: ignore[reportAttributeAccessIssue] # noqa: PLR2004 + raise ValidationError(f"Setup header has improper casing: {header}") # noqa: TRY301 self.setup_in_note = True - else: - # check that the header Config does not exist in the Setup section - if child.level == 2 and "config" in header.lower(): - raise ValidationError(f"Setup header contains Config: {header}") + # check that the header Config does not exist in the Setup section + elif child.level == 2 and "config" in header.lower(): # type: ignore[reportAttributeAccessIssue] # noqa: PLR2004 + raise ValidationError(f"Setup header contains Config: {header}") # noqa: TRY301 except Exception as e: - raise ValidationError(f"Invalid markdown in rule `{self.name}`: {e}. To bypass validation on the `note`" - f"field, use the environment variable `DR_BYPASS_NOTE_VALIDATION_AND_PARSE`") + raise ValidationError( + f"Invalid markdown in rule `{self.name}`: {e}. 
To bypass validation on the `note`" + f"field, use the environment variable `DR_BYPASS_NOTE_VALIDATION_AND_PARSE`" + ) from e # raise if setup header is in note and in setup if self.setup_in_note and (self.setup and self.setup != "None"): @@ -623,157 +638,164 @@ class QueryValidator: @property def ast(self) -> Any: - raise NotImplementedError() + raise NotImplementedError @property def unique_fields(self) -> Any: - raise NotImplementedError() + raise NotImplementedError - def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: - raise NotImplementedError() + def validate(self, _: "QueryRuleData", __: RuleMeta) -> None: + raise NotImplementedError @cached - def get_required_fields(self, index: str) -> List[Optional[dict]]: + def get_required_fields(self, index: str) -> list[dict[str, Any]]: """Retrieves fields needed for the query along with type information from the schema.""" if isinstance(self, ESQLValidator): return [] current_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - ecs_version = get_stack_schemas()[str(current_version)]['ecs'] - beats_version = get_stack_schemas()[str(current_version)]['beats'] - endgame_version = get_stack_schemas()[str(current_version)]['endgame'] + ecs_version = get_stack_schemas()[str(current_version)]["ecs"] + beats_version = get_stack_schemas()[str(current_version)]["beats"] + endgame_version = get_stack_schemas()[str(current_version)]["endgame"] ecs_schema = ecs.get_schema(ecs_version) - beat_types, beat_schema, schema = self.get_beats_schema(index or [], beats_version, ecs_version) + _, beat_schema, schema = self.get_beats_schema(index or [], beats_version, ecs_version) endgame_schema = self.get_endgame_schema(index or [], endgame_version) # construct integration schemas packages_manifest = load_integrations_manifests() integrations_schemas = load_integrations_schemas() datasets, _ = beats.get_datasets_and_modules(self.ast) - package_integrations = parse_datasets(datasets, packages_manifest) - int_schema = {} + package_integrations = parse_datasets(list(datasets), packages_manifest) + int_schema: dict[str, Any] = {} data = {"notify": False} for pk_int in package_integrations: package = pk_int["package"] integration = pk_int["integration"] - schema, _ = get_integration_schema_fields(integrations_schemas, package, integration, - current_version, packages_manifest, {}, data) + schema, _ = get_integration_schema_fields( + integrations_schemas, package, integration, current_version, packages_manifest, {}, data + ) int_schema.update(schema) - required = [] - unique_fields = self.unique_fields or [] + required: list[dict[str, Any]] = [] + unique_fields: list[str] = self.unique_fields or [] for fld in unique_fields: - field_type = ecs_schema.get(fld, {}).get('type') + field_type = ecs_schema.get(fld, {}).get("type") is_ecs = field_type is not None if not is_ecs: if int_schema: - field_type = int_schema.get(fld, None) + field_type = int_schema.get(fld) elif beat_schema: - field_type = beat_schema.get(fld, {}).get('type') + field_type = beat_schema.get(fld, {}).get("type") elif endgame_schema: field_type = endgame_schema.endgame_schema.get(fld, None) - required.append(dict(name=fld, type=field_type or 'unknown', ecs=is_ecs)) + required.append({"name": fld, "type": field_type or "unknown", "ecs": is_ecs}) - return sorted(required, key=lambda f: f['name']) + return sorted(required, key=lambda f: f["name"]) @cached - def get_beats_schema(self, index: list, beats_version: str, ecs_version: str) -> (list, dict, 
dict): + def get_beats_schema( + self, indices: list[str], beats_version: str, ecs_version: str + ) -> tuple[list[str], dict[str, Any] | None, dict[str, Any]]: """Get an assembled beats schema.""" - beat_types = beats.parse_beats_from_index(index) + beat_types = beats.parse_beats_from_index(indices) beat_schema = beats.get_schema_from_kql(self.ast, beat_types, version=beats_version) if beat_types else None - schema = ecs.get_kql_schema(version=ecs_version, indexes=index, beat_schema=beat_schema) + schema = ecs.get_kql_schema(version=ecs_version, indexes=indices, beat_schema=beat_schema) return beat_types, beat_schema, schema @cached - def get_endgame_schema(self, index: list, endgame_version: str) -> Optional[endgame.EndgameSchema]: + def get_endgame_schema(self, indices: list[str], endgame_version: str) -> endgame.EndgameSchema | None: """Get an assembled flat endgame schema.""" - if index and "endgame-*" not in index: + if indices and "endgame-*" not in indices: return None endgame_schema = endgame.read_endgame_schema(endgame_version=endgame_version) return endgame.EndgameSchema(endgame_schema) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class QueryRuleData(BaseRuleData): """Specific fields for query event types.""" - type: Literal["query"] - index: Optional[List[str]] - data_view_id: Optional[str] + type: Literal["query"] query: str language: definitions.FilterLanguages - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.8"))) + alert_suppression: AlertSuppressionMapping | None = field(metadata={"metadata": {"min_compat": "8.8"}}) + + index: list[str] | None = None + data_view_id: str | None = None @cached_property def index_or_dataview(self) -> list[str]: """Return the index or dataview depending on which is set. 
If neither returns empty list.""" if self.index is not None: return self.index - elif self.data_view_id is not None: + if self.data_view_id is not None: return [self.data_view_id] - else: - return [] + return [] @cached_property - def validator(self) -> Optional[QueryValidator]: + def validator(self) -> QueryValidator | None: if self.language == "kuery": return KQLValidator(self.query) - elif self.language == "eql": + if self.language == "eql": return EQLValidator(self.query) - elif self.language == "esql": + if self.language == "esql": return ESQLValidator(self.query) + return None - def validate_query(self, meta: RuleMeta) -> None: + def validate_query(self, meta: RuleMeta) -> None: # type: ignore[reportIncompatibleMethodOverride] validator = self.validator - if validator is not None: - return validator.validate(self, meta) + if validator: + validator.validate(self, meta) @cached_property - def ast(self): + def ast(self) -> Any: validator = self.validator if validator is not None: return validator.ast + return None @cached_property - def unique_fields(self): + def unique_fields(self) -> Any: validator = self.validator if validator is not None: return validator.unique_fields + return None @cached - def get_required_fields(self, index: str) -> List[dict]: + def get_required_fields(self, index: str) -> list[dict[str, Any]] | None: validator = self.validator if validator is not None: return validator.get_required_fields(index or []) + return None @validates_schema - def validates_index_and_data_view_id(self, data, **kwargs): + def validates_index_and_data_view_id(self, data: dict[str, Any], **_: Any) -> None: """Validate that either index or data_view_id is set, but not both.""" - if data.get('index') and data.get('data_view_id'): + if data.get("index") and data.get("data_view_id"): raise ValidationError("Only one of index or data_view_id should be set.") -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class MachineLearningRuleData(BaseRuleData): type: Literal["machine_learning"] anomaly_threshold: int - machine_learning_job_id: Union[str, List[str]] - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.15"))) + machine_learning_job_id: str | list[str] + alert_suppression: AlertSuppressionMapping | None = field(metadata={"metadata": {"min_compat": "8.15"}}) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ThresholdQueryRuleData(QueryRuleData): """Specific fields for query event types.""" - @dataclass(frozen=True) + @dataclass(frozen=True, kw_only=True) class ThresholdMapping(MarshmallowDataclassMixin): @dataclass(frozen=True) class ThresholdCardinality: @@ -782,14 +804,14 @@ class ThresholdCardinality: field: definitions.CardinalityFields value: definitions.ThresholdValue
definitions.NonEmptyStr value: definitions.NewTermsFields - history_window_start: List[HistoryWindowStart] + history_window_start: list[HistoryWindowStart] - type: Literal["new_terms"] + type: Literal["new_terms"] # type: ignore[reportIncompatibleVariableOverride] new_terms: NewTermsMapping - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.14"))) + alert_suppression: AlertSuppressionMapping | None = field(metadata={"metadata": {"min_compat": "8.14"}}) @pre_load - def preload_data(self, data: dict, **kwargs) -> dict: + def preload_data(self, data: dict[str, Any], **_: Any) -> dict[str, Any]: """Preloads and formats the data to match the required schema.""" if "new_terms_fields" in data and "history_window_start" in data: new_terms_mapping = { "field": "new_terms_fields", "value": data["new_terms_fields"], - "history_window_start": [ - { - "field": "history_window_start", - "value": data["history_window_start"] - } - ] + "history_window_start": [{"field": "history_window_start", "value": data["history_window_start"]}], } data["new_terms"] = new_terms_mapping @@ -829,7 +846,7 @@ def preload_data(self, data: dict, **kwargs) -> dict: data.pop("history_window_start") return data - def transform(self, obj: dict) -> dict: + def transform(self, obj: dict[str, Any]) -> dict[str, Any]: """Transforms new terms data to API format for Kibana.""" obj[obj["new_terms"].get("field")] = obj["new_terms"].get("value") obj["history_window_start"] = obj["new_terms"]["history_window_start"][0].get("value") @@ -837,85 +854,89 @@ def transform(self, obj: dict) -> dict: return obj -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class EQLRuleData(QueryRuleData): """EQL rules are a special case of query rules.""" - type: Literal["eql"] + + type: Literal["eql"] # type: ignore[reportIncompatibleVariableOverride] language: Literal["eql"] - timestamp_field: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.0"))) - event_category_override: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.0"))) - tiebreaker_field: Optional[str] = field(metadata=dict(metadata=dict(min_compat="8.0"))) - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.14"))) + timestamp_field: str | None = field(metadata={"metadata": {"min_compat": "8.0"}}) + event_category_override: str | None = field(metadata={"metadata": {"min_compat": "8.0"}}) + tiebreaker_field: str | None = field(metadata={"metadata": {"min_compat": "8.0"}}) + alert_suppression: AlertSuppressionMapping | None = field(metadata={"metadata": {"min_compat": "8.14"}}) def convert_relative_delta(self, lookback: str) -> int: now = len("now") - min_length = now + len('+5m') + min_length = now + len("+5m") if lookback.startswith("now") and len(lookback) >= min_length: - lookback = lookback[len("now"):] + lookback = lookback[len("now") :] sign = lookback[0] # + or - span = lookback[1:] amount = convert_time_span(span) return amount * (-1 if sign == "-" else 1) - else: - return convert_time_span(lookback) + return convert_time_span(lookback) @cached_property def is_sample(self) -> bool: """Checks if the current rule is a sample-based rule.""" - return eql.utils.get_query_type(self.ast) == 'sample' + return eql.utils.get_query_type(self.ast) == "sample" # type: ignore[reportUnknownMemberType] @cached_property def is_sequence(self) -> bool: """Checks if the current rule is a sequence-based rule.""" - return eql.utils.get_query_type(self.ast) == 
'sequence' + return eql.utils.get_query_type(self.ast) == "sequence" # type: ignore[reportUnknownMemberType] @cached_property - def max_span(self) -> Optional[int]: + def max_span(self) -> int | None: """Maxspan value for sequence rules if defined.""" - if self.is_sequence and hasattr(self.ast.first, 'max_span'): + if not self.ast: + raise ValueError("No AST found") + if self.is_sequence and hasattr(self.ast.first, "max_span"): return self.ast.first.max_span.as_milliseconds() if self.ast.first.max_span else None + return None @cached_property - def look_back(self) -> Optional[Union[int, Literal['unknown']]]: + def look_back(self) -> int | Literal["unknown"] | None: """Lookback value of a rule.""" # https://www.elastic.co/guide/en/elasticsearch/reference/current/common-options.html#date-math to = self.convert_relative_delta(self.to) if self.to else 0 from_ = self.convert_relative_delta(self.from_ or "now-6m") if not (to or from_): - return 'unknown' - else: - return to - from_ + return "unknown" + return to - from_ @cached_property - def interval_ratio(self) -> Optional[float]: + def interval_ratio(self) -> float | None: """Ratio of interval time window / max_span time window.""" if self.max_span: - interval = convert_time_span(self.interval or '5m') + interval = convert_time_span(self.interval or "5m") return interval / self.max_span + return None -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ESQLRuleData(QueryRuleData): """ESQL rules are a special case of query rules.""" - type: Literal["esql"] + + type: Literal["esql"] # type: ignore[reportIncompatibleVariableOverride] language: Literal["esql"] query: str - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.15"))) + alert_suppression: AlertSuppressionMapping | None = field(metadata={"metadata": {"min_compat": "8.15"}}) @validates_schema - def validates_esql_data(self, data, **kwargs): + def validates_esql_data(self, data: dict[str, Any], **_: Any) -> None: """Custom validation for query rule type and subclasses.""" - if data.get('index'): + if data.get("index"): raise ValidationError("Index is not a valid field for ES|QL rule type.") # Convert the query string to lowercase to handle case insensitivity - query_lower = data['query'].lower() + query_lower = data["query"].lower() # Combine both patterns using an OR operator and compile the regex combined_pattern = re.compile( - r'(from\s+\S+\s+metadata\s+_id,\s*_version,\s*_index)|(\bstats\b.*?\bby\b)', re.DOTALL + r"(from\s+\S+\s+metadata\s+_id,\s*_version,\s*_index)|(\bstats\b.*?\bby\b)", re.DOTALL ) # Ensure that non-aggregate queries have metadata @@ -927,47 +948,45 @@ def validates_esql_data(self, data, **kwargs): ) # Enforce KEEP command for ESQL rules - if '| keep' not in query_lower: + if "| keep" not in query_lower: raise ValidationError( - f"Rule: {data['name']} does not contain a 'keep' command ->" - f" Add a 'keep' command to the query." + f"Rule: {data['name']} does not contain a 'keep' command -> Add a 'keep' command to the query." 
) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ThreatMatchRuleData(QueryRuleData): """Specific fields for indicator (threat) match rule.""" @dataclass(frozen=True) class Entries: - @dataclass(frozen=True) class ThreatMapEntry: field: definitions.NonEmptyStr type: Literal["mapping"] value: definitions.NonEmptyStr - entries: List[ThreatMapEntry] + entries: list[ThreatMapEntry] - type: Literal["threat_match"] + type: Literal["threat_match"] # type: ignore[reportIncompatibleVariableOverride] - concurrent_searches: Optional[definitions.PositiveInteger] - items_per_search: Optional[definitions.PositiveInteger] + concurrent_searches: definitions.PositiveInteger | None = None + items_per_search: definitions.PositiveInteger | None = None - threat_mapping: List[Entries] - threat_filters: Optional[List[dict]] - threat_query: Optional[str] - threat_language: Optional[definitions.FilterLanguages] - threat_index: List[str] - threat_indicator_path: Optional[str] - alert_suppression: Optional[AlertSuppressionMapping] = field(metadata=dict(metadata=dict(min_compat="8.13"))) + threat_mapping: list[Entries] + threat_filters: list[dict[str, Any]] | None = None + threat_query: str | None = None + threat_language: definitions.FilterLanguages | None = None + threat_index: list[str] + threat_indicator_path: str | None = None + alert_suppression: AlertSuppressionMapping | None = field(metadata={"metadata": {"min_compat": "8.13"}}) def validate_query(self, meta: RuleMeta) -> None: - super(ThreatMatchRuleData, self).validate_query(meta) + super().validate_query(meta) if self.threat_query: if not self.threat_language: - raise ValidationError('`threat_language` required when a `threat_query` is defined') + raise ValidationError("`threat_language` required when a `threat_query` is defined") if self.threat_language == "kuery": threat_query_validator = KQLValidator(self.threat_query) @@ -981,8 +1000,15 @@ def validate_query(self, meta: RuleMeta) -> None: # All of the possible rule types # Sort inverse of any inheritance - see comment in TOMLRuleContents.to_dict -AnyRuleData = Union[EQLRuleData, ESQLRuleData, ThresholdQueryRuleData, ThreatMatchRuleData, - MachineLearningRuleData, QueryRuleData, NewTermsRuleData] +AnyRuleData = ( + EQLRuleData + | ESQLRuleData + | ThresholdQueryRuleData + | ThreatMatchRuleData + | MachineLearningRuleData + | QueryRuleData + | NewTermsRuleData +) class BaseRuleContents(ABC): @@ -990,29 +1016,27 @@ class BaseRuleContents(ABC): @property @abstractmethod - def id(self): + def id(self) -> str: pass @property @abstractmethod - def name(self): + def name(self) -> str: pass @property @abstractmethod - def version_lock(self): + def version_lock(self) -> "VersionLock": pass @property @abstractmethod - def type(self): + def type(self) -> str: pass - def lock_info(self, bump=True) -> dict: + def lock_info(self, bump: bool = True) -> dict[str, Any]: version = self.autobumped_version if bump else (self.saved_version or 1) - contents = {"rule_name": self.name, "sha256": self.get_hash(), "version": version, "type": self.type} - - return contents + return {"rule_name": self.name, "sha256": self.get_hash(), "version": version, "type": self.type} @property def is_dirty(self) -> bool: @@ -1027,21 +1051,21 @@ def is_dirty(self) -> bool: rule_hash_with_integrations = self.get_hash(include_integrations=True) # Checking against current and previous version of the hash to avoid mass version bump - is_dirty = existing_sha256 not in (rule_hash, rule_hash_with_integrations) - return is_dirty 
+ return existing_sha256 not in (rule_hash, rule_hash_with_integrations) @property - def lock_entry(self) -> Optional[dict]: + def lock_entry(self) -> dict[str, Any] | None: lock_entry = self.version_lock.version_lock.data.get(self.id) if lock_entry: return lock_entry.to_dict() + return None @property def has_forked(self) -> bool: """Determine if the rule has forked at any point (has a previous entry).""" lock_entry = self.lock_entry if lock_entry: - return 'previous' in lock_entry + return "previous" in lock_entry return False @property @@ -1049,36 +1073,45 @@ def is_in_forked_version(self) -> bool: """Determine if the rule is in a forked version.""" if not self.has_forked: return False - locked_min_stack = Version.parse(self.lock_entry['min_stack_version'], optional_minor_and_patch=True) + if not self.lock_entry: + raise ValueError("No lock entry found") + locked_min_stack = Version.parse(self.lock_entry["min_stack_version"], optional_minor_and_patch=True) current_package_ver = Version.parse(load_current_package_version(), optional_minor_and_patch=True) return current_package_ver < locked_min_stack - def get_version_space(self) -> Optional[int]: + def get_version_space(self) -> int | None: """Retrieve the number of version spaces available (None for unbound).""" if self.is_in_forked_version: - current_entry = self.lock_entry['previous'][self.metadata.min_stack_version] - current_version = current_entry['version'] - max_allowable_version = current_entry['max_allowable_version'] + if not self.lock_entry: + raise ValueError("No lock entry found") + + current_entry = self.lock_entry["previous"][self.metadata.min_stack_version] # type: ignore[reportAttributeAccessIssue] + current_version = current_entry["version"] + max_allowable_version = current_entry["max_allowable_version"] return max_allowable_version - current_version - 1 + return None @property - def saved_version(self) -> Optional[int]: + def saved_version(self) -> int | None: """Retrieve the version from the version.lock or from the file if version locking is bypassed.""" - toml_version = self.data.get("version") + + toml_version = self.data.get("version") # type: ignore[reportAttributeAccessIssue] if BYPASS_VERSION_LOCK: - return toml_version + return toml_version # type: ignore[reportUnknownVariableType] if toml_version: - print(f"WARNING: Rule {self.name} - {self.id} has a version set in the rule TOML." - " This `version` will be ignored and defaulted to the version.lock.json file." - " Set `bypass_version_lock` to `True` in the rules config to use the TOML version.") + print( + f"WARNING: Rule {self.name} - {self.id} has a version set in the rule TOML." + " This `version` will be ignored and defaulted to the version.lock.json file." + " Set `bypass_version_lock` to `True` in the rules config to use the TOML version." 
+ ) return self.version_lock.get_locked_version(self.id, self.get_supported_version()) @property - def autobumped_version(self) -> Optional[int]: + def autobumped_version(self) -> int | None: """Retrieve the current version of the rule, accounting for automatic increments.""" version = self.saved_version @@ -1092,7 +1125,7 @@ def autobumped_version(self) -> Optional[int]: # Auto-increment version if the rule is 'dirty' and not bypassing version lock return version + 1 if self.is_dirty else version - def get_synthetic_version(self, use_default: bool) -> Optional[int]: + def get_synthetic_version(self, use_default: bool) -> int | None: """ Get the latest actual representation of a rule's version, where changes are accounted for automatically when version locking is used, otherwise, return the version defined in the rule toml if present else optionally @@ -1101,7 +1134,7 @@ def get_synthetic_version(self, use_default: bool) -> Optional[int]: return self.autobumped_version or self.saved_version or (1 if use_default else None) @classmethod - def convert_supported_version(cls, stack_version: Optional[str]) -> Version: + def convert_supported_version(cls, stack_version: str | None) -> Version: """Convert an optional stack version to the minimum for the lock in the form major.minor.""" min_version = get_min_supported_stack_version() if stack_version is None: @@ -1110,11 +1143,11 @@ def convert_supported_version(cls, stack_version: Optional[str]) -> Version: def get_supported_version(self) -> str: """Get the lowest stack version for the rule that is currently supported in the form major.minor.""" - rule_min_stack = self.metadata.get('min_stack_version') - min_stack = self.convert_supported_version(rule_min_stack) + rule_min_stack = self.metadata.get("min_stack_version") # type: ignore[reportAttributeAccessIssue] + min_stack = self.convert_supported_version(rule_min_stack) # type: ignore[reportUnknownArgumentType] return f"{min_stack.major}.{min_stack.minor}" - def _post_dict_conversion(self, obj: dict) -> dict: + def _post_dict_conversion(self, obj: dict[str, Any]) -> dict[str, Any]: """Transform the converted API in place before sending to Kibana.""" # cleanup the whitespace in the rule @@ -1127,10 +1160,10 @@ def _post_dict_conversion(self, obj: dict) -> dict: return obj @abstractmethod - def to_api_format(self, include_version: bool = True) -> dict: + def to_api_format(self, include_version: bool = True) -> dict[str, Any]: """Convert the rule to the API format.""" - def get_hashable_content(self, include_version: bool = False, include_integrations: bool = False) -> dict: + def get_hashable_content(self, include_version: bool = False, include_integrations: bool = False) -> dict[str, Any]: """Returns the rule content to be used for calculating the hash value for the rule""" # get the API dict without the version by default, otherwise it'll always be dirty. 
@@ -1155,38 +1188,35 @@ def get_hash(self, include_version: bool = False, include_integrations: bool = F @dataclass(frozen=True) class TOMLRuleContents(BaseRuleContents, MarshmallowDataclassMixin): """Rule object which maps directly to the TOML layout.""" + metadata: RuleMeta - transform: Optional[RuleTransform] - data: AnyRuleData = field(metadata=dict(data_key="rule")) + data: AnyRuleData = field(metadata={"data_key": "rule"}) + transform: RuleTransform | None = None @cached_property - def version_lock(self): - # VersionLock - from .version_lock import loaded_version_lock - + def version_lock(self) -> VersionLock: # type: ignore[reportIncompatibleMethodOverride] if RULES_CONFIG.bypass_version_lock is True: - err_msg = "Cannot access the version lock when the versioning strategy is configured to bypass the" \ - " version lock. Set `bypass_version_lock` to `false` in the rules config to use the version lock." + err_msg = ( + "Cannot access the version lock when the versioning strategy is configured to bypass the" + " version lock. Set `bypass_version_lock` to `false` in the rules config to use the version lock." + ) raise ValueError(err_msg) - return getattr(self, '_version_lock', None) or loaded_version_lock - - def set_version_lock(self, value): - from .version_lock import VersionLock - - err_msg = "Cannot set the version lock when the versioning strategy is configured to bypass the version lock." \ - " Set `bypass_version_lock` to `false` in the rules config to use the version lock." - assert not RULES_CONFIG.bypass_version_lock, err_msg + return getattr(self, "_version_lock", None) or loaded_version_lock - if value and not isinstance(value, VersionLock): - raise TypeError(f'version lock property must be set with VersionLock objects only. Got {type(value)}') + def set_version_lock(self, value: VersionLock) -> None: + if RULES_CONFIG.bypass_version_lock: + raise ValueError( + "Cannot set the version lock when the versioning strategy is configured to bypass the version lock." + " Set `bypass_version_lock` to `false` in the rules config to use the version lock." 
+ ) # circumvent frozen class - self.__dict__['_version_lock'] = value + self.__dict__["_version_lock"] = value @classmethod - def all_rule_types(cls) -> set: - types = set() + def all_rule_types(cls) -> set[str]: + types: set[str] = set() for subclass in typing.get_args(AnyRuleData): field = next(field for field in dataclasses.fields(subclass) if field.name == "type") types.update(typing.get_args(field.type)) @@ -1194,11 +1224,11 @@ def all_rule_types(cls) -> set: return types @classmethod - def get_data_subclass(cls, rule_type: str) -> typing.Type[BaseRuleData]: + def get_data_subclass(cls, rule_type: str) -> type[BaseRuleData]: """Get the proper subclass depending on the rule type""" for subclass in typing.get_args(AnyRuleData): field = next(field for field in dataclasses.fields(subclass) if field.name == "type") - if (rule_type, ) == typing.get_args(field.type): + if (rule_type,) == typing.get_args(field.type): return subclass raise ValueError(f"Unknown rule type {rule_type}") @@ -1215,23 +1245,25 @@ def name(self) -> str: def type(self) -> str: return self.data.type - def _add_known_nulls(self, rule_dict: dict) -> dict: + def _add_known_nulls(self, rule_dict: dict[str, Any]) -> dict[str, Any]: """Add known nulls to the rule.""" # Note this is primarily as a stopgap until add support for Rule Actions for pair in definitions.KNOWN_NULL_ENTRIES: for compound_key, sub_key in pair.items(): value = get_nested_value(rule_dict, compound_key) if isinstance(value, list): - items_to_update = [ - item for item in value if isinstance(item, dict) and get_nested_value(item, sub_key) is None + items_to_update: list[dict[str, Any]] = [ + item + for item in value # type: ignore[reportUnknownVariableType] + if isinstance(item, dict) and get_nested_value(item, sub_key) is None ] for item in items_to_update: set_nested_value(item, sub_key, None) return rule_dict - def _post_dict_conversion(self, obj: dict) -> dict: + def _post_dict_conversion(self, obj: dict[str, Any]) -> dict[str, Any]: """Transform the converted API in place before sending to Kibana.""" - super()._post_dict_conversion(obj) + _ = super()._post_dict_conversion(obj) # build time fields self._convert_add_related_integrations(obj) @@ -1239,16 +1271,16 @@ def _post_dict_conversion(self, obj: dict) -> dict: self._convert_add_setup(obj) # validate new fields against the schema - rule_type = obj['type'] + rule_type = obj["type"] subclass = self.get_data_subclass(rule_type) subclass.from_dict(obj) # rule type transforms - self.data.transform(obj) if hasattr(self.data, 'transform') else False + self.data.transform(obj) if hasattr(self.data, "transform") else False # type: ignore[reportAttributeAccessIssue] return obj - def _convert_add_related_integrations(self, obj: dict) -> None: + def _convert_add_related_integrations(self, obj: dict[str, Any]) -> None: """Add restricted field related_integrations to the obj.""" field_name = "related_integrations" package_integrations = obj.get(field_name, []) @@ -1257,41 +1289,47 @@ def _convert_add_related_integrations(self, obj: dict) -> None: packages_manifest = load_integrations_manifests() current_stack_version = load_current_package_version() - if self.check_restricted_field_version(field_name): - if (isinstance(self.data, QueryRuleData) or isinstance(self.data, MachineLearningRuleData)): - if (self.data.get('language') is not None and self.data.get('language') != 'lucene') or \ - self.data.get('type') == 'machine_learning': - package_integrations = self.get_packaged_integrations(self.data, 
self.metadata, - packages_manifest) - - if not package_integrations: - return - - for package in package_integrations: - package["version"] = find_least_compatible_version( - package=package["package"], - integration=package["integration"], - current_stack_version=current_stack_version, - packages_manifest=packages_manifest) - - # if integration is not a policy template remove - if package["version"]: - version_data = packages_manifest.get(package["package"], - {}).get(package["version"].strip("^"), {}) - policy_templates = version_data.get("policy_templates", []) - - if package["integration"] not in policy_templates: - del package["integration"] - - # remove duplicate entries - package_integrations = list({json.dumps(d, sort_keys=True): - d for d in package_integrations}.values()) - obj.setdefault("related_integrations", package_integrations) - - def _convert_add_required_fields(self, obj: dict) -> None: + if self.check_restricted_field_version(field_name) and isinstance( + self.data, QueryRuleData | MachineLearningRuleData + ): # type: ignore[reportUnnecessaryIsInstance] + if (self.data.get("language") is not None and self.data.get("language") != "lucene") or self.data.get( + "type" + ) == "machine_learning": + package_integrations = self.get_packaged_integrations( + self.data, # type: ignore[reportArgumentType] + self.metadata, + packages_manifest, + ) + + if not package_integrations: + return + + for package in package_integrations: + package["version"] = find_least_compatible_version( + package=package["package"], + integration=package["integration"], + current_stack_version=current_stack_version, + packages_manifest=packages_manifest, + ) + + # if integration is not a policy template remove + if package["version"]: + version_data = packages_manifest.get(package["package"], {}).get( + package["version"].strip("^"), {} + ) + policy_templates = version_data.get("policy_templates", []) + + if package["integration"] not in policy_templates: + del package["integration"] + + # remove duplicate entries + package_integrations = list({json.dumps(d, sort_keys=True): d for d in package_integrations}.values()) + obj.setdefault("related_integrations", package_integrations) + + def _convert_add_required_fields(self, obj: dict[str, Any]) -> None: """Add restricted field required_fields to the obj, derived from the query AST.""" - if isinstance(self.data, QueryRuleData) and self.data.language != 'lucene': - index = obj.get('index') or [] + if isinstance(self.data, QueryRuleData) and self.data.language != "lucene": + index: list[str] = obj.get("index") or [] required_fields = self.data.get_required_fields(index) else: required_fields = [] @@ -1300,7 +1338,7 @@ def _convert_add_required_fields(self, obj: dict) -> None: if required_fields and self.check_restricted_field_version(field_name=field_name): obj.setdefault(field_name, required_fields) - def _convert_add_setup(self, obj: dict) -> None: + def _convert_add_setup(self, obj: dict[str, Any]) -> None: """Add restricted field setup to the obj.""" rule_note = obj.get("note", "") field_name = "setup" @@ -1311,13 +1349,19 @@ def _convert_add_setup(self, obj: dict) -> None: data_validator = self.data.data_validator + if not data_validator: + raise ValueError("No data validator found") + if not data_validator.skip_validate_note and data_validator.setup_in_note and not field_value: parsed_note = self.data.parsed_note + if not parsed_note: + raise ValueError("No parsed note found") + # parse note tree for i, child in enumerate(parsed_note.children): - if 
child.get_type() == "Heading" and "Setup" in gfm.render(child): - field_value = self._convert_get_setup_content(parsed_note.children[i + 1:]) + if child.get_type() == "Heading" and "Setup" in gfm.render(child): # type: ignore[reportArgumentType] + field_value = self._convert_get_setup_content(parsed_note.children[i + 1 :]) # clean up old note field investigation_guide = rule_note.replace("## Setup\n\n", "") @@ -1327,14 +1371,14 @@ def _convert_add_setup(self, obj: dict) -> None: break @cached - def _convert_get_setup_content(self, note_tree: list) -> str: + def _convert_get_setup_content(self, note_tree: list[Any]) -> str: """Get note paragraph starting from the setup header.""" - setup = [] + setup: list[str] = [] for child in note_tree: if child.get_type() == "BlankLine" or child.get_type() == "LineBreak": setup.append("\n") elif child.get_type() == "CodeSpan": - setup.append(f"`{gfm.renderer.render_raw_text(child)}`") + setup.append(f"`{gfm.renderer.render_raw_text(child)}`") # type: ignore[reportUnknownMemberType] elif child.get_type() == "Paragraph": setup.append(self._convert_get_setup_content(child.children)) setup.append("\n") @@ -1343,7 +1387,7 @@ def _convert_get_setup_content(self, note_tree: list) -> str: setup.append("\n") elif child.get_type() == "RawText": setup.append(child.children) - elif child.get_type() == "Heading" and child.level >= 2: + elif child.get_type() == "Heading" and child.level >= 2: # noqa: PLR2004 break else: setup.append(self._convert_get_setup_content(child.children)) @@ -1353,11 +1397,17 @@ def _convert_get_setup_content(self, note_tree: list) -> str: def check_explicit_restricted_field_version(self, field_name: str) -> bool: """Explicitly check restricted fields against global min and max versions.""" min_stack, max_stack = BUILD_FIELD_VERSIONS[field_name] + if not min_stack or not max_stack: + return True return self.compare_field_versions(min_stack, max_stack) def check_restricted_field_version(self, field_name: str) -> bool: """Check restricted fields against schema min and max versions.""" - min_stack, max_stack = self.data.get_restricted_fields.get(field_name) + if not self.data.get_restricted_fields: + raise ValueError("No restricted fields found") + min_stack, max_stack = self.data.get_restricted_fields[field_name] + if not min_stack or not max_stack: + return True return self.compare_field_versions(min_stack, max_stack) @staticmethod @@ -1368,10 +1418,14 @@ def compare_field_versions(min_stack: Version, max_stack: Version) -> bool: return min_stack <= current_version >= max_stack @classmethod - def get_packaged_integrations(cls, data: QueryRuleData, meta: RuleMeta, - package_manifest: dict) -> Optional[List[dict]]: - packaged_integrations = [] - datasets, _ = beats.get_datasets_and_modules(data.get('ast') or []) + def get_packaged_integrations( + cls, + data: QueryRuleData, + meta: RuleMeta, + package_manifest: dict[str, Any], + ) -> list[dict[str, Any]] | None: + packaged_integrations: list[dict[str, Any]] = [] + datasets, _ = beats.get_datasets_and_modules(data.get("ast") or []) # type: ignore[reportArgumentType] # integration is None to remove duplicate references upstream in Kibana # chronologically, event.dataset is checked for package:integration, then rule tags @@ -1381,78 +1435,93 @@ def get_packaged_integrations(cls, data: QueryRuleData, meta: RuleMeta, rule_integrations = meta.get("integration", []) if rule_integrations: for integration in rule_integrations: - ineligible_integrations = definitions.NON_DATASET_PACKAGES + \ - 
[*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] + ineligible_integrations = [ + *definitions.NON_DATASET_PACKAGES, + *map(str.lower, definitions.MACHINE_LEARNING_PACKAGES), + ] if integration in ineligible_integrations or isinstance(data, MachineLearningRuleData): packaged_integrations.append({"package": integration, "integration": None}) - packaged_integrations.extend(parse_datasets(datasets, package_manifest)) + packaged_integrations.extend(parse_datasets(list(datasets), package_manifest)) return packaged_integrations @validates_schema - def post_conversion_validation(self, value: dict, **kwargs): + def post_conversion_validation(self, value: dict[str, Any], **_: Any) -> None: """Additional validations beyond base marshmallow schemas.""" data: AnyRuleData = value["data"] metadata: RuleMeta = value["metadata"] + if not data.data_validator: + raise ValueError("No data validator found") + test_config = RULES_CONFIG.test_config - if not test_config.check_skip_by_rule_id(value['data'].rule_id): + if not test_config.check_skip_by_rule_id(value["data"].rule_id): + bypass = metadata.get("bypass_bbr_timing") or False data.validate_query(metadata) data.data_validator.validate_note() - data.data_validator.validate_bbr(metadata.get('bypass_bbr_timing')) - data.validate(metadata) if hasattr(data, 'validate') else False + data.data_validator.validate_bbr(bypass) + data.validate(metadata) if hasattr(data, "validate") else False # type: ignore[reportUnknownMemberType] @staticmethod - def validate_remote(remote_validator: 'RemoteValidator', contents: 'TOMLRuleContents'): - remote_validator.validate_rule(contents) + def validate_remote(remote_validator: "RemoteValidator", contents: "TOMLRuleContents") -> None: + _ = remote_validator.validate_rule(contents) @classmethod def from_rule_resource( - cls, rule: dict, creation_date: str = TIME_NOW, updated_date: str = TIME_NOW, maturity: str = 'development' - ) -> 'TOMLRuleContents': + cls, + rule: dict[str, Any], + creation_date: str = TIME_NOW, + updated_date: str = TIME_NOW, + maturity: str = "development", + ) -> "TOMLRuleContents": """Create a TOMLRuleContents from a kibana rule resource.""" - integrations = [r.get("package") for r in rule.get("related_integrations")] + integrations = [r["package"] for r in rule["related_integrations"]] meta = { "creation_date": creation_date, "updated_date": updated_date, "maturity": maturity, "integration": integrations, } - contents = cls.from_dict({'metadata': meta, 'rule': rule, 'transforms': None}, unknown=marshmallow.EXCLUDE) - return contents + return cls.from_dict({"metadata": meta, "rule": rule, "transforms": None}, unknown=marshmallow.EXCLUDE) - def to_dict(self, strip_none_values=True) -> dict: + def to_dict(self, strip_none_values: bool = True) -> dict[str, Any]: # Load schemas directly from the data and metadata classes to avoid schema ambiguity which can # result from union fields which contain classes and related subclasses (AnyRuleData). 
See issue #1141 metadata = self.metadata.to_dict(strip_none_values=strip_none_values) data = self.data.to_dict(strip_none_values=strip_none_values) - self.data.process_transforms(self.transform, data) - dict_obj = dict(metadata=metadata, rule=data) + if self.transform: + data = self.data.process_transforms(self.transform, data) + dict_obj = {"metadata": metadata, "rule": data} return nested_normalize(dict_obj) - def flattened_dict(self) -> dict: - flattened = dict() + def flattened_dict(self) -> dict[str, Any]: + flattened: dict[str, Any] = {} flattened.update(self.data.to_dict()) flattened.update(self.metadata.to_dict()) return flattened - def to_api_format(self, include_version: bool = not BYPASS_VERSION_LOCK, include_metadata: bool = False) -> dict: + def to_api_format( + self, + include_version: bool = not BYPASS_VERSION_LOCK, + include_metadata: bool = False, + ) -> dict[str, Any]: """Convert the TOML rule to the API format.""" + rule_dict = self.to_dict() rule_dict = self._add_known_nulls(rule_dict) - converted_data = rule_dict['rule'] + converted_data = rule_dict["rule"] converted = self._post_dict_conversion(converted_data) if include_metadata: - converted["meta"] = rule_dict['metadata'] + converted["meta"] = rule_dict["metadata"] if include_version: converted["version"] = self.autobumped_version return converted - def check_restricted_fields_compatibility(self) -> Dict[str, dict]: + def check_restricted_fields_compatibility(self) -> dict[str, dict[str, Any]]: """Check for compatibility between restricted fields and the min_stack_version of the rule.""" default_min_stack = get_min_supported_stack_version() if self.metadata.min_stack_version is not None: @@ -1461,12 +1530,19 @@ def check_restricted_fields_compatibility(self) -> Dict[str, dict]: min_stack = default_min_stack restricted = self.data.get_restricted_fields - invalid = {} + if not restricted: + raise ValueError("No restricted fields found") + + invalid: dict[str, dict[str, Any]] = {} for _field, values in restricted.items(): if self.data.get(_field) is not None: min_allowed, _ = values + + if not min_allowed: + raise ValueError("Min allowed version is None") + + if min_stack < min_allowed: - invalid[_field] = {'min_stack_version': min_stack, 'min_allowed_version': min_allowed} + invalid[_field] = {"min_stack_version": min_stack, "min_allowed_version": min_allowed} return invalid @@ -1474,132 +1550,137 @@ def check_restricted_fields_compatibility(self) -> Dict[str, dict]: @dataclass class TOMLRule: contents: TOMLRuleContents = field(hash=True) - path: Optional[Path] = None + path: Path | None = None gh_pr: Any = field(hash=False, compare=False, default=None, repr=False) @property - def id(self): + def id(self) -> definitions.UUIDString: return self.contents.id @property - def name(self): + def name(self) -> str: return self.contents.data.name - def get_asset(self) -> dict: + def get_asset(self) -> dict[str, Any]: """Generate the relevant fleet compatible asset.""" return {"id": self.id, "attributes": self.contents.to_api_format(), "type": definitions.SAVED_OBJECT_TYPE} def get_base_rule_dir(self) -> Path | None: """Get the base rule directory for the rule.""" + if not self.path: + raise ValueError("No path found") rule_path = self.path.resolve() for rules_dir in DEFAULT_PREBUILT_RULES_DIRS + DEFAULT_PREBUILT_BBR_DIRS: if rule_path.is_relative_to(rules_dir): return rule_path.relative_to(rules_dir) return None - def save_toml(self, strip_none_values: bool = True): - assert self.path is not None, f"Can't save rule {self.name} (self.id) without a path" - converted = dict( - metadata=self.contents.metadata.to_dict(), - rule=self.contents.data.to_dict(strip_none_values=strip_none_values), - ) + def save_toml(self, strip_none_values: bool = True) -> None: + if self.path is None: + raise ValueError(f"Can't save rule {self.name} ({self.id}) without a path") + + converted = { + "metadata": self.contents.metadata.to_dict(), + "rule": self.contents.data.to_dict(strip_none_values=strip_none_values), + } if self.contents.transform: converted["transform"] = self.contents.transform.to_dict() - toml_write(converted, str(self.path.absolute())) - def save_json(self, path: Path, include_version: bool = True): - path = path.with_suffix('.json') - with open(str(path.absolute()), 'w', newline='\n') as f: + if not self.path: + raise ValueError("No path found") + + toml_write(converted, self.path.absolute()) + + def save_json(self, path: Path, include_version: bool = True) -> None: + path = path.with_suffix(".json") + with path.absolute().open("w", newline="\n") as f: json.dump(self.contents.to_api_format(include_version=include_version), f, sort_keys=True, indent=2) - f.write('\n') + _ = f.write("\n") @dataclass(frozen=True) class DeprecatedRuleContents(BaseRuleContents): - metadata: dict - data: dict - transform: Optional[dict] + metadata: dict[str, Any] + data: dict[str, Any] + transform: dict[str, Any] | None = None @cached_property - def version_lock(self): + def version_lock(self) -> VersionLock: # type: ignore[reportIncompatibleMethodOverride] # VersionLock - from .version_lock import loaded_version_lock - - return getattr(self, '_version_lock', None) or loaded_version_lock + return getattr(self, "_version_lock", None) or loaded_version_lock - def set_version_lock(self, value): - from .version_lock import VersionLock - - err_msg = "Cannot set the version lock when the versioning strategy is configured to bypass the version lock." \ - " Set `bypass_version_lock` to `false` in the rules config to use the version lock." - assert not RULES_CONFIG.bypass_version_lock, err_msg - - if value and not isinstance(value, VersionLock): - raise TypeError(f'version lock property must be set with VersionLock objects only. Got {type(value)}') + def set_version_lock(self, value: VersionLock | None) -> None: + if RULES_CONFIG.bypass_version_lock: + raise ValueError( + "Cannot set the version lock when the versioning strategy is configured to bypass the version lock." + " Set `bypass_version_lock` to `false` in the rules config to use the version lock."
+ ) # circumvent frozen class - self.__dict__['_version_lock'] = value + self.__dict__["_version_lock"] = value @property - def id(self) -> str: - return self.data.get('rule_id') + def id(self) -> str | None: # type: ignore[reportIncompatibleMethodOverride] + return self.data.get("rule_id") @property - def name(self) -> str: - return self.data.get('name') + def name(self) -> str | None: # type: ignore[reportIncompatibleMethodOverride] + return self.data.get("name") @property - def type(self) -> str: - return self.data.get('type') + def type(self) -> str | None: # type: ignore[reportIncompatibleMethodOverride] + return self.data.get("type") @classmethod - def from_dict(cls, obj: dict): - kwargs = dict(metadata=obj['metadata'], data=obj['rule']) - kwargs['transform'] = obj['transform'] if 'transform' in obj else None + def from_dict(cls, obj: dict[str, Any]) -> "DeprecatedRuleContents": + kwargs = {"metadata": obj["metadata"], "data": obj["rule"]} + kwargs["transform"] = obj.get("transform") return cls(**kwargs) - def to_api_format(self, include_version: bool = not BYPASS_VERSION_LOCK) -> dict: + def to_api_format(self, include_version: bool = not BYPASS_VERSION_LOCK) -> dict[str, Any]: """Convert the TOML rule to the API format.""" data = copy.deepcopy(self.data) if self.transform: transform = RuleTransform.from_dict(self.transform) - BaseRuleData.process_transforms(transform, data) + _ = BaseRuleData.process_transforms(transform, data) converted = data if include_version: converted["version"] = self.autobumped_version - converted = self._post_dict_conversion(converted) - return converted + return self._post_dict_conversion(converted) -class DeprecatedRule(dict): +class DeprecatedRule(dict[str, Any]): """Minimal dict object for deprecated rule.""" - def __init__(self, path: Path, contents: DeprecatedRuleContents, *args, **kwargs): - super(DeprecatedRule, self).__init__(*args, **kwargs) + def __init__(self, path: Path, contents: DeprecatedRuleContents, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self.path = path self.contents = contents - def __repr__(self): - return f'{type(self).__name__}(contents={self.contents}, path={self.path})' + def __repr__(self) -> str: + return f"{type(self).__name__}(contents={self.contents}, path={self.path})" @property - def id(self) -> str: + def id(self) -> str | None: return self.contents.id @property - def name(self) -> str: + def name(self) -> str | None: return self.contents.name -def downgrade_contents_from_rule(rule: TOMLRule, target_version: str, - replace_id: bool = True, include_metadata: bool = False) -> dict: +def downgrade_contents_from_rule( + rule: TOMLRule, + target_version: str, + replace_id: bool = True, + include_metadata: bool = False, +) -> dict[str, Any]: """Generate the downgraded contents from a rule.""" rule_dict = rule.contents.to_dict()["rule"] min_stack_version = target_version or rule.contents.metadata.min_stack_version or "8.3.0" - min_stack_version = Version.parse(min_stack_version, - optional_minor_and_patch=True) + min_stack_version = Version.parse(min_stack_version, optional_minor_and_patch=True) rule_dict.setdefault("meta", {}).update(rule.contents.metadata.to_dict()) if replace_id: @@ -1614,41 +1695,69 @@ def downgrade_contents_from_rule(rule: TOMLRule, target_version: str, rule_contents = TOMLRuleContents.from_dict(rule_contents_dict) payload = rule_contents.to_api_format(include_metadata=include_metadata) - payload = strip_non_public_fields(min_stack_version, payload) - return payload + return 
strip_non_public_fields(min_stack_version, payload) -def set_eql_config(min_stack_version: str) -> eql.parser.ParserConfig: +def set_eql_config(min_stack_version_val: str) -> eql.parser.ParserConfig: """Based on the rule version set the eql functions allowed.""" - if not min_stack_version: - min_stack_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) + if min_stack_version_val: + min_stack_version = Version.parse(min_stack_version_val, optional_minor_and_patch=True) else: - min_stack_version = Version.parse(min_stack_version, optional_minor_and_patch=True) + min_stack_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) config = eql.parser.ParserConfig() for feature, version_range in definitions.ELASTICSEARCH_EQL_FEATURES.items(): if version_range[0] <= min_stack_version <= (version_range[1] or min_stack_version): - config.context[feature] = True + config.context[feature] = True # type: ignore[reportUnknownMemberType] return config -def get_unique_query_fields(rule: TOMLRule) -> List[str]: +def get_unique_query_fields(rule: TOMLRule) -> list[str] | None: """Get a list of unique fields used in a rule query from rule contents.""" contents = rule.contents.to_api_format() - language = contents.get('language') - query = contents.get('query') - if language in ('kuery', 'eql'): - # TODO: remove once py-eql supports ipv6 for cidrmatch + language = contents.get("language") + query = contents.get("query") + if language not in ("kuery", "eql"): + return None + + # remove once py-eql supports ipv6 for cidrmatch + + min_stack_version = rule.contents.metadata.get("min_stack_version") + if not min_stack_version: + raise ValueError("Min stack version not found") + cfg = set_eql_config(min_stack_version) + with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions, eql.parser.skip_optimizations, cfg: + parsed = ( # type: ignore[reportUnknownVariableType] + kql.parse(query, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) # type: ignore[reportUnknownMemberType] + if language == "kuery" + else eql.parse_query(query) # type: ignore[reportUnknownMemberType] + ) + return sorted({str(f) for f in parsed if isinstance(f, (eql.ast.Field | kql.ast.Field))}) # type: ignore[reportUnknownVariableType] + + +def parse_datasets(datasets: list[str], package_manifest: dict[str, Any]) -> list[dict[str, Any]]: + """Parses datasets into packaged integrations from rule data.""" + packaged_integrations: list[dict[str, Any]] = [] + for _value in sorted(datasets): + # cleanup extra quotes pulled from ast field + value = _value.strip('"') + + integration = "Unknown" + if "." 
in value: + package, integration = value.split(".", 1) + # Handle cases where endpoint event datasource needs to be parsed uniquely (e.g endpoint.events.network) + # as endpoint.network + if package == "endpoint" and "events" in integration: + integration = integration.split(".")[1] + else: + package = value - cfg = set_eql_config(rule.contents.metadata.get('min_stack_version')) - with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions, eql.parser.skip_optimizations, cfg: - parsed = (kql.parse(query, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) - if language == 'kuery' else eql.parse_query(query)) - return sorted(set(str(f) for f in parsed if isinstance(f, (eql.ast.Field, kql.ast.Field)))) + if package in list(package_manifest): + packaged_integrations.append({"package": package, "integration": integration}) + return packaged_integrations # avoid a circular import from .rule_validators import EQLValidator, ESQLValidator, KQLValidator # noqa: E402 -from .remote_validation import RemoteValidator # noqa: E402 diff --git a/detection_rules/rule_formatter.py b/detection_rules/rule_formatter.py index f080ed0bf55..0702bdf8a89 100644 --- a/detection_rules/rule_formatter.py +++ b/detection_rules/rule_formatter.py @@ -4,13 +4,15 @@ # 2.0. """Helper functions for managing rules in the repository.""" + import copy import dataclasses -import io import json import textwrap -import typing from collections import OrderedDict +from collections.abc import Iterable +from pathlib import Path +from typing import Any, TextIO import toml @@ -24,63 +26,76 @@ @cached -def get_preserved_fmt_fields(): +def get_preserved_fmt_fields() -> set[str]: from .rule import BaseRuleData - preserved_keys = set() - for field in dataclasses.fields(BaseRuleData): # type: dataclasses.Field - if field.type in (definitions.Markdown, typing.Optional[definitions.Markdown]): + preserved_keys: set[str] = set() + + for field in dataclasses.fields(BaseRuleData): + if field.type in (definitions.Markdown, definitions.Markdown | None): preserved_keys.add(field.metadata.get("data_key", field.name)) return preserved_keys -def cleanup_whitespace(val): +def cleanup_whitespace(val: Any) -> Any: if isinstance(val, str): return " ".join(line.strip() for line in val.strip().splitlines()) return val -def nested_normalize(d, skip_cleanup=False): +def nested_normalize(d: Any, skip_cleanup: bool = False) -> Any: + preserved_fields = get_preserved_fmt_fields() + if isinstance(d, str): return d if skip_cleanup else cleanup_whitespace(d) - elif isinstance(d, list): - return [nested_normalize(val) for val in d] - elif isinstance(d, dict): - for k, v in d.items(): - if k == 'query': - # TODO: the linter still needs some work, but once up to par, uncomment to implement - kql.lint(v) + if isinstance(d, list): + return [nested_normalize(val) for val in d] # type: ignore[reportUnknownVariableType] + if isinstance(d, dict): + for k, v in d.items(): # type: ignore[reportUnknownVariableType] + if k == "query": + # the linter still needs some work, but once up to par, uncomment to implement - kql.lint(v) # do not normalize queries - d.update({k: v}) - elif k in get_preserved_fmt_fields(): + d.update({k: v}) # type: ignore[reportUnknownMemberType] + elif k in preserved_fields: # let these maintain newlines and whitespace for markdown support - d.update({k: nested_normalize(v, skip_cleanup=True)}) + d.update({k: nested_normalize(v, skip_cleanup=True)}) # type: ignore[reportUnknownMemberType] else: - d.update({k: nested_normalize(v)}) - 
return d - else: - return d + d.update({k: nested_normalize(v)}) # type: ignore[reportUnknownMemberType] + return d # type: ignore[reportUnknownVariableType] + return d -def wrap_text(v, block_indent=0, join=False): +def wrap_text(v: str, block_indent: int = 0) -> list[str]: """Block and indent a blob of text.""" - v = ' '.join(v.split()) - lines = textwrap.wrap(v, initial_indent=' ' * block_indent, subsequent_indent=' ' * block_indent, width=120, - break_long_words=False, break_on_hyphens=False) - lines = [line + '\n' for line in lines] + v = " ".join(v.split()) + lines = textwrap.wrap( + v, + initial_indent=" " * block_indent, + subsequent_indent=" " * block_indent, + width=120, + break_long_words=False, + break_on_hyphens=False, + ) + lines = [line + "\n" for line in lines] # If there is a single line that contains a quote, add a new blank line to trigger multiline formatting if len(lines) == 1 and '"' in lines[0]: - lines = lines + [''] - return lines if not join else ''.join(lines) + lines = [*lines, ""] + return lines -class NonformattedField(str): +def wrap_text_and_join(v: str, block_indent: int = 0) -> str: + lines = wrap_text(v, block_indent=block_indent) + return "".join(lines) + + +class NonformattedField(str): # noqa: SLOT000 """Non-formatting class.""" -def preserve_formatting_for_fields(data: OrderedDict, fields_to_preserve: list) -> OrderedDict: +def preserve_formatting_for_fields(data: OrderedDict[str, Any], fields_to_preserve: list[str]) -> OrderedDict[str, Any]: """Preserve formatting for specified nested fields in an action.""" - def apply_preservation(target: OrderedDict, keys: list) -> None: + def apply_preservation(target: OrderedDict[str, Any], keys: list[str]) -> None: """Apply NonformattedField preservation based on keys path.""" for key in keys[:-1]: # Iterate to the key, diving into nested dictionaries @@ -96,28 +111,28 @@ def apply_preservation(target: OrderedDict, keys: list) -> None: target[final_key] = NonformattedField(target[final_key]) for field_path in fields_to_preserve: - keys = field_path.split('.') + keys = field_path.split(".") apply_preservation(data, keys) return data -class RuleTomlEncoder(toml.TomlEncoder): +class RuleTomlEncoder(toml.TomlEncoder): # type: ignore[reportMissingTypeArgument] """Generate a pretty form of toml.""" - def __init__(self, _dict=dict, preserve=False): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Create the encoder but override some default functions.""" - super(RuleTomlEncoder, self).__init__(_dict, preserve) + super().__init__(*args, **kwargs) # type: ignore[reportUnknownMemberType] self._old_dump_str = toml.TomlEncoder().dump_funcs[str] self._old_dump_list = toml.TomlEncoder().dump_funcs[list] self.dump_funcs[str] = self.dump_str - self.dump_funcs[type(u"")] = self.dump_str + self.dump_funcs[str] = self.dump_str self.dump_funcs[list] = self.dump_list self.dump_funcs[NonformattedField] = self.dump_str - def dump_str(self, v): + def dump_str(self, v: str | NonformattedField) -> str: """Change the TOML representation to multi-line or single quote when logical.""" - initial_newline = ['\n'] + initial_newline = ["\n"] if isinstance(v, NonformattedField): # first line break is not forced like other multiline string dumps @@ -132,135 +147,123 @@ def dump_str(self, v): if multiline: if raw: - return "".join([TRIPLE_DQ] + initial_newline + lines + [TRIPLE_DQ]) - else: - return "\n".join([TRIPLE_SQ] + [self._old_dump_str(line)[1:-1] for line in lines] + [TRIPLE_SQ]) - elif raw: - return u"'{:s}'".format(lines[0]) 
+ return "".join([TRIPLE_DQ, *initial_newline, *lines, TRIPLE_DQ]) + return "\n".join([TRIPLE_SQ] + [self._old_dump_str(line)[1:-1] for line in lines] + [TRIPLE_SQ]) + if raw: + return f"'{lines[0]:s}'" return self._old_dump_str(v) - def _dump_flat_list(self, v): + def _dump_flat_list(self, v: Iterable[Any]) -> str: """A slightly tweaked version of original dump_list, removing trailing commas.""" if not v: return "[]" - retval = "[" + str(self.dump_value(v[0])) + "," - for u in v[1:]: + v_list = list(v) + + retval = "[" + str(self.dump_value(v_list[0])) + "," + for u in v_list[1:]: retval += " " + str(self.dump_value(u)) + "," - retval = retval.rstrip(',') + "]" - return retval + return retval.rstrip(",") + "]" - def dump_list(self, v): + def dump_list(self, v: Iterable[Any]) -> str: """Dump a list more cleanly.""" - if all([isinstance(d, str) for d in v]) and sum(len(d) + 3 for d in v) > 100: - dump = [] + if all(isinstance(d, str) for d in v) and sum(len(d) + 3 for d in v) > 100: # noqa: PLR2004 + dump: list[str] = [] for item in v: - if len(item) > (120 - 4 - 3 - 3) and ' ' in item: - dump.append(' """\n{} """'.format(wrap_text(item, block_indent=4, join=True))) + if len(item) > (120 - 4 - 3 - 3) and " " in item: + dump.append(f' """\n{wrap_text_and_join(item, block_indent=4)} """') else: - dump.append(' ' * 4 + self.dump_value(item)) - return '[\n{},\n]'.format(',\n'.join(dump)) + dump.append(" " * 4 + self.dump_value(item)) + return "[\n{},\n]".format(",\n".join(dump)) if v and all(isinstance(i, dict) for i in v): # Compact inline format for lists of dictionaries with proper indentation - retval = "\n" + ' ' * 2 + "[\n" - retval += ",\n".join([' ' * 4 + self.dump_inline_table(u).strip() for u in v]) - retval += "\n" + ' ' * 2 + "]\n" + retval = "\n" + " " * 2 + "[\n" + retval += ",\n".join([" " * 4 + self.dump_inline_table(u).strip() for u in v]) + retval += "\n" + " " * 2 + "]\n" return retval return self._dump_flat_list(v) -def toml_write(rule_contents, outfile=None): +def toml_write(rule_contents: dict[str, Any], out_file_path: Path | None = None) -> None: # noqa: PLR0915 """Write rule in TOML.""" - def write(text, nl=True): - if outfile: - outfile.write(text) - if nl: - outfile.write(u"\n") - else: - print(text, end='' if not nl else '\n') encoder = RuleTomlEncoder() contents = copy.deepcopy(rule_contents) - needs_close = False - def order_rule(obj): + def order_rule(obj: Any) -> Any: if isinstance(obj, dict): - obj = OrderedDict(sorted(obj.items())) + obj = OrderedDict(sorted(obj.items())) # type: ignore[reportUnknownArgumentType, reportUnknownVariableType] for k, v in obj.items(): - if isinstance(v, dict) or isinstance(v, list): + if isinstance(v, dict | list): obj[k] = order_rule(v) if isinstance(obj, list): - for i, v in enumerate(obj): - if isinstance(v, dict) or isinstance(v, list): + for i, v in enumerate(obj): # type: ignore[reportUnknownMemberType] + if isinstance(v, dict | list): obj[i] = order_rule(v) - obj = sorted(obj, key=lambda x: json.dumps(x)) + obj = sorted(obj, key=lambda x: json.dumps(x)) # type: ignore[reportUnknownArgumentType, reportUnknownVariableType] return obj - def _do_write(_data, _contents): + def _do_write(f: TextIO | None, _data: str, _contents: dict[str, Any]) -> None: # noqa: PLR0912 query = None threat_query = None - if _data == 'rule': + if _data == "rule": # - We want to avoid the encoder for the query and instead use kql-lint. # - Linting is done in rule.normalize() which is also called in rule.validate(). 
# - Until lint has tabbing, this is going to result in all queries being flattened with no wrapping, # but will at least purge extraneous white space - query = contents['rule'].pop('query', '').strip() + query = contents["rule"].pop("query", "").strip() # - As tags are expanding, we may want to reconsider the need to have them in alphabetical order - # tags = contents['rule'].get("tags", []) - # - # if tags and isinstance(tags, list): - # contents['rule']["tags"] = list(sorted(set(tags))) - threat_query = contents['rule'].pop('threat_query', '').strip() + threat_query = contents["rule"].pop("threat_query", "").strip() - top = OrderedDict() - bottom = OrderedDict() + top: OrderedDict[str, Any] = OrderedDict() + bottom: OrderedDict[str, Any] = OrderedDict() - for k in sorted(list(_contents)): + for k in sorted(_contents): v = _contents.pop(k) - if k == 'actions': + if k == "actions": # explicitly preserve formatting for message field in actions preserved_fields = ["params.message"] v = [preserve_formatting_for_fields(action, preserved_fields) for action in v] if v is not None else [] - if k == 'filters': + if k == "filters": # explicitly preserve formatting for value field in filters preserved_fields = ["meta.value"] v = [preserve_formatting_for_fields(meta, preserved_fields) for meta in v] if v is not None else [] - if k == 'note' and isinstance(v, str): + if k == "note" and isinstance(v, str): # Transform instances of \ to \\ as calling write will convert \\ to \. # This will ensure that the output file has the correct number of backslashes. v = v.replace("\\", "\\\\") - if k == 'setup' and isinstance(v, str): + if k == "setup" and isinstance(v, str): # Transform instances of \ to \\ as calling write will convert \\ to \. # This will ensure that the output file has the correct number of backslashes. v = v.replace("\\", "\\\\") - if k == 'description' and isinstance(v, str): + if k == "description" and isinstance(v, str): # Transform instances of \ to \\ as calling write will convert \\ to \. # This will ensure that the output file has the correct number of backslashes. v = v.replace("\\", "\\\\") - if k == 'osquery' and isinstance(v, list): + if k == "osquery" and isinstance(v, list): # Specifically handle transform.osquery queries - for osquery_item in v: - if 'query' in osquery_item and isinstance(osquery_item['query'], str): + for osquery_item in v: # type: ignore[reportUnknownVariableType] + if "query" in osquery_item and isinstance(osquery_item["query"], str): # Transform instances of \ to \\ as calling write will convert \\ to \. # This will ensure that the output file has the correct number of backslashes. 
- osquery_item['query'] = osquery_item['query'].replace("\\", "\\\\") + osquery_item["query"] = osquery_item["query"].replace("\\", "\\\\") # type: ignore[reportUnknownMemberType] if isinstance(v, dict): - bottom[k] = OrderedDict(sorted(v.items())) + bottom[k] = OrderedDict(sorted(v.items())) # type: ignore[reportUnknownArgumentType] elif isinstance(v, list): - if any([isinstance(value, (dict, list)) for value in v]): + if any(isinstance(value, (dict | list)) for value in v): # type: ignore[reportUnknownArgumentType] bottom[k] = v else: top[k] = v @@ -270,39 +273,40 @@ def _do_write(_data, _contents): top[k] = v if query: - top.update({'query': "XXxXX"}) + top.update({"query": "XXxXX"}) # type: ignore[reportUnknownMemberType] if threat_query: - top.update({'threat_query': "XXxXX"}) + top.update({"threat_query": "XXxXX"}) # type: ignore[reportUnknownMemberType] - top.update(bottom) - top = toml.dumps(OrderedDict({data: top}), encoder=encoder) + top.update(bottom) # type: ignore[reportUnknownMemberType] + top_out = toml.dumps(OrderedDict({data: top}), encoder=encoder) # type: ignore[reportUnknownMemberType] # we want to preserve the threat_query format, but want to modify it in the context of encoded dump if threat_query: - formatted_threat_query = "\nthreat_query = '''\n{}\n'''{}".format(threat_query, '\n\n' if bottom else '') - top = top.replace('threat_query = "XXxXX"', formatted_threat_query) + formatted_threat_query = "\nthreat_query = '''\n{}\n'''{}".format(threat_query, "\n\n" if bottom else "") + top_out = top_out.replace('threat_query = "XXxXX"', formatted_threat_query) # we want to preserve the query format, but want to modify it in the context of encoded dump if query: - formatted_query = "\nquery = '''\n{}\n'''{}".format(query, '\n\n' if bottom else '') - top = top.replace('query = "XXxXX"', formatted_query) + formatted_query = "\nquery = '''\n{}\n'''{}".format(query, "\n\n" if bottom else "") + top_out = top_out.replace('query = "XXxXX"', formatted_query) - write(top) - - try: + if f: + _ = f.write(top_out + "\n") + else: + print(top_out) - if outfile and not isinstance(outfile, io.IOBase): - needs_close = True - outfile = open(outfile, 'w') + f = None + if out_file_path: + f = out_file_path.open("w") - for data in ('metadata', 'transform', 'rule'): + try: + for data in ("metadata", "transform", "rule"): _contents = contents.get(data, {}) if not _contents: continue order_rule(_contents) - _do_write(data, _contents) - + _do_write(f, data, _contents) finally: - if needs_close and hasattr(outfile, "close"): - outfile.close() + if f: + f.close() diff --git a/detection_rules/rule_loader.py b/detection_rules/rule_loader.py index b2d943c9e7d..7fa84b32ff2 100644 --- a/detection_rules/rule_loader.py +++ b/detection_rules/rule_loader.py @@ -4,39 +4,47 @@ # 2.0. """Load rule metadata transform between rule and api formats.""" + +import json from collections import OrderedDict +from collections.abc import Callable, Iterable, Iterator from dataclasses import dataclass, field +from multiprocessing.pool import ThreadPool from pathlib import Path from subprocess import CalledProcessError -from typing import Callable, Dict, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Any import click -import pytoml -import json +import pytoml # type: ignore[reportMissingTypeStubs] +import requests +from github.File import File +from github.PullRequest import PullRequest from marshmallow.exceptions import ValidationError from . 
import utils from .config import parse_rules_config -from .rule import ( - DeprecatedRule, DeprecatedRuleContents, DictRule, TOMLRule, - TOMLRuleContents -) -from .schemas import definitions +from .ghwrap import GithubClient +from .rule import DeprecatedRule, DeprecatedRuleContents, DictRule, TOMLRule, TOMLRuleContents from .utils import cached, get_path +if TYPE_CHECKING: + from .schemas import definitions + from .version_lock import VersionLock + + RULES_CONFIG = parse_rules_config() DEFAULT_PREBUILT_RULES_DIRS = RULES_CONFIG.rule_dirs DEFAULT_PREBUILT_BBR_DIRS = RULES_CONFIG.bbr_rules_dirs -FILE_PATTERN = r'^([a-z0-9_])+\.(json|toml)$' +FILE_PATTERN = r"^([a-z0-9_])+\.(json|toml)$" -def path_getter(value: str) -> Callable[[dict], bool]: +def path_getter(value: str) -> Callable[[dict[str, Any]], Any]: """Get the path from a Python object.""" path = value.replace("__", ".").split(".") - def callback(obj: dict): + def callback(obj: dict[str, Any]) -> Any: for p in path: - if isinstance(obj, dict) and p in path: + if p in path: obj = obj[p] else: return None @@ -46,28 +54,36 @@ def callback(obj: dict): return callback -def dict_filter(_obj: Optional[dict] = None, **critieria) -> Callable[[dict], bool]: +def dict_filter(_obj: dict[str, Any] | None = None, **criteria: Any) -> Callable[[dict[str, Any]], bool]: """Get a callable that will return true if a dictionary matches a set of criteria. * each key is a dotted (or __ delimited) path into a dictionary to check * each value is a value or list of values to match """ - critieria.update(_obj or {}) - checkers = [(path_getter(k), set(v) if isinstance(v, (list, set, tuple)) else {v}) for k, v in critieria.items()] - - def callback(obj: dict) -> bool: + criteria.update(_obj or {}) + checkers = [ + # What if v is not hashable? 
+ (path_getter(k), set(v if isinstance(v, (list | set | tuple)) else (v,))) # type: ignore[reportUnknownArgumentType] + for k, v in criteria.items() + ] + + def callback(obj: dict[str, Any]) -> bool: for getter, expected in checkers: target_values = getter(obj) - target_values = set(target_values) if isinstance(target_values, (list, set, tuple)) else {target_values} + target_values = ( # type: ignore[reportUnknownVariableType] + set(target_values) # type: ignore[reportUnknownVariableType] + if isinstance(target_values, (list | set | tuple)) + else {target_values} + ) - return bool(expected.intersection(target_values)) + return bool(expected.intersection(target_values)) # type: ignore[reportUnknownArgumentType] return False return callback -def metadata_filter(**metadata) -> Callable[[TOMLRule], bool]: +def metadata_filter(**metadata: Any) -> Callable[[TOMLRule], bool]: """Get a filter callback based off rule metadata""" flt = dict_filter(metadata) @@ -81,48 +97,61 @@ def callback(rule: TOMLRule) -> bool: production_filter = metadata_filter(maturity="production") -def load_locks_from_tag(remote: str, tag: str, version_lock: str = 'detection_rules/etc/version.lock.json', - deprecated_file: str = 'detection_rules/etc/deprecated_rules.json') -> (str, dict, dict): +def load_locks_from_tag( + remote: str, + tag: str, + version_lock: str = "detection_rules/etc/version.lock.json", + deprecated_file: str = "detection_rules/etc/deprecated_rules.json", +) -> tuple[str, dict[str, Any], dict[str, Any]]: """Loads version and deprecated lock files from git tag.""" import json + git = utils.make_git() - exists_args = ['ls-remote'] + exists_args = ["ls-remote"] if remote: exists_args.append(remote) - exists_args.append(f'refs/tags/{tag}') + exists_args.append(f"refs/tags/{tag}") - assert git(*exists_args), f'tag: {tag} does not exist in {remote or "local"}' + if not git(*exists_args): + raise ValueError(f"tag: {tag} does not exist in {remote or 'local'}") - fetch_tags = ['fetch'] + fetch_tags = ["fetch"] if remote: - fetch_tags += [remote, '--tags', '-f', tag] + fetch_tags += [remote, "--tags", "-f", tag] else: - fetch_tags += ['--tags', '-f', tag] + fetch_tags += ["--tags", "-f", tag] - git(*fetch_tags) + _ = git(*fetch_tags) - commit_hash = git('rev-list', '-1', tag) + commit_hash = git("rev-list", "-1", tag) try: - version = json.loads(git('show', f'{tag}:{version_lock}')) + version = json.loads(git("show", f"{tag}:{version_lock}")) except CalledProcessError: # Adding resiliency to account for the old directory structure - version = json.loads(git('show', f'{tag}:etc/version.lock.json')) + version = json.loads(git("show", f"{tag}:etc/version.lock.json")) try: - deprecated = json.loads(git('show', f'{tag}:{deprecated_file}')) + deprecated = json.loads(git("show", f"{tag}:{deprecated_file}")) except CalledProcessError: # Adding resiliency to account for the old directory structure - deprecated = json.loads(git('show', f'{tag}:etc/deprecated_rules.json')) + deprecated = json.loads(git("show", f"{tag}:etc/deprecated_rules.json")) return commit_hash, version, deprecated -def update_metadata_from_file(rule_path: Path, fields_to_update: dict) -> dict: +def update_metadata_from_file(rule_path: Path, fields_to_update: dict[str, Any]) -> dict[str, Any]: """Update metadata fields for a rule with local contents.""" - contents = {} + + contents: dict[str, Any] = {} if not rule_path.exists(): return contents - local_metadata = RuleCollection().load_file(rule_path).contents.metadata.to_dict() + + rule_contents = 
RuleCollection().load_file(rule_path).contents + + if not isinstance(rule_contents, TOMLRuleContents): + raise TypeError("TOML rule expected") + + local_metadata = rule_contents.metadata.to_dict() if local_metadata: contents["maturity"] = local_metadata.get("maturity", "development") for field_name, should_update in fields_to_update.items(): @@ -132,34 +161,34 @@ def update_metadata_from_file(rule_path: Path, fields_to_update: dict) -> dict: @dataclass -class BaseCollection: +class BaseCollection[T]: """Base class for collections.""" - rules: list + rules: list[T] - def __len__(self): + def __len__(self) -> int: """Get the total amount of rules in the collection.""" return len(self.rules) - def __iter__(self): + def __iter__(self) -> Iterator[T]: """Iterate over all rules in the collection.""" return iter(self.rules) @dataclass -class DeprecatedCollection(BaseCollection): +class DeprecatedCollection(BaseCollection[DeprecatedRule]): """Collection of loaded deprecated rule dicts.""" - id_map: Dict[str, DeprecatedRule] = field(default_factory=dict) - file_map: Dict[Path, DeprecatedRule] = field(default_factory=dict) - name_map: Dict[str, DeprecatedRule] = field(default_factory=dict) - rules: List[DeprecatedRule] = field(default_factory=list) + id_map: dict[str, DeprecatedRule] = field(default_factory=dict) # type: ignore[reportUnknownVariableType] + file_map: dict[Path, DeprecatedRule] = field(default_factory=dict) # type: ignore[reportUnknownVariableType] + name_map: dict[str, DeprecatedRule] = field(default_factory=dict) # type: ignore[reportUnknownVariableType] + rules: list[DeprecatedRule] = field(default_factory=list) # type: ignore[reportUnknownVariableType] - def __contains__(self, rule: DeprecatedRule): + def __contains__(self, rule: DeprecatedRule) -> bool: """Check if a rule is in the map by comparing IDs.""" return rule.id in self.id_map - def filter(self, cb: Callable[[DeprecatedRule], bool]) -> 'RuleCollection': + def filter(self, cb: Callable[[DeprecatedRule], bool]) -> "RuleCollection": """Retrieve a filtered collection of rules.""" filtered_collection = RuleCollection() @@ -169,33 +198,33 @@ def filter(self, cb: Callable[[DeprecatedRule], bool]) -> 'RuleCollection': return filtered_collection -class RawRuleCollection(BaseCollection): +class RawRuleCollection(BaseCollection[DictRule]): """Collection of rules in raw dict form.""" __default = None __default_bbr = None - def __init__(self, rules: Optional[List[dict]] = None, ext_patterns: Optional[List[str]] = None): + def __init__(self, rules: list[DictRule] | None = None, ext_patterns: list[str] | None = None) -> None: """Create a new raw rule collection, with optional file ext pattern override.""" # ndjson is unsupported since it breaks the contract of 1 rule per file, so rules should be manually broken out # first - self.ext_patterns = ext_patterns or ['*.toml', '*.json'] - self.id_map: Dict[definitions.UUIDString, DictRule] = {} - self.file_map: Dict[Path, DictRule] = {} - self.name_map: Dict[definitions.RuleName, DictRule] = {} - self.rules: List[DictRule] = [] - self.errors: Dict[Path, Exception] = {} + self.ext_patterns = ext_patterns or ["*.toml", "*.json"] + self.id_map: dict[definitions.UUIDString, DictRule] = {} + self.file_map: dict[Path, DictRule] = {} + self.name_map: dict[definitions.RuleName, DictRule] = {} + self.rules: list[DictRule] = [] + self.errors: dict[Path, Exception] = {} self.frozen = False - self._raw_load_cache: Dict[Path, dict] = {} - for rule in (rules or []): + self._raw_load_cache: dict[Path, 
dict[str, Any]] = {} + for rule in rules or []: self.add_rule(rule) - def __contains__(self, rule: DictRule): + def __contains__(self, rule: DictRule) -> bool: """Check if a rule is in the map by comparing IDs.""" return rule.id in self.id_map - def filter(self, cb: Callable[[DictRule], bool]) -> 'RawRuleCollection': + def filter(self, cb: Callable[[DictRule], bool]) -> "RawRuleCollection": """Retrieve a filtered collection of rules.""" filtered_collection = RawRuleCollection() @@ -204,7 +233,7 @@ def filter(self, cb: Callable[[DictRule], bool]) -> 'RawRuleCollection': return filtered_collection - def _load_rule_file(self, path: Path) -> dict: + def _load_rule_file(self, path: Path) -> dict[str, Any]: """Load a rule file into a dictionary.""" if path in self._raw_load_cache: return self._raw_load_cache[path] @@ -213,49 +242,53 @@ def _load_rule_file(self, path: Path) -> dict: # use pytoml instead of toml because of annoying bugs # https://github.com/uiri/toml/issues/152 # might also be worth looking at https://github.com/sdispater/tomlkit - raw_dict = pytoml.loads(path.read_text()) + raw_dict = pytoml.loads(path.read_text()) # type: ignore[reportUnknownMemberType] elif path.suffix == ".json": raw_dict = json.loads(path.read_text()) elif path.suffix == ".ndjson": - raise ValueError('ndjson is not supported in RawRuleCollection. Break out the rules individually.') + raise ValueError("ndjson is not supported in RawRuleCollection. Break out the rules individually.") else: raise ValueError(f"Unsupported file type {path.suffix} for rule {path}") self._raw_load_cache[path] = raw_dict - return raw_dict + return raw_dict # type: ignore[reportUnknownVariableType] - def _get_paths(self, directory: Path, recursive=True) -> List[Path]: + def _get_paths(self, directory: Path, recursive: bool = True) -> list[Path]: """Get all paths in a directory that match the ext patterns.""" - paths = [] + paths: list[Path] = [] for pattern in self.ext_patterns: paths.extend(sorted(directory.rglob(pattern) if recursive else directory.glob(pattern))) return paths - def _assert_new(self, rule: DictRule): + def _assert_new(self, rule: DictRule) -> None: """Assert that a rule is new and can be added to the collection.""" id_map = self.id_map file_map = self.file_map name_map = self.name_map - assert not self.frozen, f"Unable to add rule {rule.name} {rule.id} to a frozen collection" - assert rule.id not in id_map, \ - f"Rule ID {rule.id} for {rule.name} collides with rule {id_map.get(rule.id).name}" - assert rule.name not in name_map, \ - f"Rule Name {rule.name} for {rule.id} collides with rule ID {name_map.get(rule.name).id}" + if self.frozen: + raise ValueError(f"Unable to add rule {rule.name} {rule.id} to a frozen collection") + + if rule.id in id_map: + raise ValueError(f"Rule ID {rule.id} for {rule.name} collides with rule {id_map[rule.id].name}") + + if rule.name in name_map: + raise ValueError(f"Rule Name {rule.name} for {rule.id} collides with rule ID {name_map[rule.name].id}") if rule.path is not None: rule_path = rule.path.resolve() - assert rule_path not in file_map, f"Rule file {rule_path} already loaded" + if rule_path in file_map: + raise ValueError(f"Rule file {rule_path} already loaded") file_map[rule_path] = rule - def add_rule(self, rule: DictRule): + def add_rule(self, rule: DictRule) -> None: """Add a rule to the collection.""" self._assert_new(rule) self.id_map[rule.id] = rule self.name_map[rule.name] = rule self.rules.append(rule) - def load_dict(self, obj: dict, path: Optional[Path] = None) -> 
DictRule: + def load_dict(self, obj: dict[str, Any], path: Path | None = None) -> DictRule: """Load a rule from a dictionary.""" rule = DictRule(contents=obj, path=path) self.add_rule(rule) @@ -267,11 +300,10 @@ def load_file(self, path: Path) -> DictRule: path = path.resolve() # use the default rule loader as a cache. # if it already loaded the rule, then we can just use it from that - if self.__default is not None and self is not self.__default: - if path in self.__default.file_map: - rule = self.__default.file_map[path] - self.add_rule(rule) - return rule + if self.__default and self is not self.__default and path in self.__default.file_map: + rule = self.__default.file_map[path] + self.add_rule(rule) + return rule obj = self._load_rule_file(path) return self.load_dict(obj, path=path) @@ -279,12 +311,17 @@ def load_file(self, path: Path) -> DictRule: print(f"Error loading rule in {path}") raise - def load_files(self, paths: Iterable[Path]): + def load_files(self, paths: Iterable[Path]) -> None: """Load multiple files into the collection.""" for path in paths: - self.load_file(path) - - def load_directory(self, directory: Path, recursive=True, obj_filter: Optional[Callable[[dict], bool]] = None): + _ = self.load_file(path) + + def load_directory( + self, + directory: Path, + recursive: bool = True, + obj_filter: Callable[..., bool] | None = None, + ) -> None: """Load all rules in a directory.""" paths = self._get_paths(directory, recursive=recursive) if obj_filter is not None: @@ -292,18 +329,22 @@ def load_directory(self, directory: Path, recursive=True, obj_filter: Optional[C self.load_files(paths) - def load_directories(self, directories: Iterable[Path], recursive=True, - obj_filter: Optional[Callable[[dict], bool]] = None): + def load_directories( + self, + directories: Iterable[Path], + recursive: bool = True, + obj_filter: Callable[..., bool] | None = None, + ) -> None: """Load all rules in multiple directories.""" for path in directories: self.load_directory(path, recursive=recursive, obj_filter=obj_filter) - def freeze(self): + def freeze(self) -> None: """Freeze the rule collection and make it immutable going forward.""" self.frozen = True @classmethod - def default(cls) -> 'RawRuleCollection': + def default(cls) -> "RawRuleCollection": """Return the default rule collection, which retrieves from rules/.""" if cls.__default is None: collection = RawRuleCollection() @@ -315,7 +356,7 @@ def default(cls) -> 'RawRuleCollection': return cls.__default @classmethod - def default_bbr(cls) -> 'RawRuleCollection': + def default_bbr(cls) -> "RawRuleCollection": """Return the default BBR collection, which retrieves from building_block_rules/.""" if cls.__default_bbr is None: collection = RawRuleCollection() @@ -326,34 +367,32 @@ def default_bbr(cls) -> 'RawRuleCollection': return cls.__default_bbr -class RuleCollection(BaseCollection): +class RuleCollection(BaseCollection[TOMLRule]): """Collection of rule objects.""" __default = None __default_bbr = None - def __init__(self, rules: Optional[List[TOMLRule]] = None): - from .version_lock import VersionLock - - self.id_map: Dict[definitions.UUIDString, TOMLRule] = {} - self.file_map: Dict[Path, TOMLRule] = {} - self.name_map: Dict[definitions.RuleName, TOMLRule] = {} - self.rules: List[TOMLRule] = [] + def __init__(self, rules: list[TOMLRule] | None = None) -> None: + self.id_map: dict[definitions.UUIDString, TOMLRule] = {} + self.file_map: dict[Path, TOMLRule] = {} + self.name_map: dict[definitions.RuleName, TOMLRule] = {} + self.rules: 
list[TOMLRule] = [] self.deprecated: DeprecatedCollection = DeprecatedCollection() - self.errors: Dict[Path, Exception] = {} + self.errors: dict[Path, Exception] = {} self.frozen = False - self._toml_load_cache: Dict[Path, dict] = {} - self._version_lock: Optional[VersionLock] = None + self._toml_load_cache: dict[Path, dict[str, Any]] = {} + self._version_lock: VersionLock | None = None - for rule in (rules or []): + for rule in rules or []: self.add_rule(rule) - def __contains__(self, rule: TOMLRule): + def __contains__(self, rule: TOMLRule) -> bool: """Check if a rule is in the map by comparing IDs.""" return rule.id in self.id_map - def filter(self, cb: Callable[[TOMLRule], bool]) -> 'RuleCollection': + def filter(self, cb: Callable[[TOMLRule], bool]) -> "RuleCollection": """Retrieve a filtered collection of rules.""" filtered_collection = RuleCollection() @@ -363,10 +402,10 @@ def filter(self, cb: Callable[[TOMLRule], bool]) -> 'RuleCollection': return filtered_collection @staticmethod - def deserialize_toml_string(contents: Union[bytes, str]) -> dict: - return pytoml.loads(contents) + def deserialize_toml_string(contents: bytes | str) -> dict[str, Any]: + return pytoml.loads(contents) # type: ignore[reportUnknownMemberType] - def _load_toml_file(self, path: Path) -> dict: + def _load_toml_file(self, path: Path) -> dict[str, Any]: if path in self._toml_load_cache: return self._toml_load_cache[path] @@ -378,10 +417,10 @@ def _load_toml_file(self, path: Path) -> dict: self._toml_load_cache[path] = toml_dict return toml_dict - def _get_paths(self, directory: Path, recursive=True) -> List[Path]: - return sorted(directory.rglob('*.toml') if recursive else directory.glob('*.toml')) + def _get_paths(self, directory: Path, recursive: bool = True) -> list[Path]: + return sorted(directory.rglob("*.toml") if recursive else directory.glob("*.toml")) - def _assert_new(self, rule: Union[TOMLRule, DeprecatedRule], is_deprecated=False): + def _assert_new(self, rule: TOMLRule | DeprecatedRule, is_deprecated: bool = False) -> None: if is_deprecated: id_map = self.deprecated.id_map file_map = self.deprecated.file_map @@ -391,47 +430,64 @@ def _assert_new(self, rule: Union[TOMLRule, DeprecatedRule], is_deprecated=False file_map = self.file_map name_map = self.name_map - assert not self.frozen, f"Unable to add rule {rule.name} {rule.id} to a frozen collection" - assert rule.id not in id_map, \ - f"Rule ID {rule.id} for {rule.name} collides with rule {id_map.get(rule.id).name}" - assert rule.name not in name_map, \ - f"Rule Name {rule.name} for {rule.id} collides with rule ID {name_map.get(rule.name).id}" + if not rule.id: + raise ValueError("Rule has no ID") + + if self.frozen: + raise ValueError(f"Unable to add rule {rule.name} {rule.id} to a frozen collection") + + if rule.id in id_map: + raise ValueError(f"Rule ID {rule.id} for {rule.name} collides with rule {id_map[rule.id].name}") + + if not rule.name: + raise ValueError("Rule has no name") + + if rule.name in name_map: + raise ValueError(f"Rule Name {rule.name} for {rule.id} collides with rule ID {name_map[rule.name].id}") if rule.path is not None: rule_path = rule.path.resolve() - assert rule_path not in file_map, f"Rule file {rule_path} already loaded" - file_map[rule_path] = rule + if rule_path in file_map: + raise ValueError(f"Rule file {rule_path} already loaded") + file_map[rule_path] = rule # type: ignore[reportArgumentType] - def add_rule(self, rule: TOMLRule): + def add_rule(self, rule: TOMLRule) -> None: self._assert_new(rule) 
self.id_map[rule.id] = rule self.name_map[rule.name] = rule self.rules.append(rule) - def add_deprecated_rule(self, rule: DeprecatedRule): + def add_deprecated_rule(self, rule: DeprecatedRule) -> None: self._assert_new(rule, is_deprecated=True) + + if not rule.id: + raise ValueError("Rule has no ID") + if not rule.name: + raise ValueError("Rule has no name") + self.deprecated.id_map[rule.id] = rule self.deprecated.name_map[rule.name] = rule self.deprecated.rules.append(rule) - def load_dict(self, obj: dict, path: Optional[Path] = None) -> Union[TOMLRule, DeprecatedRule]: + def load_dict(self, obj: dict[str, Any], path: Path | None = None) -> TOMLRule | DeprecatedRule: # bypass rule object load (load_dict) and load as a dict only - if obj.get('metadata', {}).get('maturity', '') == 'deprecated': + if obj.get("metadata", {}).get("maturity", "") == "deprecated": contents = DeprecatedRuleContents.from_dict(obj) if not RULES_CONFIG.bypass_version_lock: contents.set_version_lock(self._version_lock) + if not path: + raise ValueError("No path value provided") deprecated_rule = DeprecatedRule(path, contents) self.add_deprecated_rule(deprecated_rule) return deprecated_rule - else: - contents = TOMLRuleContents.from_dict(obj) - if not RULES_CONFIG.bypass_version_lock: - contents.set_version_lock(self._version_lock) - rule = TOMLRule(path=path, contents=contents) - self.add_rule(rule) - return rule + contents = TOMLRuleContents.from_dict(obj) + if not RULES_CONFIG.bypass_version_lock: + contents.set_version_lock(self._version_lock) + rule = TOMLRule(path=path, contents=contents) + self.add_rule(rule) + return rule - def load_file(self, path: Path) -> Union[TOMLRule, DeprecatedRule]: + def load_file(self, path: Path) -> TOMLRule | DeprecatedRule: try: path = path.resolve() @@ -442,7 +498,7 @@ def load_file(self, path: Path) -> Union[TOMLRule, DeprecatedRule]: rule = self.__default.file_map[path] self.add_rule(rule) return rule - elif path in self.__default.deprecated.file_map: + if path in self.__default.deprecated.file_map: deprecated_rule = self.__default.deprecated.file_map[path] self.add_deprecated_rule(deprecated_rule) return deprecated_rule @@ -453,36 +509,37 @@ def load_file(self, path: Path) -> Union[TOMLRule, DeprecatedRule]: print(f"Error loading rule in {path}") raise - def load_git_tag(self, branch: str, remote: Optional[str] = None, skip_query_validation=False): + def load_git_tag(self, branch: str, remote: str, skip_query_validation: bool = False) -> None: """Load rules from a Git branch.""" from .version_lock import VersionLock, add_rule_types_to_lock git = utils.make_git() - paths = [] + paths: list[str] = [] for rules_dir in DEFAULT_PREBUILT_RULES_DIRS: - rules_dir = rules_dir.relative_to(get_path(".")) - paths.extend(git("ls-tree", "-r", "--name-only", branch, rules_dir).splitlines()) + rdir = rules_dir.relative_to(get_path(["."])) + git_output = git("ls-tree", "-r", "--name-only", branch, rdir) + paths.extend(git_output.splitlines()) - rule_contents = [] - rule_map = {} + rule_contents: list[tuple[dict[str, Any], Path]] = [] + rule_map: dict[str, Any] = {} for path in paths: - path = Path(path) - if path.suffix != ".toml": + ppath = Path(path) + if ppath.suffix != ".toml": continue - contents = git("show", f"{branch}:{path}") + contents = git("show", f"{branch}:{ppath}") toml_dict = self.deserialize_toml_string(contents) if skip_query_validation: - toml_dict['metadata']['query_schema_validation'] = False + toml_dict["metadata"]["query_schema_validation"] = False - 
rule_contents.append((toml_dict, path)) - rule_map[toml_dict['rule']['rule_id']] = toml_dict + rule_contents.append((toml_dict, ppath)) + rule_map[toml_dict["rule"]["rule_id"]] = toml_dict commit_hash, v_lock, d_lock = load_locks_from_tag(remote, branch) - v_lock_name_prefix = f'{remote}/' if remote else '' - v_lock_name = f'{v_lock_name_prefix}{branch}-{commit_hash}' + v_lock_name_prefix = f"{remote}/" if remote else "" + v_lock_name = f"{v_lock_name_prefix}{branch}-{commit_hash}" # For backwards compatibility with tagged branches that existed before the types were added and validation # enforced, we will need to manually add the rule types to the version lock allow them to pass validation. @@ -494,34 +551,43 @@ def load_git_tag(self, branch: str, remote: Optional[str] = None, skip_query_val for rule_content in rule_contents: toml_dict, path = rule_content try: - self.load_dict(toml_dict, path) + _ = self.load_dict(toml_dict, path) except ValidationError as e: self.errors[path] = e continue - def load_files(self, paths: Iterable[Path]): + def load_files(self, paths: Iterable[Path]) -> None: """Load multiple files into the collection.""" for path in paths: - self.load_file(path) - - def load_directory(self, directory: Path, recursive=True, obj_filter: Optional[Callable[[dict], bool]] = None): + _ = self.load_file(path) + + def load_directory( + self, + directory: Path, + recursive: bool = True, + obj_filter: Callable[..., bool] | None = None, + ) -> None: paths = self._get_paths(directory, recursive=recursive) if obj_filter is not None: paths = [path for path in paths if obj_filter(self._load_toml_file(path))] self.load_files(paths) - def load_directories(self, directories: Iterable[Path], recursive=True, - obj_filter: Optional[Callable[[dict], bool]] = None): + def load_directories( + self, + directories: Iterable[Path], + recursive: bool = True, + obj_filter: Callable[..., bool] | None = None, + ) -> None: for path in directories: self.load_directory(path, recursive=recursive, obj_filter=obj_filter) - def freeze(self): + def freeze(self) -> None: """Freeze the rule collection and make it immutable going forward.""" self.frozen = True @classmethod - def default(cls) -> 'RuleCollection': + def default(cls) -> "RuleCollection": """Return the default rule collection, which retrieves from rules/.""" if cls.__default is None: collection = RuleCollection() @@ -533,7 +599,7 @@ def default(cls) -> 'RuleCollection': return cls.__default @classmethod - def default_bbr(cls) -> 'RuleCollection': + def default_bbr(cls) -> "RuleCollection": """Return the default BBR collection, which retrieves from building_block_rules/.""" if cls.__default_bbr is None: collection = RuleCollection() @@ -543,28 +609,32 @@ def default_bbr(cls) -> 'RuleCollection': return cls.__default_bbr - def compare_collections(self, other: 'RuleCollection' - ) -> (Dict[str, TOMLRule], Dict[str, TOMLRule], Dict[str, DeprecatedRule]): + def compare_collections( + self, other: "RuleCollection" + ) -> tuple[dict[str, TOMLRule], dict[str, TOMLRule], dict[str, DeprecatedRule]]: """Get the changes between two sets of rules.""" - assert self._version_lock, 'RuleCollection._version_lock must be set for self' - assert other._version_lock, 'RuleCollection._version_lock must be set for other' + if not self._version_lock: + raise ValueError("RuleCollection._version_lock must be set for self") + + if not other._version_lock: # noqa: SLF001 + raise ValueError("RuleCollection._version_lock must be set for other") # we cannot trust the assumption 
that either of the versions or deprecated files were pre-locked, which means we # have to perform additional checks beyond what is done in manage_versions - changed_rules = {} - new_rules = {} - newly_deprecated = {} + changed_rules: dict[str, TOMLRule] = {} + new_rules: dict[str, TOMLRule] = {} + newly_deprecated: dict[str, DeprecatedRule] = {} pre_versions_hash = utils.dict_hash(self._version_lock.version_lock.to_dict()) - post_versions_hash = utils.dict_hash(other._version_lock.version_lock.to_dict()) + post_versions_hash = utils.dict_hash(other._version_lock.version_lock.to_dict()) # noqa: SLF001 pre_deprecated_hash = utils.dict_hash(self._version_lock.deprecated_lock.to_dict()) - post_deprecated_hash = utils.dict_hash(other._version_lock.deprecated_lock.to_dict()) + post_deprecated_hash = utils.dict_hash(other._version_lock.deprecated_lock.to_dict()) # noqa: SLF001 if pre_versions_hash == post_versions_hash and pre_deprecated_hash == post_deprecated_hash: return changed_rules, new_rules, newly_deprecated for rule in other: - if rule.contents.metadata.maturity != 'production': + if rule.contents.metadata.maturity != "production": continue if rule.id not in self.id_map: @@ -575,46 +645,44 @@ def compare_collections(self, other: 'RuleCollection' changed_rules[rule.id] = rule for rule in other.deprecated: - if rule.id not in self.deprecated.id_map: + if rule.id and rule.id not in self.deprecated.id_map: newly_deprecated[rule.id] = rule return changed_rules, new_rules, newly_deprecated @cached -def load_github_pr_rules(labels: list = None, repo: str = 'elastic/detection-rules', token=None, threads=50, - verbose=True) -> (Dict[str, TOMLRule], Dict[str, TOMLRule], Dict[str, list]): +def load_github_pr_rules( + labels: list[str] | None = None, + repo_name: str = "elastic/detection-rules", + token: str | None = None, + threads: int = 50, + verbose: bool = True, +) -> tuple[dict[str, TOMLRule], dict[str, list[TOMLRule]], dict[str, list[str]]]: """Load all rules active as a GitHub PR.""" - from multiprocessing.pool import ThreadPool - from pathlib import Path - - import pytoml - import requests - - from .ghwrap import GithubClient github = GithubClient(token=token) - repo = github.client.get_repo(repo) - labels = set(labels or []) - open_prs = [r for r in repo.get_pulls() if not labels.difference(set(list(lbl.name for lbl in r.get_labels())))] + repo = github.client.get_repo(repo_name) + labels_set = set(labels or []) + open_prs = [r for r in repo.get_pulls() if not labels_set.difference({lbl.name for lbl in r.get_labels()})] - new_rules: List[TOMLRule] = [] - modified_rules: List[TOMLRule] = [] - errors: Dict[str, list] = {} + new_rules: list[TOMLRule] = [] + modified_rules: list[TOMLRule] = [] + errors: dict[str, list[str]] = {} existing_rules = RuleCollection.default() - pr_rules = [] + pr_rules: list[tuple[PullRequest, File]] = [] if verbose: - click.echo('Downloading rules from GitHub PRs') + click.echo("Downloading rules from GitHub PRs") - def download_worker(pr_info): + def download_worker(pr_info: tuple[PullRequest, File]) -> None: pull, rule_file = pr_info - response = requests.get(rule_file.raw_url) + response = requests.get(rule_file.raw_url, timeout=10) try: - raw_rule = pytoml.loads(response.text) - contents = TOMLRuleContents.from_dict(raw_rule) - rule = TOMLRule(path=rule_file.filename, contents=contents) + raw_rule = pytoml.loads(response.text) # type: ignore[reportUnknownVariableType] + contents = TOMLRuleContents.from_dict(raw_rule) # type: ignore[reportUnknownArgumentType] + 
rule = TOMLRule(path=Path(rule_file.filename), contents=contents) rule.gh_pr = pull if rule in existing_rules: @@ -622,20 +690,22 @@ def download_worker(pr_info): else: new_rules.append(rule) - except Exception as e: - errors.setdefault(Path(rule_file.filename).name, []).append(str(e)) + except Exception as e: # noqa: BLE001 + name = Path(rule_file.filename).name + errors.setdefault(name, []).append(str(e)) for pr in open_prs: - pr_rules.extend([(pr, f) for f in pr.get_files() - if f.filename.startswith('rules/') and f.filename.endswith('.toml')]) + pr_rules.extend( + [(pr, f) for f in pr.get_files() if f.filename.startswith("rules/") and f.filename.endswith(".toml")] + ) pool = ThreadPool(processes=threads) - pool.map(download_worker, pr_rules) + _ = pool.map(download_worker, pr_rules) pool.close() pool.join() new = OrderedDict([(rule.contents.id, rule) for rule in sorted(new_rules, key=lambda r: r.contents.name)]) - modified = OrderedDict() + modified: OrderedDict[str, list[TOMLRule]] = OrderedDict() for modified_rule in sorted(modified_rules, key=lambda r: r.contents.name): modified.setdefault(modified_rule.contents.id, []).append(modified_rule) @@ -644,15 +714,15 @@ def download_worker(pr_info): __all__ = ( - "FILE_PATTERN", - "DEFAULT_PREBUILT_RULES_DIRS", "DEFAULT_PREBUILT_BBR_DIRS", - "load_github_pr_rules", + "DEFAULT_PREBUILT_RULES_DIRS", + "FILE_PATTERN", "DeprecatedCollection", "DeprecatedRule", "RawRuleCollection", "RuleCollection", + "dict_filter", + "load_github_pr_rules", "metadata_filter", "production_filter", - "dict_filter", ) diff --git a/detection_rules/rule_validators.py b/detection_rules/rule_validators.py index 5fe957ec1a5..cbdd7fe2eb2 100644 --- a/detection_rules/rule_validators.py +++ b/detection_rules/rule_validators.py @@ -4,38 +4,40 @@ # 2.0. """Validation logic for rules containing queries.""" + import re +import typing +from collections.abc import Callable from enum import Enum from functools import cached_property, wraps -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any -import eql -from eql import ast -from eql.parser import KvTree, LarkToEQL, NodeInfo, TypeHint -from eql.parser import _parse as base_parse +import click +import eql # type: ignore[reportMissingTypeStubs] +import kql # type: ignore[reportMissingTypeStubs] +from eql import ast # type: ignore[reportMissingTypeStubs] +from eql.parser import KvTree, LarkToEQL, NodeInfo, TypeHint # type: ignore[reportMissingTypeStubs] +from eql.parser import _parse as base_parse # type: ignore[reportMissingTypeStubs] from marshmallow import ValidationError from semver import Version -import kql -import click - from . 
import ecs, endgame from .config import CUSTOM_RULES_DIR, load_current_package_version, parse_rules_config from .custom_schemas import update_auto_generated_schema -from .integrations import (get_integration_schema_data, - load_integrations_manifests) -from .rule import (EQLRuleData, QueryRuleData, QueryValidator, RuleMeta, - TOMLRuleContents, set_eql_config) +from .integrations import get_integration_schema_data, load_integrations_manifests +from .rule import EQLRuleData, QueryRuleData, QueryValidator, RuleMeta, TOMLRuleContents, set_eql_config from .schemas import get_stack_schemas -EQL_ERROR_TYPES = Union[eql.EqlCompileError, - eql.EqlError, - eql.EqlParseError, - eql.EqlSchemaError, - eql.EqlSemanticError, - eql.EqlSyntaxError, - eql.EqlTypeMismatchError] -KQL_ERROR_TYPES = Union[kql.KqlCompileError, kql.KqlParseError] +EQL_ERROR_TYPES = ( + eql.EqlCompileError + | eql.EqlError + | eql.EqlParseError + | eql.EqlSchemaError + | eql.EqlSemanticError + | eql.EqlSyntaxError + | eql.EqlTypeMismatchError +) +KQL_ERROR_TYPES = kql.KqlCompileError | kql.KqlParseError RULES_CONFIG = parse_rules_config() @@ -43,23 +45,27 @@ class ExtendedTypeHint(Enum): IP = "ip" @classmethod - def primitives(cls): + def primitives(cls): # noqa: ANN206 """Get all primitive types.""" return TypeHint.Boolean, TypeHint.Numeric, TypeHint.Null, TypeHint.String, ExtendedTypeHint.IP - def is_primitive(self): + def is_primitive(self) -> bool: """Check if a type is a primitive.""" return self in self.primitives() -def custom_in_set(self, node: KvTree) -> NodeInfo: +@typing.no_type_check +def custom_in_set(self: LarkToEQL, node: KvTree) -> NodeInfo: """Override and address the limitations of the eql in_set method.""" - # return BaseInSetMethod(self, node) - outer, container = self.visit(node.child_trees) # type: (NodeInfo, list[NodeInfo]) + response = self.visit(node.child_trees) + if not response: + raise ValueError("Child trees are not provided") + + outer, container = response if not outer.validate_type(ExtendedTypeHint.primitives()): # can't compare non-primitives to sets - raise self._type_error(outer, ExtendedTypeHint.primitives()) + raise self._type_error(outer, ExtendedTypeHint.primitives()) # Check that everything inside the container has the same type as outside error_message = "Unable to compare {expected_type} to {actual_type}" @@ -91,8 +97,8 @@ def custom_base_parse_decorator(func: Callable[..., Any]) -> Callable[..., Any]: """Override and address the limitations of the eql in_set method.""" @wraps(func) - def wrapper(query: str, start: Optional[str] = None, **kwargs: Dict[str, Any]) -> Any: + def wrapper(query: str, start: str | None = None, **kwargs: dict[str, Any]) -> Any: - original_in_set = LarkToEQL.in_set + original_in_set = LarkToEQL.in_set # type: ignore[reportUnknownMemberType] LarkToEQL.in_set = custom_in_set try: result = func(query, start=start, **kwargs) @@ -103,40 +109,42 @@ def wrapper(query: str, start: Optional[str] = None, **kwargs: Dict[str, Any]) - return wrapper -eql.parser._parse = custom_base_parse_decorator(base_parse) +eql.parser._parse = custom_base_parse_decorator(base_parse) # type: ignore[reportPrivateUsage] # noqa: SLF001 class KQLValidator(QueryValidator): """Specific fields for KQL query event types.""" @cached_property - def ast(self) -> kql.ast.Expression: - return kql.parse(self.query, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) + def ast(self) -> kql.ast.Expression: # type: ignore[reportIncompatibleMethod] + return kql.parse(self.query, 
normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) # type: ignore[reportUnknownMemberType] @cached_property - def unique_fields(self) -> List[str]: - return list(set(str(f) for f in self.ast if isinstance(f, kql.ast.Field))) + def unique_fields(self) -> list[str]: # type: ignore[reportIncompatibleMethod] + return list({str(f) for f in self.ast if isinstance(f, kql.ast.Field)}) # type: ignore[reportUnknownVariableType] def auto_add_field(self, validation_checks_error: kql.errors.KqlParseError, index_or_dataview: str) -> None: """Auto add a missing field to the schema.""" field_name = extract_error_field(self.query, validation_checks_error) + if not field_name: + raise ValueError("No fied name found for the error") field_type = ecs.get_all_flattened_schema().get(field_name) update_auto_generated_schema(index_or_dataview, field_name, field_type) def to_eql(self) -> eql.ast.Expression: - return kql.to_eql(self.query) + return kql.to_eql(self.query) # type: ignore[reportUnknownVariableType] - def validate(self, data: QueryRuleData, meta: RuleMeta, max_attempts: int = 10) -> None: + def validate(self, data: QueryRuleData, meta: RuleMeta, max_attempts: int = 10) -> None: # type: ignore[reportIncompatibleMethod] """Validate the query, called from the parent which contains [metadata] information.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": # syntax only, which is done via self.ast return - if isinstance(data, QueryRuleData) and data.language != 'lucene': + if data.language != "lucene": packages_manifest = load_integrations_manifests() package_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) for _ in range(max_attempts): - validation_checks = {"stack": None, "integrations": None} + validation_checks: dict[str, KQL_ERROR_TYPES | None] = {"stack": None, "integrations": None} # validate the query against fields within beats validation_checks["stack"] = self.validate_stack_combos(data, meta) @@ -144,60 +152,63 @@ def validate(self, data: QueryRuleData, meta: RuleMeta, max_attempts: int = 10) # validate the query against related integration fields validation_checks["integrations"] = self.validate_integration(data, meta, package_integrations) - if (validation_checks["stack"] and not package_integrations): + if validation_checks["stack"] and not package_integrations: # if auto add, try auto adding and then call stack_combo validation again - if validation_checks["stack"].error_msg == "Unknown field" and RULES_CONFIG.auto_gen_schema_file: + if validation_checks["stack"].error_msg == "Unknown field" and RULES_CONFIG.auto_gen_schema_file: # type: ignore[reportAttributeAccessIssue] # auto add the field and re-validate - self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) + self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) # type: ignore[reportArgumentType] else: raise validation_checks["stack"] - if (validation_checks["stack"] and validation_checks["integrations"]): + if validation_checks["stack"] and validation_checks["integrations"]: # if auto add, try auto adding and then call stack_combo validation again - if validation_checks["stack"].error_msg == "Unknown field" and RULES_CONFIG.auto_gen_schema_file: + if validation_checks["stack"].error_msg == "Unknown field" and RULES_CONFIG.auto_gen_schema_file: # type: ignore[reportAttributeAccessIssue] # auto add the field and re-validate - self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) + 
self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) # type: ignore[reportArgumentType] else: - click.echo(f"Stack Error Trace: {validation_checks["stack"]}") - click.echo(f"Integrations Error Trace: {validation_checks["integrations"]}") + click.echo(f"Stack Error Trace: {validation_checks['stack']}") + click.echo(f"Integrations Error Trace: {validation_checks['integrations']}") raise ValueError("Error in both stack and integrations checks") else: break - else: - raise ValueError(f"Maximum validation attempts exceeded for {data.rule_id} - {data.name}") - - def validate_stack_combos(self, data: QueryRuleData, meta: RuleMeta) -> Union[KQL_ERROR_TYPES, None, TypeError]: + def validate_stack_combos(self, data: QueryRuleData, meta: RuleMeta) -> KQL_ERROR_TYPES | None: """Validate the query against ECS and beats schemas across stack combinations.""" for stack_version, mapping in meta.get_validation_stack_versions().items(): - beats_version = mapping['beats'] - ecs_version = mapping['ecs'] - err_trailer = f'stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}' + beats_version = mapping["beats"] + ecs_version = mapping["ecs"] + err_trailer = f"stack: {stack_version}, beats: {beats_version}, ecs: {ecs_version}" - beat_types, beat_schema, schema = self.get_beats_schema(data.index_or_dataview, - beats_version, ecs_version) + beat_types, _, schema = self.get_beats_schema(data.index_or_dataview, beats_version, ecs_version) try: - kql.parse(self.query, schema=schema, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) + kql.parse(self.query, schema=schema, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) # type: ignore[reportUnknownMemberType] except kql.KqlParseError as exc: message = exc.error_msg trailer = err_trailer if "Unknown field" in message and beat_types: trailer = f"\nTry adding event.module or event.dataset to specify beats module\n\n{trailer}" - return kql.KqlParseError(exc.error_msg, exc.line, exc.column, exc.source, - len(exc.caret.lstrip()), trailer=trailer) - except Exception as exc: - print(err_trailer) - return exc + return kql.KqlParseError( + exc.error_msg, # type: ignore[reportUnknownArgumentType] + exc.line, # type: ignore[reportUnknownArgumentType] + exc.column, # type: ignore[reportUnknownArgumentType] + exc.source, # type: ignore[reportUnknownArgumentType] + len(exc.caret.lstrip()), + trailer=trailer, + ) + return None - def validate_integration( - self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict] - ) -> Union[KQL_ERROR_TYPES, None, TypeError]: + def validate_integration( # noqa: PLR0912 + self, + data: QueryRuleData, + meta: RuleMeta, + package_integrations: list[dict[str, Any]], + ) -> KQL_ERROR_TYPES | None: """Validate the query, called from the parent which contains [metadata] information.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": - return + return None error_fields = {} package_schemas = {} @@ -207,14 +218,12 @@ def validate_integration( package = integration_data["package"] integration = integration_data["integration"] if integration: - package_schemas.setdefault(package, {}).setdefault(integration, {}) + package_schemas.setdefault(package, {}).setdefault(integration, {}) # type: ignore[reportUnknownMemberType] else: - package_schemas.setdefault(package, {}) + package_schemas.setdefault(package, {}) # type: ignore[reportUnknownMemberType] # Process each integration schema - for integration_schema_data in get_integration_schema_data( - data, meta, 
package_integrations - ): + for integration_schema_data in get_integration_schema_data(data, meta, package_integrations): package, integration = ( integration_schema_data["package"], integration_schema_data["integration"], @@ -240,9 +249,11 @@ def validate_integration( # Validate the query against the schema try: - kql.parse(self.query, - schema=integration_schema, - normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) + kql.parse( # type: ignore[reportUnknownMemberType] + self.query, + schema=integration_schema, + normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords, + ) except kql.KqlParseError as exc: if exc.error_msg == "Unknown field": field = extract_error_field(self.query, exc) @@ -260,26 +271,24 @@ def validate_integration( "integration": integration, } if data.get("notify", False): - print( - f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}" - ) + print(f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}") else: return kql.KqlParseError( - exc.error_msg, - exc.line, - exc.column, - exc.source, + exc.error_msg, # type: ignore[reportUnknownArgumentType] + exc.line, # type: ignore[reportUnknownArgumentType] + exc.column, # type: ignore[reportUnknownArgumentType] + exc.source, # type: ignore[reportUnknownArgumentType] len(exc.caret.lstrip()), - exc.trailer, + exc.trailer, # type: ignore[reportUnknownArgumentType] ) # Check error fields against schemas of different packages or different integrations - for field, error_data in list(error_fields.items()): - error_package, error_integration = ( + for field, error_data in list(error_fields.items()): # type: ignore[reportUnknownArgumentType] + error_package, error_integration = ( # type: ignore[reportUnknownVariableType] error_data["package"], error_data["integration"], ) - for package, integrations_or_schema in package_schemas.items(): + for package, integrations_or_schema in package_schemas.items(): # type: ignore[reportUnknownVariableType] if error_integration is None: # Compare against the schema directly if there's no integration if error_package != package and field in integrations_or_schema: @@ -287,95 +296,100 @@ def validate_integration( break else: # Compare against integration schemas - for integration, schema in integrations_or_schema.items(): - check_alt_schema = ( - error_package != package or # noqa: W504 - (error_package == package and error_integration != integration) + for integration, schema in integrations_or_schema.items(): # type: ignore[reportUnknownMemberType] + check_alt_schema = error_package != package or ( # type: ignore[reportUnknownVariableType] + error_package == package and error_integration != integration ) if check_alt_schema and field in schema: del error_fields[field] # Raise the first error if error_fields: - _, error_data = next(iter(error_fields.items())) + _, error_data = next(iter(error_fields.items())) # type: ignore[reportUnknownVariableType] return kql.KqlParseError( - error_data["error"].error_msg, - error_data["error"].line, - error_data["error"].column, - error_data["error"].source, - len(error_data["error"].caret.lstrip()), - error_data["trailer"], + error_data["error"].error_msg, # type: ignore[reportUnknownArgumentType] + error_data["error"].line, # type: ignore[reportUnknownArgumentType] + error_data["error"].column, # type: ignore[reportUnknownArgumentType] + error_data["error"].source, # type: ignore[reportUnknownArgumentType] + len(error_data["error"].caret.lstrip()), # type: ignore[reportUnknownArgumentType] + error_data["trailer"], 
# type: ignore[reportUnknownArgumentType] ) + return None class EQLValidator(QueryValidator): """Specific fields for EQL query event types.""" @cached_property - def ast(self) -> eql.ast.Expression: + def ast(self) -> eql.ast.Expression: # type: ignore[reportIncompatibleMethodOverrichemas] latest_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) cfg = set_eql_config(str(latest_version)) with eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions, eql.parser.skip_optimizations, cfg: - return eql.parse_query(self.query) + return eql.parse_query(self.query) # type: ignore[reportUnknownVariableType] - def text_fields(self, eql_schema: Union[ecs.KqlSchema2Eql, endgame.EndgameSchema]) -> List[str]: + def text_fields(self, eql_schema: ecs.KqlSchema2Eql | endgame.EndgameSchema) -> list[str]: """Return a list of fields of type text.""" - from kql.parser import elasticsearch_type_family + from kql.parser import elasticsearch_type_family # type: ignore[reportMissingTypeStubs] + schema = eql_schema.kql_schema if isinstance(eql_schema, ecs.KqlSchema2Eql) else eql_schema.endgame_schema - return [f for f in self.unique_fields if elasticsearch_type_family(schema.get(f)) == 'text'] + return [f for f in self.unique_fields if elasticsearch_type_family(schema.get(f)) == "text"] # type: ignore[reportArgumentType] @cached_property - def unique_fields(self) -> List[str]: - return list(set(str(f) for f in self.ast if isinstance(f, eql.ast.Field))) + def unique_fields(self) -> list[str]: # type: ignore[reportIncompatibleMethodOverride] + return list({str(f) for f in self.ast if isinstance(f, eql.ast.Field)}) # type: ignore[reportUnknownVariableType] def auto_add_field(self, validation_checks_error: eql.errors.EqlParseError, index_or_dataview: str) -> None: """Auto add a missing field to the schema.""" field_name = extract_error_field(self.query, validation_checks_error) + if not field_name: + raise ValueError("No field name found") field_type = ecs.get_all_flattened_schema().get(field_name) update_auto_generated_schema(index_or_dataview, field_name, field_type) - def validate(self, data: "QueryRuleData", meta: RuleMeta, max_attempts: int = 10) -> None: + def validate(self, data: "QueryRuleData", meta: RuleMeta, max_attempts: int = 10) -> None: # type: ignore[reportIncompatibleMethodOverride] # noqa: PLR0912 """Validate an EQL query while checking TOMLRule.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": # syntax only, which is done via self.ast return - if isinstance(data, QueryRuleData) and data.language != "lucene": + if data.language != "lucene": packages_manifest = load_integrations_manifests() package_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) for _ in range(max_attempts): validation_checks = {"stack": None, "integrations": None} # validate the query against fields within beats - validation_checks["stack"] = self.validate_stack_combos(data, meta) + validation_checks["stack"] = self.validate_stack_combos(data, meta) # type: ignore[reportArgumentType] + + stack_check = validation_checks["stack"] if package_integrations: # validate the query against related integration fields - validation_checks["integrations"] = self.validate_integration(data, meta, package_integrations) + validation_checks["integrations"] = self.validate_integration(data, meta, package_integrations) # type: ignore[reportArgumentType] - if validation_checks["stack"] and not package_integrations: + if stack_check and not 
package_integrations: # if auto add, try auto adding and then validate again if ( - "Field not recognized" in validation_checks["stack"].error_msg - and RULES_CONFIG.auto_gen_schema_file # noqa: W503 + "Field not recognized" in str(stack_check) # type: ignore[reportUnknownMemberType] + and RULES_CONFIG.auto_gen_schema_file ): # auto add the field and re-validate - self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) + self.auto_add_field(stack_check, data.index_or_dataview[0]) # type: ignore[reportArgumentType] else: - raise validation_checks["stack"] + raise stack_check - elif validation_checks["stack"] and validation_checks["integrations"]: + elif stack_check and validation_checks["integrations"]: # if auto add, try auto adding and then validate again if ( - "Field not recognized" in validation_checks["stack"].error_msg - and RULES_CONFIG.auto_gen_schema_file # noqa: W503 + "Field not recognized" in stack_check.error_msg # type: ignore[reportUnknownMemberType] + and RULES_CONFIG.auto_gen_schema_file ): # auto add the field and re-validate - self.auto_add_field(validation_checks["stack"], data.index_or_dataview[0]) + self.auto_add_field(stack_check, data.index_or_dataview[0]) # type: ignore[reportArgumentType] else: - click.echo(f"Stack Error Trace: {validation_checks["stack"]}") - click.echo(f"Integrations Error Trace: {validation_checks["integrations"]}") + click.echo(f"Stack Error Trace: {stack_check}") + click.echo(f"Integrations Error Trace: {validation_checks['integrations']}") raise ValueError("Error in both stack and integrations checks") else: @@ -385,7 +399,8 @@ def validate(self, data: "QueryRuleData", meta: RuleMeta, max_attempts: int = 10 raise ValueError(f"Maximum validation attempts exceeded for {data.rule_id} - {data.name}") rule_type_config_fields, rule_type_config_validation_failed = self.validate_rule_type_configurations( - data, meta + data, # type: ignore[reportArgumentType] + meta, ) if rule_type_config_validation_failed: raise ValueError( @@ -393,42 +408,56 @@ def validate(self, data: "QueryRuleData", meta: RuleMeta, max_attempts: int = 10 {rule_type_config_fields}""" ) - def validate_stack_combos(self, data: QueryRuleData, meta: RuleMeta) -> Union[EQL_ERROR_TYPES, None, ValueError]: + def validate_stack_combos(self, data: QueryRuleData, meta: RuleMeta) -> EQL_ERROR_TYPES | None | ValueError: """Validate the query against ECS and beats schemas across stack combinations.""" for stack_version, mapping in meta.get_validation_stack_versions().items(): - beats_version = mapping['beats'] - ecs_version = mapping['ecs'] - endgame_version = mapping['endgame'] - err_trailer = f'stack: {stack_version}, beats: {beats_version},' \ - f'ecs: {ecs_version}, endgame: {endgame_version}' - - beat_types, beat_schema, schema = self.get_beats_schema(data.index_or_dataview, - beats_version, ecs_version) + beats_version = mapping["beats"] + ecs_version = mapping["ecs"] + endgame_version = mapping["endgame"] + err_trailer = ( + f"stack: {stack_version}, beats: {beats_version},ecs: {ecs_version}, endgame: {endgame_version}" + ) + + beat_types, _, schema = self.get_beats_schema(data.index_or_dataview, beats_version, ecs_version) endgame_schema = self.get_endgame_schema(data.index_or_dataview, endgame_version) eql_schema = ecs.KqlSchema2Eql(schema) # validate query against the beats and eql schema - exc = self.validate_query_with_schema(data=data, schema=eql_schema, err_trailer=err_trailer, - beat_types=beat_types, min_stack_version=meta.min_stack_version) + exc = 
self.validate_query_with_schema( # type: ignore[reportUnknownVariableType] + data=data, + schema=eql_schema, + err_trailer=err_trailer, + beat_types=beat_types, + min_stack_version=meta.min_stack_version, # type: ignore[reportArgumentType] + ) if exc: return exc if endgame_schema: # validate query against the endgame schema - exc = self.validate_query_with_schema(data=data, schema=endgame_schema, err_trailer=err_trailer, - min_stack_version=meta.min_stack_version) + exc = self.validate_query_with_schema( + data=data, + schema=endgame_schema, + err_trailer=err_trailer, + min_stack_version=meta.min_stack_version, # type: ignore[reportArgumentType] + ) if exc: raise exc + return None - def validate_integration(self, data: QueryRuleData, meta: RuleMeta, - package_integrations: List[dict]) -> Union[EQL_ERROR_TYPES, None, ValueError]: + def validate_integration( # noqa: PLR0912 + self, + data: QueryRuleData, + meta: RuleMeta, + package_integrations: list[dict[str, Any]], + ) -> EQL_ERROR_TYPES | None | ValueError: """Validate an EQL query while checking TOMLRule against integration schemas.""" if meta.query_schema_validation is False or meta.maturity == "deprecated": # syntax only, which is done via self.ast - return + return None error_fields = {} - package_schemas = {} + package_schemas: dict[str, Any] = {} # Initialize package_schemas with a nested structure for integration_data in package_integrations: @@ -440,9 +469,7 @@ def validate_integration(self, data: QueryRuleData, meta: RuleMeta, package_schemas.setdefault(package, {}) # Process each integration schema - for integration_schema_data in get_integration_schema_data( - data, meta, package_integrations - ): + for integration_schema_data in get_integration_schema_data(data, meta, package_integrations): ecs_version = integration_schema_data["ecs_version"] package, integration = ( integration_schema_data["package"], @@ -477,11 +504,11 @@ def validate_integration(self, data: QueryRuleData, meta: RuleMeta, data=data, schema=eql_schema, err_trailer=err_trailer, - min_stack_version=meta.min_stack_version, + min_stack_version=meta.min_stack_version, # type: ignore[reportArgumentType] ) if isinstance(exc, eql.EqlParseError): - message = exc.error_msg + message = exc.error_msg # type: ignore[reportUnknownVariableType] if message == "Unknown field" or "Field not recognized" in message: field = extract_error_field(self.query, exc) trailer = ( @@ -497,15 +524,13 @@ def validate_integration(self, data: QueryRuleData, meta: RuleMeta, "integration": integration, } if data.get("notify", False): - print( - f"\nWarning: `{field}` in `{data.name}` not found in schema. {trailer}" - ) + print(f"\nWarning: `{field}` in `{data.name}` not found in schema. 
{trailer}") else: return exc # Check error fields against schemas of different packages or different integrations - for field, error_data in list(error_fields.items()): - error_package, error_integration = ( + for field, error_data in list(error_fields.items()): # type: ignore[reportUnknownArgumentType] + error_package, error_integration = ( # type: ignore[reportUnknownVariableType] error_data["package"], error_data["integration"], ) @@ -517,27 +542,31 @@ def validate_integration(self, data: QueryRuleData, meta: RuleMeta, else: # Compare against integration schemas for integration, schema in integrations_or_schema.items(): - check_alt_schema = ( - error_package != package or # noqa: W504 - (error_package == package and error_integration != integration) + check_alt_schema = ( # type: ignore[reportUnknownVariableType] + error_package != package or (error_package == package and error_integration != integration) ) if check_alt_schema and field in schema: del error_fields[field] # raise the first error if error_fields: - _, data = next(iter(error_fields.items())) - exc = data["error"] - return exc + _, data = next(iter(error_fields.items())) # type: ignore[reportUnknownArgumentType] + return data["error"] # type: ignore[reportIndexIssue] + return None - def validate_query_with_schema(self, data: 'QueryRuleData', schema: Union[ecs.KqlSchema2Eql, endgame.EndgameSchema], - err_trailer: str, min_stack_version: str, beat_types: list = None) -> Union[ - EQL_ERROR_TYPES, ValueError, None]: + def validate_query_with_schema( + self, + data: "QueryRuleData", # noqa: ARG002 + schema: ecs.KqlSchema2Eql | endgame.EndgameSchema, + err_trailer: str, + min_stack_version: str, + beat_types: list[str] | None = None, + ) -> EQL_ERROR_TYPES | ValueError | None: """Validate the query against the schema.""" try: config = set_eql_config(min_stack_version) with config, schema, eql.parser.elasticsearch_syntax, eql.parser.ignore_missing_functions: - eql.parse_query(self.query) + _ = eql.parse_query(self.query) # type: ignore[reportUnknownMemberType] except eql.EqlParseError as exc: message = exc.error_msg trailer = err_trailer @@ -546,69 +575,75 @@ def validate_query_with_schema(self, data: 'QueryRuleData', schema: Union[ecs.Kq elif "Field not recognized" in message: text_fields = self.text_fields(schema) if text_fields: - fields_str = ', '.join(text_fields) + fields_str = ", ".join(text_fields) trailer = f"\neql does not support text fields: {fields_str}\n\n{trailer}" - return exc.__class__(exc.error_msg, exc.line, exc.column, exc.source, - len(exc.caret.lstrip()), trailer=trailer) + return exc.__class__( + exc.error_msg, # type: ignore[reportUnknownArgumentType] + exc.line, # type: ignore[reportUnknownArgumentType] + exc.column, # type: ignore[reportUnknownArgumentType] + exc.source, # type: ignore[reportUnknownArgumentType] + len(exc.caret.lstrip()), + trailer=trailer, + ) - except Exception as exc: + except Exception as exc: # noqa: BLE001 print(err_trailer) - return exc + return exc # type: ignore[reportReturnType] - def validate_rule_type_configurations(self, data: EQLRuleData, meta: RuleMeta) -> \ - Tuple[List[Optional[str]], bool]: + def validate_rule_type_configurations(self, data: EQLRuleData, meta: RuleMeta) -> tuple[list[str | None], bool]: """Validate EQL rule type configurations.""" if data.timestamp_field or data.event_category_override or data.tiebreaker_field: - # get a list of rule type configuration fields # Get a list of rule type configuration fields fields = ["timestamp_field", 
"event_category_override", "tiebreaker_field"] - set_fields = list(filter(None, (data.get(field) for field in fields))) + set_fields = list(filter(None, (data.get(field) for field in fields))) # type: ignore[reportUnknownVariableType] # get stack_version and ECS schema min_stack_version = meta.get("min_stack_version") if min_stack_version is None: min_stack_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True) - ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs'] + ecs_version = get_stack_schemas()[str(min_stack_version)]["ecs"] schema = ecs.get_schema(ecs_version) # return a list of rule type config field values and whether any are not in the schema - return (set_fields, any([f not in schema.keys() for f in set_fields])) - else: - # if rule type fields are not set, return an empty list and False - return [], False + return (set_fields, any(f not in schema for f in set_fields)) # type: ignore[reportUnknownVariableType] + # if rule type fields are not set, return an empty list and False + return [], False class ESQLValidator(QueryValidator): """Validate specific fields for ESQL query event types.""" @cached_property - def ast(self): + def ast(self) -> None: # type: ignore[reportIncompatibleMethodOverride] return None @cached_property - def unique_fields(self) -> List[str]: + def unique_fields(self) -> list[str]: # type: ignore[reportIncompatibleMethodOverride] """Return a list of unique fields in the query.""" # return empty list for ES|QL rules until ast is available (friendlier than raising error) - # raise NotImplementedError('ES|QL query parsing not yet supported') return [] - def validate(self, data: 'QueryRuleData', meta: RuleMeta) -> None: + def validate(self, _: "QueryRuleData", __: RuleMeta) -> None: # type: ignore[reportIncompatibleMethodOverride] """Validate an ESQL query while checking TOMLRule.""" # temporarily override to NOP until ES|QL query parsing is supported - def validate_integration(self, data: QueryRuleData, meta: RuleMeta, package_integrations: List[dict]) -> Union[ - ValidationError, None, ValueError]: - # return self.validate(data, meta) + def validate_integration( + self, + _: QueryRuleData, + __: RuleMeta, + ___: list[dict[str, Any]], + ) -> ValidationError | None | ValueError: + # Disabling self.validate(data, meta) pass -def extract_error_field(source: str, exc: Union[eql.EqlParseError, kql.KqlParseError]) -> Optional[str]: +def extract_error_field(source: str, exc: eql.EqlParseError | kql.KqlParseError) -> str | None: """Extract the field name from an EQL or KQL parse error.""" lines = source.splitlines() - mod = -1 if exc.line == len(lines) else 0 - line = lines[exc.line + mod] - start = exc.column - stop = start + len(exc.caret.strip()) - return re.sub(r'^\W+|\W+$', '', line[start:stop]) + mod = -1 if exc.line == len(lines) else 0 # type: ignore[reportUnknownMemberType] + line = lines[exc.line + mod] # type: ignore[reportUnknownMemberType] + start = exc.column # type: ignore[reportUnknownMemberType] + stop = start + len(exc.caret.strip()) # type: ignore[reportUnknownVariableType] + return re.sub(r"^\W+|\W+$", "", line[start:stop]) # type: ignore[reportUnknownArgumentType] diff --git a/detection_rules/schemas/__init__.py b/detection_rules/schemas/__init__.py index 98506eeb2bd..416094cb1d3 100644 --- a/detection_rules/schemas/__init__.py +++ b/detection_rules/schemas/__init__.py @@ -4,47 +4,52 @@ # 2.0. 
import json from collections import OrderedDict -from typing import List, Optional -from typing import OrderedDict as OrderedDictType +from collections import OrderedDict as OrderedDictType +from collections.abc import Callable +from typing import Any import jsonschema from semver import Version -from ..config import load_current_package_version, parse_rules_config -from ..utils import cached, get_etc_path +from detection_rules.config import load_current_package_version, parse_rules_config +from detection_rules.utils import cached, get_etc_path + from . import definitions from .stack_compat import get_incompatible_fields - __all__ = ( "SCHEMA_DIR", + "all_versions", "definitions", "downgrade", "get_incompatible_fields", "get_min_supported_stack_version", "get_stack_schemas", "get_stack_versions", - "all_versions", ) RULES_CONFIG = parse_rules_config() -SCHEMA_DIR = get_etc_path("api_schemas") -migrations = {} +SCHEMA_DIR = get_etc_path(["api_schemas"]) + +MigratedFuncT = Callable[..., Any] + +migrations: dict[str, MigratedFuncT] = {} -def all_versions() -> List[str]: +def all_versions() -> list[str]: """Get all known stack versions.""" return [str(v) for v in sorted(migrations, key=lambda x: Version.parse(x, optional_minor_and_patch=True))] -def migrate(version: str): +def migrate(version: str) -> Callable[[MigratedFuncT], MigratedFuncT]: """Decorator to set a migration.""" # checks that the migrate decorator name is semi-semantic versioned # raises validation error from semver if not - Version.parse(version, optional_minor_and_patch=True) + _ = Version.parse(version, optional_minor_and_patch=True) - def wrapper(f): - assert version not in migrations + def wrapper(f: MigratedFuncT) -> MigratedFuncT: + if version in migrations: + raise ValueError("Version found in migrations") migrations[version] = f return f @@ -52,7 +57,7 @@ def wrapper(f): @cached -def get_schema_file(version: Version, rule_type: str) -> dict: +def get_schema_file(version: Version, rule_type: str) -> dict[str, Any]: path = SCHEMA_DIR / str(version) / f"{version}.{rule_type}.json" if not path.exists(): @@ -61,13 +66,13 @@ def get_schema_file(version: Version, rule_type: str) -> dict: return json.loads(path.read_text(encoding="utf8")) -def strip_additional_properties(version: Version, api_contents: dict) -> dict: +def strip_additional_properties(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Remove all fields that the target schema doesn't recognize.""" - stripped = {} + stripped: dict[str, Any] = {} target_schema = get_schema_file(version, api_contents["type"]) - for field, field_schema in target_schema["properties"].items(): + for field in target_schema["properties"]: if field in api_contents: stripped[field] = api_contents[field] @@ -76,47 +81,46 @@ def strip_additional_properties(version: Version, api_contents: dict) -> dict: return stripped -def strip_non_public_fields(min_stack_version: Version, data_dict: dict) -> dict: +def strip_non_public_fields(min_stack_version: Version, data_dict: dict[str, Any]) -> dict[str, Any]: """Remove all non public fields.""" for field, version_range in definitions.NON_PUBLIC_FIELDS.items(): if version_range[0] <= min_stack_version <= (version_range[1] or min_stack_version): - if field in data_dict: - del data_dict[field] + data_dict.pop(field, None) return data_dict @migrate("7.8") -def migrate_to_7_8(version: Version, api_contents: dict) -> dict: +def migrate_to_7_8(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 
7.8.""" return strip_additional_properties(version, api_contents) @migrate("7.9") -def migrate_to_7_9(version: Version, api_contents: dict) -> dict: +def migrate_to_7_9(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.9.""" return strip_additional_properties(version, api_contents) @migrate("7.10") -def downgrade_threat_to_7_10(version: Version, api_contents: dict) -> dict: +def downgrade_threat_to_7_10(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Downgrade the threat mapping changes from 7.11 to 7.10.""" if "threat" in api_contents: v711_threats = api_contents.get("threat", []) - v710_threats = [] + v710_threats: list[Any] = [] for threat in v711_threats: # drop tactic without threat if "technique" not in threat: continue - threat = threat.copy() - threat["technique"] = [t.copy() for t in threat["technique"]] + threat_copy = threat.copy() + threat_copy["technique"] = [t.copy() for t in threat_copy["technique"]] # drop subtechniques - for technique in threat["technique"]: + for technique in threat_copy["technique"]: technique.pop("subtechnique", None) - v710_threats.append(threat) + v710_threats.append(threat_copy) api_contents = api_contents.copy() api_contents.pop("threat") @@ -130,24 +134,24 @@ def downgrade_threat_to_7_10(version: Version, api_contents: dict) -> dict: @migrate("7.11") -def downgrade_threshold_to_7_11(version: Version, api_contents: dict) -> dict: +def downgrade_threshold_to_7_11(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Remove 7.12 threshold changes that don't impact the rule.""" if "threshold" in api_contents: - threshold = api_contents['threshold'] - threshold_field = threshold['field'] + threshold = api_contents["threshold"] + threshold_field = threshold["field"] # attempt to convert threshold field to a string if len(threshold_field) > 1: - raise ValueError('Cannot downgrade a threshold rule that has multiple threshold fields defined') + raise ValueError("Cannot downgrade a threshold rule that has multiple threshold fields defined") - if threshold.get('cardinality'): - raise ValueError('Cannot downgrade a threshold rule that has a defined cardinality') + if threshold.get("cardinality"): + raise ValueError("Cannot downgrade a threshold rule that has a defined cardinality") api_contents = api_contents.copy() api_contents["threshold"] = api_contents["threshold"].copy() # if cardinality was defined with no field or value - api_contents['threshold'].pop('cardinality', None) + api_contents["threshold"].pop("cardinality", None) api_contents["threshold"]["field"] = api_contents["threshold"]["field"][0] # finally, downgrade any additional properties that were added @@ -155,20 +159,20 @@ def downgrade_threshold_to_7_11(version: Version, api_contents: dict) -> dict: @migrate("7.12") -def migrate_to_7_12(version: Version, api_contents: dict) -> dict: +def migrate_to_7_12(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.12.""" return strip_additional_properties(version, api_contents) @migrate("7.13") -def downgrade_ml_multijob_713(version: Version, api_contents: dict) -> dict: +def downgrade_ml_multijob_713(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Convert `machine_learning_job_id` as an array to a string for < 7.13.""" if "machine_learning_job_id" in api_contents: job_id = api_contents["machine_learning_job_id"] if isinstance(job_id, list): - if len(job_id) > 1: - raise ValueError('Cannot downgrade an ML rule 
with multiple jobs defined') + if len(job_id) > 1: # type: ignore[reportUnknownArgumentType] + raise ValueError("Cannot downgrade an ML rule with multiple jobs defined") api_contents = api_contents.copy() api_contents["machine_learning_job_id"] = job_id[0] @@ -178,149 +182,150 @@ def downgrade_ml_multijob_713(version: Version, api_contents: dict) -> dict: @migrate("7.14") -def migrate_to_7_14(version: Version, api_contents: dict) -> dict: +def migrate_to_7_14(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.14.""" return strip_additional_properties(version, api_contents) @migrate("7.15") -def migrate_to_7_15(version: Version, api_contents: dict) -> dict: +def migrate_to_7_15(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.15.""" return strip_additional_properties(version, api_contents) @migrate("7.16") -def migrate_to_7_16(version: Version, api_contents: dict) -> dict: +def migrate_to_7_16(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 7.16.""" return strip_additional_properties(version, api_contents) @migrate("8.0") -def migrate_to_8_0(version: Version, api_contents: dict) -> dict: +def migrate_to_8_0(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.0.""" return strip_additional_properties(version, api_contents) @migrate("8.1") -def migrate_to_8_1(version: Version, api_contents: dict) -> dict: +def migrate_to_8_1(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.1.""" return strip_additional_properties(version, api_contents) @migrate("8.2") -def migrate_to_8_2(version: Version, api_contents: dict) -> dict: +def migrate_to_8_2(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.2.""" return strip_additional_properties(version, api_contents) @migrate("8.3") -def migrate_to_8_3(version: Version, api_contents: dict) -> dict: +def migrate_to_8_3(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.3.""" return strip_additional_properties(version, api_contents) @migrate("8.4") -def migrate_to_8_4(version: Version, api_contents: dict) -> dict: +def migrate_to_8_4(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.4.""" return strip_additional_properties(version, api_contents) @migrate("8.5") -def migrate_to_8_5(version: Version, api_contents: dict) -> dict: +def migrate_to_8_5(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.5.""" return strip_additional_properties(version, api_contents) @migrate("8.6") -def migrate_to_8_6(version: Version, api_contents: dict) -> dict: +def migrate_to_8_6(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.6.""" return strip_additional_properties(version, api_contents) @migrate("8.7") -def migrate_to_8_7(version: Version, api_contents: dict) -> dict: +def migrate_to_8_7(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.7.""" return strip_additional_properties(version, api_contents) @migrate("8.8") -def migrate_to_8_8(version: Version, api_contents: dict) -> dict: +def migrate_to_8_8(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.8.""" return strip_additional_properties(version, api_contents) @migrate("8.9") -def migrate_to_8_9(version: 
Version, api_contents: dict) -> dict: +def migrate_to_8_9(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.9.""" return strip_additional_properties(version, api_contents) @migrate("8.10") -def migrate_to_8_10(version: Version, api_contents: dict) -> dict: +def migrate_to_8_10(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.10.""" return strip_additional_properties(version, api_contents) @migrate("8.11") -def migrate_to_8_11(version: Version, api_contents: dict) -> dict: +def migrate_to_8_11(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.11.""" return strip_additional_properties(version, api_contents) @migrate("8.12") -def migrate_to_8_12(version: Version, api_contents: dict) -> dict: +def migrate_to_8_12(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.12.""" return strip_additional_properties(version, api_contents) @migrate("8.13") -def migrate_to_8_13(version: Version, api_contents: dict) -> dict: +def migrate_to_8_13(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.13.""" return strip_additional_properties(version, api_contents) @migrate("8.14") -def migrate_to_8_14(version: Version, api_contents: dict) -> dict: +def migrate_to_8_14(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.14.""" return strip_additional_properties(version, api_contents) @migrate("8.15") -def migrate_to_8_15(version: Version, api_contents: dict) -> dict: +def migrate_to_8_15(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.15.""" return strip_additional_properties(version, api_contents) @migrate("8.16") -def migrate_to_8_16(version: Version, api_contents: dict) -> dict: +def migrate_to_8_16(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.16.""" return strip_additional_properties(version, api_contents) @migrate("8.17") -def migrate_to_8_17(version: Version, api_contents: dict) -> dict: +def migrate_to_8_17(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.17.""" return strip_additional_properties(version, api_contents) @migrate("8.18") -def migrate_to_8_18(version: Version, api_contents: dict) -> dict: +def migrate_to_8_18(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 8.18.""" return strip_additional_properties(version, api_contents) @migrate("9.0") -def migrate_to_9_0(version: Version, api_contents: dict) -> dict: +def migrate_to_9_0(version: Version, api_contents: dict[str, Any]) -> dict[str, Any]: """Default migration for 9.0.""" return strip_additional_properties(version, api_contents) -def downgrade(api_contents: dict, target_version: str, current_version: Optional[str] = None) -> dict: +def downgrade( + api_contents: dict[str, Any], target_version: str, current_version_val: str | None = None +) -> dict[str, Any]: """Downgrade a rule to a target stack version.""" - from ..packaging import current_stack_version + from ..packaging import current_stack_version # noqa: TID252 - if current_version is None: - current_version = current_stack_version() + current_version = current_version_val or current_stack_version() current = Version.parse(current_version, optional_minor_and_patch=True) target = Version.parse(target_version, optional_minor_and_patch=True) @@ -340,47 
+345,46 @@ def downgrade(api_contents: dict, target_version: str, current_version: Optional @cached -def load_stack_schema_map() -> dict: +def load_stack_schema_map() -> dict[str, Any]: return RULES_CONFIG.stack_schema_map @cached -def get_stack_schemas(stack_version: Optional[str] = '0.0.0') -> OrderedDictType[str, dict]: +def get_stack_schemas(stack_version_val: str | None = "0.0.0") -> OrderedDictType[str, dict[str, Any]]: """ Return all ECS, beats, and custom stack versions for every stack version. Only versions >= specified stack version and <= package are returned. """ - stack_version = Version.parse(stack_version or '0.0.0', optional_minor_and_patch=True) + stack_version = Version.parse(stack_version_val or "0.0.0", optional_minor_and_patch=True) current_package = Version.parse(load_current_package_version(), optional_minor_and_patch=True) stack_map = load_stack_schema_map() - versions = {k: v for k, v in stack_map.items() if - (((mapped_version := Version.parse(k)) >= stack_version) - and (mapped_version <= current_package) and v)} # noqa: W503 + versions = { + k: v + for k, v in stack_map.items() + if (((mapped_version := Version.parse(k)) >= stack_version) and (mapped_version <= current_package) and v) + } if stack_version > current_package: - versions[stack_version] = {'beats': 'main', 'ecs': 'master'} + versions[stack_version] = {"beats": "main", "ecs": "master"} - versions_reversed = OrderedDict(sorted(versions.items(), reverse=True)) - return versions_reversed + return OrderedDict(sorted(versions.items(), reverse=True)) -def get_stack_versions(drop_patch=False) -> List[str]: +def get_stack_versions(drop_patch: bool = False) -> list[str]: """Get a list of stack versions supported (for the matrix).""" versions = list(load_stack_schema_map()) if drop_patch: - abridged_versions = [] + abridged_versions: list[str] = [] for version in versions: - abridged, _ = version.rsplit('.', 1) + abridged, _ = version.rsplit(".", 1) abridged_versions.append(abridged) return abridged_versions - else: - return versions + return versions @cached def get_min_supported_stack_version() -> Version: """Get the minimum defined and supported stack version.""" stack_map = load_stack_schema_map() - min_version = min([Version.parse(v) for v in list(stack_map)]) - return min_version + return min([Version.parse(v) for v in list(stack_map)]) diff --git a/detection_rules/schemas/definitions.py b/detection_rules/schemas/definitions.py index b18cc9cca9f..0d9317a730a 100644 --- a/detection_rules/schemas/definitions.py +++ b/detection_rules/schemas/definitions.py @@ -4,256 +4,348 @@ # 2.0. 
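The `@migrate(...)` functions above all share one contract: take a rule's API contents, return a copy shaped for that stack version, and let `downgrade()` chain them from the current stack version down to the requested target. As a reading aid only, here is a minimal, self-contained sketch of that registry pattern; the names (`MIGRATIONS`, `register`, `apply_downgrades`, `drop_threat_for_7_10`) are invented for illustration, not the module's own helpers, and the replay order shown is my reading of how the chain is applied.

from collections.abc import Callable
from typing import Any

from semver import Version

Migration = Callable[[Version, dict[str, Any]], dict[str, Any]]
MIGRATIONS: dict[str, Migration] = {}


def register(version: str) -> Callable[[Migration], Migration]:
    """Register a callable that reshapes rule contents for `version`."""
    def wrapper(func: Migration) -> Migration:
        MIGRATIONS[version] = func
        return func
    return wrapper


@register("7.10")
def drop_threat_for_7_10(_: Version, api_contents: dict[str, Any]) -> dict[str, Any]:
    """Example step: strip a field an older API would reject (illustrative only)."""
    contents = api_contents.copy()
    contents.pop("threat", None)
    return contents


def _parse(key: str) -> Version:
    return Version.parse(key, optional_minor_and_patch=True)


def apply_downgrades(api_contents: dict[str, Any], target: str, current: str) -> dict[str, Any]:
    """Replay every registered step v with target <= v < current, newest first."""
    target_ver = _parse(target)
    current_ver = _parse(current)
    for key in sorted(MIGRATIONS, key=_parse, reverse=True):
        step_ver = _parse(key)
        if target_ver <= step_ver < current_ver:
            api_contents = MIGRATIONS[key](step_ver, api_contents)
    return api_contents


print(apply_downgrades({"name": "demo", "threat": []}, target="7.10", current="7.11"))
# -> {'name': 'demo'}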
"""Custom shared definitions for schemas.""" + import os -from typing import Final, List, Literal +import re +from collections.abc import Callable +from re import Pattern +from typing import Annotated, Any, Final, Literal, NewType from marshmallow import fields, validate -from marshmallow_dataclass import NewType from semver import Version from detection_rules.config import CUSTOM_RULES_DIR -def elastic_timeline_template_id_validator(): +def elastic_timeline_template_id_validator() -> Callable[[Any], Any]: """Custom validator for Timeline Template IDs.""" - def validator(value): - if os.environ.get('DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION') is not None: - fields.String().deserialize(value) - else: - validate.OneOf(list(TIMELINE_TEMPLATES))(value) - return validator + def validator_wrapper(value: Any) -> Any: + if os.environ.get("DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION") is None: + template_ids = list(TIMELINE_TEMPLATES) + validator = validate.OneOf(template_ids) + validator(value) + return value + + return validator_wrapper -def elastic_timeline_template_title_validator(): +def elastic_timeline_template_title_validator() -> Callable[[Any], Any]: """Custom validator for Timeline Template Titles.""" - def validator(value): - if os.environ.get('DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION') is not None: - fields.String().deserialize(value) - else: - validate.OneOf(TIMELINE_TEMPLATES.values())(value) - return validator + def validator_wrapper(value: Any) -> Any: + if os.environ.get("DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION") is None: + template_titles = TIMELINE_TEMPLATES.values() + validator = validate.OneOf(template_titles) + validator(value) + return value + + return validator_wrapper -def elastic_rule_name_regexp(pattern): +def elastic_rule_name_regexp(pattern: Pattern[str]) -> Callable[[Any], Any]: """Custom validator for rule names.""" - def validator(value): + + regexp_validator = validate.Regexp(pattern) + + def validator_wrapper(value: Any) -> Any: if not CUSTOM_RULES_DIR: - validate.Regexp(pattern)(value) - else: - fields.String().deserialize(value) - return validator + regexp_validator(value) + return value + + return validator_wrapper ASSET_TYPE = "security_rule" SAVED_OBJECT_TYPE = "security-rule" -DATE_PATTERN = r'^\d{4}/\d{2}/\d{2}$' -MATURITY_LEVELS = ['development', 'experimental', 'beta', 'production', 'deprecated'] -OS_OPTIONS = ['windows', 'linux', 'macos'] -NAME_PATTERN = r'^[a-zA-Z0-9].+?[a-zA-Z0-9\[\]()]$' -PR_PATTERN = r'^$|\d+$' -SHA256_PATTERN = r'^[a-fA-F0-9]{64}$' -UUID_PATTERN = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' - -_version = r'\d+\.\d+(\.\d+[\w-]*)*' -CONDITION_VERSION_PATTERN = rf'^\^{_version}$' -VERSION_PATTERN = f'^{_version}$' -MINOR_SEMVER = r'^\d+\.\d+$' -BRANCH_PATTERN = f'{VERSION_PATTERN}|^master$' +DATE_PATTERN = re.compile(r"^\d{4}/\d{2}/\d{2}$") +MATURITY_LEVELS = ["development", "experimental", "beta", "production", "deprecated"] +OS_OPTIONS = ["windows", "linux", "macos"] + +NAME_PATTERN = re.compile(r"^[a-zA-Z0-9].+?[a-zA-Z0-9\[\]()]$") +PR_PATTERN = re.compile(r"^$|\d+$") +SHA256_PATTERN = re.compile(r"^[a-fA-F0-9]{64}$") +UUID_PATTERN = re.compile(r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$") + +_version = r"\d+\.\d+(\.\d+[\w-]*)*" +CONDITION_VERSION_PATTERN = re.compile(rf"^\^{_version}$") +VERSION_PATTERN = f"^{_version}$" +MINOR_SEMVER = re.compile(r"^\d+\.\d+$") +BRANCH_PATTERN = f"{VERSION_PATTERN}|^master$" ELASTICSEARCH_EQL_FEATURES = { - "allow_negation": (Version.parse('8.9.0'), None), - 
"allow_runs": (Version.parse('7.16.0'), None), - "allow_sample": (Version.parse('8.6.0'), None), - "elasticsearch_validate_optional_fields": (Version.parse('7.16.0'), None) + "allow_negation": (Version.parse("8.9.0"), None), + "allow_runs": (Version.parse("7.16.0"), None), + "allow_sample": (Version.parse("8.6.0"), None), + "elasticsearch_validate_optional_fields": (Version.parse("7.16.0"), None), } -NON_DATASET_PACKAGES = ['apm', - 'auditd_manager', - 'cloud_defend', - 'endpoint', - 'jamf_protect', - 'network_traffic', - 'system', - 'windows', - 'sentinel_one_cloud_funnel', - 'ti_rapid7_threat_command', - 'm365_defender', - 'panw', - 'crowdstrike'] +NON_DATASET_PACKAGES = [ + "apm", + "auditd_manager", + "cloud_defend", + "endpoint", + "jamf_protect", + "network_traffic", + "system", + "windows", + "sentinel_one_cloud_funnel", + "ti_rapid7_threat_command", + "m365_defender", + "panw", + "crowdstrike", +] NON_PUBLIC_FIELDS = { - "related_integrations": (Version.parse('8.3.0'), None), - "required_fields": (Version.parse('8.3.0'), None), - "setup": (Version.parse('8.3.0'), None) + "related_integrations": (Version.parse("8.3.0"), None), + "required_fields": (Version.parse("8.3.0"), None), + "setup": (Version.parse("8.3.0"), None), } -INTERVAL_PATTERN = r'^\d+[mshd]$' -TACTIC_URL = r'^https://attack.mitre.org/tactics/TA[0-9]+/$' -TECHNIQUE_URL = r'^https://attack.mitre.org/techniques/T[0-9]+/$' -SUBTECHNIQUE_URL = r'^https://attack.mitre.org/techniques/T[0-9]+/[0-9]+/$' -MACHINE_LEARNING = 'machine_learning' -QUERY = 'query' +INTERVAL_PATTERN = r"^\d+[mshd]$" +TACTIC_URL = r"^https://attack.mitre.org/tactics/TA[0-9]+/$" +TECHNIQUE_URL = r"^https://attack.mitre.org/techniques/T[0-9]+/$" +SUBTECHNIQUE_URL = r"^https://attack.mitre.org/techniques/T[0-9]+/[0-9]+/$" +MACHINE_LEARNING = "machine_learning" +QUERY = "query" QUERY_FIELD_OP_EXCEPTIONS = ["powershell.file.script_block_text"] # we had a bad rule ID make it in before tightening up the pattern, and so we have to let it bypass -KNOWN_BAD_RULE_IDS = Literal['119c8877-8613-416d-a98a-96b6664ee73a5'] -KNOWN_BAD_DEPRECATED_DATES = Literal['2021-03-03'] +KNOWN_BAD_RULE_IDS = Literal["119c8877-8613-416d-a98a-96b6664ee73a5"] +KNOWN_BAD_DEPRECATED_DATES = Literal["2021-03-03"] # Known Null values that cannot be handled in TOML due to lack of Null value support via compound dicts KNOWN_NULL_ENTRIES = [{"rule.actions": "frequency.throttle"}] -OPERATORS = ['equals'] - -TIMELINE_TEMPLATES: Final[dict] = { - 'db366523-f1c6-4c1f-8731-6ce5ed9e5717': 'Generic Endpoint Timeline', - '91832785-286d-4ebe-b884-1a208d111a70': 'Generic Network Timeline', - '76e52245-7519-4251-91ab-262fb1a1728c': 'Generic Process Timeline', - '495ad7a7-316e-4544-8a0f-9c098daee76e': 'Generic Threat Match Timeline', - '4d4c0b59-ea83-483f-b8c1-8c360ee53c5c': 'Comprehensive File Timeline', - 'e70679c2-6cde-4510-9764-4823df18f7db': 'Comprehensive Process Timeline', - '300afc76-072d-4261-864d-4149714bf3f1': 'Comprehensive Network Timeline', - '3e47ef71-ebfc-4520-975c-cb27fc090799': 'Comprehensive Registry Timeline', - '3e827bab-838a-469f-bd1e-5e19a2bff2fd': 'Alerts Involving a Single User Timeline', - '4434b91a-94ca-4a89-83cb-a37cdc0532b7': 'Alerts Involving a Single Host Timeline' +OPERATORS = ["equals"] + +TIMELINE_TEMPLATES: Final[dict[str, str]] = { + "db366523-f1c6-4c1f-8731-6ce5ed9e5717": "Generic Endpoint Timeline", + "91832785-286d-4ebe-b884-1a208d111a70": "Generic Network Timeline", + "76e52245-7519-4251-91ab-262fb1a1728c": "Generic Process Timeline", + 
"495ad7a7-316e-4544-8a0f-9c098daee76e": "Generic Threat Match Timeline", + "4d4c0b59-ea83-483f-b8c1-8c360ee53c5c": "Comprehensive File Timeline", + "e70679c2-6cde-4510-9764-4823df18f7db": "Comprehensive Process Timeline", + "300afc76-072d-4261-864d-4149714bf3f1": "Comprehensive Network Timeline", + "3e47ef71-ebfc-4520-975c-cb27fc090799": "Comprehensive Registry Timeline", + "3e827bab-838a-469f-bd1e-5e19a2bff2fd": "Alerts Involving a Single User Timeline", + "4434b91a-94ca-4a89-83cb-a37cdc0532b7": "Alerts Involving a Single Host Timeline", } EXPECTED_RULE_TAGS = [ - 'Data Source: Active Directory', - 'Data Source: Amazon Web Services', - 'Data Source: Auditd Manager', - 'Data Source: AWS', - 'Data Source: APM', - 'Data Source: Azure', - 'Data Source: CyberArk PAS', - 'Data Source: Elastic Defend', - 'Data Source: Elastic Defend for Containers', - 'Data Source: Elastic Endgame', - 'Data Source: GCP', - 'Data Source: Google Cloud Platform', - 'Data Source: Google Workspace', - 'Data Source: Kubernetes', - 'Data Source: Microsoft 365', - 'Data Source: Okta', - 'Data Source: PowerShell Logs', - 'Data Source: Sysmon Only', - 'Data Source: Zoom', - 'Domain: Cloud', - 'Domain: Container', - 'Domain: Endpoint', - 'Mitre Atlas: *', - 'OS: Linux', - 'OS: macOS', - 'OS: Windows', - 'Rule Type: BBR', - 'Resources: Investigation Guide', - 'Rule Type: Higher-Order Rule', - 'Rule Type: Machine Learning', - 'Rule Type: ML', - 'Tactic: Collection', - 'Tactic: Command and Control', - 'Tactic: Credential Access', - 'Tactic: Defense Evasion', - 'Tactic: Discovery', - 'Tactic: Execution', - 'Tactic: Exfiltration', - 'Tactic: Impact', - 'Tactic: Initial Access', - 'Tactic: Lateral Movement', - 'Tactic: Persistence', - 'Tactic: Privilege Escalation', - 'Tactic: Reconnaissance', - 'Tactic: Resource Development', - 'Threat: BPFDoor', - 'Threat: Cobalt Strike', - 'Threat: Lightning Framework', - 'Threat: Orbit', - 'Threat: Rootkit', - 'Threat: TripleCross', - 'Use Case: Active Directory Monitoring', - 'Use Case: Asset Visibility', - 'Use Case: Configuration Audit', - 'Use Case: Guided Onboarding', - 'Use Case: Identity and Access Audit', - 'Use Case: Log Auditing', - 'Use Case: Network Security Monitoring', - 'Use Case: Threat Detection', - 'Use Case: UEBA', - 'Use Case: Vulnerability' + "Data Source: Active Directory", + "Data Source: Amazon Web Services", + "Data Source: Auditd Manager", + "Data Source: AWS", + "Data Source: APM", + "Data Source: Azure", + "Data Source: CyberArk PAS", + "Data Source: Elastic Defend", + "Data Source: Elastic Defend for Containers", + "Data Source: Elastic Endgame", + "Data Source: GCP", + "Data Source: Google Cloud Platform", + "Data Source: Google Workspace", + "Data Source: Kubernetes", + "Data Source: Microsoft 365", + "Data Source: Okta", + "Data Source: PowerShell Logs", + "Data Source: Sysmon Only", + "Data Source: Zoom", + "Domain: Cloud", + "Domain: Container", + "Domain: Endpoint", + "Mitre Atlas: *", + "OS: Linux", + "OS: macOS", + "OS: Windows", + "Rule Type: BBR", + "Resources: Investigation Guide", + "Rule Type: Higher-Order Rule", + "Rule Type: Machine Learning", + "Rule Type: ML", + "Tactic: Collection", + "Tactic: Command and Control", + "Tactic: Credential Access", + "Tactic: Defense Evasion", + "Tactic: Discovery", + "Tactic: Execution", + "Tactic: Exfiltration", + "Tactic: Impact", + "Tactic: Initial Access", + "Tactic: Lateral Movement", + "Tactic: Persistence", + "Tactic: Privilege Escalation", + "Tactic: Reconnaissance", + "Tactic: Resource Development", + 
"Threat: BPFDoor", + "Threat: Cobalt Strike", + "Threat: Lightning Framework", + "Threat: Orbit", + "Threat: Rootkit", + "Threat: TripleCross", + "Use Case: Active Directory Monitoring", + "Use Case: Asset Visibility", + "Use Case: Configuration Audit", + "Use Case: Guided Onboarding", + "Use Case: Identity and Access Audit", + "Use Case: Log Auditing", + "Use Case: Network Security Monitoring", + "Use Case: Threat Detection", + "Use Case: UEBA", + "Use Case: Vulnerability", ] -NonEmptyStr = NewType('NonEmptyStr', str, validate=validate.Length(min=1)) -MACHINE_LEARNING_PACKAGES = ['LMD', 'DGA', 'DED', 'ProblemChild', 'Beaconing', 'PAD'] -AlertSuppressionGroupBy = NewType('AlertSuppressionGroupBy', List[NonEmptyStr], validate=validate.Length(min=1, max=3)) -AlertSuppressionMissing = NewType('AlertSuppressionMissing', str, - validate=validate.OneOf(['suppress', 'doNotSuppress'])) -AlertSuppressionValue = NewType("AlertSupressionValue", int, validate=validate.Range(min=1)) -TimeUnits = Literal['s', 'm', 'h'] -BranchVer = NewType('BranchVer', str, validate=validate.Regexp(BRANCH_PATTERN)) -CardinalityFields = NewType('CardinalityFields', List[NonEmptyStr], validate=validate.Length(min=0, max=3)) + +MACHINE_LEARNING_PACKAGES = ["LMD", "DGA", "DED", "ProblemChild", "Beaconing", "PAD"] + CodeString = NewType("CodeString", str) -ConditionSemVer = NewType('ConditionSemVer', str, validate=validate.Regexp(CONDITION_VERSION_PATTERN)) -Date = NewType('Date', str, validate=validate.Regexp(DATE_PATTERN)) -ExceptionEntryOperator = Literal['included', 'excluded'] -ExceptionEntryType = Literal['match', 'match_any', 'exists', 'list', 'wildcard', 'nested'] -ExceptionNamespaceType = Literal['single', 'agnostic'] -ExceptionItemEndpointTags = Literal['endpoint', 'os:windows', 'os:linux', 'os:macos'] -ExceptionContainerType = Literal['detection', 'endpoint', 'rule_default'] -ExceptionItemType = Literal['simple'] +Markdown = NewType("Markdown", CodeString) + +TimeUnits = Literal["s", "m", "h"] +ExceptionEntryOperator = Literal["included", "excluded"] +ExceptionEntryType = Literal["match", "match_any", "exists", "list", "wildcard", "nested"] +ExceptionNamespaceType = Literal["single", "agnostic"] +ExceptionItemEndpointTags = Literal["endpoint", "os:windows", "os:linux", "os:macos"] +ExceptionContainerType = Literal["detection", "endpoint", "rule_default"] +ExceptionItemType = Literal["simple"] FilterLanguages = Literal["eql", "esql", "kuery", "lucene"] -Interval = NewType('Interval', str, validate=validate.Regexp(INTERVAL_PATTERN)) + InvestigateProviderQueryType = Literal["phrase", "range"] InvestigateProviderValueType = Literal["string", "boolean"] -Markdown = NewType("MarkdownField", CodeString) -Maturity = Literal['development', 'experimental', 'beta', 'production', 'deprecated'] -MaxSignals = NewType("MaxSignals", int, validate=validate.Range(min=1)) -NewTermsFields = NewType('NewTermsFields', List[NonEmptyStr], validate=validate.Length(min=1, max=3)) -Operator = Literal['equals'] -OSType = Literal['windows', 'linux', 'macos'] -PositiveInteger = NewType('PositiveInteger', int, validate=validate.Range(min=1)) -RiskScore = NewType("MaxSignals", int, validate=validate.Range(min=1, max=100)) -RuleName = NewType('RuleName', str, validate=elastic_rule_name_regexp(NAME_PATTERN)) -RuleType = Literal['query', 'saved_query', 'machine_learning', 'eql', 'esql', 'threshold', 'threat_match', 'new_terms'] -SemVer = NewType('SemVer', str, validate=validate.Regexp(VERSION_PATTERN)) -SemVerMinorOnly = NewType('SemVerFullStrict', 
str, validate=validate.Regexp(MINOR_SEMVER)) -Severity = Literal['low', 'medium', 'high', 'critical'] -Sha256 = NewType('Sha256', str, validate=validate.Regexp(SHA256_PATTERN)) -SubTechniqueURL = NewType('SubTechniqueURL', str, validate=validate.Regexp(SUBTECHNIQUE_URL)) -StoreType = Literal['appState', 'globalState'] -TacticURL = NewType('TacticURL', str, validate=validate.Regexp(TACTIC_URL)) -TechniqueURL = NewType('TechniqueURL', str, validate=validate.Regexp(TECHNIQUE_URL)) -ThresholdValue = NewType("ThresholdValue", int, validate=validate.Range(min=1)) -TimelineTemplateId = NewType('TimelineTemplateId', str, validate=elastic_timeline_template_id_validator()) -TimelineTemplateTitle = NewType('TimelineTemplateTitle', str, validate=elastic_timeline_template_title_validator()) + +Operator = Literal["equals"] +OSType = Literal["windows", "linux", "macos"] + +Severity = Literal["low", "medium", "high", "critical"] +Maturity = Literal["development", "experimental", "beta", "production", "deprecated"] +RuleType = Literal["query", "saved_query", "machine_learning", "eql", "esql", "threshold", "threat_match", "new_terms"] +StoreType = Literal["appState", "globalState"] TransformTypes = Literal["osquery", "investigate"] -UUIDString = NewType('UUIDString', str, validate=validate.Regexp(UUID_PATTERN)) -BuildingBlockType = Literal['default'] +BuildingBlockType = Literal["default"] + +NON_EMPTY_STRING_FIELD = fields.String(validate=validate.Length(min=1)) +NonEmptyStr = Annotated[str, NON_EMPTY_STRING_FIELD] + +AlertSuppressionGroupBy = Annotated[ + list[NonEmptyStr], fields.List(NON_EMPTY_STRING_FIELD, validate=validate.Length(min=1, max=3)) +] +AlertSuppressionMissing = Annotated[str, fields.String(validate=validate.OneOf(["suppress", "doNotSuppress"]))] +AlertSuppressionValue = Annotated[int, fields.Integer(validate=validate.Range(min=1))] +BranchVer = Annotated[str, fields.String(validate=validate.Regexp(BRANCH_PATTERN))] +CardinalityFields = Annotated[ + list[NonEmptyStr], + fields.List(NON_EMPTY_STRING_FIELD, validate=validate.Length(min=0, max=3)), +] +ConditionSemVer = Annotated[str, fields.String(validate=validate.Regexp(CONDITION_VERSION_PATTERN))] +Date = Annotated[str, fields.String(validate=validate.Regexp(DATE_PATTERN))] +Interval = Annotated[str, fields.String(validate=validate.Regexp(INTERVAL_PATTERN))] +MaxSignals = Annotated[int, fields.Integer(validate=validate.Range(min=1))] +NewTermsFields = Annotated[ + list[NonEmptyStr], fields.List(NON_EMPTY_STRING_FIELD, validate=validate.Length(min=1, max=3)) +] +PositiveInteger = Annotated[int, fields.Integer(validate=validate.Range(min=1))] +RiskScore = Annotated[int, fields.Integer(validate=validate.Range(min=1, max=100))] +RuleName = Annotated[str, fields.String(validate=elastic_rule_name_regexp(NAME_PATTERN))] +SemVer = Annotated[str, fields.String(validate=validate.Regexp(VERSION_PATTERN))] +SemVerMinorOnly = Annotated[str, fields.String(validate=validate.Regexp(MINOR_SEMVER))] +Sha256 = Annotated[str, fields.String(validate=validate.Regexp(SHA256_PATTERN))] +SubTechniqueURL = Annotated[str, fields.String(validate=validate.Regexp(SUBTECHNIQUE_URL))] +TacticURL = Annotated[str, fields.String(validate=validate.Regexp(TACTIC_URL))] +TechniqueURL = Annotated[str, fields.String(validate=validate.Regexp(TECHNIQUE_URL))] +ThresholdValue = Annotated[int, fields.Integer(validate=validate.Range(min=1))] +TimelineTemplateId = Annotated[str, fields.String(validate=elastic_timeline_template_id_validator())] +TimelineTemplateTitle = Annotated[str, 
fields.String(validate=elastic_timeline_template_title_validator())] +UUIDString = Annotated[str, fields.String(validate=validate.Regexp(UUID_PATTERN))] # experimental machine learning features and releases -MachineLearningType = getattr(Literal, '__getitem__')(tuple(MACHINE_LEARNING_PACKAGES)) # noqa: E999 -MachineLearningTypeLower = getattr(Literal, '__getitem__')( - tuple(map(str.lower, MACHINE_LEARNING_PACKAGES))) # noqa: E999 -## +MachineLearningType = Literal[MACHINE_LEARNING_PACKAGES] +MACHINE_LEARNING_PACKAGES_LOWER = tuple(map(str.lower, MACHINE_LEARNING_PACKAGES)) +MachineLearningTypeLower = Literal[MACHINE_LEARNING_PACKAGES_LOWER] ActionTypeId = Literal[ - ".slack", ".slack_api", ".email", ".index", ".pagerduty", ".swimlane", ".webhook", ".servicenow", - ".servicenow-itom", ".servicenow-sir", ".jira", ".resilient", ".opsgenie", ".teams", ".torq", ".tines", - ".d3security" + ".slack", + ".slack_api", + ".email", + ".index", + ".pagerduty", + ".swimlane", + ".webhook", + ".servicenow", + ".servicenow-itom", + ".servicenow-sir", + ".jira", + ".resilient", + ".opsgenie", + ".teams", + ".torq", + ".tines", + ".d3security", ] EsDataTypes = Literal[ - 'binary', 'boolean', - 'keyword', 'constant_keyword', 'wildcard', - 'long', 'integer', 'short', 'byte', 'double', 'float', 'half_float', 'scaled_float', 'unsigned_long', - 'date', 'date_nanos', - 'alias', 'object', 'flatten', 'nested', 'join', - 'integer_range', 'float_range', 'long_range', 'double_range', 'date_range', 'ip_range', - 'ip', 'version', 'murmur3', 'aggregate_metric_double', 'histogram', - 'text', 'text_match_only', 'annotated-text', 'completion', 'search_as_you_type', 'token_count', - 'dense_vector', 'sparse_vector', 'rank_feature', 'rank_features', - 'geo_point', 'geo_shape', 'point', 'shape', - 'percolator' + "binary", + "boolean", + "keyword", + "constant_keyword", + "wildcard", + "long", + "integer", + "short", + "byte", + "double", + "float", + "half_float", + "scaled_float", + "unsigned_long", + "date", + "date_nanos", + "alias", + "object", + "flatten", + "nested", + "join", + "integer_range", + "float_range", + "long_range", + "double_range", + "date_range", + "ip_range", + "ip", + "version", + "murmur3", + "aggregate_metric_double", + "histogram", + "text", + "text_match_only", + "annotated-text", + "completion", + "search_as_you_type", + "token_count", + "dense_vector", + "sparse_vector", + "rank_feature", + "rank_features", + "geo_point", + "geo_shape", + "point", + "shape", + "percolator", ] # definitions for the integration to index mapping unit test case -IGNORE_IDS = ["eb079c62-4481-4d6e-9643-3ca499df7aaa", "699e9fdb-b77c-4c01-995c-1c15019b9c43", - "0c9a14d9-d65d-486f-9b5b-91e4e6b22bd0", "a198fbbd-9413-45ec-a269-47ae4ccf59ce", - "0c41e478-5263-4c69-8f9e-7dfd2c22da64", "aab184d3-72b3-4639-b242-6597c99d8bca", - "a61809f3-fb5b-465c-8bff-23a8a068ac60", "f3e22c8b-ea47-45d1-b502-b57b6de950b3", - "fcf18de8-ad7d-4d01-b3f7-a11d5b3883af"] -IGNORE_INDICES = ['.alerts-security.*', 'logs-*', 'metrics-*', 'traces-*', 'endgame-*', - 'filebeat-*', 'packetbeat-*', 'auditbeat-*', 'winlogbeat-*'] +IGNORE_IDS = [ + "eb079c62-4481-4d6e-9643-3ca499df7aaa", + "699e9fdb-b77c-4c01-995c-1c15019b9c43", + "0c9a14d9-d65d-486f-9b5b-91e4e6b22bd0", + "a198fbbd-9413-45ec-a269-47ae4ccf59ce", + "0c41e478-5263-4c69-8f9e-7dfd2c22da64", + "aab184d3-72b3-4639-b242-6597c99d8bca", + "a61809f3-fb5b-465c-8bff-23a8a068ac60", + "f3e22c8b-ea47-45d1-b502-b57b6de950b3", + "fcf18de8-ad7d-4d01-b3f7-a11d5b3883af", +] +IGNORE_INDICES = [ + ".alerts-security.*", 
+ "logs-*", + "metrics-*", + "traces-*", + "endgame-*", + "filebeat-*", + "packetbeat-*", + "auditbeat-*", + "winlogbeat-*", +] diff --git a/detection_rules/schemas/registry_package.py b/detection_rules/schemas/registry_package.py index 23aa7e46e00..334bb7b9be0 100644 --- a/detection_rules/schemas/registry_package.py +++ b/detection_rules/schemas/registry_package.py @@ -6,16 +6,15 @@ """Definitions for packages destined for the registry.""" from dataclasses import dataclass, field -from typing import Dict, List, Optional -from .definitions import ConditionSemVer, SemVer -from ..mixins import MarshmallowDataclassMixin +from detection_rules.mixins import MarshmallowDataclassMixin +from detection_rules.schemas.definitions import ConditionSemVer, SemVer @dataclass class ConditionElastic: subscription: str - capabilities: Optional[List[str]] + capabilities: list[str] | None @dataclass @@ -35,26 +34,26 @@ class Icon: class RegistryPackageManifestBase(MarshmallowDataclassMixin): """Base class for registry packages.""" - categories: List[str] + categories: list[str] description: str format_version: SemVer - icons: List[Icon] + icons: list[Icon] name: str - owner: Dict[str, str] + owner: dict[str, str] title: str type: str version: SemVer - internal: Optional[bool] - policy_templates: Optional[List[str]] - screenshots: Optional[List[str]] + internal: bool | None + policy_templates: list[str] | None + screenshots: list[str] | None @dataclass class RegistryPackageManifestV1(RegistryPackageManifestBase): """Registry packages using elastic-package v1.""" - conditions: Dict[str, ConditionSemVer] + conditions: dict[str, ConditionSemVer] license: str release: str @@ -64,4 +63,4 @@ class RegistryPackageManifestV3(RegistryPackageManifestBase): """Registry packages using elastic-package v3.""" conditions: Condition - source: Dict[str, str] + source: dict[str, str] diff --git a/detection_rules/schemas/stack_compat.py b/detection_rules/schemas/stack_compat.py index 0981f30cb5c..ca0cabf965f 100644 --- a/detection_rules/schemas/stack_compat.py +++ b/detection_rules/schemas/stack_compat.py @@ -4,30 +4,29 @@ # 2.0. 
from dataclasses import Field -from typing import Dict, List, Optional, Tuple +from typing import Any from semver import Version -from ..misc import cached +from detection_rules.misc import cached @cached -def get_restricted_field(schema_field: Field) -> Tuple[Optional[Version], Optional[Version]]: +def get_restricted_field(schema_field: Field[Any]) -> tuple[Version | None, Version | None]: """Get an optional min and max compatible versions of a field (from a schema or dataclass).""" # nested get is to support schema fields being passed directly from dataclass or fields in schema class, since # marshmallow_dataclass passes the embedded metadata directly - min_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('min_compat') - max_compat = schema_field.metadata.get('metadata', schema_field.metadata).get('max_compat') + min_compat = schema_field.metadata.get("metadata", schema_field.metadata).get("min_compat") + max_compat = schema_field.metadata.get("metadata", schema_field.metadata).get("max_compat") min_compat = Version.parse(min_compat, optional_minor_and_patch=True) if min_compat else None max_compat = Version.parse(max_compat, optional_minor_and_patch=True) if max_compat else None return min_compat, max_compat @cached -def get_restricted_fields(schema_fields: List[Field]) -> Dict[str, Tuple[Optional[Version], - Optional[Version]]]: +def get_restricted_fields(schema_fields: list[Field[Any]]) -> dict[str, tuple[Version | None, Version | None]]: """Get a list of optional min and max compatible versions of fields (from a schema or dataclass).""" - restricted = {} + restricted: dict[str, tuple[Version | None, Version | None]] = {} for _field in schema_fields: min_compat, max_compat = get_restricted_field(_field) if min_compat or max_compat: @@ -37,18 +36,20 @@ def get_restricted_fields(schema_fields: List[Field]) -> Dict[str, Tuple[Optiona @cached -def get_incompatible_fields(schema_fields: List[Field], package_version: Version) -> \ - Optional[Dict[str, tuple]]: +def get_incompatible_fields( + schema_fields: list[Field[Any]], + package_version: Version, +) -> dict[str, tuple[Version | None, Version | None]] | None: """Get a list of fields that are incompatible with the package version.""" if not schema_fields: - return + return None - incompatible = {} + incompatible: dict[str, tuple[Version | None, Version | None]] = {} restricted_fields = get_restricted_fields(schema_fields) for field_name, values in restricted_fields.items(): min_compat, max_compat = values - if min_compat and package_version < min_compat or max_compat and package_version > max_compat: + if (min_compat and package_version < min_compat) or (max_compat and package_version > max_compat): incompatible[field_name] = (min_compat, max_compat) return incompatible diff --git a/detection_rules/utils.py b/detection_rules/utils.py index bbc8515f6d8..c1d3a9e4dfb 100644 --- a/detection_rules/utils.py +++ b/detection_rules/utils.py @@ -4,10 +4,10 @@ # 2.0. 
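The reworked `get_incompatible_fields()` in `stack_compat.py` above reduces to a window check: a field is incompatible when the package version falls below its `min_compat` or above its `max_compat`. A simplified sketch of that check with sample data; only `related_integrations` and its 8.3.0 floor come from the repo, while `legacy_field` and the 8.1.0 package version are invented.

from semver import Version

# (min_compat, max_compat) window per field; None means unbounded on that side.
restricted = {
    "related_integrations": (Version.parse("8.3.0"), None),  # from NON_PUBLIC_FIELDS above
    "legacy_field": (None, Version.parse("7.17.0")),          # invented example
}

package_version = Version.parse("8.1.0")

incompatible = {
    name: (lo, hi)
    for name, (lo, hi) in restricted.items()
    if (lo and package_version < lo) or (hi and package_version > hi)
}
print(incompatible)  # both fields fall outside their window for 8.1.0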
"""Util functions.""" + import base64 import contextlib import functools -import glob import gzip import hashlib import io @@ -17,46 +17,38 @@ import shutil import subprocess import zipfile -from dataclasses import is_dataclass, astuple -from datetime import datetime, date, timezone +from collections.abc import Callable, Iterator +from dataclasses import astuple, is_dataclass +from datetime import UTC, date, datetime from pathlib import Path -from typing import Dict, Union, Optional, Callable from string import Template +from typing import Any import click -import pytoml -import eql.utils -from eql.utils import load_dump, stream_json_lines +import eql.utils # type: ignore[reportMissingTypeStubs] +import pytoml # type: ignore[reportMissingTypeStubs] +from eql.utils import load_dump # type: ignore[reportMissingTypeStubs] from github.Repository import Repository -import kql - - CURR_DIR = Path(__file__).resolve().parent ROOT_DIR = CURR_DIR.parent ETC_DIR = ROOT_DIR / "detection_rules" / "etc" INTEGRATION_RULE_DIR = ROOT_DIR / "rules" / "integrations" -class NonelessDict(dict): - """Wrapper around dict that doesn't populate None values.""" - - def __setitem__(self, key, value): - if value is not None: - dict.__setitem__(self, key, value) - - class DateTimeEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, (date, datetime)): - return obj.isoformat() + def default(self, o: Any) -> Any: + if isinstance(o, (date | datetime)): + return o.isoformat() + return None marshmallow_schemas = {} -def gopath() -> Optional[str]: - """Retrieve $GOPATH.""" +def gopath() -> str | None: + """Retrieve $GOPATH""" + env_path = os.getenv("GOPATH") if env_path: return env_path @@ -66,170 +58,150 @@ def gopath() -> Optional[str]: output = subprocess.check_output([go_bin, "env"], encoding="utf-8").splitlines() for line in output: if line.startswith("GOPATH="): - return line[len("GOPATH="):].strip('"') + return line[len("GOPATH=") :].strip('"') + return None -def dict_hash(obj: dict) -> str: +def dict_hash(obj: dict[Any, Any]) -> str: """Hash a dictionary deterministically.""" - raw_bytes = base64.b64encode(json.dumps(obj, sort_keys=True).encode('utf-8')) + raw_bytes = base64.b64encode(json.dumps(obj, sort_keys=True).encode("utf-8")) return hashlib.sha256(raw_bytes).hexdigest() -def ensure_list_of_strings(value: str | list) -> list[str]: +def ensure_list_of_strings(value: str | list[str]) -> list[str]: """Ensure or convert a value is a list of strings.""" if isinstance(value, str): # Check if the string looks like a JSON list - if value.startswith('[') and value.endswith(']'): + if value.startswith("[") and value.endswith("]"): try: # Attempt to parse the string as a JSON list parsed_value = json.loads(value) if isinstance(parsed_value, list): - return [str(v) for v in parsed_value] + return [str(v) for v in parsed_value] # type: ignore[reportUnknownVariableType] except json.JSONDecodeError: pass # If it's not a JSON list, split by commas if present # Else return a list with the original string - return list(map(lambda x: x.strip().strip('"'), value.split(','))) - elif isinstance(value, list): - return [str(v) for v in value] - else: - return [] - - -def get_json_iter(f): - """Get an iterator over a JSON file.""" - first = f.read(2) - f.seek(0) - - if first[0] == '[' or first == "{\n": - return json.load(f) - else: - data = list(stream_json_lines(f)) - return data + return [x.strip().strip('"') for x in value.split(",")] + return [str(v) for v in value] -def get_nested_value(dictionary, 
compound_key): - """Get a nested value from a dictionary.""" - keys = compound_key.split('.') +def get_nested_value(obj: Any, compound_key: str) -> Any: + """Get a nested value from a obj.""" + keys = compound_key.split(".") for key in keys: - if isinstance(dictionary, dict): - dictionary = dictionary.get(key) + if isinstance(obj, dict): + obj = obj.get(key) # type: ignore[reportUnknownVariableType] else: return None - return dictionary + return obj # type: ignore[reportUnknownVariableType] -def get_path(*paths) -> Path: +def get_path(paths: list[str]) -> Path: """Get a file by relative path.""" return ROOT_DIR.joinpath(*paths) -def get_etc_path(*paths) -> Path: +def get_etc_path(paths: list[str]) -> Path: """Load a file from the detection_rules/etc/ folder.""" return ETC_DIR.joinpath(*paths) -def get_etc_glob_path(*patterns) -> list: +def get_etc_glob_path(patterns: list[str]) -> list[Path]: """Load a file from the detection_rules/etc/ folder.""" - pattern = os.path.join(*patterns) - return glob.glob(str(ETC_DIR / pattern)) + pattern = os.path.join(*patterns) # noqa: PTH118 + return list(ETC_DIR.glob(pattern)) -def get_etc_file(name, mode="r"): +def get_etc_file(name: str, mode: str = "r") -> str: """Load a file from the detection_rules/etc/ folder.""" - with open(get_etc_path(name), mode) as f: + with get_etc_path([name]).open(mode) as f: return f.read() -def load_etc_dump(*path): +def load_etc_dump(paths: list[str]) -> Any: """Load a json/yml/toml file from the detection_rules/etc/ folder.""" - return eql.utils.load_dump(str(get_etc_path(*path))) + return eql.utils.load_dump(str(get_etc_path(paths))) # type: ignore[reportUnknownVariableType] -def save_etc_dump(contents, *path, **kwargs): +def save_etc_dump(contents: dict[str, Any], path: list[str], sort_keys: bool = True, indent: int = 2) -> None: """Save a json/yml/toml file from the detection_rules/etc/ folder.""" - path = str(get_etc_path(*path)) - _, ext = os.path.splitext(path) - sort_keys = kwargs.pop('sort_keys', True) - indent = kwargs.pop('indent', 2) - - if ext == ".json": - with open(path, "wt") as f: - json.dump(contents, f, cls=DateTimeEncoder, sort_keys=sort_keys, indent=indent, **kwargs) + path_joined = get_etc_path(path) + + if path_joined.suffix == ".json": + with path_joined.open("w") as f: + json.dump(contents, f, cls=DateTimeEncoder, sort_keys=sort_keys, indent=indent) else: - return eql.utils.save_dump(contents, path) + eql.utils.save_dump(contents, path) # type: ignore[reportUnknownVariableType] -def set_all_validation_bypass(env_value: bool = False): +def set_all_validation_bypass(env_value: bool = False) -> None: """Set all validation bypass environment variables.""" - os.environ['DR_BYPASS_NOTE_VALIDATION_AND_PARSE'] = str(env_value) - os.environ['DR_BYPASS_BBR_LOOKBACK_VALIDATION'] = str(env_value) - os.environ['DR_BYPASS_TAGS_VALIDATION'] = str(env_value) - os.environ['DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION'] = str(env_value) + os.environ["DR_BYPASS_NOTE_VALIDATION_AND_PARSE"] = str(env_value) + os.environ["DR_BYPASS_BBR_LOOKBACK_VALIDATION"] = str(env_value) + os.environ["DR_BYPASS_TAGS_VALIDATION"] = str(env_value) + os.environ["DR_BYPASS_TIMELINE_TEMPLATE_VALIDATION"] = str(env_value) -def set_nested_value(dictionary, compound_key, value): - """Set a nested value in a dictionary.""" - keys = compound_key.split('.') +def set_nested_value(obj: dict[str, Any], compound_key: str, value: Any) -> None: + """Set a nested value in a obj.""" + keys = compound_key.split(".") for key in keys[:-1]: - dictionary = 
dictionary.setdefault(key, {}) - dictionary[keys[-1]] = value + obj = obj.setdefault(key, {}) + obj[keys[-1]] = value -def gzip_compress(contents) -> bytes: +def gzip_compress(contents: str) -> bytes: gz_file = io.BytesIO() with gzip.GzipFile(mode="w", fileobj=gz_file) as f: - if not isinstance(contents, bytes): - contents = contents.encode("utf8") - f.write(contents) + encoded = contents if isinstance(contents, bytes) else contents.encode("utf8") + _ = f.write(encoded) return gz_file.getvalue() -def read_gzip(path): - with gzip.GzipFile(path, mode='r') as gz: +def read_gzip(path: str | Path) -> str: + with gzip.GzipFile(str(path), mode="r") as gz: return gz.read().decode("utf8") @contextlib.contextmanager -def unzip(contents): # type: (bytes) -> zipfile.ZipFile +def unzip(contents: bytes) -> Iterator[zipfile.ZipFile]: """Get zipped contents.""" zipped = io.BytesIO(contents) archive = zipfile.ZipFile(zipped, mode="r") try: yield archive - finally: archive.close() -def unzip_and_save(contents, path, member=None, verbose=True): +def unzip_and_save(contents: bytes, path: str, member: str | None = None, verbose: bool = True) -> None: """Save unzipped from raw zipped contents.""" with unzip(contents) as archive: - if member: - archive.extract(member, path) + _ = archive.extract(member, path) else: archive.extractall(path) if verbose: - name_list = archive.namelist()[member] if not member else archive.namelist() - print('Saved files to {}: \n\t- {}'.format(path, '\n\t- '.join(name_list))) + name_list = archive.namelist() + print("Saved files to {}: \n\t- {}".format(path, "\n\t- ".join(name_list))) -def unzip_to_dict(zipped: zipfile.ZipFile, load_json=True) -> Dict[str, Union[dict, str]]: +def unzip_to_dict(zipped: zipfile.ZipFile, load_json: bool = True) -> dict[str, Any]: """Unzip and load contents to dict with filenames as keys.""" - bundle = {} + bundle: dict[str, Any] = {} for filename in zipped.namelist(): - if filename.endswith('/'): + if filename.endswith("/"): continue fp = Path(filename) contents = zipped.read(filename) - if load_json and fp.suffix == '.json': + if load_json and fp.suffix == ".json": contents = json.loads(contents) bundle[fp.name] = contents @@ -237,7 +209,12 @@ def unzip_to_dict(zipped: zipfile.ZipFile, load_json=True) -> Dict[str, Union[di return bundle -def event_sort(events, timestamp='@timestamp', date_format='%Y-%m-%dT%H:%M:%S.%f%z', asc=True): +def event_sort( + events: list[Any], + timestamp: str = "@timestamp", + date_format: str = "%Y-%m-%dT%H:%M:%S.%f%z", + order_asc: bool = True, +) -> list[Any]: """Sort events from elasticsearch by timestamp.""" def round_microseconds(t: str) -> str: @@ -247,40 +224,31 @@ def round_microseconds(t: str) -> str: # Return early if the timestamp string is empty return t - parts = t.split('.') - if len(parts) == 2: + parts = t.split(".") + if len(parts) == 2: # noqa: PLR2004 # Remove trailing "Z" from microseconds part micro_seconds = parts[1].rstrip("Z") - if len(micro_seconds) > 6: + if len(micro_seconds) > 6: # noqa: PLR2004 # If the microseconds part has more than 6 digits # Convert the microseconds part to a float and round to 6 decimal places rounded_micro_seconds = round(float(f"0.{micro_seconds}"), 6) # Format the rounded value to always have 6 decimal places # Reconstruct the timestamp string with the rounded microseconds part - formatted_micro_seconds = f'{rounded_micro_seconds:0.6f}'.split(".")[-1] + formatted_micro_seconds = f"{rounded_micro_seconds:0.6f}".split(".")[-1] t = f"{parts[0]}.{formatted_micro_seconds}Z" 
return t - def _event_sort(event: dict) -> datetime: + def _event_sort(event: dict[str, Any]) -> datetime: """Calculates the sort key for an event as a datetime object.""" t = round_microseconds(event[timestamp]) # Return the timestamp as a datetime object for comparison - return datetime.strptime(t, date_format) + return datetime.strptime(t, date_format) # noqa: DTZ007 - return sorted(events, key=_event_sort, reverse=not asc) - - -def combine_sources(*sources): # type: (list[list]) -> list - """Combine lists of events from multiple sources.""" - combined = [] - for source in sources: - combined.extend(source.copy()) - - return event_sort(combined) + return sorted(events, key=_event_sort, reverse=not order_asc) def convert_time_span(span: str) -> int: @@ -290,55 +258,55 @@ def convert_time_span(span: str) -> int: return eql.ast.TimeRange(amount, unit).as_milliseconds() -def evaluate(rule, events, normalize_kql_keywords: bool = False): - """Evaluate a query against events.""" - evaluator = kql.get_evaluator(kql.parse(rule.query), normalize_kql_keywords=normalize_kql_keywords) - filtered = list(filter(evaluator, events)) - return filtered - - -def unix_time_to_formatted(timestamp): # type: (int|str) -> str +def unix_time_to_formatted(timestamp: float | str) -> str: """Converts unix time in seconds or milliseconds to the default format.""" - if isinstance(timestamp, (int, float)): - if timestamp > 2 ** 32: + if isinstance(timestamp, (int | float)): + if timestamp > 2**32: timestamp = round(timestamp / 1000, 3) - return datetime.fromtimestamp(timestamp, timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z' + return datetime.fromtimestamp(timestamp, UTC).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + return timestamp -def normalize_timing_and_sort(events, timestamp='@timestamp', asc=True): +def normalize_timing_and_sort( + events: list[dict[str, Any]], + timestamp: str = "@timestamp", + order_asc: bool = True, +) -> list[Any]: """Normalize timestamp formats and sort events.""" for event in events: _timestamp = event[timestamp] if not isinstance(_timestamp, str): event[timestamp] = unix_time_to_formatted(_timestamp) - return event_sort(events, timestamp=timestamp, asc=asc) + return event_sort(events, timestamp=timestamp, order_asc=order_asc) -def freeze(obj): +def freeze(obj: Any) -> Any: """Helper function to make mutable objects immutable and hashable.""" if not isinstance(obj, type) and is_dataclass(obj): - obj = astuple(obj) + obj = astuple(obj) # type: ignore[reportUnknownVariableType] - if isinstance(obj, (list, tuple)): - return tuple(freeze(o) for o in obj) - elif isinstance(obj, dict): - return freeze(sorted(obj.items())) - else: - return obj + if isinstance(obj, (list | tuple)): + return tuple(freeze(o) for o in obj) # type: ignore[reportUnknownVariableType] + if isinstance(obj, dict): + items = obj.items() # type: ignore[reportUnknownVariableType] + return freeze(sorted(items)) # type: ignore[reportUnknownVariableType] + return obj -_cache = {} +_cache: dict[int, dict[tuple[Any, Any], Any]] = {} -def cached(f): +# Should be replaced with `functools.cache` +# https://docs.python.org/3/library/functools.html#functools.cache +def cached(f: Callable[..., Any]) -> Callable[..., Any]: """Helper function to memoize functions.""" func_key = id(f) @functools.wraps(f) - def wrapped(*args, **kwargs): - _cache.setdefault(func_key, {}) + def wrapped(*args: Any, **kwargs: Any) -> Any: + _ = _cache.setdefault(func_key, {}) cache_key = freeze(args), freeze(kwargs) if cache_key not in 
_cache[func_key]: @@ -346,69 +314,75 @@ def wrapped(*args, **kwargs): return _cache[func_key][cache_key] - def clear(): - _cache.pop(func_key, None) + def clear() -> None: + _ = _cache.pop(func_key, None) - wrapped.clear = clear + wrapped.clear = clear # type: ignore[reportAttributeAccessIssue] return wrapped -def clear_caches(): +def clear_caches() -> None: _cache.clear() -def rulename_to_filename(name: str, tactic_name: str = None, ext: str = '.toml') -> str: +def rulename_to_filename(name: str, tactic_name: str | None = None, ext: str = ".toml") -> str: """Convert a rule name to a filename.""" - name = re.sub(r'[^_a-z0-9]+', '_', name.strip().lower()).strip('_') + name = re.sub(r"[^_a-z0-9]+", "_", name.strip().lower()).strip("_") if tactic_name: - pre = rulename_to_filename(name=tactic_name, ext='') - name = f'{pre}_{name}' - return name + ext or '' + pre = rulename_to_filename(name=tactic_name, ext="") + name = f"{pre}_{name}" + return name + ext or "" -def load_rule_contents(rule_file: Path, single_only=False) -> list: +def load_rule_contents(rule_file: Path, single_only: bool = False) -> list[Any]: """Load a rule file from multiple formats.""" - _, extension = os.path.splitext(rule_file) + extension = rule_file.suffix raw_text = rule_file.read_text() - if extension in ('.ndjson', '.jsonl'): + if extension in (".ndjson", ".jsonl"): # kibana exported rule object is ndjson with the export metadata on the last line contents = [json.loads(line) for line in raw_text.splitlines()] - if len(contents) > 1 and 'exported_count' in contents[-1]: + if len(contents) > 1 and "exported_count" in contents[-1]: contents.pop(-1) if single_only and len(contents) > 1: - raise ValueError('Multiple rules not allowed') + raise ValueError("Multiple rules not allowed") return contents or [{}] - elif extension == '.toml': - rule = pytoml.loads(raw_text) - elif extension.lower() in ('yaml', 'yml'): + if extension == ".toml": + rule = pytoml.loads(raw_text) # type: ignore[reportUnknownVariableType] + elif extension.lower() in ("yaml", "yml"): rule = load_dump(str(rule_file)) else: return [] if isinstance(rule, dict): return [rule] - elif isinstance(rule, list): - return rule - else: - raise ValueError(f"Expected a list or dictionary in {rule_file}") + if isinstance(rule, list): + return rule # type: ignore[reportUnknownVariableType] + raise ValueError(f"Expected a list or dictionary in {rule_file}") -def load_json_from_branch(repo: Repository, file_path: str, branch: Optional[str]): +def load_json_from_branch(repo: Repository, file_path: str, branch: str) -> dict[str, Any]: """Load JSON file from a specific branch.""" - content_file = repo.get_contents(file_path, ref=branch) - return json.loads(content_file.decoded_content.decode("utf-8")) + content_files = repo.get_contents(file_path, ref=branch) + + if isinstance(content_files, list): + raise ValueError("Receive a list instead of a single value") # noqa: TRY004 + content_file = content_files + content = content_file.decoded_content + data = content.decode("utf-8") + return json.loads(data) -def compare_versions(base_json: dict, branch_json: dict) -> list[tuple[str, str, int, int]]: + +def compare_versions(base_json: dict[str, Any], branch_json: dict[str, Any]) -> list[tuple[str, str, int, int]]: """Compare versions of two lock version file JSON objects.""" - changes = [] - for key in base_json: + changes: list[tuple[str, str, int, int]] = [] + for key, base_val in base_json.items(): if key in branch_json: - base_version = base_json[key].get("version") + 
base_version = base_val.get("version") branch_name = branch_json[key].get("rule_name") branch_version = branch_json[key].get("version") if base_version != branch_version: @@ -418,7 +392,7 @@ def compare_versions(base_json: dict, branch_json: dict) -> list[tuple[str, str, def check_double_bumps(changes: list[tuple[str, str, int, int]]) -> list[tuple[str, str, int, int]]: """Check for double bumps in version changes of the result of compare versions of a version lock file.""" - double_bumps = [] + double_bumps: list[tuple[str, str, int, int]] = [] for key, name, removed, added in changes: # Determine the modulo dynamically based on the highest number of digits max_digits = max(len(str(removed)), len(str(added))) @@ -429,7 +403,11 @@ def check_double_bumps(changes: list[tuple[str, str, int, int]]) -> list[tuple[s def check_version_lock_double_bumps( - repo: Repository, file_path: str, base_branch: str, branch: str = "", local_file: Path = None + repo: Repository, + file_path: str, + base_branch: str, + branch: str = "", + local_file: Path | None = None, ) -> list[tuple[str, str, int, int]]: """Check for double bumps in version changes of the result of compare versions of a version lock file.""" base_json = load_json_from_branch(repo, file_path, base_branch) @@ -440,18 +418,16 @@ def check_version_lock_double_bumps( branch_json = load_json_from_branch(repo, file_path, branch) changes = compare_versions(base_json, branch_json) - double_bumps = check_double_bumps(changes) - - return double_bumps + return check_double_bumps(changes) -def format_command_options(ctx): +def format_command_options(ctx: click.Context) -> str: """Echo options for a click command.""" formatter = ctx.make_formatter() - opts = [] + opts: list[tuple[str, str]] = [] for param in ctx.command.get_params(ctx): - if param.name == 'help': + if param.name == "help": continue rv = param.get_help_record(ctx) @@ -459,15 +435,18 @@ def format_command_options(ctx): opts.append(rv) if opts: - with formatter.section('Options'): + with formatter.section("Options"): formatter.write_dl(opts) return formatter.getvalue() -def make_git(*prefix_args) -> Optional[Callable]: +def make_git(*prefix_args: Any) -> Callable[..., str]: git_exe = shutil.which("git") - prefix_args = [str(arg) for arg in prefix_args] + prefix_arg_strs = [str(arg) for arg in prefix_args] + + if "-C" not in prefix_arg_strs: + prefix_arg_strs = ["-C", str(ROOT_DIR), *prefix_arg_strs] if not git_exe: click.secho("Unable to find git", err=True, fg="red") @@ -476,58 +455,56 @@ def make_git(*prefix_args) -> Optional[Callable]: if ctx is not None: ctx.exit(1) - return + raise ValueError("Git not found") - def git(*args, print_output=False): - nonlocal prefix_args - - if '-C' not in prefix_args: - prefix_args = ['-C', get_path()] + prefix_args - - full_args = [git_exe] + prefix_args + [str(arg) for arg in args] - if print_output: - return subprocess.check_call(full_args) + def git(*args: Any) -> str: + arg_strs = [str(arg) for arg in args] + full_args = [git_exe, *prefix_arg_strs, *arg_strs] return subprocess.check_output(full_args, encoding="utf-8").rstrip() return git -def git(*args, **kwargs): +def git(*args: Any, **kwargs: Any) -> str | int: """Find and run a one-off Git command.""" - return make_git()(*args, **kwargs) + g = make_git() + return g(*args, **kwargs) + + +FuncT = Callable[..., Any] -def add_params(*params): +def add_params(*params: Any) -> Callable[[FuncT], FuncT]: """Add parameters to a click command.""" - def decorator(f): - if not hasattr(f, 
'__click_params__'): - f.__click_params__ = [] - f.__click_params__.extend(params) + def decorator(f: FuncT) -> FuncT: + if not hasattr(f, "__click_params__"): + f.__click_params__ = [] # type: ignore[reportFunctionMemberAccess] + f.__click_params__.extend(params) # type: ignore[reportFunctionMemberAccess] return f return decorator -class Ndjson(list): +class Ndjson(list[dict[str, Any]]): """Wrapper for ndjson data.""" - def to_string(self, sort_keys: bool = False): + def to_string(self, sort_keys: bool = False) -> str: """Format contents list to ndjson string.""" - return '\n'.join(json.dumps(c, sort_keys=sort_keys) for c in self) + '\n' + return "\n".join(json.dumps(c, sort_keys=sort_keys) for c in self) + "\n" @classmethod - def from_string(cls, ndjson_string: str, **kwargs): + def from_string(cls, ndjson_string: str, **kwargs: Any) -> "Ndjson": """Load ndjson string to a list.""" contents = [json.loads(line, **kwargs) for line in ndjson_string.strip().splitlines()] return Ndjson(contents) - def dump(self, filename: Path, sort_keys=False): + def dump(self, filename: Path, sort_keys: bool = False) -> None: """Save contents to an ndjson file.""" - filename.write_text(self.to_string(sort_keys=sort_keys)) + _ = filename.write_text(self.to_string(sort_keys=sort_keys)) @classmethod - def load(cls, filename: Path, **kwargs): + def load(cls, filename: Path, **kwargs: Any) -> "Ndjson": """Load content from an ndjson file.""" return cls.from_string(filename.read_text(), **kwargs) @@ -535,19 +512,18 @@ def load(cls, filename: Path, **kwargs): class PatchedTemplate(Template): """String template with updated methods from future versions.""" - def get_identifiers(self): + def get_identifiers(self) -> list[str]: """Returns a list of the valid identifiers in the template, in the order they first appear, ignoring any invalid identifiers.""" # https://github.com/python/cpython/blob/3b4f8fc83dcea1a9d0bc5bd33592e5a3da41fa71/Lib/string.py#LL157-L171C19 - ids = [] + ids: list[str] = [] for mo in self.pattern.finditer(self.template): - named = mo.group('named') or mo.group('braced') - if named is not None and named not in ids: + named = mo.group("named") or mo.group("braced") + if named and named not in ids: # add a named group only the first time it appears ids.append(named) - elif named is None and mo.group('invalid') is None and mo.group('escaped') is None: + elif not named and mo.group("invalid") is None and mo.group("escaped") is None: # If all the groups are None, there must be # another group we're not expecting - raise ValueError('Unrecognized named group in pattern', - self.pattern) + raise ValueError("Unrecognized named group in pattern", self.pattern) return ids diff --git a/detection_rules/version_lock.py b/detection_rules/version_lock.py index 36c62633bb6..1c6d893bda7 100644 --- a/detection_rules/version_lock.py +++ b/detection_rules/version_lock.py @@ -3,26 +3,25 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. 
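One note on `utils.py` above: `cached()` keys its memo on `freeze(args), freeze(kwargs)` precisely because dicts and lists are unhashable, so arguments are first collapsed into nested tuples. The standalone toy below mimics that idea; it is not the module's own code, and the real helper additionally exposes `clear()` and `clear_caches()`.

import functools
from collections.abc import Callable
from typing import Any


def _freeze(obj: Any) -> Any:
    """Collapse lists/dicts into nested tuples so they can serve as dict keys."""
    if isinstance(obj, (list, tuple)):
        return tuple(_freeze(o) for o in obj)
    if isinstance(obj, dict):
        return tuple((k, _freeze(v)) for k, v in sorted(obj.items()))
    return obj


def memoize(func: Callable[..., Any]) -> Callable[..., Any]:
    cache: dict[Any, Any] = {}

    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        key = (_freeze(args), _freeze(kwargs))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    return wrapped


@memoize
def expensive(obj: dict[str, Any]) -> int:
    print("computing...")  # printed only on the first call per distinct input
    return len(str(sorted(obj.items())))


expensive({"a": 1, "b": [2, 3]})  # prints "computing..."
expensive({"a": 1, "b": [2, 3]})  # served from the cache, no print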
"""Helper utilities to manage the version lock.""" + from copy import deepcopy from dataclasses import dataclass from pathlib import Path -from typing import ClassVar, Dict, List, Optional, Union +from typing import Any, ClassVar, Literal import click from semver import Version from .config import parse_rules_config from .mixins import LockDataclassMixin, MarshmallowDataclassMixin -from .rule_loader import RuleCollection from .schemas import definitions from .utils import cached - RULES_CONFIG = parse_rules_config() # This was the original version the lock was created under. This constant has been replaced by -# schemas.get_min_supported_stack_version to dynamically determine the minimum -# MIN_LOCK_VERSION_DEFAULT = Version("7.13.0") +# schemas.get_min_supported_stack_version to dynamically determine the minimum: +# - MIN_LOCK_VERSION_DEFAULT = Version("7.13.0") @dataclass(frozen=True) @@ -35,30 +34,33 @@ class BaseEntry: @dataclass(frozen=True) class PreviousEntry(BaseEntry): - # this is Optional for resiliency in already tagged branches missing this field. This means we should strictly # validate elsewhere - max_allowable_version: Optional[int] + max_allowable_version: int | None = None @dataclass(frozen=True) class VersionLockFileEntry(MarshmallowDataclassMixin, BaseEntry): """Schema for a rule entry in the version lock.""" - min_stack_version: Optional[definitions.SemVerMinorOnly] - previous: Optional[Dict[definitions.SemVerMinorOnly, PreviousEntry]] + + min_stack_version: definitions.SemVerMinorOnly | None = None + previous: dict[definitions.SemVerMinorOnly, PreviousEntry] | None = None @dataclass(frozen=True) class VersionLockFile(LockDataclassMixin): """Schema for the full version lock file.""" - data: Dict[Union[definitions.UUIDString, definitions.KNOWN_BAD_RULE_IDS], VersionLockFileEntry] + + data: dict[definitions.UUIDString | definitions.KNOWN_BAD_RULE_IDS, VersionLockFileEntry] file_path: ClassVar[Path] = RULES_CONFIG.version_lock_file - def __contains__(self, rule_id: str): + def __contains__(self, rule_id: str) -> bool: """Check if a rule is in the map by comparing IDs.""" return rule_id in self.data - def __getitem__(self, item) -> VersionLockFileEntry: + def __getitem__( + self, item: definitions.UUIDString | Literal["119c8877-8613-416d-a98a-96b6664ee73a5"] + ) -> VersionLockFileEntry: """Return entries by rule id.""" if item not in self.data: raise KeyError(item) @@ -68,7 +70,8 @@ def __getitem__(self, item) -> VersionLockFileEntry: @dataclass(frozen=True) class DeprecatedRulesEntry(MarshmallowDataclassMixin): """Schema for rule entry in the deprecated rules file.""" - deprecation_date: Union[definitions.Date, definitions.KNOWN_BAD_DEPRECATED_DATES] + + deprecation_date: definitions.Date | definitions.KNOWN_BAD_DEPRECATED_DATES rule_name: definitions.RuleName stack_version: definitions.SemVer @@ -76,14 +79,17 @@ class DeprecatedRulesEntry(MarshmallowDataclassMixin): @dataclass(frozen=True) class DeprecatedRulesFile(LockDataclassMixin): """Schema for the full deprecated rules file.""" - data: Dict[Union[definitions.UUIDString, definitions.KNOWN_BAD_RULE_IDS], DeprecatedRulesEntry] + + data: dict[definitions.UUIDString | definitions.KNOWN_BAD_RULE_IDS, DeprecatedRulesEntry] file_path: ClassVar[Path] = RULES_CONFIG.deprecated_rules_file - def __contains__(self, rule_id: str): + def __contains__(self, rule_id: str) -> bool: """Check if a rule is in the map by comparing IDs.""" return rule_id in self.data - def __getitem__(self, item) -> DeprecatedRulesEntry: + def 
__getitem__( + self, item: definitions.UUIDString | Literal["119c8877-8613-416d-a98a-96b6664ee73a5"] + ) -> DeprecatedRulesEntry: """Return entries by rule id.""" if item not in self.data: raise KeyError(item) @@ -91,7 +97,7 @@ def __getitem__(self, item) -> DeprecatedRulesEntry: @cached -def load_versions() -> dict: +def load_versions() -> dict[str, Any]: """Load and validate the default version.lock file.""" version_lock_file = VersionLockFile.load_from_file() return version_lock_file.to_dict() @@ -100,20 +106,20 @@ def load_versions() -> dict: # for tagged branches which existed before the types were added and validation enforced, we will need to manually add # them to allow them to pass validation. These will only ever currently be loaded via the RuleCollection.load_git_tag # method, which is primarily for generating diffs across releases, so there is no risk to versioning -def add_rule_types_to_lock(lock_contents: dict, rule_map: Dict[str, dict]): +def add_rule_types_to_lock(lock_contents: dict[str, Any], rule_map: dict[str, Any]) -> dict[str, Any]: """Add the rule type to entries in the lock file, if missing.""" for rule_id, lock in lock_contents.items(): rule = rule_map.get(rule_id, {}) # this defaults to query if the rule is not found - it is just for validation so should not impact - rule_type = rule.get('rule', {}).get('type', 'query') + rule_type = rule.get("rule", {}).get("type", "query") # the type is a bit less important than the structure to pass validation - lock['type'] = rule_type + lock["type"] = rule_type - if 'previous' in lock: - for _, prev_lock in lock['previous'].items(): - prev_lock['type'] = rule_type + if "previous" in lock: + for prev_lock in lock["previous"].values(): + prev_lock["type"] = rule_type return lock_contents @@ -121,16 +127,24 @@ def add_rule_types_to_lock(lock_contents: dict, rule_map: Dict[str, dict]): class VersionLock: """Version handling for rule files and collections.""" - def __init__(self, version_lock_file: Optional[Path] = None, deprecated_lock_file: Optional[Path] = None, - version_lock: Optional[dict] = None, deprecated_lock: Optional[dict] = None, - name: Optional[str] = None, invalidated: Optional[bool] = False): - + def __init__( # noqa: PLR0913 + self, + version_lock_file: Path | None = None, + deprecated_lock_file: Path | None = None, + version_lock: dict[str, Any] | None = None, + deprecated_lock: dict[str, Any] | None = None, + name: str | None = None, + invalidated: bool | None = False, + ) -> None: if invalidated: err_msg = "This VersionLock configuration is not valid when configured to bypass_version_lock." 
raise NotImplementedError(err_msg) - assert (version_lock_file or version_lock), 'Must provide version lock file or contents' - assert (deprecated_lock_file or deprecated_lock), 'Must provide deprecated lock file or contents' + if not version_lock_file and not version_lock: + raise ValueError("Must provide version lock file or contents") + + if not deprecated_lock_file and not deprecated_lock: + raise ValueError("Must provide deprecated lock file or contents") self.name = name self.version_lock_file = version_lock_file @@ -139,55 +153,66 @@ def __init__(self, version_lock_file: Optional[Path] = None, deprecated_lock_fil if version_lock_file: self.version_lock = VersionLockFile.load_from_file(version_lock_file) else: - self.version_lock = VersionLockFile.from_dict(dict(data=version_lock)) + self.version_lock = VersionLockFile.from_dict({"data": version_lock}) if deprecated_lock_file: self.deprecated_lock = DeprecatedRulesFile.load_from_file(deprecated_lock_file) else: - self.deprecated_lock = DeprecatedRulesFile.from_dict(dict(data=deprecated_lock)) + self.deprecated_lock = DeprecatedRulesFile.from_dict({"data": deprecated_lock}) @staticmethod - def save_file(path: Path, lock_file: Union[VersionLockFile, DeprecatedRulesFile]): - assert path, f'{path} not set' + def save_file(path: Path, lock_file: VersionLockFile | DeprecatedRulesFile) -> None: lock_file.save_to_file(path) - print(f'Updated {path} file') - - def get_locked_version(self, rule_id: str, min_stack_version: Optional[str] = None) -> Optional[int]: - if rule_id in self.version_lock: - latest_version_info = self.version_lock[rule_id] - if latest_version_info.previous and latest_version_info.previous.get(min_stack_version): - stack_version_info = latest_version_info.previous.get(min_stack_version) - else: - stack_version_info = latest_version_info - return stack_version_info.version - - def get_locked_hash(self, rule_id: str, min_stack_version: Optional[str] = None) -> Optional[str]: + print(f"Updated {path} file") + + def get_locked_version(self, rule_id: str, min_stack_version: str | None = None) -> int | None: + if rule_id not in self.version_lock: + return None + + latest_version_info = self.version_lock[rule_id] + if latest_version_info.previous and latest_version_info.previous.get(min_stack_version): + stack_version_info = latest_version_info.previous.get(min_stack_version) + else: + stack_version_info = latest_version_info + return stack_version_info.version + + def get_locked_hash(self, rule_id: str, min_stack_version: str | None = None) -> str | None: """Get the version info matching the min_stack_version if present.""" - if rule_id in self.version_lock: - latest_version_info = self.version_lock[rule_id] - if latest_version_info.previous and latest_version_info.previous.get(min_stack_version): - stack_version_info = latest_version_info.previous.get(min_stack_version) - else: - stack_version_info = latest_version_info - existing_sha256: str = stack_version_info.sha256 - return existing_sha256 - - def manage_versions(self, rules: RuleCollection, - exclude_version_update=False, save_changes=False, - verbose=True, buffer_int: int = 100) -> (List[str], List[str], List[str]): + if rule_id not in self.version_lock: + return None + latest_version_info = self.version_lock[rule_id] + if latest_version_info.previous and latest_version_info.previous.get(min_stack_version): + stack_version_info = latest_version_info.previous.get(min_stack_version) + else: + stack_version_info = latest_version_info + existing_sha256: str = 
stack_version_info.sha256 + return existing_sha256 + + def manage_versions( # noqa: PLR0912, PLR0915 + self, + rules: Any, # type: ignore[reportRedeclaration] + exclude_version_update: bool = False, + save_changes: bool = False, + verbose: bool = True, + buffer_int: int = 100, + ) -> tuple[list[definitions.UUIDString], list[str], list[str]]: """Update the contents of the version.lock file and optionally save changes.""" from .packaging import current_stack_version + from .rule import TOMLRule + from .rule_loader import RuleCollection # noqa: TC001 + + rules: RuleCollection = rules # noqa: PLW0127 version_lock_hash = self.version_lock.sha256() lock_file_contents = deepcopy(self.version_lock.to_dict()) current_deprecated_lock = deepcopy(self.deprecated_lock.to_dict()) - verbose_echo = click.echo if verbose else (lambda x: None) + verbose_echo = click.echo if verbose else (lambda _: None) # type: ignore[reportUnknownVariableType] already_deprecated = set(current_deprecated_lock) deprecated_rules = set(rules.deprecated.id_map) - new_rules = set(rule.id for rule in rules if rule.contents.saved_version is None) - deprecated_rules - changed_rules = set(rule.id for rule in rules if rule.contents.is_dirty) - deprecated_rules + new_rules = {rule.id for rule in rules if rule.contents.saved_version is None} - deprecated_rules + changed_rules = {rule.id for rule in rules if rule.contents.is_dirty} - deprecated_rules # manage deprecated rules newly_deprecated = deprecated_rules - already_deprecated @@ -195,22 +220,22 @@ def manage_versions(self, rules: RuleCollection, if not (new_rules or changed_rules or newly_deprecated): return list(changed_rules), list(new_rules), list(newly_deprecated) - verbose_echo('Rule changes detected!') - changes = [] + verbose_echo("Rule changes detected!") - def log_changes(r, route_taken, new_rule_version, *msg): - new = [f' {route_taken}: {r.id}, new version: {new_rule_version}'] - new.extend([f' - {m}' for m in msg if m]) + changes: list[str] = [] + + def log_changes(r: TOMLRule, route_taken: str, new_rule_version: Any, *msg: str) -> None: + new = [f" {route_taken}: {r.id}, new version: {new_rule_version}"] + new.extend([f" - {m}" for m in msg if m]) changes.extend(new) for rule in rules: if rule.contents.metadata.maturity == "production" or rule.id in newly_deprecated: # assume that older stacks are always locked first - min_stack = Version.parse(rule.contents.get_supported_version(), - optional_minor_and_patch=True) + min_stack = Version.parse(rule.contents.get_supported_version(), optional_minor_and_patch=True) lock_from_rule = rule.contents.lock_info(bump=not exclude_version_update) - lock_from_file: dict = lock_file_contents.setdefault(rule.id, {}) + lock_from_file = lock_file_contents.setdefault(rule.id, {}) # scenarios to handle, assuming older stacks are always locked first: # 1) no breaking changes ever made or the first time a rule is created @@ -218,40 +243,42 @@ def log_changes(r, route_taken, new_rule_version, *msg): # 3) on the latest stack, locking in a breaking change # 4) on an old stack, after a breaking change has been made latest_locked_stack_version = rule.contents.convert_supported_version( - lock_from_file.get("min_stack_version")) + lock_from_file.get("min_stack_version") + ) # strip version down to only major.minor to compare against lock file versioning stripped_version = f"{min_stack.major}.{min_stack.minor}" if not lock_from_file or min_stack == latest_locked_stack_version: - route = 'A' + route = "A" # 1) no breaking changes ever made or the 
first time a rule is created # 2) on the latest, after a breaking change has been locked lock_from_file.update(lock_from_rule) - new_version = lock_from_rule['version'] + new_version = lock_from_rule["version"] # add the min_stack_version to the lock if it's explicitly set if rule.contents.metadata.min_stack_version is not None: lock_from_file["min_stack_version"] = stripped_version - log_msg = f'min_stack_version added: {min_stack}' + log_msg = f"min_stack_version added: {min_stack}" log_changes(rule, route, new_version, log_msg) elif min_stack > latest_locked_stack_version: - route = 'B' + route = "B" # 3) on the latest stack, locking in a breaking change - stripped_latest_locked_stack_version = f"{latest_locked_stack_version.major}." \ - f"{latest_locked_stack_version.minor}" + stripped_latest_locked_stack_version = ( + f"{latest_locked_stack_version.major}.{latest_locked_stack_version.minor}" + ) # preserve buffer space to support forked version spacing if exclude_version_update: buffer_int -= 1 lock_from_rule["version"] = lock_from_file["version"] + buffer_int previous_lock_info = { - "max_allowable_version": lock_from_rule['version'] - 1, + "max_allowable_version": lock_from_rule["version"] - 1, "rule_name": lock_from_file["rule_name"], "sha256": lock_from_file["sha256"], "version": lock_from_file["version"], - "type": lock_from_file["type"] + "type": lock_from_file["type"], } lock_from_file.setdefault("previous", {}) @@ -260,42 +287,50 @@ def log_changes(r, route_taken, new_rule_version, *msg): # overwrite the "latest" part of the lock at the top level lock_from_file.update(lock_from_rule, min_stack_version=stripped_version) - new_version = lock_from_rule['version'] + new_version = lock_from_rule["version"] log_changes( - rule, route, new_version, - f'previous {stripped_latest_locked_stack_version} saved as \ - version: {previous_lock_info["version"]}', - f'current min_stack updated to {stripped_version}' + rule, + route, + new_version, + f"previous {stripped_latest_locked_stack_version} saved as \ + version: {previous_lock_info['version']}", + f"current min_stack updated to {stripped_version}", ) elif min_stack < latest_locked_stack_version: - route = 'C' + route = "C" # 4) on an old stack, after a breaking change has been made (updated fork) - assert stripped_version in lock_from_file.get("previous", {}), \ - f"Expected {rule.id} @ v{stripped_version} in the rule lock" + if stripped_version not in lock_from_file.get("previous", {}): + raise ValueError(f"Expected {rule.id} @ v{stripped_version} in the rule lock") - # TODO: Figure out whether we support locking old versions and if we want to - # "leave room" by skipping versions when breaking changes are made. - # We can still inspect the version lock manually after locks are made, - # since it's a good summary of everything that happens + # TODO: Figure out whether we support locking old versions # noqa: TD002, TD003, FIX002 + # and if we want to "leave room" by skipping versions when breaking changes are made. 
+ # We can still inspect the version lock manually after locks are made, + # since it's a good summary of everything that happens previous_entry = lock_from_file["previous"][stripped_version] - max_allowable_version = previous_entry['max_allowable_version'] + max_allowable_version = previous_entry["max_allowable_version"] - # if version bump collides with future bump: fail - # if space: change and log - info_from_rule = (lock_from_rule['sha256'], lock_from_rule['version']) - info_from_file = (previous_entry['sha256'], previous_entry['version']) + # if version bump collides with future bump, fail + # if space, change and log + info_from_rule = (lock_from_rule["sha256"], lock_from_rule["version"]) + info_from_file = (previous_entry["sha256"], previous_entry["version"]) - if lock_from_rule['version'] > max_allowable_version: - raise ValueError(f'Forked rule: {rule.id} - {rule.name} has changes that will force it to ' - f'exceed the max allowable version of {max_allowable_version}') + if lock_from_rule["version"] > max_allowable_version: + raise ValueError( + f"Forked rule: {rule.id} - {rule.name} has changes that will force it to " + f"exceed the max allowable version of {max_allowable_version}" + ) if info_from_rule != info_from_file: lock_from_file["previous"][stripped_version].update(lock_from_rule) new_version = lock_from_rule["version"] - log_changes(rule, route, 'unchanged', - f'previous version {stripped_version} updated version to {new_version}') + log_changes( + rule, + route, + "unchanged", + f"previous version {stripped_version} updated version to {new_version}", + ) continue else: raise RuntimeError("Unreachable code") @@ -305,34 +340,35 @@ def log_changes(r, route_taken, new_rule_version, *msg): current_deprecated_lock[rule.id] = { "rule_name": rule.name, "stack_version": current_stack_version(), - "deprecation_date": rule.contents.metadata['deprecation_date'] + "deprecation_date": rule.contents.metadata["deprecation_date"], } if save_changes or verbose: - click.echo(f' - {len(changed_rules)} changed rules') - click.echo(f' - {len(new_rules)} new rules') - click.echo(f' - {len(newly_deprecated)} newly deprecated rules') + click.echo(f" - {len(changed_rules)} changed rules") + click.echo(f" - {len(new_rules)} new rules") + click.echo(f" - {len(newly_deprecated)} newly deprecated rules") if not save_changes: verbose_echo( - 'run `build-release --update-version-lock` to update version.lock.json and deprecated_rules.json') + "run `build-release --update-version-lock` to update version.lock.json and deprecated_rules.json" + ) return list(changed_rules), list(new_rules), list(newly_deprecated) - click.echo('Detailed changes: \n' + '\n'.join(changes)) + click.echo("Detailed changes: \n" + "\n".join(changes)) # reset local version lock - self.version_lock = VersionLockFile.from_dict(dict(data=lock_file_contents)) - self.deprecated_lock = DeprecatedRulesFile.from_dict(dict(data=current_deprecated_lock)) + self.version_lock = VersionLockFile.from_dict({"data": lock_file_contents}) + self.deprecated_lock = DeprecatedRulesFile.from_dict({"data": current_deprecated_lock}) new_hash = self.version_lock.sha256() - if version_lock_hash != new_hash: + if version_lock_hash != new_hash and self.version_lock_file: self.save_file(self.version_lock_file, self.version_lock) - if newly_deprecated: + if newly_deprecated and self.deprecated_lock_file: self.save_file(self.deprecated_lock_file, self.deprecated_lock) - return changed_rules, list(new_rules), newly_deprecated + return list(changed_rules), 
list(new_rules), list(newly_deprecated) name = str(RULES_CONFIG.version_lock_file) diff --git a/hunting/__main__.py b/hunting/__main__.py index 4ce320566a0..44bcc6da339 100644 --- a/hunting/__main__.py +++ b/hunting/__main__.py @@ -8,9 +8,10 @@ from collections import Counter from dataclasses import asdict from pathlib import Path +from typing import Any import click -from tabulate import tabulate +from tabulate import tabulate # type: ignore[reportMissingModuleSource] from detection_rules.misc import parse_user_config @@ -18,30 +19,27 @@ from .markdown import MarkdownGenerator from .run import QueryRunner from .search import QueryIndex -from .utils import (filter_elasticsearch_params, get_hunt_path, load_all_toml, - load_toml, update_index_yml) +from .utils import filter_elasticsearch_params, get_hunt_path, load_all_toml, load_toml, update_index_yml @click.group() -def hunting(): +def hunting() -> None: """Commands for managing hunting queries and converting TOML to Markdown.""" - pass -@hunting.command('generate-markdown') -@click.argument('path', required=False) -def generate_markdown(path: Path = None): +@hunting.command("generate-markdown") +@click.argument("path", required=False, type=Path) +def generate_markdown(path: Path | None = None) -> None: """Convert TOML hunting queries to Markdown format.""" markdown_generator = MarkdownGenerator(HUNTING_DIR) if path: - path = Path(path) - if path.is_file() and path.suffix == '.toml': + if path.is_file() and path.suffix == ".toml": click.echo(f"Generating Markdown for single file: {path}") markdown_generator.process_file(path) elif (HUNTING_DIR / path).is_dir(): click.echo(f"Generating Markdown for folder: {path}") - markdown_generator.process_folder(path) + markdown_generator.process_folder(str(path)) else: raise ValueError(f"Invalid path provided: {path}") else: @@ -52,8 +50,8 @@ def generate_markdown(path: Path = None): markdown_generator.update_index_md() -@hunting.command('refresh-index') -def refresh_index(): +@hunting.command("refresh-index") +def refresh_index() -> None: """Refresh the index.yml file from TOML files and then refresh the index.md file.""" click.echo("Refreshing the index.yml and index.md files.") update_index_yml(HUNTING_DIR) @@ -62,13 +60,13 @@ def refresh_index(): click.echo("Index refresh complete.") -@hunting.command('search') -@click.option('--tactic', type=str, default=None, help="Search by MITRE tactic ID (e.g., TA0001)") -@click.option('--technique', type=str, default=None, help="Search by MITRE technique ID (e.g., T1078)") -@click.option('--sub-technique', type=str, default=None, help="Search by MITRE sub-technique ID (e.g., T1078.001)") -@click.option('--data-source', type=str, default=None, help="Filter by data_source like 'aws', 'macos', or 'linux'") -@click.option('--keyword', type=str, default=None, help="Search by keyword in name, description, and notes") -def search_queries(tactic: str, technique: str, sub_technique: str, data_source: str, keyword: str): +@hunting.command("search") +@click.option("--tactic", type=str, default=None, help="Search by MITRE tactic ID (e.g., TA0001)") +@click.option("--technique", type=str, default=None, help="Search by MITRE technique ID (e.g., T1078)") +@click.option("--sub-technique", type=str, default=None, help="Search by MITRE sub-technique ID (e.g., T1078.001)") +@click.option("--data-source", type=str, default=None, help="Filter by data_source like 'aws', 'macos', or 'linux'") +@click.option("--keyword", type=str, default=None, help="Search by keyword in name, 
description, and notes") +def search_queries(tactic: str, technique: str, sub_technique: str, data_source: str, keyword: str) -> None: """Search for queries based on MITRE tactic, technique, sub-technique, or data_source.""" if not any([tactic, technique, sub_technique, data_source, keyword]): @@ -90,13 +88,13 @@ def search_queries(tactic: str, technique: str, sub_technique: str, data_source: click.secho(f"\nFound {len(results)} matching queries:\n", fg="green", bold=True) # Prepare the data for tabulate - table_data = [] + table_data: list[str | Any] = [] for result in results: # Customize output to include technique, data_source, and UUID - data_source_str = result['data_source'] - mitre_str = ", ".join(result['mitre']) - uuid = result['uuid'] - table_data.append([result['name'], uuid, result['path'], data_source_str, mitre_str]) + data_source_str = result["data_source"] + mitre_str = ", ".join(result["mitre"]) + uuid = result["uuid"] + table_data.append([result["name"], uuid, result["path"], data_source_str, mitre_str]) # Output results using tabulate table_headers = ["Name", "UUID", "Location", "Data Source", "MITRE"] @@ -106,13 +104,18 @@ def search_queries(tactic: str, technique: str, sub_technique: str, data_source: click.secho("No matching queries found.", fg="red", bold=True) -@hunting.command('view-hunt') -@click.option('--uuid', type=str, help="View a specific hunt by UUID.") -@click.option('--path', type=str, help="View a specific hunt by file path.") -@click.option('--format', 'output_format', default='toml', type=click.Choice(['toml', 'json'], case_sensitive=False), - help="Output format (toml or json).") -@click.option('--query-only', is_flag=True, help="Only display the query content.") -def view_hunt(uuid: str, path: str, output_format: str, query_only: bool): +@hunting.command("view-hunt") +@click.option("--uuid", type=str, help="View a specific hunt by UUID.") +@click.option("--path", type=str, help="View a specific hunt by file path.") +@click.option( + "--format", + "output_format", + default="toml", + type=click.Choice(["toml", "json"], case_sensitive=False), + help="Output format (toml or json).", +) +@click.option("--query-only", is_flag=True, help="Only display the query content.") +def view_hunt(uuid: str, path: str, output_format: str, query_only: bool) -> None: """View a specific hunt by UUID or file path in the specified format (TOML or JSON).""" # Get the hunt path or error message @@ -121,6 +124,9 @@ def view_hunt(uuid: str, path: str, output_format: str, query_only: bool): if error_message: raise click.ClickException(error_message) + if not hunt_path: + raise ValueError("No hunt path found") + # Load the TOML data hunt = load_toml(hunt_path) @@ -134,18 +140,21 @@ def view_hunt(uuid: str, path: str, output_format: str, query_only: bool): return # Output the hunt in the requested format - if output_format == 'toml': + if output_format == "toml": click.echo(hunt_path.read_text()) - elif output_format == 'json': + elif output_format == "json": hunt_dict = asdict(hunt) click.echo(json.dumps(hunt_dict, indent=4)) -@hunting.command('hunt-summary') -@click.option('--breakdown', type=click.Choice(['platform', 'integration', 'language'], - case_sensitive=False), default='platform', - help="Specify how to break down the summary: 'platform', 'integration', or 'language'.") -def hunt_summary(breakdown: str): +@hunting.command("hunt-summary") +@click.option( + "--breakdown", + type=click.Choice(["platform", "integration", "language"], case_sensitive=False), + 
default="platform", + help="Specify how to break down the summary: 'platform', 'integration', or 'language'.", +) +def hunt_summary(breakdown: str) -> None: """ Generate a summary of hunt queries, broken down by platform, integration, or language. """ @@ -155,9 +164,9 @@ def hunt_summary(breakdown: str): all_hunts = load_all_toml(HUNTING_DIR) # Use Counter for more concise counting - platform_counter = Counter() - integration_counter = Counter() - language_counter = Counter() + platform_counter: Counter[str] = Counter() + integration_counter: Counter[str] = Counter() + language_counter: Counter[str] = Counter() for hunt, path in all_hunts: # Get the platform based on the folder name @@ -168,29 +177,31 @@ def hunt_summary(breakdown: str): integration_counter.update(hunt.integration) # Count languages, renaming 'SQL' to 'OSQuery' - languages = ['OSQuery' if lang == 'SQL' else lang for lang in hunt.language] + languages = ["OSQuery" if lang == "SQL" else lang for lang in hunt.language] language_counter.update(languages) # Prepare and display the table based on the selected breakdown - if breakdown == 'platform': + if breakdown == "platform": table_data = [[platform, count] for platform, count in platform_counter.items()] table_headers = ["Platform (Folder)", "Hunt Count"] - elif breakdown == 'integration': + elif breakdown == "integration": table_data = [[integration, count] for integration, count in integration_counter.items()] table_headers = ["Integration", "Hunt Count"] - elif breakdown == 'language': + elif breakdown == "language": table_data = [[language, count] for language, count in language_counter.items()] table_headers = ["Language", "Hunt Count"] + else: + raise ValueError(f"Unsupported breakdown value: {breakdown}") click.echo(tabulate(table_data, headers=table_headers, tablefmt="fancy_grid")) -@hunting.command('run-query') -@click.option('--uuid', help="The UUID of the hunting query to run.") -@click.option('--file-path', help="The file path of the hunting query to run.") -@click.option('--all', 'run_all', is_flag=True, help="Run all eligible queries in the file.") -@click.option('--wait-time', 'wait_time', default=180, help="Time to wait for query completion.") -def run_query(uuid: str, file_path: str, run_all: bool, wait_time: int): +@hunting.command("run-query") +@click.option("--uuid", help="The UUID of the hunting query to run.") +@click.option("--file-path", help="The file path of the hunting query to run.") +@click.option("--all", "run_all", is_flag=True, help="Run all eligible queries in the file.") +@click.option("--wait-time", "wait_time", default=180, help="Time to wait for query completion.") +def run_query(uuid: str, file_path: str, run_all: bool, wait_time: int) -> None: """Run a hunting query by UUID or file path. 
Only ES|QL queries are supported.""" # Get the hunt path or error message @@ -200,6 +211,9 @@ def run_query(uuid: str, file_path: str, run_all: bool, wait_time: int): click.echo(error_message) return + if not hunt_path: + raise ValueError("No hunt path found") + # Load the user configuration config = parse_user_config() if not config: @@ -234,7 +248,7 @@ def run_query(uuid: str, file_path: str, run_all: bool, wait_time: int): click.secho("Available queries:", fg="blue", bold=True) for i, query in eligible_queries.items(): click.secho(f"\nQuery {i + 1}:", fg="green", bold=True) - click.echo(query_runner._format_query(query)) + click.echo(query_runner.format_query(query)) click.secho("\n" + "-" * 120, fg="yellow") # Handle query selection @@ -244,8 +258,7 @@ def run_query(uuid: str, file_path: str, run_all: bool, wait_time: int): if query_number - 1 in eligible_queries: selected_query = eligible_queries[query_number - 1] break - else: - click.secho(f"Invalid query number: {query_number}. Please try again.", fg="yellow") + click.secho(f"Invalid query number: {query_number}. Please try again.", fg="yellow") except ValueError: click.secho("Please enter a valid number.", fg="yellow") diff --git a/hunting/definitions.py b/hunting/definitions.py index ef52da689f6..01e519958e3 100644 --- a/hunting/definitions.py +++ b/hunting/definitions.py @@ -6,7 +6,6 @@ import re from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, List # Define the hunting directory path HUNTING_DIR = Path(__file__).parent @@ -16,33 +15,32 @@ ATTACK_URL = "https://attack.mitre.org/techniques/" # Static mapping for specific integrations -STATIC_INTEGRATION_LINK_MAP = { - 'aws_bedrock.invocation': 'aws_bedrock' -} +STATIC_INTEGRATION_LINK_MAP = {"aws_bedrock.invocation": "aws_bedrock"} @dataclass class Hunt: """Dataclass to represent a hunt.""" + author: str description: str - integration: List[str] + integration: list[str] uuid: str name: str - language: List[str] + language: list[str] license: str - query: List[str] - notes: Optional[List[str]] = field(default_factory=list) - mitre: List[str] = field(default_factory=list) - references: Optional[List[str]] = field(default_factory=list) + query: list[str] + notes: list[str] | None = field(default_factory=list) # type: ignore[reportUnknownVariableType] + mitre: list[str] = field(default_factory=list) # type: ignore[reportUnknownVariableType] + references: list[str] | None = field(default_factory=list) # type: ignore[reportUnknownVariableType] - def __post_init__(self): + def __post_init__(self) -> None: """Post-initialization to determine which validation to apply.""" if not self.query: raise ValueError(f"Hunt: {self.name} - Query field must be provided.") # Loop through each query in the array - for idx, q in enumerate(self.query): + for q in self.query: query_start = q.strip().lower() # Only validate queries that start with "from" (ESQL queries) @@ -55,11 +53,11 @@ def validate_esql_query(self, query: str) -> None: if self.author == "Elastic": # Regex patterns for checking "stats by" and "| keep" - stats_by_pattern = re.compile(r'\bstats\b.*?\bby\b', re.DOTALL) - keep_pattern = re.compile(r'\| keep', re.DOTALL) + stats_by_pattern = re.compile(r"\bstats\b.*?\bby\b", re.DOTALL) + keep_pattern = re.compile(r"\| keep", re.DOTALL) # Check if either "stats by" or "| keep" exists in the query if not stats_by_pattern.search(query) and not keep_pattern.search(query): raise ValueError( - f"Hunt: {self.name} contains an ES|QL query that must 
contain either 'stats by' or 'keep' functions." + f"Hunt: {self.name} contains an ES|QL query that must contain either 'stats by' or 'keep' functions." ) diff --git a/hunting/markdown.py b/hunting/markdown.py index 19c0c57995e..74f1063ff23 100644 --- a/hunting/markdown.py +++ b/hunting/markdown.py @@ -4,21 +4,24 @@ # 2.0. from pathlib import Path + import click + from .definitions import ATLAS_URL, ATTACK_URL, STATIC_INTEGRATION_LINK_MAP, Hunt from .utils import load_index_file, load_toml, save_index_file, validate_link class MarkdownGenerator: """Class to generate or update Markdown documentation from TOML or YAML files.""" - def __init__(self, base_path: Path): + + def __init__(self, base_path: Path) -> None: """Initialize with the base path and load the hunting index.""" self.base_path = base_path self.hunting_index = load_index_file() def process_file(self, file_path: Path) -> None: """Process a single TOML file and generate its Markdown representation.""" - if not file_path.is_file() or file_path.suffix != '.toml': + if not file_path.is_file() or file_path.suffix != ".toml": raise ValueError(f"The provided path is not a valid TOML file: {file_path}") click.echo(f"Processing specific TOML file: {file_path}") @@ -83,7 +86,7 @@ def convert_toml_to_markdown(self, hunt_config: Hunt, file_path: Path) -> str: def save_markdown(self, markdown_path: Path, content: str) -> None: """Save the Markdown content to a file.""" - markdown_path.write_text(content, encoding="utf-8") + _ = markdown_path.write_text(content, encoding="utf-8") click.echo(f"Markdown generated: {markdown_path}") def update_or_add_entry(self, hunt_config: Hunt, toml_path: Path) -> None: @@ -92,9 +95,9 @@ def update_or_add_entry(self, hunt_config: Hunt, toml_path: Path) -> None: uuid = hunt_config.uuid entry = { - 'name': hunt_config.name, - 'path': f"./{toml_path.resolve().relative_to(self.base_path).as_posix()}", - 'mitre': hunt_config.mitre + "name": hunt_config.name, + "path": f"./{toml_path.resolve().relative_to(self.base_path).as_posix()}", + "mitre": hunt_config.mitre, } if folder_name not in self.hunting_index: @@ -112,16 +115,16 @@ def create_docs_folder(self, file_path: Path) -> Path: def generate_integration_links(self, integrations: list[str]) -> list[str]: """Generate integration links for the documentation.""" - base_url = 'https://docs.elastic.co/integrations' - generated = [] + base_url = "https://docs.elastic.co/integrations" + generated: list[str] = [] for integration in integrations: if integration in STATIC_INTEGRATION_LINK_MAP: link_str = STATIC_INTEGRATION_LINK_MAP[integration] else: - link_str = integration.replace('.', '/') - link = f'{base_url}/{link_str}' + link_str = integration.replace(".", "/") + link = f"{base_url}/{link_str}" validate_link(link) - generated.append(f'[{integration}]({link})') + generated.append(f"[{integration}]({link})") return generated def update_index_md(self) -> None: @@ -135,10 +138,10 @@ def update_index_md(self) -> None: for folder, files in sorted(self.hunting_index.items()): index_content += f"\n\n## {folder}\n" - for file_info in sorted(files.values(), key=lambda x: x['name']): - md_path = file_info['path'].replace('queries', 'docs').replace('.toml', '.md') + for file_info in sorted(files.values(), key=lambda x: x["name"]): + md_path = file_info["path"].replace("queries", "docs").replace(".toml", ".md") index_content += f"- [{file_info['name']}]({md_path}) (ES|QL)\n" index_md_path = self.base_path / "index.md" - index_md_path.write_text(index_content, encoding="utf-8") + 
_ = index_md_path.write_text(index_content, encoding="utf-8") click.echo(f"Index Markdown updated at: {index_md_path}") diff --git a/hunting/run.py b/hunting/run.py index 17ce2991474..d7893f7d509 100644 --- a/hunting/run.py +++ b/hunting/run.py @@ -6,32 +6,34 @@ import re import textwrap from pathlib import Path +from typing import Any import click from detection_rules.misc import get_elasticsearch_client +from .definitions import Hunt from .utils import load_toml class QueryRunner: - def __init__(self, es_config: dict): + def __init__(self, es_config: dict[str, Any]) -> None: """Initialize the QueryRunner with Elasticsearch config.""" self.es_config = es_config - def load_hunting_file(self, file_path: Path): + def load_hunting_file(self, file_path: Path) -> Hunt: """Load the hunting file and return the data.""" return load_toml(file_path) def preprocess_query(self, query: str) -> str: """Pre-process the query by removing comments and adding a LIMIT.""" - query = re.sub(r'//.*', '', query) - if not re.search(r'LIMIT', query, re.IGNORECASE): + query = re.sub(r"//.*", "", query) + if not re.search(r"LIMIT", query, re.IGNORECASE): query += " | LIMIT 10" click.echo("No LIMIT detected in query. Added LIMIT 10 to truncate output.") return query - def run_individual_query(self, query: str, wait_timeout: int): + def run_individual_query(self, query: str, _: int) -> None: """Run a single query with the Elasticsearch config.""" es = get_elasticsearch_client(**self.es_config) query = self.preprocess_query(query) @@ -42,35 +44,33 @@ def run_individual_query(self, query: str, wait_timeout: int): # Start the query synchronously response = es.esql.query(query=query) - self.process_results(response) - except Exception as e: + + response_data = response.body + if response_data.get("values"): + click.secho("Query matches found!", fg="red", bold=True) + else: + click.secho("No matches found!", fg="green", bold=True) + + except Exception as e: # noqa: BLE001 # handle missing index error if "Unknown index" in str(e): click.secho("This query references indexes that do not exist in the target stack.", fg="red") click.secho("Check if index exists (via integration installation) and contains data.", fg="red") click.secho("Alternatively, update the query to reference an existing index.", fg="red") else: - click.secho(f"Error running query: {str(e)}", fg="red") + click.secho(f"Error running query: {e!s}", fg="red") - def run_all_queries(self, queries: dict, wait_timeout: int): + def run_all_queries(self, queries: dict[int, Any], wait_timeout: int) -> None: """Run all eligible queries in the hunting file.""" click.secho("Running all eligible queries...", fg="green", bold=True) for i, query in queries.items(): click.secho(f"\nRunning Query {i + 1}:", fg="green", bold=True) - click.echo(self._format_query(query)) + click.echo(self.format_query(query)) self.run_individual_query(query, wait_timeout) click.secho("\n" + "-" * 120, fg="yellow") - def process_results(self, response): - """Process the Elasticsearch query results and display the outcome.""" - response_data = response.body - if response_data.get('values'): - click.secho("Query matches found!", fg="red", bold=True) - else: - click.secho("No matches found!", fg="green", bold=True) - - def _format_query(self, query: str) -> str: + def format_query(self, query: str) -> str: """Format the query with word wrapping for better readability.""" - lines = query.split('\n') - wrapped_lines = [textwrap.fill(line, width=120, subsequent_indent=' ') for line in lines] - return 
'\n'.join(wrapped_lines) + lines = query.split("\n") + wrapped_lines = [textwrap.fill(line, width=120, subsequent_indent=" ") for line in lines] + return "\n".join(wrapped_lines) diff --git a/hunting/search.py b/hunting/search.py index 615e8cd222d..2c7b4fdf50e 100644 --- a/hunting/search.py +++ b/hunting/search.py @@ -5,20 +5,24 @@ from pathlib import Path +from typing import Any + import click + from detection_rules.attack import tactics_map, technique_lookup -from .utils import load_index_file, load_all_toml + +from .utils import load_all_toml, load_index_file class QueryIndex: - def __init__(self, base_path: Path): + def __init__(self, base_path: Path) -> None: """Initialize with the base path and load the index.""" self.base_path = base_path self.hunting_index = load_index_file() - self.mitre_technique_ids = set() + self.mitre_technique_ids: set[str] = set() self.reverse_tactics_map = {v: k for k, v in tactics_map.items()} - def _process_mitre_filter(self, mitre_filter: tuple): + def _process_mitre_filter(self, mitre_filter: tuple[str, ...]) -> None: """Process the MITRE filter to gather all matching techniques.""" for filter_item in mitre_filter: if filter_item in self.reverse_tactics_map: @@ -26,29 +30,30 @@ def _process_mitre_filter(self, mitre_filter: tuple): elif filter_item in technique_lookup: self._process_technique_id(filter_item) - def _process_tactic_id(self, filter_item): + def _process_tactic_id(self, filter_item: str) -> None: """Helper method to process a tactic ID.""" tactic_name = self.reverse_tactics_map[filter_item] click.echo(f"Found tactic ID {filter_item} (Tactic Name: {tactic_name}). Searching for associated techniques.") for tech_id, details in technique_lookup.items(): - kill_chain_phases = details.get('kill_chain_phases', []) - if any(tactic_name.lower().replace(' ', '-') == phase['phase_name'] for phase in kill_chain_phases): + kill_chain_phases = details.get("kill_chain_phases", []) + if any(tactic_name.lower().replace(" ", "-") == phase["phase_name"] for phase in kill_chain_phases): self.mitre_technique_ids.add(tech_id) - def _process_technique_id(self, filter_item): + def _process_technique_id(self, filter_item: str) -> None: """Helper method to process a technique or sub-technique ID.""" self.mitre_technique_ids.add(filter_item) - if '.' not in filter_item: + if "." not in filter_item: sub_techniques = { - sub_tech_id for sub_tech_id in technique_lookup - if sub_tech_id.startswith(f"{filter_item}.") + sub_tech_id for sub_tech_id in technique_lookup if sub_tech_id.startswith(f"{filter_item}.") } self.mitre_technique_ids.update(sub_techniques) - def search(self, mitre_filter: tuple = (), data_source: str = None, keyword: str = None) -> list: + def search( + self, mitre_filter: tuple[str, ...] 
= (), data_source: str | None = None, keyword: str | None = None + ) -> list[dict[str, Any]]: """Search the index based on MITRE techniques, data source, or keyword.""" - results = [] + results: list[dict[str, Any]] = [] # Step 1: If data source is provided, filter by data source first if data_source: @@ -65,8 +70,9 @@ def search(self, mitre_filter: tuple = (), data_source: str = None, keyword: str self._process_mitre_filter(mitre_filter) if results: # Filter existing results further by MITRE if data source results already exist - results = [result for result in results if - any(tech in self.mitre_technique_ids for tech in result['mitre'])] + results = [ + result for result in results if any(tech in self.mitre_technique_ids for tech in result["mitre"]) + ] else: # Otherwise, perform a fresh search based on MITRE filter results = self._search_index(mitre_filter) @@ -83,9 +89,9 @@ def search(self, mitre_filter: tuple = (), data_source: str = None, keyword: str return self._handle_no_results(results, mitre_filter, data_source, keyword) - def _search_index(self, mitre_filter: tuple = ()) -> list: + def _search_index(self, mitre_filter: tuple[str, ...] = ()) -> list[dict[str, Any]]: """Private method to search the index based on MITRE filter.""" - results = [] + results: list[dict[str, Any]] = [] # Load all TOML data for detailed fields hunting_content = load_all_toml(self.base_path) @@ -96,23 +102,23 @@ def _search_index(self, mitre_filter: tuple = ()) -> list: # Prepare the result with full hunt content fields matches = hunt_content.__dict__.copy() - matches['mitre'] = hunt_content.mitre - matches['data_source'] = hunt_content.integration - matches['uuid'] = hunt_content.uuid - matches['path'] = file_path + matches["mitre"] = hunt_content.mitre + matches["data_source"] = hunt_content.integration + matches["uuid"] = hunt_content.uuid + matches["path"] = file_path results.append(matches) return results - def _search_keyword(self, keyword: str) -> list: + def _search_keyword(self, keyword: str) -> list[dict[str, Any]]: """Private method to search description, name, notes, and references fields for a keyword.""" - results = [] + results: list[dict[str, Any]] = [] hunting_content = load_all_toml(self.base_path) for hunt_content, file_path in hunting_content: # Assign blank if notes or references are missing - notes = '::'.join(hunt_content.notes) if hunt_content.notes else '' - references = '::'.join(hunt_content.references) if hunt_content.references else '' + notes = "::".join(hunt_content.notes) if hunt_content.notes else "" + references = "::".join(hunt_content.references) if hunt_content.references else "" # Combine name, description, notes, and references for the search combined_content = f"{hunt_content.name}::{hunt_content.description}::{notes}::{references}" @@ -120,63 +126,69 @@ def _search_keyword(self, keyword: str) -> list: if keyword.lower() in combined_content.lower(): # Copy hunt_content data and prepare the result matches = hunt_content.__dict__.copy() - matches['mitre'] = hunt_content.mitre - matches['data_source'] = hunt_content.integration - matches['uuid'] = hunt_content.uuid - matches['path'] = file_path + matches["mitre"] = hunt_content.mitre + matches["data_source"] = hunt_content.integration + matches["uuid"] = hunt_content.uuid + matches["path"] = file_path results.append(matches) return results - def _filter_by_data_source(self, data_source: str) -> list: + def _filter_by_data_source(self, data_source: str) -> list[dict[str, Any]]: """Filter the index by data source, 
checking both the actual files and the index.""" - results = [] - seen_uuids = set() # Track UUIDs to avoid duplicates + results: list[dict[str, Any]] = [] + seen_uuids: set[str] = set() # Track UUIDs to avoid duplicates # Load all TOML data for detailed fields hunting_content = load_all_toml(self.base_path) # Step 1: Check files first by their 'integration' field for hunt_content, file_path in hunting_content: - if data_source in hunt_content.integration: - if hunt_content.uuid not in seen_uuids: - # Prepare the result with full hunt content fields - matches = hunt_content.__dict__.copy() - matches['mitre'] = hunt_content.mitre - matches['data_source'] = hunt_content.integration - matches['uuid'] = hunt_content.uuid - matches['path'] = file_path - results.append(matches) - seen_uuids.add(hunt_content.uuid) + if data_source in hunt_content.integration and hunt_content.uuid not in seen_uuids: + # Prepare the result with full hunt content fields + matches = hunt_content.__dict__.copy() + matches["mitre"] = hunt_content.mitre + matches["data_source"] = hunt_content.integration + matches["uuid"] = hunt_content.uuid + matches["path"] = file_path + results.append(matches) + seen_uuids.add(hunt_content.uuid) # Step 2: Check the index for generic data sources (e.g., 'aws', 'linux') if data_source in self.hunting_index: - for query_uuid, query_data in self.hunting_index[data_source].items(): + for query_uuid in self.hunting_index[data_source]: if query_uuid not in seen_uuids: # Find corresponding TOML content for this query - hunt_content = next((hunt for hunt, path in hunting_content if hunt.uuid == query_uuid), None) - if hunt_content: + h = next(((hunt, path) for hunt, path in hunting_content if hunt.uuid == query_uuid), None) + if h: + hunt_content, path = h # Prepare the result with full hunt content fields matches = hunt_content.__dict__.copy() - matches['mitre'] = hunt_content.mitre - matches['data_source'] = hunt_content.integration - matches['uuid'] = hunt_content.uuid - matches['path'] = file_path + matches["mitre"] = hunt_content.mitre + matches["data_source"] = hunt_content.integration + matches["uuid"] = hunt_content.uuid + matches["path"] = path results.append(matches) seen_uuids.add(query_uuid) return results - def _matches_keyword(self, result: dict, keyword: str) -> bool: + def _matches_keyword(self, result: dict[str, Any], keyword: str) -> bool: """Check if the result matches the keyword in name, description, or notes.""" # Combine relevant fields for keyword search - notes = '::'.join(result.get('notes', [])) if 'notes' in result else '' - references = '::'.join(result.get('references', [])) if 'references' in result else '' + notes = "::".join(result.get("notes", [])) if "notes" in result else "" + references = "::".join(result.get("references", [])) if "references" in result else "" combined_content = f"{result['name']}::{result['description']}::{notes}::{references}" return keyword.lower() in combined_content.lower() - def _handle_no_results(self, results: list, mitre_filter=None, data_source=None, keyword=None) -> list: + def _handle_no_results( + self, + results: list[dict[str, Any]], + mitre_filter: tuple[str, ...] 
| None = None, + data_source: str | None = None, + keyword: str | None = None, + ) -> list[dict[str, Any]]: """Handle cases where no results are found.""" if not results: if mitre_filter and not self.mitre_technique_ids: diff --git a/hunting/utils.py b/hunting/utils.py index c704d16245d..d6779ea6f28 100644 --- a/hunting/utils.py +++ b/hunting/utils.py @@ -6,7 +6,7 @@ import inspect import tomllib from pathlib import Path -from typing import Union +from typing import Any import click import urllib3 @@ -17,21 +17,21 @@ from .definitions import HUNTING_DIR, Hunt -def get_hunt_path(uuid: str, file_path: str) -> (Path, str): +def get_hunt_path(uuid: str, file_path: str) -> tuple[Path | None, str | None]: """Resolve the path of the hunting query using either a UUID or file path.""" if uuid: # Load the index and find the hunt by UUID index_data = load_index_file() - for data_source, hunts in index_data.items(): + for hunts in index_data.values(): if uuid in hunts: hunt_data = hunts[uuid] # Combine the relative path from the index with the HUNTING_DIR - hunt_path = HUNTING_DIR / hunt_data['path'] + hunt_path = HUNTING_DIR / hunt_data["path"] return hunt_path.resolve(), None return None, f"No hunt found for UUID: {uuid}" - elif file_path: + if file_path: # Use the provided file path hunt_path = Path(file_path) if not hunt_path.is_file(): @@ -41,20 +41,18 @@ def get_hunt_path(uuid: str, file_path: str) -> (Path, str): return None, "Either UUID or file path must be provided." -def load_index_file() -> dict: +def load_index_file() -> dict[str, Any]: """Load the hunting index.yml file.""" index_file = HUNTING_DIR / "index.yml" if not index_file.exists(): click.echo(f"No index.yml found at {index_file}.") return {} - with open(index_file, 'r') as f: - hunting_index = yaml.safe_load(f) + with index_file.open() as f: + return yaml.safe_load(f) - return hunting_index - -def load_toml(source: Union[Path, str]) -> Hunt: +def load_toml(source: Path | str) -> Hunt: """Load and validate TOML content as Hunt dataclass.""" if isinstance(source, Path): if not source.is_file(): @@ -69,28 +67,28 @@ def load_toml(source: Union[Path, str]) -> Hunt: return Hunt(**toml_dict["hunt"]) -def load_all_toml(base_path: Path): +def load_all_toml(base_path: Path) -> list[tuple[Hunt, Path]]: """Load all TOML files from the directory and return a list of Hunt configurations and their paths.""" - hunts = [] + hunts: list[tuple[Hunt, Path]] = [] for toml_file in base_path.rglob("*.toml"): hunt_config = load_toml(toml_file) hunts.append((hunt_config, toml_file)) return hunts -def save_index_file(base_path: Path, directories: dict) -> None: +def save_index_file(base_path: Path, directories: dict[str, Any]) -> None: """Save the updated index.yml file.""" index_file = base_path / "index.yml" - with open(index_file, 'w') as f: + with index_file.open("w") as f: yaml.safe_dump(directories, f, default_flow_style=False, sort_keys=False) print(f"Index YAML updated at: {index_file}") -def validate_link(link: str): +def validate_link(link: str) -> None: """Validate and return the link.""" http = urllib3.PoolManager() - response = http.request('GET', link) - if response.status != 200: + response = http.request("GET", link) + if response.status != 200: # noqa: PLR2004 raise ValueError(f"Invalid link: {link}") @@ -109,9 +107,9 @@ def update_index_yml(base_path: Path) -> None: uuid = hunt_config.uuid entry = { - 'name': hunt_config.name, - 'path': f"./{toml_file.relative_to(base_path).as_posix()}", - 'mitre': hunt_config.mitre + "name": 
hunt_config.name, + "path": f"./{toml_file.relative_to(base_path).as_posix()}", + "mitre": hunt_config.mitre, } # Check if the folder_name exists and if it's a list, convert it to a dictionary @@ -120,14 +118,14 @@ def update_index_yml(base_path: Path) -> None: else: if isinstance(directories[folder_name], list): # Convert the list to a dictionary, using UUIDs as keys - directories[folder_name] = {item['uuid']: item for item in directories[folder_name]} + directories[folder_name] = {item["uuid"]: item for item in directories[folder_name]} directories[folder_name][uuid] = entry # Save the updated index.yml save_index_file(base_path, directories) -def filter_elasticsearch_params(config: dict) -> dict: +def filter_elasticsearch_params(config: dict[str, Any]) -> dict[str, Any]: """Filter out unwanted keys from the config by inspecting the Elasticsearch client constructor.""" # Get the parameter names from the Elasticsearch class constructor es_params = inspect.signature(get_elasticsearch_client).parameters diff --git a/pyproject.toml b/pyproject.toml index 0cd15ae174e..f87ed96d5ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "detection_rules" -version = "1.2.26" +version = "1.3.0" description = "Detection Rules is the home for rules used by Elastic Security. This repository is used for the development, maintenance, testing, validation, and release of rules for Elastic Security’s Detection Engine." readme = "README.md" requires-python = ">=3.12" @@ -19,32 +19,43 @@ classifiers = [ "Topic :: Utilities" ] dependencies = [ - "Click~=8.1.7", - "elasticsearch~=8.12.1", - "eql==0.9.19", - "jsl==0.2.4", - "jsonschema>=4.21.1", - "marko==2.0.3", - "marshmallow-dataclass[union]~=8.6.0", - "marshmallow-jsonschema~=0.13.0", - "marshmallow-union~=0.1.15", - "marshmallow~=3.21.1", - "pywin32 ; platform_system=='Windows'", - "pytoml==0.1.21", - "PyYAML~=6.0.1", - "requests~=2.31.0", - "toml==0.10.2", - "typing-inspect==0.9.0", - "typing-extensions==4.10.0", - "XlsxWriter~=3.2.0", - "semver==3.0.2", - "PyGithub==2.2.0", - "detection-rules-kql @ git+https://github.com/elastic/detection-rules.git#subdirectory=lib/kql", - "detection-rules-kibana @ git+https://github.com/elastic/detection-rules.git#subdirectory=lib/kibana", - "setuptools==75.2.0" + "Click~=8.1.7", + "elasticsearch~=8.12.1", + "eql==0.9.19", + "jsl==0.2.4", + "jsonschema>=4.21.1", + "marko==2.0.3", + "marshmallow-dataclass[union]>=8.7", + "marshmallow-jsonschema~=0.13.0", + "marshmallow-union~=0.1.15", + "marshmallow~=3.26.1", + "pywin32 ; platform_system=='Windows'", + # FIXME: pytoml is outdated and should not be used + "pytoml==0.1.21", + "PyYAML~=6.0.1", + "requests~=2.31.0", + "toml==0.10.2", + "typing-inspect==0.9.0", + "typing-extensions>=4.12", + "XlsxWriter~=3.2.0", + "semver==3.0.2", + "PyGithub==2.2.0", + "detection-rules-kql @ git+https://github.com/elastic/detection-rules.git#subdirectory=lib/kql", + "detection-rules-kibana @ git+https://github.com/elastic/detection-rules.git#subdirectory=lib/kibana", + "setuptools==75.2.0" ] [project.optional-dependencies] -dev = ["pep8-naming==0.13.0", "flake8==7.0.0", "pyflakes==3.2.0", "pytest>=8.1.1", "nodeenv==1.8.0", "pre-commit==3.6.2"] +dev = [ + "pep8-naming==0.13.0", + "flake8==7.0.0", + "pyflakes==3.2.0", + "pytest>=8.1.1", + "nodeenv==1.8.0", + "pre-commit==3.6.2", + "ruff>=0.11", + "pyright>=1.1", +] + hunting = ["tabulate==0.9.0"] [project.urls] @@ -53,15 +64,133 @@ hunting = ["tabulate==0.9.0"] "Research" = "https://www.elastic.co/security-labs" 
"Elastic" = "https://www.elastic.co" +[build-system] +requires = ["setuptools", "wheel", "setuptools_scm"] +build-backend = "setuptools.build_meta" + [tool.setuptools] package-data = {"kql" = ["*.g"]} packages = ["detection_rules", "hunting"] [tool.pytest.ini_options] filterwarnings = [ - "ignore::DeprecationWarning" + "ignore::DeprecationWarning" ] -[build-system] -requires = ["setuptools", "wheel", "setuptools_scm"] -build-backend = "setuptools.build_meta" +[tool.ruff] +line-length = 120 +indent-width = 4 +include = [ + "pyproject.toml", + "detection_rules/**/*.py", + "hunting/**/*.py", + "tests/**/*.py", +] +show-fixes = true + +[tool.ruff.lint] +select = [ + "E", # pycodestyle + "F", # Pyflakes + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "YTT", # flake8-2020 + "ANN", # flake8-annotations + "ASYNC", # flake8-async + "S", # flake8-bandit + "BLE", # flake8-blind-except + "B", # flake8-bugbear + "A", # flake8-builtins + "COM", # flake8-commas + "C4", # flake8-comprehensions + "DTZ", # flake8-datetimez + "T10", # flake8-debugger + "DJ", # flake8-django + "EM", # flake8-errmsg + "EXE", # flake8-executable + "ISC", # flake8-implicit-str-concat + "ICN", # flake8-import-conventions + "G", # flake8-logging-format + "INP", # flake8-no-pep420 + "PIE", # flake8-pie + "PYI", # flake8-pyi + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RSE", # flake8-raise + "RET", # flake8-return + "SLF", # flake8-self + "SLOT", # flake8-slots + "TID", # flake8-tidy-imports + "TCH", # flake8-type-checking + "INT", # flake8-gettext + "ARG", # flake8-unused-arguments + "PTH", # flake8-use-pathlib + "TD", # flake8-todos + "FIX", # flake8-fixme + "ERA", # eradicate + "PGH", # pygrep-hooks + "PL", # Pylint + "TRY", # tryceratops + "FLY", # flynt + "PERF", # Perflint + "RUF", # Ruff-specific rules +] +ignore = [ + "ANN401", # any-type + "EM101", # raw-string-in-exception + "EM102", # f-string-in-exception + "PT009", # pytest-unittest-assertion + "TRY003", # raise-vanilla-args + + "N815", # mixed-case-variable-in-class-scope + + "PLC0415", # import-outside-top-level, erratic behavior + "S603", # subprocess-without-shell-equals-true, prone to false positives + + "COM812", # missing-trailing-comma, might cause issues with ruff formatter +] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = [ + "ANN001", # missing-type-function-argument + "ANN002", # missing-type-args + "ANN003", # missing-type-kwargs + "ANN101", # missing-type-self + "ANN102", # missing-type-cls + "ANN201", # missing-return-type-undocumented-public-function + "ANN202", # missing-return-type-private-function + "ANN205", # missing-return-type-static-method + "ARG001", # unused-function-argument + "ANN206", # missing-return-type-class-method + "PLR2004", # magic-value-comparison + "SIM300", # yoda-conditions + "S101", # assert + "PT009", # pytest-unittest-assertion + "PT012", # pytest-raises-with-multiple-statements + "PT027", # pytest-unittest-raises-assertion + "FIX001", # line-contains-fixme + "FIX002", # line-contains-todo + + # FIXME: the long static strings should be moved to the resource files + "E501", # line-too-long + + # FIXME: we should avoid TODOs in the code as much as possible + "TD002", # missing-todo-author + "TD003", # missing-todo-link +] + +[tool.pyright] +include = [ + "detection_rules/", + "hunting/", +] +exclude = [ + "tests/", +] +reportMissingTypeStubs = true +reportUnusedCallResult = "error" +typeCheckingMode = "strict" diff --git 
a/tests/__init__.py b/tests/__init__.py index 43f48cf2144..a3ca9636224 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,53 +4,56 @@ # 2.0. """Detection Rules tests.""" -import glob + import json import os +import pathlib -from detection_rules.utils import combine_sources +from detection_rules.eswrap import combine_sources -CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -DATA_DIR = os.path.join(CURRENT_DIR, 'data') -TP_DIR = os.path.join(DATA_DIR, 'true_positives') -FP_DIR = os.path.join(DATA_DIR, 'false_positives') +CURRENT_DIR = pathlib.Path(__file__).resolve().parent +DATA_DIR = CURRENT_DIR / "data" +TP_DIR = DATA_DIR / "true_positives" +FP_DIR = DATA_DIR / "false_positives" def get_fp_dirs(): """Get a list of fp dir names.""" - return glob.glob(os.path.join(FP_DIR, '*')) + return FP_DIR.glob("*") def get_fp_data_files(): """get FP data files by fp dir name.""" data = {} for fp_dir in get_fp_dirs(): - fp_dir_name = os.path.basename(fp_dir) - relative_dir_name = os.path.join('false_positives', fp_dir_name) + path = pathlib.Path(fp_dir) + fp_dir_name = path.name + relative_dir_name = pathlib.Path("false_positives") / fp_dir_name data[fp_dir_name] = combine_sources(*get_data_files(relative_dir_name).values()) return data -def get_data_files_list(*folder, ext='ndjson', recursive=False): +def get_data_files_list(*folder, ext="ndjson", recursive=False): """Get TP or FP file list.""" folder = os.path.sep.join(folder) - data_dir = [DATA_DIR, folder] - if recursive: - data_dir.append('**') + data_dir = pathlib.Path(DATA_DIR) / folder + + glob = "**" if recursive else "" + glob += f"*.{ext}" - data_dir.append('*.{}'.format(ext)) - return glob.glob(os.path.join(*data_dir), recursive=recursive) + return data_dir.glob(glob) -def get_data_files(*folder, ext='ndjson', recursive=False): +def get_data_files(*folder, ext="ndjson", recursive=False): """Get data from data files.""" data_files = {} for data_file in get_data_files_list(*folder, ext=ext, recursive=recursive): - with open(data_file, 'r') as f: - file_name = os.path.splitext(os.path.basename(data_file))[0] + path = pathlib.Path(data_file) + with path.open() as f: + file_name = path.stem - if ext in ('.ndjson', '.jsonl'): + if ext in (".ndjson", ".jsonl"): data = f.readlines() data_files[file_name] = [json.loads(d) for d in data] else: @@ -60,7 +63,8 @@ def get_data_files(*folder, ext='ndjson', recursive=False): def get_data_file(*folder): - file = os.path.join(DATA_DIR, os.path.sep.join(folder)) - if os.path.exists(file): - with open(file, 'r') as f: + path = pathlib.Path(DATA_DIR) / os.path.sep.join(folder) + if path.exists(): + with path.open() as f: return json.load(f) + return None diff --git a/tests/base.py b/tests/base.py index b56ede1c5c2..cbd3d52597c 100644 --- a/tests/base.py +++ b/tests/base.py @@ -4,22 +4,21 @@ # 2.0. 
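For reference, a short sketch of the pathlib equivalents used in the migration above; the `glob`/`os.path` calls on the left of the hunk map to the `Path` calls on the right, and the directory layout here is made up for the example.

```python
from pathlib import Path

data_dir = Path(__file__).resolve().parent / "data"    # os.path.join(CURRENT_DIR, "data")

for path in sorted(data_dir.glob("**/*.ndjson")):       # glob.glob(..., recursive=True)
    name = path.stem                                     # os.path.splitext(os.path.basename(p))[0]
    with path.open() as handle:                          # open(p, "r")
        print(name, len(handle.readlines()))
```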
"""Shared resources for tests.""" + import os import unittest -from pathlib import Path from functools import lru_cache -from typing import Union +from pathlib import Path from detection_rules.config import parse_rules_config from detection_rules.rule import TOMLRule from detection_rules.rule_loader import DeprecatedCollection, DeprecatedRule, RuleCollection, production_filter - RULE_LOADER_FAIL = False RULE_LOADER_FAIL_MSG = None RULE_LOADER_FAIL_RAISED = False -CUSTOM_RULES_DIR = os.getenv('CUSTOM_RULES_DIR', None) +CUSTOM_RULES_DIR = os.getenv("CUSTOM_RULES_DIR", None) RULES_CONFIG = parse_rules_config() @@ -28,7 +27,7 @@ def load_rules() -> RuleCollection: if CUSTOM_RULES_DIR: rc = RuleCollection() path = Path(CUSTOM_RULES_DIR) - assert path.exists(), f'Custom rules directory {path} does not exist' + assert path.exists(), f"Custom rules directory {path} does not exist" rc.load_directories(directories=RULES_CONFIG.rule_dirs) rc.freeze() return rc @@ -36,7 +35,7 @@ def load_rules() -> RuleCollection: def default_bbr(rc: RuleCollection) -> RuleCollection: - rules = [r for r in rc.rules if 'rules_building_block' in r.path.parent.parts] + rules = [r for r in rc.rules if "rules_building_block" in r.path.parent.parts] return RuleCollection(rules=rules) @@ -49,10 +48,7 @@ class BaseRuleTest(unittest.TestCase): @classmethod def setUpClass(cls): - global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG - - # too noisy; refactor - # os.environ["DR_NOTIFY_INTEGRATION_UPDATE_AVAILABLE"] = "1" + global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG # noqa: PLW0603 if not RULE_LOADER_FAIL: try: @@ -62,7 +58,7 @@ def setUpClass(cls): cls.all_rules = rc.filter(production_filter) cls.bbr = rc_bbr.rules cls.deprecated_rules: DeprecatedCollection = rc.deprecated - except Exception as e: + except Exception as e: # noqa: BLE001 RULE_LOADER_FAIL = True RULE_LOADER_FAIL_MSG = str(e) @@ -70,20 +66,20 @@ def setUpClass(cls): cls.rules_config = RULES_CONFIG @staticmethod - def rule_str(rule: Union[DeprecatedRule, TOMLRule], trailer=' ->') -> str: - return f'{rule.id} - {rule.name}{trailer or ""}' + def rule_str(rule: DeprecatedRule | TOMLRule, trailer=" ->") -> str: + return f"{rule.id} - {rule.name}{trailer or ''}" def setUp(self) -> None: - global RULE_LOADER_FAIL, RULE_LOADER_FAIL_MSG, RULE_LOADER_FAIL_RAISED + global RULE_LOADER_FAIL_RAISED # noqa: PLW0603 if RULE_LOADER_FAIL: # limit the loader failure to just one run # raise a dedicated test failure for the loader if not RULE_LOADER_FAIL_RAISED: RULE_LOADER_FAIL_RAISED = True - with self.subTest('Test that the rule loader loaded with no validation or other failures.'): - self.fail(f'Rule loader failure: \n{RULE_LOADER_FAIL_MSG}') + with self.subTest("Test that the rule loader loaded with no validation or other failures."): + self.fail(f"Rule loader failure: {RULE_LOADER_FAIL_MSG}") - self.skipTest('Rule loader failure') + self.skipTest("Rule loader failure") else: super().setUp() diff --git a/tests/kuery/test_dsl.py b/tests/kuery/test_dsl.py index 4af3217ebc0..ff87ec96987 100644 --- a/tests/kuery/test_dsl.py +++ b/tests/kuery/test_dsl.py @@ -4,6 +4,7 @@ # 2.0. 
import unittest + import kql @@ -51,10 +52,9 @@ def test_and_query(self): def test_not_query(self): self.validate("not field:value", {"must_not": [{"match": {"field": "value"}}]}) self.validate("field:(not value)", {"must_not": [{"match": {"field": "value"}}]}) - self.validate("field:(a and not b)", { - "filter": [{"match": {"field": "a"}}], - "must_not": [{"match": {"field": "b"}}] - }) + self.validate( + "field:(a and not b)", {"filter": [{"match": {"field": "a"}}], "must_not": [{"match": {"field": "b"}}]} + ) self.validate( "not field:value and not field2:value2", {"must_not": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}]}, @@ -74,13 +74,10 @@ def test_not_query(self): optimize=False, ) - self.validate("not (field:value and field2:value2)", - { - "must_not": [ - {"match": {"field": "value"}}, - {"match": {"field2": "value2"}} - ] - }) + self.validate( + "not (field:value and field2:value2)", + {"must_not": [{"match": {"field": "value"}}, {"match": {"field2": "value2"}}]}, + ) def test_optimizations(self): self.validate( @@ -120,21 +117,9 @@ def test_optimizations(self): self.validate( "a:(v1 or v2 or v3) and b:(v4 or v5)", { - "should": [ - {"match": {"a": "v1"}}, - {"match": {"a": "v2"}}, - {"match": {"a": "v3"}} - ], + "should": [{"match": {"a": "v1"}}, {"match": {"a": "v2"}}, {"match": {"a": "v3"}}], "filter": [ - { - "bool": { - "should": [ - {"match": {"b": "v4"}}, - {"match": {"b": "v5"}} - ], - "minimum_should_match": 1 - } - } + {"bool": {"should": [{"match": {"b": "v4"}}, {"match": {"b": "v5"}}], "minimum_should_match": 1}} ], "minimum_should_match": 1, }, diff --git a/tests/kuery/test_eql2kql.py b/tests/kuery/test_eql2kql.py index de0e4404cb7..dd2d56860d0 100644 --- a/tests/kuery/test_eql2kql.py +++ b/tests/kuery/test_eql2kql.py @@ -3,13 +3,13 @@ # 2.0; you may not use this file except in compliance with the Elastic License # 2.0. 
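As a plain-dict reminder of what the DSL assertions above encode: in an Elasticsearch bool query, `filter` acts as AND, `must_not` as NOT, and `should` with `minimum_should_match: 1` as OR. The field and value names below are illustrative.

```python
kql_source = "field:(a and not b)"
expected_bool = {
    "filter": [{"match": {"field": "a"}}],
    "must_not": [{"match": {"field": "b"}}],
}
search_body = {"query": {"bool": expected_bool}}  # shape a client would send to _search
print(kql_source, "->", search_body)
```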
-import eql import unittest + +import eql import kql class TestEql2Kql(unittest.TestCase): - def validate(self, kql_source, eql_source): self.assertEqual(kql_source, str(kql.from_eql(eql_source))) @@ -55,8 +55,8 @@ def test_ip_checks(self): def test_wildcard_field(self): with eql.parser.elasticsearch_validate_optional_fields: - self.validate('field:value-*', 'field : "value-*"') - self.validate('field:value-?', 'field : "value-?"') + self.validate("field:value-*", 'field : "value-*"') + self.validate("field:value-?", 'field : "value-?"') with eql.parser.elasticsearch_validate_optional_fields, self.assertRaises(AssertionError): self.validate('field:"value-*"', 'field == "value-*"') diff --git a/tests/kuery/test_evaluator.py b/tests/kuery/test_evaluator.py index 97033e97b03..d30210670e6 100644 --- a/tests/kuery/test_evaluator.py +++ b/tests/kuery/test_evaluator.py @@ -7,116 +7,105 @@ import kql +document = { + "number": 1, + "boolean": True, + "ip": "192.168.16.3", + "string": "hello world", + "string_list": ["hello world", "example"], + "number_list": [1, 2, 3], + "boolean_list": [True, False], + "structured": [{"a": [{"b": 1}]}], +} -class EvaluatorTests(unittest.TestCase): - - document = { - "number": 1, - "boolean": True, - "ip": "192.168.16.3", - "string": "hello world", - - "string_list": ["hello world", "example"], - "number_list": [1, 2, 3], - "boolean_list": [True, False], - "structured": [ - { - "a": [ - {"b": 1} - ] - } - ], - } - - def evaluate(self, source_text, document=None): - if document is None: - document = self.document +class EvaluatorTests(unittest.TestCase): + def evaluate(self, source_text): evaluator = kql.get_evaluator(source_text, optimize=False) return evaluator(document) def test_single_value(self): - self.assertTrue(self.evaluate('number:1')) + self.assertTrue(self.evaluate("number:1")) self.assertTrue(self.evaluate('number:"1"')) - self.assertTrue(self.evaluate('boolean:true')) + self.assertTrue(self.evaluate("boolean:true")) self.assertTrue(self.evaluate('string:"hello world"')) - self.assertFalse(self.evaluate('number:0')) - self.assertFalse(self.evaluate('boolean:false')) + self.assertFalse(self.evaluate("number:0")) + self.assertFalse(self.evaluate("boolean:false")) self.assertFalse(self.evaluate('string:"missing"')) def test_list_value(self): - self.assertTrue(self.evaluate('number_list:1')) - self.assertTrue(self.evaluate('number_list:2')) - self.assertTrue(self.evaluate('number_list:3')) + self.assertTrue(self.evaluate("number_list:1")) + self.assertTrue(self.evaluate("number_list:2")) + self.assertTrue(self.evaluate("number_list:3")) - self.assertTrue(self.evaluate('boolean_list:true')) - self.assertTrue(self.evaluate('boolean_list:false')) + self.assertTrue(self.evaluate("boolean_list:true")) + self.assertTrue(self.evaluate("boolean_list:false")) self.assertTrue(self.evaluate('string_list:"hello world"')) - self.assertTrue(self.evaluate('string_list:example')) + self.assertTrue(self.evaluate("string_list:example")) - self.assertFalse(self.evaluate('number_list:4')) + self.assertFalse(self.evaluate("number_list:4")) self.assertFalse(self.evaluate('string_list:"missing"')) def test_and_values(self): - self.assertTrue(self.evaluate('number_list:(1 and 2)')) - self.assertTrue(self.evaluate('boolean_list:(false and true)')) + self.assertTrue(self.evaluate("number_list:(1 and 2)")) + self.assertTrue(self.evaluate("boolean_list:(false and true)")) self.assertFalse(self.evaluate('string:("missing" and "hello world")')) - 
self.assertFalse(self.evaluate('number:(0 and 1)')) - self.assertFalse(self.evaluate('boolean:(false and true)')) + self.assertFalse(self.evaluate("number:(0 and 1)")) + self.assertFalse(self.evaluate("boolean:(false and true)")) def test_not_value(self): - self.assertTrue(self.evaluate('number_list:1')) - self.assertFalse(self.evaluate('not number_list:1')) - self.assertFalse(self.evaluate('number_list:(not 1)')) + self.assertTrue(self.evaluate("number_list:1")) + self.assertFalse(self.evaluate("not number_list:1")) + self.assertFalse(self.evaluate("number_list:(not 1)")) def test_or_values(self): - self.assertTrue(self.evaluate('number:(0 or 1)')) - self.assertTrue(self.evaluate('number:(1 or 2)')) - self.assertTrue(self.evaluate('boolean:(false or true)')) + self.assertTrue(self.evaluate("number:(0 or 1)")) + self.assertTrue(self.evaluate("number:(1 or 2)")) + self.assertTrue(self.evaluate("boolean:(false or true)")) self.assertTrue(self.evaluate('string:("missing" or "hello world")')) - self.assertFalse(self.evaluate('number:(0 or 3)')) + self.assertFalse(self.evaluate("number:(0 or 3)")) def test_and_expr(self): - self.assertTrue(self.evaluate('number:1 and boolean:true')) + self.assertTrue(self.evaluate("number:1 and boolean:true")) - self.assertFalse(self.evaluate('number:1 and boolean:false')) + self.assertFalse(self.evaluate("number:1 and boolean:false")) def test_or_expr(self): - self.assertTrue(self.evaluate('number:1 or boolean:false')) - self.assertFalse(self.evaluate('number:0 or boolean:false')) + self.assertTrue(self.evaluate("number:1 or boolean:false")) + self.assertFalse(self.evaluate("number:0 or boolean:false")) def test_range(self): - self.assertTrue(self.evaluate('number < 2')) - self.assertFalse(self.evaluate('number > 2')) + self.assertTrue(self.evaluate("number < 2")) + self.assertFalse(self.evaluate("number > 2")) def test_cidr_match(self): - self.assertTrue(self.evaluate('ip:192.168.0.0/16')) + self.assertTrue(self.evaluate("ip:192.168.0.0/16")) - self.assertFalse(self.evaluate('ip:10.0.0.0/8')) + self.assertFalse(self.evaluate("ip:10.0.0.0/8")) def test_quoted_wildcard(self): self.assertFalse(self.evaluate("string:'*'")) self.assertFalse(self.evaluate("string:'?'")) def test_wildcard(self): - self.assertTrue(self.evaluate('string:hello*')) - self.assertTrue(self.evaluate('string:*world')) - self.assertFalse(self.evaluate('string:foobar*')) + self.assertTrue(self.evaluate("string:hello*")) + self.assertTrue(self.evaluate("string:*world")) + self.assertFalse(self.evaluate("string:foobar*")) def test_field_exists(self): - self.assertTrue(self.evaluate('number:*')) - self.assertTrue(self.evaluate('boolean:*')) - self.assertTrue(self.evaluate('ip:*')) - self.assertTrue(self.evaluate('string:*')) - self.assertTrue(self.evaluate('string_list:*')) - self.assertTrue(self.evaluate('number_list:*')) - self.assertTrue(self.evaluate('boolean_list:*')) - - self.assertFalse(self.evaluate('a:*')) + self.assertTrue(self.evaluate("number:*")) + self.assertTrue(self.evaluate("boolean:*")) + self.assertTrue(self.evaluate("ip:*")) + self.assertTrue(self.evaluate("string:*")) + self.assertTrue(self.evaluate("string_list:*")) + self.assertTrue(self.evaluate("number_list:*")) + self.assertTrue(self.evaluate("boolean_list:*")) + + self.assertFalse(self.evaluate("a:*")) def test_flattening(self): self.assertTrue(self.evaluate("structured.a.b:*")) diff --git a/tests/kuery/test_kql2eql.py b/tests/kuery/test_kql2eql.py index bfa9a242589..efe04cc6de0 100644 --- a/tests/kuery/test_kql2eql.py 
+++ b/tests/kuery/test_kql2eql.py @@ -4,13 +4,12 @@ # 2.0. import unittest -import eql +import eql import kql class TestKql2Eql(unittest.TestCase): - def validate(self, kql_source, eql_source, schema=None): self.assertEqual(kql.to_eql(kql_source, schema=schema), eql.parse_expression(eql_source)) @@ -54,7 +53,7 @@ def test_list_of_values(self): self.validate("a:(0 or 1 and 2 or (3 and 4))", "a == 0 or a == 1 and a == 2 or (a == 3 and a == 4)") def test_lone_value(self): - for value in ["1", "-1.4", "true", "\"string test\""]: + for value in ["1", "-1.4", "true", '"string test"']: with self.assertRaisesRegex(kql.KqlParseError, "Value not tied to field"): kql.to_eql(value) @@ -71,15 +70,15 @@ def test_schema(self): } self.validate("top.numF : 1", "top.numF == 1", schema=schema) - self.validate("top.numF : \"1\"", "top.numF == 1", schema=schema) + self.validate('top.numF : "1"', "top.numF == 1", schema=schema) self.validate("top.keyword : 1", "top.keyword == '1'", schema=schema) - self.validate("top.keyword : \"hello\"", "top.keyword == 'hello'", schema=schema) + self.validate('top.keyword : "hello"', "top.keyword == 'hello'", schema=schema) self.validate("dest:192.168.255.255", "dest == '192.168.255.255'", schema=schema) self.validate("dest:192.168.0.0/16", "cidrMatch(dest, '192.168.0.0/16')", schema=schema) - self.validate("dest:\"192.168.0.0/16\"", "cidrMatch(dest, '192.168.0.0/16')", schema=schema) + self.validate('dest:"192.168.0.0/16"', "cidrMatch(dest, '192.168.0.0/16')", schema=schema) with self.assertRaises(eql.EqlSemanticError): - self.validate("top.text : \"hello\"", "top.text == 'hello'", schema=schema) + self.validate('top.text : "hello"', "top.text == 'hello'", schema=schema) with self.assertRaises(eql.EqlSemanticError): self.validate("top.text : 1 ", "top.text == '1'", schema=schema) @@ -93,7 +92,7 @@ def test_schema(self): with self.assertRaisesRegex(kql.KqlParseError, "Unable to convert nested query to EQL"): kql.to_eql("top:{middle:{bool: true}}", schema=schema) - invalid_ips = ["192.168.0.256", "192.168.0.256/33", "1", "\"1\""] + invalid_ips = ["192.168.0.256", "192.168.0.256/33", "1", '"1"'] for ip in invalid_ips: with self.assertRaisesRegex(kql.KqlParseError, r"Value doesn't match dest's type: ip"): - kql.to_eql("dest:{ip}".format(ip=ip), schema=schema) + kql.to_eql(f"dest:{ip}", schema=schema) diff --git a/tests/kuery/test_lint.py b/tests/kuery/test_lint.py index 31953cbd792..519025202d8 100644 --- a/tests/kuery/test_lint.py +++ b/tests/kuery/test_lint.py @@ -4,21 +4,21 @@ # 2.0. 
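A small sketch of the schema-aware conversion exercised above, using the `kql.to_eql` and `eql.parse_expression` calls shown in the test; the flat one-field schema is an assumption for the sketch, and the in-repo `kql` package is required.

```python
import eql
import kql

schema = {"dest": "ip"}

# With "dest" typed as ip, a CIDR value converts to an EQL cidrMatch() call.
converted = kql.to_eql("dest:192.168.0.0/16", schema=schema)
expected = eql.parse_expression("cidrMatch(dest, '192.168.0.0/16')")
assert converted == expected, f"unexpected conversion: {converted}"
```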
import unittest + import kql class LintTests(unittest.TestCase): - def validate(self, source, linted, *args): self.assertEqual(kql.lint(source), linted, *args) def test_lint_field(self): self.validate("a : b", "a:b") - self.validate("\"a\": b", "a:b") - self.validate("a : \"b\"", "a:b") + self.validate('"a": b', "a:b") + self.validate('a : "b"', "a:b") self.validate("a : (b)", "a:b") self.validate("a:1.234", "a:1.234") - self.validate("a:\"1.234\"", "a:1.234") + self.validate('a:"1.234"', "a:1.234") def test_upper_tokens(self): queries = [ @@ -80,7 +80,7 @@ def test_double_negate(self): self.validate("not (not (a:(not b) or c:(not d)))", "not a:b or not c:d") def test_ip(self): - self.validate("a:ff02\\:\\:fb", "a:\"ff02::fb\"") + self.validate("a:ff02\\:\\:fb", 'a:"ff02::fb"') def test_compound(self): self.validate("a:1 and b:2 and not (c:3 or c:4)", "a:1 and b:2 and not c:(3 or 4)") diff --git a/tests/kuery/test_parser.py b/tests/kuery/test_parser.py index 444d55f1bd1..c09cb2fccb8 100644 --- a/tests/kuery/test_parser.py +++ b/tests/kuery/test_parser.py @@ -4,19 +4,19 @@ # 2.0. import unittest + import kql from kql.ast import ( + Exists, Field, FieldComparison, FieldRange, - String, Number, - Exists, + String, ) class ParserTests(unittest.TestCase): - def validate(self, source, tree, *args, **kwargs): kwargs.setdefault("optimize", False) self.assertEqual(kql.parse(source, *args, **kwargs), tree) @@ -28,14 +28,14 @@ def test_keyword(self): "b": "long", } - self.validate('a.text:hello', FieldComparison(Field("a.text"), String("hello")), schema=schema) - self.validate('a.keyword:hello', FieldComparison(Field("a.keyword"), String("hello")), schema=schema) + self.validate("a.text:hello", FieldComparison(Field("a.text"), String("hello")), schema=schema) + self.validate("a.keyword:hello", FieldComparison(Field("a.keyword"), String("hello")), schema=schema) self.validate('a.text:"hello"', FieldComparison(Field("a.text"), String("hello")), schema=schema) self.validate('a.keyword:"hello"', FieldComparison(Field("a.keyword"), String("hello")), schema=schema) - self.validate('a.text:1', FieldComparison(Field("a.text"), String("1")), schema=schema) - self.validate('a.keyword:1', FieldComparison(Field("a.keyword"), String("1")), schema=schema) + self.validate("a.text:1", FieldComparison(Field("a.text"), String("1")), schema=schema) + self.validate("a.keyword:1", FieldComparison(Field("a.keyword"), String("1")), schema=schema) self.validate('a.text:"1"', FieldComparison(Field("a.text"), String("1")), schema=schema) self.validate('a.keyword:"1"', FieldComparison(Field("a.keyword"), String("1")), schema=schema) @@ -43,10 +43,10 @@ def test_keyword(self): def test_conversion(self): schema = {"num": "long", "text": "text"} - self.validate('num:1', FieldComparison(Field("num"), Number(1)), schema=schema) + self.validate("num:1", FieldComparison(Field("num"), Number(1)), schema=schema) self.validate('num:"1"', FieldComparison(Field("num"), Number(1)), schema=schema) - self.validate('text:1', FieldComparison(Field("text"), String("1")), schema=schema) + self.validate("text:1", FieldComparison(Field("text"), String("1")), schema=schema) self.validate('text:"1"', FieldComparison(Field("text"), String("1")), schema=schema) def test_list_equals(self): @@ -57,11 +57,11 @@ def test_number_exists(self): def test_multiple_types_success(self): schema = {"common.a": "keyword", "common.b": "keyword"} - self.validate("common.* : \"hello\"", FieldComparison(Field("common.*"), String("hello")), schema=schema) + 
self.validate('common.* : "hello"', FieldComparison(Field("common.*"), String("hello")), schema=schema) def test_multiple_types_fail(self): with self.assertRaises(kql.KqlParseError): - kql.parse("common.* : \"hello\"", schema={"common.a": "keyword", "common.b": "ip"}) + kql.parse('common.* : "hello"', schema={"common.a": "keyword", "common.b": "ip"}) def test_number_wildcard_fail(self): with self.assertRaises(kql.KqlParseError): @@ -81,7 +81,7 @@ def test_type_family_fail(self): def test_date(self): schema = {"@time": "date"} - self.validate('@time <= now-10d', FieldRange(Field("@time"), "<=", String("now-10d")), schema=schema) + self.validate("@time <= now-10d", FieldRange(Field("@time"), "<=", String("now-10d")), schema=schema) with self.assertRaises(kql.KqlParseError): kql.parse("@time > 5", schema=schema) diff --git a/tests/test_all_rules.py b/tests/test_all_rules.py index d93d4df4270..840753e5c85 100644 --- a/tests/test_all_rules.py +++ b/tests/test_all_rules.py @@ -4,28 +4,35 @@ # 2.0. """Test that all rules have valid metadata and syntax.""" + import os import re import unittest import uuid -import warnings from collections import defaultdict from pathlib import Path -import eql.ast - +import eql +import kql from marshmallow import ValidationError from semver import Version -import kql from detection_rules import attack from detection_rules.config import load_current_package_version -from detection_rules.integrations import (find_latest_compatible_version, - load_integrations_manifests, - load_integrations_schemas) +from detection_rules.integrations import ( + find_latest_compatible_version, + load_integrations_manifests, + load_integrations_schemas, +) from detection_rules.packaging import current_stack_version -from detection_rules.rule import (AlertSuppressionMapping, EQLRuleData, QueryRuleData, QueryValidator, - ThresholdAlertSuppression, TOMLRuleContents) +from detection_rules.rule import ( + AlertSuppressionMapping, + EQLRuleData, + QueryRuleData, + QueryValidator, + ThresholdAlertSuppression, + TOMLRuleContents, +) from detection_rules.rule_loader import FILE_PATTERN, RULES_CONFIG from detection_rules.rule_validators import EQLValidator, KQLValidator from detection_rules.schemas import definitions, get_min_supported_stack_version, get_stack_schemas @@ -42,34 +49,37 @@ class TestValidRules(BaseRuleTest): def test_schema_and_dupes(self): """Ensure that every rule matches the schema and there are no duplicates.""" - self.assertGreaterEqual(len(self.all_rules), 1, 'No rules were loaded from rules directory!') + self.assertGreaterEqual(len(self.all_rules), 1, "No rules were loaded from rules directory!") def test_file_names(self): """Test that the file names meet the requirement.""" file_pattern = FILE_PATTERN - self.assertIsNone(re.match(file_pattern, 'NotValidRuleFile.toml'), - f'Incorrect pattern for verifying rule names: {file_pattern}') - self.assertIsNone(re.match(file_pattern, 'still_not_a_valid_file_name.not_json'), - f'Incorrect pattern for verifying rule names: {file_pattern}') + self.assertIsNone( + re.match(file_pattern, "NotValidRuleFile.toml"), + f"Incorrect pattern for verifying rule names: {file_pattern}", + ) + self.assertIsNone( + re.match(file_pattern, "still_not_a_valid_file_name.not_json"), + f"Incorrect pattern for verifying rule names: {file_pattern}", + ) for rule in self.all_rules: file_name = str(rule.path.name) - self.assertIsNotNone(re.match(file_pattern, file_name), f'Invalid file name for {rule.path}') + self.assertIsNotNone(re.match(file_pattern, 
file_name), f"Invalid file name for {rule.path}") def test_all_rule_queries_optimized(self): """Ensure that every rule query is in optimized form.""" for rule in self.all_rules: - if ( - rule.contents.data.get("language") == "kuery" and not any( - item in rule.contents.data.query for item in definitions.QUERY_FIELD_OP_EXCEPTIONS - ) + if rule.contents.data.get("language") == "kuery" and not any( + item in rule.contents.data.query for item in definitions.QUERY_FIELD_OP_EXCEPTIONS ): source = rule.contents.data.query tree = kql.parse(source, optimize=False, normalize_kql_keywords=RULES_CONFIG.normalize_kql_keywords) optimized = tree.optimize(recursive=True) - err_message = f'\n{self.rule_str(rule)} Query not optimized for rule\n' \ - f'Expected: {optimized}\nActual: {source}' + err_message = ( + f"\n{self.rule_str(rule)} Query not optimized for rule\nExpected: {optimized}\nActual: {source}" + ) self.assertEqual(tree, optimized, err_message) def test_duplicate_file_names(self): @@ -99,7 +109,7 @@ def test_bbr_validation(self): "rule_id": str(uuid.uuid4()), "severity": "low", "type": "query", - "timestamp_override": "event.ingested" + "timestamp_override": "event.ingested", } def build_rule(query, bbr_type="default", from_field="now-120m", interval="60m"): @@ -107,7 +117,7 @@ def build_rule(query, bbr_type="default", from_field="now-120m", interval="60m") "creation_date": "1970/01/01", "updated_date": "1970/01/01", "min_stack_version": load_current_package_version(), - "integration": ["cloud_defend"] + "integration": ["cloud_defend"], } data = base_fields.copy() data["query"] = query @@ -133,61 +143,63 @@ def build_rule(query, bbr_type="default", from_field="now-120m", interval="60m") def test_max_signals_note(self): """Ensure the max_signals note is present when max_signals > 1000.""" - max_signal_standard_setup = 'For information on troubleshooting the maximum alerts warning '\ - 'please refer to this [guide]'\ - '(https://www.elastic.co/guide/en/security/current/alerts-ui-monitor.html#troubleshoot-max-alerts).' # noqa: E501 + max_signal_standard_setup = ( + "For information on troubleshooting the maximum alerts warning " + "please refer to this [guide]" + "(https://www.elastic.co/guide/en/security/current/alerts-ui-monitor.html#troubleshoot-max-alerts)." + ) for rule in self.all_rules: if rule.contents.data.max_signals and rule.contents.data.max_signals > 1000: - error_message = f'{self.rule_str(rule)} max_signals cannot exceed 1000.' - self.fail(f'{self.rule_str(rule)} max_signals cannot exceed 1000.') + error_message = f"{self.rule_str(rule)} max_signals cannot exceed 1000." 
+ self.fail(f"{self.rule_str(rule)} max_signals cannot exceed 1000.") if rule.contents.data.max_signals and rule.contents.data.max_signals == 1000: - error_message = f'{self.rule_str(rule)} note required for max_signals == 1000' + error_message = f"{self.rule_str(rule)} note required for max_signals == 1000" self.assertIsNotNone(rule.contents.data.setup, error_message) if max_signal_standard_setup not in rule.contents.data.setup: - self.fail(f'{self.rule_str(rule)} expected max_signals note missing\n\n' - f'Expected: {max_signal_standard_setup}\n\n' - f'Actual: {rule.contents.data.setup}') + self.fail( + f"{self.rule_str(rule)} expected max_signals note missing\n\n" + f"Expected: {max_signal_standard_setup}\n\n" + f"Actual: {rule.contents.data.setup}" + ) def test_from_filed_value(self): - """ Add "from" Field Validation for All Rules""" + """Add "from" Field Validation for All Rules""" failures = [] - valid_format = re.compile(r'^now-\d+[yMwdhHms]$') + valid_format = re.compile(r"^now-\d+[yMwdhHms]$") for rule in self.all_rules: - from_field = rule.contents.data.get('from_') - if from_field is not None: - if not valid_format.match(from_field): - err_msg = f'{self.rule_str(rule)} has invalid value {from_field}' - failures.append(err_msg) + from_field = rule.contents.data.get("from_") + if from_field and not valid_format.match(from_field): + err_msg = f"{self.rule_str(rule)} has invalid value {from_field}" + failures.append(err_msg) if failures: fail_msg = """ The following rules have invalid 'from' filed value \n """ - self.fail(fail_msg + '\n'.join(failures)) + self.fail(fail_msg + "\n".join(failures)) def test_index_or_data_view_id_present(self): """Ensure that either 'index' or 'data_view_id' is present for prebuilt rules.""" failures = [] machine_learning_packages = [val.lower() for val in definitions.MACHINE_LEARNING_PACKAGES] for rule in self.all_rules: - rule_type = rule.contents.data.get('language') - rule_integrations = rule.contents.metadata.get('integration') or [] - if rule_type == 'esql': + rule_type = rule.contents.data.get("language") + rule_integrations = rule.contents.metadata.get("integration") or [] + if rule_type == "esql": continue # the index is part of the query and would be validated in the query - elif rule.contents.data.type == 'machine_learning' or rule_integrations in machine_learning_packages: + if rule.contents.data.type == "machine_learning" or rule_integrations in machine_learning_packages: continue # Skip all rules of machine learning type or rules that are part of machine learning packages - elif rule.contents.data.type == 'threat_match': + if rule.contents.data.type == "threat_match": continue # Skip all rules of threat_match type - else: - index = rule.contents.data.get('index') - data_view_id = rule.contents.data.get('data_view_id') - if index is None and data_view_id is None: - err_msg = f'{self.rule_str(rule)} does not have either index or data_view_id' - failures.append(err_msg) + index = rule.contents.data.get("index") + data_view_id = rule.contents.data.get("data_view_id") + if index is None and data_view_id is None: + err_msg = f"{self.rule_str(rule)} does not have either index or data_view_id" + failures.append(err_msg) if failures: fail_msg = """ The following prebuilt rules do not have either 'index' or 'data_view_id' \n """ - self.fail(fail_msg + '\n'.join(failures)) + self.fail(fail_msg + "\n".join(failures)) class TestThreatMappings(BaseRuleTest): @@ -205,14 +217,15 @@ def test_technique_deprecations(self): if threat_mapping: for entry in 
threat_mapping: - for technique in (entry.technique or []): + for technique in entry.technique or []: if technique.id in revoked + deprecated: - revoked_techniques[technique.id] = replacement_map.get(technique.id, - 'DEPRECATED - DO NOT USE') + revoked_techniques[technique.id] = replacement_map.get( + technique.id, "DEPRECATED - DO NOT USE" + ) if revoked_techniques: - old_new_mapping = "\n".join(f'Actual: {k} -> Expected {v}' for k, v in revoked_techniques.items()) - self.fail(f'{self.rule_str(rule)} Using deprecated ATT&CK techniques: \n{old_new_mapping}') + old_new_mapping = "\n".join(f"Actual: {k} -> Expected {v}" for k, v in revoked_techniques.items()) + self.fail(f"{self.rule_str(rule)} Using deprecated ATT&CK techniques: \n{old_new_mapping}") def test_tactic_to_technique_correlations(self): """Ensure rule threat info is properly related to a single tactic and technique.""" @@ -225,66 +238,89 @@ def test_tactic_to_technique_correlations(self): mismatched = [t.id for t in techniques if t.id not in attack.matrix[tactic.name]] if mismatched: - self.fail(f'mismatched ATT&CK techniques for rule: {self.rule_str(rule)} ' - f'{", ".join(mismatched)} not under: {tactic["name"]}') + self.fail( + f"mismatched ATT&CK techniques for rule: {self.rule_str(rule)} " + f"{', '.join(mismatched)} not under: {tactic['name']}" + ) # tactic expected_tactic = attack.tactics_map[tactic.name] - self.assertEqual(expected_tactic, tactic.id, - f'ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n' - f'expected: {expected_tactic} for {tactic.name}\n' - f'actual: {tactic.id}') - - tactic_reference_id = tactic.reference.rstrip('/').split('/')[-1] - self.assertEqual(tactic.id, tactic_reference_id, - f'ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n' - f'tactic ID {tactic.id} does not match the reference URL ID ' - f'{tactic.reference}') + self.assertEqual( + expected_tactic, + tactic.id, + f"ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n" + f"expected: {expected_tactic} for {tactic.name}\n" + f"actual: {tactic.id}", + ) + + tactic_reference_id = tactic.reference.rstrip("/").split("/")[-1] + self.assertEqual( + tactic.id, + tactic_reference_id, + f"ATT&CK tactic mapping error for rule: {self.rule_str(rule)}\n" + f"tactic ID {tactic.id} does not match the reference URL ID " + f"{tactic.reference}", + ) # techniques for technique in techniques: - expected_technique = attack.technique_lookup[technique.id]['name'] - self.assertEqual(expected_technique, technique.name, - f'ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n' - f'expected: {expected_technique} for {technique.id}\n' - f'actual: {technique.name}') - - technique_reference_id = technique.reference.rstrip('/').split('/')[-1] - self.assertEqual(technique.id, technique_reference_id, - f'ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n' - f'technique ID {technique.id} does not match the reference URL ID ' - f'{technique.reference}') + expected_technique = attack.technique_lookup[technique.id]["name"] + self.assertEqual( + expected_technique, + technique.name, + f"ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n" + f"expected: {expected_technique} for {technique.id}\n" + f"actual: {technique.name}", + ) + + technique_reference_id = technique.reference.rstrip("/").split("/")[-1] + self.assertEqual( + technique.id, + technique_reference_id, + f"ATT&CK technique mapping error for rule: {self.rule_str(rule)}\n" + f"technique ID {technique.id} does not match the reference URL ID " + 
f"{technique.reference}", + ) # sub-techniques sub_techniques = technique.subtechnique or [] if sub_techniques: for sub_technique in sub_techniques: - expected_sub_technique = attack.technique_lookup[sub_technique.id]['name'] - self.assertEqual(expected_sub_technique, sub_technique.name, - f'ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n' - f'expected: {expected_sub_technique} for {sub_technique.id}\n' - f'actual: {sub_technique.name}') - - sub_technique_reference_id = '.'.join( - sub_technique.reference.rstrip('/').split('/')[-2:]) - self.assertEqual(sub_technique.id, sub_technique_reference_id, - f'ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n' - f'sub-technique ID {sub_technique.id} does not match the reference URL ID ' # noqa: E501 - f'{sub_technique.reference}') + expected_sub_technique = attack.technique_lookup[sub_technique.id]["name"] + self.assertEqual( + expected_sub_technique, + sub_technique.name, + f"ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n" + f"expected: {expected_sub_technique} for {sub_technique.id}\n" + f"actual: {sub_technique.name}", + ) + + sub_technique_reference_id = ".".join( + sub_technique.reference.rstrip("/").split("/")[-2:] + ) + self.assertEqual( + sub_technique.id, + sub_technique_reference_id, + f"ATT&CK sub-technique mapping error for rule: {self.rule_str(rule)}\n" + f"sub-technique ID {sub_technique.id} does not match the reference URL ID " + f"{sub_technique.reference}", + ) def test_duplicated_tactics(self): """Check that a tactic is only defined once.""" for rule in self.all_rules: threat_mapping = rule.contents.data.threat tactics = [t.tactic.name for t in threat_mapping or []] - duplicates = sorted(set(t for t in tactics if tactics.count(t) > 1)) + duplicates = sorted({t for t in tactics if tactics.count(t) > 1}) if duplicates: - self.fail(f'{self.rule_str(rule)} duplicate tactics defined for {duplicates}. ' - f'Flatten to a single entry per tactic') + self.fail( + f"{self.rule_str(rule)} duplicate tactics defined for {duplicates}. 
" + f"Flatten to a single entry per tactic" + ) -@unittest.skipIf(os.environ.get('DR_BYPASS_TAGS_VALIDATION') is not None, "Skipping tag validation") +@unittest.skipIf(os.environ.get("DR_BYPASS_TAGS_VALIDATION") is not None, "Skipping tag validation") class TestRuleTags(BaseRuleTest): """Test tags data for rules.""" @@ -297,59 +333,62 @@ def test_casing_and_spacing(self): rule_tags = rule.contents.data.tags if rule_tags: - invalid_tags = {t: expected_case[t.casefold()] for t in rule_tags - if t.casefold() in list(expected_case) and t != expected_case[t.casefold()]} + invalid_tags = { + t: expected_case[t.casefold()] + for t in rule_tags + if t.casefold() in list(expected_case) and t != expected_case[t.casefold()] + } if invalid_tags: - error_msg = f'{self.rule_str(rule)} Invalid casing for expected tags\n' - error_msg += f'Actual tags: {", ".join(invalid_tags)}\n' - error_msg += f'Expected tags: {", ".join(invalid_tags.values())}' + error_msg = f"{self.rule_str(rule)} Invalid casing for expected tags\n" + error_msg += f"Actual tags: {', '.join(invalid_tags)}\n" + error_msg += f"Expected tags: {', '.join(invalid_tags.values())}" self.fail(error_msg) def test_required_tags(self): """Test that expected tags are present within rules.""" required_tags_map = { - 'logs-endpoint.events.*': {'all': ['Domain: Endpoint', 'Data Source: Elastic Defend']}, - 'endgame-*': {'all': ['Data Source: Elastic Endgame']}, - 'logs-aws*': {'all': ['Data Source: AWS', 'Data Source: Amazon Web Services', 'Domain: Cloud']}, - 'logs-azure*': {'all': ['Data Source: Azure', 'Domain: Cloud']}, - 'logs-o365*': {'all': ['Data Source: Microsoft 365', 'Domain: Cloud']}, - 'logs-okta*': {'all': ['Data Source: Okta']}, - 'logs-gcp*': {'all': ['Data Source: Google Cloud Platform', 'Data Source: GCP', 'Domain: Cloud']}, - 'logs-google_workspace*': {'all': ['Data Source: Google Workspace', 'Domain: Cloud']}, - 'logs-cloud_defend.alerts-*': {'all': ['Data Source: Elastic Defend for Containers', 'Domain: Container']}, - 'logs-cloud_defend*': {'all': ['Data Source: Elastic Defend for Containers', 'Domain: Container']}, - 'logs-kubernetes.*': {'all': ['Data Source: Kubernetes']}, - 'apm-*-transaction*': {'all': ['Data Source: APM']}, - 'traces-apm*': {'all': ['Data Source: APM']}, - '.alerts-security.*': {'all': ['Rule Type: Higher-Order Rule']}, - 'logs-cyberarkpas.audit*': {'all': ['Data Source: CyberArk PAS']}, - 'logs-endpoint.alerts-*': {'all': ['Data Source: Elastic Defend']}, - 'logs-windows.sysmon_operational-*': {'all': ['Data Source: Sysmon']}, - 'logs-windows.powershell*': {'all': ['Data Source: PowerShell Logs']}, - 'logs-system.security*': {'all': ['Data Source: Windows Security Event Logs']}, - 'logs-system.forwarded*': {'all': ['Data Source: Windows Security Event Logs']}, - 'logs-system.system*': {'all': ['Data Source: Windows System Event Logs']}, - 'logs-sentinel_one_cloud_funnel.*': {'all': ['Data Source: SentinelOne']}, - 'logs-fim.event-*': {'all': ['Data Source: File Integrity Monitoring']}, - 'logs-m365_defender.event-*': {'all': ['Data Source: Microsoft Defender for Endpoint']}, - 'logs-crowdstrike.fdr*': {'all': ['Data Source: Crowdstrike']} + "logs-endpoint.events.*": {"all": ["Domain: Endpoint", "Data Source: Elastic Defend"]}, + "endgame-*": {"all": ["Data Source: Elastic Endgame"]}, + "logs-aws*": {"all": ["Data Source: AWS", "Data Source: Amazon Web Services", "Domain: Cloud"]}, + "logs-azure*": {"all": ["Data Source: Azure", "Domain: Cloud"]}, + "logs-o365*": {"all": ["Data Source: Microsoft 365", 
"Domain: Cloud"]}, + "logs-okta*": {"all": ["Data Source: Okta"]}, + "logs-gcp*": {"all": ["Data Source: Google Cloud Platform", "Data Source: GCP", "Domain: Cloud"]}, + "logs-google_workspace*": {"all": ["Data Source: Google Workspace", "Domain: Cloud"]}, + "logs-cloud_defend.alerts-*": {"all": ["Data Source: Elastic Defend for Containers", "Domain: Container"]}, + "logs-cloud_defend*": {"all": ["Data Source: Elastic Defend for Containers", "Domain: Container"]}, + "logs-kubernetes.*": {"all": ["Data Source: Kubernetes"]}, + "apm-*-transaction*": {"all": ["Data Source: APM"]}, + "traces-apm*": {"all": ["Data Source: APM"]}, + ".alerts-security.*": {"all": ["Rule Type: Higher-Order Rule"]}, + "logs-cyberarkpas.audit*": {"all": ["Data Source: CyberArk PAS"]}, + "logs-endpoint.alerts-*": {"all": ["Data Source: Elastic Defend"]}, + "logs-windows.sysmon_operational-*": {"all": ["Data Source: Sysmon"]}, + "logs-windows.powershell*": {"all": ["Data Source: PowerShell Logs"]}, + "logs-system.security*": {"all": ["Data Source: Windows Security Event Logs"]}, + "logs-system.forwarded*": {"all": ["Data Source: Windows Security Event Logs"]}, + "logs-system.system*": {"all": ["Data Source: Windows System Event Logs"]}, + "logs-sentinel_one_cloud_funnel.*": {"all": ["Data Source: SentinelOne"]}, + "logs-fim.event-*": {"all": ["Data Source: File Integrity Monitoring"]}, + "logs-m365_defender.event-*": {"all": ["Data Source: Microsoft Defender for Endpoint"]}, + "logs-crowdstrike.fdr*": {"all": ["Data Source: Crowdstrike"]}, } for rule in self.all_rules: rule_tags = rule.contents.data.tags - error_msg = f'{self.rule_str(rule)} Missing tags:\nActual tags: {", ".join(rule_tags)}' + error_msg = f"{self.rule_str(rule)} Missing tags:\nActual tags: {', '.join(rule_tags)}" consolidated_optional_tags = [] is_missing_any_tags = False missing_required_tags = set() if isinstance(rule.contents.data, QueryRuleData): - for index in rule.contents.data.get('index') or []: + for index in rule.contents.data.get("index") or []: expected_tags = required_tags_map.get(index, {}) - expected_all = expected_tags.get('all', []) - expected_any = expected_tags.get('any', []) + expected_all = expected_tags.get("all", []) + expected_any = expected_tags.get("any", []) existing_any_tags = [t for t in rule_tags if t in expected_any] if expected_any: @@ -360,21 +399,20 @@ def test_required_tags(self): is_missing_any_tags = expected_any and not set(expected_any) & set(existing_any_tags) consolidated_optional_tags = [t for t in consolidated_optional_tags if t not in missing_required_tags] - error_msg += f'\nMissing all of: {", ".join(missing_required_tags)}' if missing_required_tags else '' - error_msg += f'\nMissing any of: {", " .join(consolidated_optional_tags)}' if is_missing_any_tags else '' + error_msg += f"\nMissing all of: {', '.join(missing_required_tags)}" if missing_required_tags else "" + error_msg += f"\nMissing any of: {', '.join(consolidated_optional_tags)}" if is_missing_any_tags else "" if missing_required_tags or is_missing_any_tags: self.fail(error_msg) def test_bbr_tags(self): """Test that "Rule Type: BBR" tag is present for all BBR rules.""" - invalid_bbr_rules = [] - for rule in self.bbr: - if 'Rule Type: BBR' not in rule.contents.data.tags: - invalid_bbr_rules.append(self.rule_str(rule)) + invalid_bbr_rules = [ + self.rule_str(rule) for rule in self.bbr if "Rule Type: BBR" not in rule.contents.data.tags + ] if invalid_bbr_rules: - error_rules = '\n'.join(invalid_bbr_rules) - self.fail(f'The following building block 
rule(s) have missing tag: Rule Type: BBR:\n{error_rules}') + error_rules = "\n".join(invalid_bbr_rules) + self.fail(f"The following building block rule(s) have missing tag: Rule Type: BBR:\n{error_rules}") def test_primary_tactic_as_tag(self): """Test that the primary tactic is present as a tag.""" @@ -386,7 +424,7 @@ def test_primary_tactic_as_tag(self): for rule in self.all_rules: rule_tags = rule.contents.data.tags - if 'Continuous Monitoring' in rule_tags or rule.contents.data.type == 'machine_learning': + if "Continuous Monitoring" in rule_tags or rule.contents.data.type == "machine_learning": continue threat = rule.contents.data.threat @@ -406,37 +444,31 @@ def test_primary_tactic_as_tag(self): if missing or missing_from_threat: err_msg = self.rule_str(rule) if missing: - err_msg += f'\n expected: {missing}' + err_msg += f"\n expected: {missing}" if missing_from_threat: - err_msg += f'\n unexpected (or missing from threat mapping): {missing_from_threat}' + err_msg += f"\n unexpected (or missing from threat mapping): {missing_from_threat}" invalid.append(err_msg) if invalid: - err_msg = '\n'.join(invalid) - self.fail(f'Rules with misaligned tags and tactics:\n{err_msg}') + err_msg = "\n".join(invalid) + self.fail(f"Rules with misaligned tags and tactics:\n{err_msg}") def test_os_tags(self): """Test that OS tags are present within rules.""" - required_tags_map = { - 'linux': 'OS: Linux', - 'macos': 'OS: macOS', - 'windows': 'OS: Windows' - } + required_tags_map = {"linux": "OS: Linux", "macos": "OS: macOS", "windows": "OS: Windows"} invalid = [] for rule in self.all_rules: dir_name = rule.path.parent.name - # if directory name is linux, macos, or windows, - # ensure the rule has the corresponding tag - if dir_name in ['linux', 'macos', 'windows']: - if required_tags_map[dir_name] not in rule.contents.data.tags: - err_msg = self.rule_str(rule) - err_msg += f'\n expected: {required_tags_map[dir_name]}' - invalid.append(err_msg) + # if directory name is linux, macos, or windows, ensure the rule has the corresponding tag + if dir_name in ["linux", "macos", "windows"] and required_tags_map[dir_name] not in rule.contents.data.tags: + err_msg = self.rule_str(rule) + err_msg += f"\n expected: {required_tags_map[dir_name]}" + invalid.append(err_msg) if invalid: - err_msg = '\n'.join(invalid) - self.fail(f'Rules with missing OS tags:\n{err_msg}') + err_msg = "\n".join(invalid) + self.fail(f"Rules with missing OS tags:\n{err_msg}") def test_ml_rule_type_tags(self): """Test that ML rule type tags are present within rules.""" @@ -445,36 +477,35 @@ def test_ml_rule_type_tags(self): for rule in self.all_rules: rule_tags = rule.contents.data.tags - if rule.contents.data.type == 'machine_learning': - if 'Rule Type: Machine Learning' not in rule_tags: + if rule.contents.data.type == "machine_learning": + if "Rule Type: Machine Learning" not in rule_tags: err_msg = self.rule_str(rule) - err_msg += '\n expected: Rule Type: Machine Learning' + err_msg += "\n expected: Rule Type: Machine Learning" invalid.append(err_msg) - if 'Rule Type: ML' not in rule_tags: + if "Rule Type: ML" not in rule_tags: err_msg = self.rule_str(rule) - err_msg += '\n expected: Rule Type: ML' + err_msg += "\n expected: Rule Type: ML" invalid.append(err_msg) if invalid: - err_msg = '\n'.join(invalid) - self.fail(f'Rules with misaligned ML rule type tags:\n{err_msg}') + err_msg = "\n".join(invalid) + self.fail(f"Rules with misaligned ML rule type tags:\n{err_msg}") def test_investigation_guide_tag(self): """Test that investigation 
guide tags are present within rules.""" invalid = [] for rule in self.all_rules: - note = rule.contents.data.get('note') + note = rule.contents.data.get("note") if note is not None: - results = re.search(r'Investigating', note, re.M) - if results is not None: - # check if investigation guide tag is present - if 'Resources: Investigation Guide' not in rule.contents.data.tags: - err_msg = self.rule_str(rule) - err_msg += '\n expected: Resources: Investigation Guide' - invalid.append(err_msg) + results = re.search(r"Investigating", note, re.M) + # check if investigation guide tag is present + if results and "Resources: Investigation Guide" not in rule.contents.data.tags: + err_msg = self.rule_str(rule) + err_msg += "\n expected: Resources: Investigation Guide" + invalid.append(err_msg) if invalid: - err_msg = '\n'.join(invalid) - self.fail(f'Rules with missing Investigation tag:\n{err_msg}') + err_msg = "\n".join(invalid) + self.fail(f"Rules with missing Investigation tag:\n{err_msg}") def test_tag_prefix(self): """Ensure all tags have a prefix from an expected list.""" @@ -482,11 +513,16 @@ def test_tag_prefix(self): for rule in self.all_rules: rule_tags = rule.contents.data.tags - expected_prefixes = set([tag.split(":")[0] + ":" for tag in definitions.EXPECTED_RULE_TAGS]) - [invalid.append(f"{self.rule_str(rule)}-{tag}") for tag in rule_tags - if not any(prefix in tag for prefix in expected_prefixes)] + expected_prefixes = {tag.split(":")[0] + ":" for tag in definitions.EXPECTED_RULE_TAGS} + invalid.extend( + [ + f"{self.rule_str(rule)}-{tag}" + for tag in rule_tags + if not any(prefix in tag for prefix in expected_prefixes) + ] + ) if invalid: - self.fail(f'Rules with invalid tags:\n{invalid}') + self.fail(f"Rules with invalid tags:\n{invalid}") def test_no_duplicate_tags(self): """Ensure no rules have duplicate tags.""" @@ -498,7 +534,7 @@ def test_no_duplicate_tags(self): invalid.append(self.rule_str(rule)) if invalid: - self.fail(f'Rules with duplicate tags:\n{invalid}') + self.fail(f"Rules with duplicate tags:\n{invalid}") class TestRuleTimelines(BaseRuleTest): @@ -517,14 +553,15 @@ def test_timeline_has_title(self): self.fail(missing_err) if timeline_id: - unknown_id = f'{self.rule_str(rule)} Unknown timeline_id: {timeline_id}.' - unknown_id += f' replace with {", ".join(TIMELINE_TEMPLATES)} ' \ - f'or update this unit test with acceptable ids' + unknown_id = f"{self.rule_str(rule)} Unknown timeline_id: {timeline_id}." 
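A condensed sketch of the tag-prefix check above: derive the allowed `Category:` prefixes, then flag any tag that uses none of them. The tag values are illustrative, not the full `EXPECTED_RULE_TAGS` list.

```python
expected_rule_tags = ["Domain: Endpoint", "OS: Windows", "Rule Type: BBR"]
expected_prefixes = {tag.split(":")[0] + ":" for tag in expected_rule_tags}

rule_tags = ["OS: Windows", "Windows Security"]  # the second tag has no known prefix
invalid = [tag for tag in rule_tags if not any(prefix in tag for prefix in expected_prefixes)]
print(invalid)  # ['Windows Security']
```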
+ unknown_id += ( + f" replace with {', '.join(TIMELINE_TEMPLATES)} or update this unit test with acceptable ids" + ) self.assertIn(timeline_id, list(TIMELINE_TEMPLATES), unknown_id) - unknown_title = f'{self.rule_str(rule)} unknown timeline_title: {timeline_title}' - unknown_title += f' replace with {", ".join(TIMELINE_TEMPLATES.values())}' - unknown_title += ' or update this unit test with acceptable titles' + unknown_title = f"{self.rule_str(rule)} unknown timeline_title: {timeline_title}" + unknown_title += f" replace with {', '.join(TIMELINE_TEMPLATES.values())}" + unknown_title += " or update this unit test with acceptable titles" self.assertEqual(timeline_title, TIMELINE_TEMPLATES[timeline_id], unknown_title) @@ -546,39 +583,48 @@ def test_rule_file_name_tactic(self): threat = rule.contents.data.threat authors = rule.contents.data.author - if threat and 'Elastic' in authors: + if threat and "Elastic" in authors: primary_tactic = threat[0].tactic.name - tactic_str = primary_tactic.lower().replace(' ', '_') + tactic_str = primary_tactic.lower().replace(" ", "_") - if tactic_str != filename[:len(tactic_str)]: - bad_name_rules.append(f'{rule.id} - {Path(rule.path).name} -> expected: {tactic_str}') + if tactic_str != filename[: len(tactic_str)]: + bad_name_rules.append(f"{rule.id} - {Path(rule.path).name} -> expected: {tactic_str}") if bad_name_rules: - error_msg = 'filename does not start with the primary tactic - update the tactic or the rule filename' - rule_err_str = '\n'.join(bad_name_rules) - self.fail(f'{error_msg}:\n{rule_err_str}') + error_msg = "filename does not start with the primary tactic - update the tactic or the rule filename" + rule_err_str = "\n".join(bad_name_rules) + self.fail(f"{error_msg}:\n{rule_err_str}") def test_bbr_in_correct_dir(self): """Ensure that BBR are in the correct directory.""" for rule in self.bbr: # Is the rule a BBR - self.assertEqual(rule.contents.data.building_block_type, 'default', - f'{self.rule_str(rule)} should have building_block_type = "default"') + self.assertEqual( + rule.contents.data.building_block_type, + "default", + f'{self.rule_str(rule)} should have building_block_type = "default"', + ) # Is the rule in the rules_building_block directory - self.assertEqual(rule.path.parent.name, 'rules_building_block', - f'{self.rule_str(rule)} should be in the rules_building_block directory') + self.assertEqual( + rule.path.parent.name, + "rules_building_block", + f"{self.rule_str(rule)} should be in the rules_building_block directory", + ) def test_non_bbr_in_correct_dir(self): """Ensure that non-BBR are not in BBR directory.""" - proper_directory = 'rules_building_block' + proper_directory = "rules_building_block" for rule in self.all_rules: - if rule.path.parent.name == 'rules_building_block': - self.assertIn(rule, self.bbr, f'{self.rule_str(rule)} should be in the {proper_directory}') + if rule.path.parent.name == "rules_building_block": + self.assertIn(rule, self.bbr, f"{self.rule_str(rule)} should be in the {proper_directory}") else: # Is the rule of type BBR and not in the correct directory - self.assertEqual(rule.contents.data.building_block_type, None, - f'{self.rule_str(rule)} should be in {proper_directory}') + self.assertEqual( + rule.contents.data.building_block_type, + None, + f"{self.rule_str(rule)} should be in {proper_directory}", + ) class TestRuleMetadata(BaseRuleTest): @@ -589,14 +635,14 @@ def test_updated_date_newer_than_creation(self): invalid = [] for rule in self.all_rules: - created = 
rule.contents.metadata.creation_date.split('/') - updated = rule.contents.metadata.updated_date.split('/') + created = rule.contents.metadata.creation_date.split("/") + updated = rule.contents.metadata.updated_date.split("/") if updated < created: invalid.append(rule) if invalid: - rules_str = '\n '.join(self.rule_str(r, trailer=None) for r in invalid) - err_msg = f'The following rules have an updated_date older than the creation_date\n {rules_str}' + rules_str = "\n ".join(self.rule_str(r, trailer=None) for r in invalid) + err_msg = f"The following rules have an updated_date older than the creation_date\n {rules_str}" self.fail(err_msg) @unittest.skipIf(RULES_CONFIG.bypass_version_lock, "Skipping deprecated version lock check") @@ -614,58 +660,55 @@ def test_deprecated_rules(self): misplaced_rules.append(r) else: for rules_path in rules_paths: - if "_deprecated" in r.path.relative_to(rules_path).parts \ - and r.contents.metadata.maturity != "deprecated": + if ( + "_deprecated" in r.path.relative_to(rules_path).parts + and r.contents.metadata.maturity != "deprecated" + ): misplaced_rules.append(r) break - misplaced = '\n'.join(f'{self.rule_str(r)} {r.contents.metadata.maturity}' for r in misplaced_rules) - err_str = f'The following rules are stored in _deprecated but are not marked as deprecated:\n{misplaced}' + misplaced = "\n".join(f"{self.rule_str(r)} {r.contents.metadata.maturity}" for r in misplaced_rules) + err_str = f"The following rules are stored in _deprecated but are not marked as deprecated:\n{misplaced}" self.assertListEqual(misplaced_rules, [], err_str) for rule in self.deprecated_rules: meta = rule.contents.metadata deprecated_rules[rule.id] = rule - err_msg = f'{self.rule_str(rule)} cannot be deprecated if it has not been version locked. ' \ - f'Convert to `development` or delete the rule file instead' + err_msg = ( + f"{self.rule_str(rule)} cannot be deprecated if it has not been version locked. " + f"Convert to `development` or delete the rule file instead" + ) self.assertIn(rule.id, versions, err_msg) rule_path = rule.path.relative_to(rules_path) - err_msg = f'{self.rule_str(rule)} deprecated rules should be stored in ' \ - f'"{rule_path.parent / "_deprecated"}" folder' - self.assertEqual('_deprecated', rule_path.parts[-2], err_msg) - - err_msg = f'{self.rule_str(rule)} missing deprecation date' - self.assertIsNotNone(meta['deprecation_date'], err_msg) + err_msg = ( + f"{self.rule_str(rule)} deprecated rules should be stored in " + f'"{rule_path.parent / "_deprecated"}" folder' + ) + self.assertEqual("_deprecated", rule_path.parts[-2], err_msg) - err_msg = f'{self.rule_str(rule)} deprecation_date and updated_date should match' - self.assertEqual(meta['deprecation_date'], meta['updated_date'], err_msg) + err_msg = f"{self.rule_str(rule)} missing deprecation date" + self.assertIsNotNone(meta["deprecation_date"], err_msg) - # skip this so the lock file can be shared across branches - # - # missing_rules = sorted(set(versions).difference(set(self.rule_lookup))) - # missing_rule_strings = '\n '.join(f'{r} - {versions[r]["rule_name"]}' for r in missing_rules) - # err_msg = f'Deprecated rules should not be removed, but moved to the rules/_deprecated folder instead. ' \ - # f'The following rules have been version locked and are missing. 
' \ - # f'Re-add to the deprecated folder and update maturity to "deprecated": \n {missing_rule_strings}' - # self.assertEqual([], missing_rules, err_msg) + err_msg = f"{self.rule_str(rule)} deprecation_date and updated_date should match" + self.assertEqual(meta["deprecation_date"], meta["updated_date"], err_msg) for rule_id, entry in deprecations.items(): # if a rule is deprecated and not backported in order to keep the rule active in older branches, then it # will exist in the deprecated_rules.json file and not be in the _deprecated folder - this is expected. # However, that should not occur except by exception - the proper way to handle this situation is to # "fork" the existing rule by adding a new min_stack_version. - if PACKAGE_STACK_VERSION < Version.parse(entry['stack_version'], optional_minor_and_patch=True): + if PACKAGE_STACK_VERSION < Version.parse(entry["stack_version"], optional_minor_and_patch=True): continue - rule_str = f'{rule_id} - {entry["rule_name"]} ->' + rule_str = f"{rule_id} - {entry['rule_name']} ->" self.assertIn(rule_id, deprecated_rules, f'{rule_str} is logged in "deprecated_rules.json" but is missing') def test_deprecated_rules_modified(self): """Test to ensure deprecated rules are not modified.""" - rules_path = get_path("rules", "_deprecated") + rules_path = get_path(["rules", "_deprecated"]) # Use git diff to check if the file(s) has been modified in rules/_deprecated directory detection_rules_git = make_git() @@ -675,13 +718,12 @@ def test_deprecated_rules_modified(self): if result: self.fail(f"Deprecated rules {result} has been modified") - @unittest.skipIf(os.getenv('GITHUB_EVENT_NAME') == 'push', - "Skipping this test when not running on pull requests.") + @unittest.skipIf(os.getenv("GITHUB_EVENT_NAME") == "push", "Skipping this test when not running on pull requests.") def test_rule_change_has_updated_date(self): """Test to ensure modified rules have updated_date field updated.""" - rules_path = get_path("rules") - rules_bbr_path = get_path("rules_building_block") + rules_path = get_path(["rules"]) + rules_bbr_path = get_path(["rules_building_block"]) # Use git diff to check if the file(s) has been modified in rules/ rules_build_block/ directories. # For now this checks even rules/_deprecated any modification there will fail @@ -689,120 +731,146 @@ def test_rule_change_has_updated_date(self): # is not required as there is a specific test for deprecated rules. 
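A standalone sketch of the `updated_date` check above, using `subprocess` directly in place of the repo's `make_git()` helper; the rules path and base branch are illustrative, and `check=True` will raise if git itself fails.

```python
import re
import subprocess


def git(*args: str) -> str:
    return subprocess.run(["git", *args], capture_output=True, text=True, check=True).stdout


modified = git("diff", "--diff-filter=M", "origin/main", "--name-only", "rules/").splitlines()
stale = [path for path in modified if not re.search(r"\+\s*updated_date =", git("diff", "origin/main", path))]
if stale:
    print("rules modified without an updated_date bump:", *stale, sep="\n  ")
```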
detection_rules_git = make_git() - result = detection_rules_git("diff", "--diff-filter=M", "origin/main", "--name-only", - rules_path, rules_bbr_path) + result = detection_rules_git( + "diff", "--diff-filter=M", "origin/main", "--name-only", rules_path, rules_bbr_path + ) # If the output is not empty, then file(s) have changed in the directory(s) if result: modified_rules = result.splitlines() failed_rules = [] for modified_rule_path in modified_rules: - diff_output = detection_rules_git('diff', 'origin/main', modified_rule_path) - if not re.search(r'\+\s*updated_date =', diff_output): + diff_output = detection_rules_git("diff", "origin/main", modified_rule_path) + if not re.search(r"\+\s*updated_date =", diff_output): # Rule has been modified but updated_date has not been changed, add to list of failed rules - failed_rules.append(f'{modified_rule_path}') + failed_rules.append(f"{modified_rule_path}") if failed_rules: fail_msg = """ The following rules in the below path(s) have been modified but updated_date has not been changed \n """ - self.fail(fail_msg + '\n'.join(failed_rules)) + self.fail(fail_msg + "\n".join(failed_rules)) - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.3.0"), - "Test only applicable to 8.3+ stacks regarding related integrations build time field.") - def test_integration_tag(self): + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.3.0"), + "Test only applicable to 8.3+ stacks regarding related integrations build time field.", + ) + def test_integration_tag(self): # noqa: PLR0912, PLR0915 """Test integration rules defined by metadata tag.""" failures = [] - non_dataset_packages = definitions.NON_DATASET_PACKAGES + ["winlog"] + non_dataset_packages = [*definitions.NON_DATASET_PACKAGES, "winlog"] packages_manifest = load_integrations_manifests() - valid_integration_folders = [p.name for p in list(Path(INTEGRATION_RULE_DIR).glob("*")) if p.name != 'endpoint'] + valid_integration_folders = [p.name for p in list(Path(INTEGRATION_RULE_DIR).glob("*")) if p.name != "endpoint"] for rule in self.all_rules: # TODO: temp bypass for esql rules; once parsed, we should be able to look for indexes via `FROM` - if not rule.contents.data.get('index'): + if not rule.contents.data.get("index"): continue - if isinstance(rule.contents.data, QueryRuleData) and rule.contents.data.language != 'lucene': - rule_integrations = rule.contents.metadata.get('integration') or [] + if isinstance(rule.contents.data, QueryRuleData) and rule.contents.data.language != "lucene": + rule_integrations = rule.contents.metadata.get("integration") or [] rule_integrations = [rule_integrations] if isinstance(rule_integrations, str) else rule_integrations - rule_promotion = rule.contents.metadata.get('promotion') + rule_promotion = rule.contents.metadata.get("promotion") data = rule.contents.data meta = rule.contents.metadata package_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) - package_integrations_list = list(set([integration["package"] for integration in package_integrations])) - indices = data.get('index') or [] + package_integrations_list = {integration["package"] for integration in package_integrations} + indices = data.get("index") or [] for rule_integration in rule_integrations: - if ("even.dataset" in rule.contents.data.query and not package_integrations and # noqa: W504 - not rule_promotion and rule_integration not in definitions.NON_DATASET_PACKAGES): # noqa: W504 - err_msg = f'{self.rule_str(rule)} {rule_integration} tag, but 
integration not \ - found in manifests/schemas.' + if ( + "even.dataset" in rule.contents.data.query + and not package_integrations + and not rule_promotion + and rule_integration not in definitions.NON_DATASET_PACKAGES + ): + err_msg = f"{self.rule_str(rule)} {rule_integration} tag, but integration not \ + found in manifests/schemas." failures.append(err_msg) # checks if the rule path matches the intended integration # excludes BBR rules - if rule_integration in valid_integration_folders and \ - not hasattr(rule.contents.data, 'building_block_type'): - if rule.path.parent.name not in rule_integrations: - err_msg = f'{self.rule_str(rule)} {rule_integration} tag, path is {rule.path.parent.name}' - failures.append(err_msg) + if ( + rule_integration in valid_integration_folders + and not hasattr(rule.contents.data, "building_block_type") + and rule.path.parent.name not in rule_integrations + ): + err_msg = f"{self.rule_str(rule)} {rule_integration} tag, path is {rule.path.parent.name}" + failures.append(err_msg) # checks if an index pattern exists if the package integration tag exists # and is of pattern logs-{integration}* integration_string = "|".join(indices) if not re.search(f"logs-{rule_integration}*", integration_string): - if rule_integration == "windows" and re.search("winlog", integration_string) or \ - any(ri in [*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] - for ri in rule_integrations): - continue - elif rule_integration == "apm" and \ - re.search("apm-*-transaction*|traces-apm*", integration_string): + if ( + (rule_integration == "windows" and re.search("winlog", integration_string)) + or any( + ri in [*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] + for ri in rule_integrations + ) + ) or ( + rule_integration == "apm" + and re.search("apm-*-transaction*|traces-apm*", integration_string) + ): continue - elif rule.contents.data.type == 'threat_match': + if rule.contents.data.type == "threat_match": continue - err_msg = f'{self.rule_str(rule)} {rule_integration} tag, index pattern missing or incorrect.' + err_msg = f"{self.rule_str(rule)} {rule_integration} tag, index pattern missing or incorrect." 
failures.append(err_msg) # checks if event.dataset exists in query object and a tag exists in metadata # checks if metadata tag matches from a list of integrations in EPR if package_integrations and sorted(rule_integrations) != sorted(package_integrations_list): - err_msg = f'{self.rule_str(rule)} integration tags: {rule_integrations} != ' \ - f'package integrations: {package_integrations_list}' + err_msg = ( + f"{self.rule_str(rule)} integration tags: {rule_integrations} != " + f"package integrations: {package_integrations_list}" + ) failures.append(err_msg) - else: - # checks if rule has index pattern integration and the integration tag exists - # ignore the External Alerts rule, Threat Indicator Matching Rules, Guided onboarding - if any([re.search("|".join(non_dataset_packages), i, re.IGNORECASE) - for i in rule.contents.data.get('index') or []]): - if not rule.contents.metadata.integration and rule.id not in definitions.IGNORE_IDS and \ - rule.contents.data.type not in definitions.MACHINE_LEARNING: - err_msg = f'substrings {non_dataset_packages} found in '\ - f'{self.rule_str(rule)} rule index patterns are {rule.contents.data.index},' \ - f'but no integration tag found' - failures.append(err_msg) + # checks if rule has index pattern integration and the integration tag exists + # ignore the External Alerts rule, Threat Indicator Matching Rules, Guided onboarding + elif any( + re.search("|".join(non_dataset_packages), i, re.IGNORECASE) + for i in rule.contents.data.get("index") or [] + ): + if ( + not rule.contents.metadata.integration + and rule.id not in definitions.IGNORE_IDS + and rule.contents.data.type not in definitions.MACHINE_LEARNING + ): + err_msg = ( + f"substrings {non_dataset_packages} found in " + f"{self.rule_str(rule)} rule index patterns are {rule.contents.data.index}," + f"but no integration tag found" + ) + failures.append(err_msg) # checks for a defined index pattern, the related integration exists in metadata expected_integrations, missing_integrations = set(), set() - ignore_ml_packages = any(ri in [*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] - for ri in rule_integrations) + ignore_ml_packages = any( + ri in [*map(str.lower, definitions.MACHINE_LEARNING_PACKAGES)] for ri in rule_integrations + ) for index in indices: - if index in definitions.IGNORE_INDICES or ignore_ml_packages or \ - rule.id in definitions.IGNORE_IDS or rule.contents.data.type == 'threat_match': + if ( + index in definitions.IGNORE_INDICES + or ignore_ml_packages + or rule.id in definitions.IGNORE_IDS + or rule.contents.data.type == "threat_match" + ): continue # Outlier integration log pattern to identify integration - if index == 'apm-*-transaction*': - index_map = ['apm'] + if index == "apm-*-transaction*": + index_map = ["apm"] else: # Split by hyphen to get the second part of index - index_part1, _, index_part2 = index.partition('-') + index_part1, _, index_part2 = index.partition("-") # Use regular expression to extract alphanumeric words, which is integration name - parsed_integration = re.search(r'\b\w+\b', index_part2 or index_part1) + parsed_integration = re.search(r"\b\w+\b", index_part2 or index_part1) index_map = [parsed_integration.group(0) if parsed_integration else None] if not index_map: - self.fail(f'{self.rule_str(rule)} Could not determine the integration from Index {index}') + self.fail(f"{self.rule_str(rule)} Could not determine the integration from Index {index}") expected_integrations.update(index_map) 
missing_integrations.update(expected_integrations.difference(set(rule_integrations))) if missing_integrations: - error_msg = f'{self.rule_str(rule)} Missing integration metadata: {", ".join(missing_integrations)}' + error_msg = f"{self.rule_str(rule)} Missing integration metadata: {', '.join(missing_integrations)}" failures.append(error_msg) if failures: @@ -811,7 +879,7 @@ def test_integration_tag(self): Try updating the integrations manifest file: - `python -m detection_rules dev integrations build-manifests`\n """ - self.fail(err_msg + '\n'.join(failures)) + self.fail(err_msg + "\n".join(failures)) def test_invalid_queries(self): invalid_queries_eql = [ @@ -836,7 +904,7 @@ def test_invalid_queries(self): "p7r", "p12", "asc", "jks", "p7b", "signature", "gpg", "pgp.sig", "sst", "pgp", "gpgz", "pfx", "crt", "p8", "sig", "pkcs7", "jceks", "pkcs8", "psc1", "p7c", "csr", "cer", "spc", "ps2xml") - """ + """, ] valid_queries_eql = [ @@ -851,8 +919,7 @@ def test_invalid_queries(self): "token","assig", "pssc", "keystore", "pub", "pgp.asc", "ps1xml", "pem", "gpg.sig", "der", "key", "p7r", "p12", "asc", "jks", "p7b", "signature", "gpg", "pgp.sig", "sst", "pgp", "gpgz", "pfx", "p8", "sig", "pkcs7", "jceks", "pkcs8", "psc1", "p7c", "csr", "cer", "spc", "ps2xml") - """ - + """, ] invalid_queries_kql = [ @@ -875,8 +942,7 @@ def test_invalid_queries(self): """, """ event.dataset:"google_workspace.admin" and event.action:"CREATE_DATA_TRANSFER_REQUEST" - """ - + """, ] base_fields_eql = { @@ -889,7 +955,7 @@ def test_invalid_queries(self): "risk_score": 21, "rule_id": str(uuid.uuid4()), "severity": "low", - "type": "eql" + "type": "eql", } base_fields_kql = { @@ -902,7 +968,7 @@ def test_invalid_queries(self): "risk_score": 21, "rule_id": str(uuid.uuid4()), "severity": "low", - "type": "query" + "type": "query", } def build_rule(query: str, query_language: str): @@ -912,7 +978,7 @@ def build_rule(query: str, query_language: str): "updated_date": "1970/01/01", "query_schema_validation": True, "maturity": "production", - "min_stack_version": load_current_package_version() + "min_stack_version": load_current_package_version(), } if query_language == "eql": data = base_fields_eql.copy() @@ -921,6 +987,7 @@ def build_rule(query: str, query_language: str): data["query"] = query obj = {"metadata": metadata, "rule": data} return TOMLRuleContents.from_dict(obj) + # eql for query in valid_queries_eql: build_rule(query, "eql") @@ -956,83 +1023,58 @@ def test_event_dataset(self): continue data = rule.contents.data meta = rule.contents.metadata - if meta.query_schema_validation is not False or meta.maturity != "deprecated": - if isinstance(data, QueryRuleData) and data.language != 'lucene': - packages_manifest = load_integrations_manifests() - pkg_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) - - validation_integrations_check = None - - if pkg_integrations: - # validate the query against related integration fields - validation_integrations_check = test_validator.validate_integration(data, - meta, - pkg_integrations) + if (meta.query_schema_validation is not False or meta.maturity != "deprecated") and ( + isinstance(data, QueryRuleData) and data.language != "lucene" + ): + packages_manifest = load_integrations_manifests() + pkg_integrations = TOMLRuleContents.get_packaged_integrations(data, meta, packages_manifest) + + validation_integrations_check = None + + if pkg_integrations: + # validate the query against related integration fields + validation_integrations_check = 
test_validator.validate_integration( + data, meta, pkg_integrations + ) - if validation_integrations_check and "event.dataset" in rule.contents.data.query: - raise validation_integrations_check + if validation_integrations_check and "event.dataset" in rule.contents.data.query: + raise validation_integrations_check class TestIntegrationRules(BaseRuleTest): """Test integration rules.""" - @unittest.skip("8.3+ Stacks Have Related Integrations Feature") - def test_integration_guide(self): - """Test that rules which require a config note are using standard verbiage.""" - config = '## Setup\n\n' - beats_integration_pattern = config + 'The {} Fleet integration, Filebeat module, or similarly ' \ - 'structured data is required to be compatible with this rule.' - render = beats_integration_pattern.format - integration_notes = { - 'aws': render('AWS'), - 'azure': render('Azure'), - 'cyberarkpas': render('CyberArk Privileged Access Security (PAS)'), - 'gcp': render('GCP'), - 'google_workspace': render('Google Workspace'), - 'o365': render('Office 365 Logs'), - 'okta': render('Okta'), - } - - for rule in self.all_rules: - integration = rule.contents.metadata.integration - note_str = integration_notes.get(integration) - - if note_str: - error_message = f'{self.rule_str(rule)} note required for config information' - self.assertIsNotNone(rule.contents.data.note, error_message) - - if note_str not in rule.contents.data.note: - self.fail(f'{self.rule_str(rule)} expected {integration} config missing\n\n' - f'Expected: {note_str}\n\n' - f'Actual: {rule.contents.data.note}') - def test_rule_demotions(self): """Test to ensure a locked rule is not dropped to development, only deprecated""" versions = loaded_version_lock.version_lock failures = [] for rule in self.all_rules: - if rule.id in versions and rule.contents.metadata.maturity not in ('production', 'deprecated'): - err_msg = f'{self.rule_str(rule)} a version locked rule can only go from production to deprecated\n' - err_msg += f'Actual: {rule.contents.metadata.maturity}' + if rule.id in versions and rule.contents.metadata.maturity not in ("production", "deprecated"): + err_msg = f"{self.rule_str(rule)} a version locked rule can only go from production to deprecated\n" + err_msg += f"Actual: {rule.contents.metadata.maturity}" failures.append(err_msg) if failures: - err_msg = '\n'.join(failures) - self.fail(f'The following rules have been improperly demoted:\n{err_msg}') + err_msg = "\n".join(failures) + self.fail(f"The following rules have been improperly demoted:\n{err_msg}") def test_all_min_stack_rules_have_comment(self): - failures = [] - - for rule in self.all_rules: - if rule.contents.metadata.min_stack_version and not rule.contents.metadata.min_stack_comments: - failures.append(f'{self.rule_str(rule)} missing `metadata.min_stack_comments`. min_stack_version: ' - f'{rule.contents.metadata.min_stack_version}') + failures = [ + ( + f"{self.rule_str(rule)} missing `metadata.min_stack_comments`. 
min_stack_version: " + f"{rule.contents.metadata.min_stack_version}" + ) + for rule in self.all_rules + if rule.contents.metadata.min_stack_version and not rule.contents.metadata.min_stack_comments + ] if failures: - err_msg = '\n'.join(failures) - self.fail(f'The following ({len(failures)}) rules have a `min_stack_version` defined but missing comments:' - f'\n{err_msg}') + err_msg = "\n".join(failures) + self.fail( + f"The following ({len(failures)}) rules have a `min_stack_version` defined but missing comments:" + f"\n{err_msg}" + ) def test_ml_integration_jobs_exist(self): """Test that machine learning jobs exist in the integration.""" @@ -1044,36 +1086,37 @@ def test_ml_integration_jobs_exist(self): for rule in self.all_rules: if rule.contents.data.type == "machine_learning": - ml_integration_name = next((i for i in rule.contents.metadata.integration - if i in ml_integration_names), None) + ml_integration_name = next( + (i for i in rule.contents.metadata.integration if i in ml_integration_names), None + ) if ml_integration_name: if "machine_learning_job_id" not in dir(rule.contents.data): - failures.append(f'{self.rule_str(rule)} missing `machine_learning_job_id`') + failures.append(f"{self.rule_str(rule)} missing `machine_learning_job_id`") else: rule_job_id = rule.contents.data.machine_learning_job_id ml_schema = integration_schemas.get(ml_integration_name) min_version = Version.parse( rule.contents.metadata.min_stack_version or load_current_package_version(), - optional_minor_and_patch=True + optional_minor_and_patch=True, ) latest_compat_ver = find_latest_compatible_version( package=ml_integration_name, integration="", rule_stack_version=min_version, - packages_manifest=integration_manifests + packages_manifest=integration_manifests, ) compat_integration_schema = ml_schema[latest_compat_ver[0]] - if rule_job_id not in compat_integration_schema['jobs']: + if rule_job_id not in compat_integration_schema["jobs"]: failures.append( - f'{self.rule_str(rule)} machine_learning_job_id `{rule_job_id}` not found ' - f'in version `{latest_compat_ver[0]}` of `{ml_integration_name}` integration. ' - f'existing jobs: {compat_integration_schema["jobs"]}' + f"{self.rule_str(rule)} machine_learning_job_id `{rule_job_id}` not found " + f"in version `{latest_compat_ver[0]}` of `{ml_integration_name}` integration. 
" + f"existing jobs: {compat_integration_schema['jobs']}" ) if failures: - err_msg = '\n'.join(failures) + err_msg = "\n".join(failures) self.fail( - f'The following ({len(failures)}) rules are missing a valid `machine_learning_job_id`:\n{err_msg}' + f"The following ({len(failures)}) rules are missing a valid `machine_learning_job_id`:\n{err_msg}" ) @@ -1082,11 +1125,11 @@ class TestRuleTiming(BaseRuleTest): def test_event_override(self): """Test that timestamp_override is properly applied to rules.""" - # kql: always require (fallback to @timestamp enabled) - # eql: - # sequences: never - # min_stack_version >= 8.2: any - fallback to @timestamp enabled https://github.com/elastic/kibana/pull/127989 - # if 'event.ingested' is missing, '@timestamp' will be default + # when kql: always require (fallback to @timestamp enabled) + # when eql: + # - sequences: never + # - min_stack_version >= 8.2: any - fallback to @timestamp enabled https://github.com/elastic/kibana/pull/127989 + # - if 'event.ingested' is missing, '@timestamp' will be default errors = [] for rule in self.all_rules: @@ -1095,36 +1138,38 @@ def test_event_override(self): # QueryRuleData should inherently ignore machine learning rules if isinstance(rule.contents.data, QueryRuleData): rule_language = rule.contents.data.language - has_event_ingested = rule.contents.data.get('timestamp_override') == 'event.ingested' + has_event_ingested = rule.contents.data.get("timestamp_override") == "event.ingested" rule_str = self.rule_str(rule, trailer=None) if not has_event_ingested: # TODO: determine if we expand this to ES|QL # ignores any rule that does not use EQL or KQL queries specifically # this does not avoid rule types where variants of KQL are used (e.g. new terms) - if rule_language not in ('eql', 'kuery') or getattr(rule.contents.data, 'is_sequence', False): + if rule_language not in ("eql", "kuery") or getattr(rule.contents.data, "is_sequence", False): continue - else: - errors.append(f'{rule_str} - rule must have `timestamp_override: event.ingested`') + errors.append(f"{rule_str} - rule must have `timestamp_override: event.ingested`") if errors: - self.fail('The following rules are invalid:\n' + '\n'.join(errors)) + self.fail("The following rules are invalid:\n" + "\n".join(errors)) def test_required_lookback(self): """Ensure endpoint rules have the proper lookback time.""" - long_indexes = {'logs-endpoint.events.*'} + long_indexes = {"logs-endpoint.events.*"} missing = [] for rule in self.all_rules: contents = rule.contents - if isinstance(contents.data, QueryRuleData): - if set(getattr(contents.data, "index", None) or []) & long_indexes and not contents.data.from_: - missing.append(rule) + if ( + isinstance(contents.data, QueryRuleData) + and (set(getattr(contents.data, "index", None) or []) & long_indexes) + and not contents.data.from_ + ): + missing.append(rule) if missing: - rules_str = '\n '.join(self.rule_str(r, trailer=None) for r in missing) - err_msg = f'The following rules should have a longer `from` defined, due to indexes used\n {rules_str}' + rules_str = "\n ".join(self.rule_str(r, trailer=None) for r in missing) + err_msg = f"The following rules should have a longer `from` defined, due to indexes used\n {rules_str}" self.fail(err_msg) def test_eql_lookback(self): @@ -1134,8 +1179,8 @@ def test_eql_lookback(self): ten_minutes = 10 * 60 * 1000 for rule in self.all_rules: - if rule.contents.data.type == 'eql' and rule.contents.data.max_span: - if rule.contents.data.look_back == 'unknown': + if rule.contents.data.type 
== "eql" and rule.contents.data.max_span: + if rule.contents.data.look_back == "unknown": unknowns.append(self.rule_str(rule, trailer=None)) else: look_back = rule.contents.data.look_back @@ -1143,16 +1188,17 @@ def test_eql_lookback(self): expected = look_back + ten_minutes if expected < max_span: - invalids.append(f'{self.rule_str(rule)} lookback: {look_back}, maxspan: {max_span}, ' - f'expected: >={expected}') + invalids.append( + f"{self.rule_str(rule)} lookback: {look_back}, maxspan: {max_span}, expected: >={expected}" + ) if unknowns: - warn_str = '\n'.join(unknowns) - warnings.warn(f'Unable to determine lookbacks for the following rules:\n{warn_str}') + warn_str = "\n".join(unknowns) + print(f"WARNING: Unable to determine lookbacks for the following rules:\n{warn_str}") if invalids: - invalids_str = '\n'.join(invalids) - self.fail(f'The following rules have longer max_spans than lookbacks:\n{invalids_str}') + invalids_str = "\n".join(invalids) + self.fail(f"The following rules have longer max_spans than lookbacks:\n{invalids_str}") def test_eql_interval_to_maxspan(self): """Check the ratio of interval to maxspan for eql rules.""" @@ -1160,33 +1206,34 @@ def test_eql_interval_to_maxspan(self): five_minutes = 5 * 60 * 1000 for rule in self.all_rules: - if rule.contents.data.type == 'eql': + if rule.contents.data.type == "eql": interval = rule.contents.data.interval or five_minutes maxspan = rule.contents.data.max_span ratio = rule.contents.data.interval_ratio # we want to test for at least a ratio of: interval >= 1/2 maxspan # but we only want to make an exception and cap the ratio at 5m interval (2.5m maxspan) - if maxspan and maxspan > (five_minutes / 2) and ratio and ratio < .5: + if maxspan and maxspan > (five_minutes / 2) and ratio and ratio < 0.5: expected = maxspan // 2 - err_msg = f'{self.rule_str(rule)} interval: {interval}, maxspan: {maxspan}, expected: >={expected}' + err_msg = f"{self.rule_str(rule)} interval: {interval}, maxspan: {maxspan}, expected: >={expected}" invalids.append(err_msg) if invalids: - invalids_str = '\n'.join(invalids) - self.fail(f'The following rules have intervals too short for their given max_spans (ms):\n{invalids_str}') + invalids_str = "\n".join(invalids) + self.fail(f"The following rules have intervals too short for their given max_spans (ms):\n{invalids_str}") class TestLicense(BaseRuleTest): """Test rule license.""" - @unittest.skipIf(os.environ.get('CUSTOM_RULES_DIR'), 'Skipping test for custom rules.') + + @unittest.skipIf(os.environ.get("CUSTOM_RULES_DIR"), "Skipping test for custom rules.") def test_elastic_license_only_v2(self): """Test to ensure that production rules with the elastic license are only v2.""" for rule in self.all_rules: rule_license = rule.contents.data.license - if 'elastic license' in rule_license.lower(): - err_msg = f'{self.rule_str(rule)} If Elastic License is used, only v2 should be used' - self.assertEqual(rule_license, 'Elastic License v2', err_msg) + if "elastic license" in rule_license.lower(): + err_msg = f"{self.rule_str(rule)} If Elastic License is used, only v2 should be used" + self.assertEqual(rule_license, "Elastic License v2", err_msg) class TestIncompatibleFields(BaseRuleTest): @@ -1199,11 +1246,11 @@ def test_rule_backports_for_restricted_fields(self): for rule in self.all_rules: invalid = rule.contents.check_restricted_fields_compatibility() if invalid: - invalid_rules.append(f'{self.rule_str(rule)} {invalid}') + invalid_rules.append(f"{self.rule_str(rule)} {invalid}") if invalid_rules: - invalid_str = 
'\n'.join(invalid_rules) - err_msg = 'The following rules have min_stack_versions lower than allowed for restricted fields:\n' + invalid_str = "\n".join(invalid_rules) + err_msg = "The following rules have min_stack_versions lower than allowed for restricted fields:\n" err_msg += invalid_str self.fail(err_msg) @@ -1229,14 +1276,19 @@ def test_build_fields_min_stack(self): # change which is different because of the build time fields. # This also ensures that the introduced version is greater than the min supported, in order to age off # old and unneeded checks. (i.e. 8.3.0 < 8.9.0 min supported, so it is irrelevant now) - if start_ver is not None and current_stack_ver >= start_ver >= min_supported_stack_version: - if min_stack is None or not Version.parse(min_stack) >= start_ver: - errors.append(f'{build_field} >= {start_ver}') + if ( + start_ver + and current_stack_ver >= start_ver >= min_supported_stack_version + and (min_stack is None or not Version.parse(min_stack) >= start_ver) + ): + errors.append(f"{build_field} >= {start_ver}") if errors: - err_str = ', '.join(errors) - invalids.append(f'{self.rule_str(rule)} uses a rule type with build fields requiring min_stack_versions' - f' to be set: {err_str}') + err_str = ", ".join(errors) + invalids.append( + f"{self.rule_str(rule)} uses a rule type with build fields requiring min_stack_versions" + f" to be set: {err_str}" + ) if invalids: self.fail(invalids) @@ -1249,9 +1301,9 @@ def test_rule_risk_score_severity_mismatch(self): invalid_list = [] risk_severity = { "critical": (74, 100), # updated range for critical - "high": (48, 73), # updated range for high - "medium": (22, 47), # updated range for medium - "low": (0, 21), # updated range for low + "high": (48, 73), # updated range for high + "medium": (22, 47), # updated range for medium + "low": (0, 21), # updated range for low } for rule in self.all_rules: severity = rule.contents.data.severity @@ -1260,11 +1312,11 @@ def test_rule_risk_score_severity_mismatch(self): # Check if the risk_score falls within the range for the severity level min_score, max_score = risk_severity[severity] if not min_score <= risk_score <= max_score: - invalid_list.append(f'{self.rule_str(rule)} Severity: {severity}, Risk Score: {risk_score}') + invalid_list.append(f"{self.rule_str(rule)} Severity: {severity}, Risk Score: {risk_score}") if invalid_list: - invalid_str = '\n'.join(invalid_list) - err_msg = 'The following rules have mismatches between Severity and Risk Score field values:\n' + invalid_str = "\n".join(invalid_list) + err_msg = "The following rules have mismatches between Severity and Risk Score field values:\n" err_msg += invalid_str self.fail(err_msg) @@ -1278,9 +1330,9 @@ def test_note_contains_triage_and_analysis(self): for rule in self.all_rules: if ( - not rule.contents.data.is_elastic_rule or # noqa: W504 - rule.contents.data.building_block_type or # noqa: W504 - rule.contents.data.severity in ("medium", "low") + not rule.contents.data.is_elastic_rule + or rule.contents.data.building_block_type + or rule.contents.data.severity in ("medium", "low") ): # dont enforce continue @@ -1299,16 +1351,14 @@ def test_investigation_guide_uses_rule_name(self): """Check if investigation guide uses rule name in the title.""" errors = [] for rule in self.all_rules.rules: - note = rule.contents.data.get('note') - if note is not None: - # Check if `### Investigating` is present and if so, - # check if it is followed by the rule name. 
- if '### Investigating' in note: - results = re.search(rf'### Investigating\s+{re.escape(rule.name)}', note, re.I | re.M) - if results is None: - errors.append(f'{self.rule_str(rule)} investigation guide does not use rule name in the title') + note = rule.contents.data.get("note") + # Check if `### Investigating` is present and if so, check if it is followed by the rule name. + if note and "### Investigating" in note: + results = re.search(rf"### Investigating\s+{re.escape(rule.name)}", note, re.I | re.M) + if results is None: + errors.append(f"{self.rule_str(rule)} investigation guide does not use rule name in the title") if errors: - self.fail('\n'.join(errors)) + self.fail("\n".join(errors)) class TestNoteMarkdownPlugins(BaseRuleTest): @@ -1316,50 +1366,56 @@ class TestNoteMarkdownPlugins(BaseRuleTest): def test_note_has_osquery_warning(self): """Test that all rules with osquery entries have the default notification of stack compatibility.""" - osquery_note_pattern = ('> **Note**:\n> This investigation guide uses the [Osquery Markdown Plugin]' - '(https://www.elastic.co/guide/en/security/current/invest-guide-run-osquery.html) ' - 'introduced in Elastic Stack version 8.5.0. Older Elastic Stack versions will display ' - 'unrendered Markdown in this guide.') + osquery_note_pattern = ( + "> **Note**:\n> This investigation guide uses the [Osquery Markdown Plugin]" + "(https://www.elastic.co/guide/en/security/current/invest-guide-run-osquery.html) " + "introduced in Elastic Stack version 8.5.0. Older Elastic Stack versions will display " + "unrendered Markdown in this guide." + ) invest_note_pattern = ( - '> This investigation guide uses the [Investigate Markdown Plugin]' - '(https://www.elastic.co/guide/en/security/current/interactive-investigation-guides.html)' - ' introduced in Elastic Stack version 8.8.0. Older Elastic Stack versions will display ' - 'unrendered Markdown in this guide.') + "> This investigation guide uses the [Investigate Markdown Plugin]" + "(https://www.elastic.co/guide/en/security/current/interactive-investigation-guides.html)" + " introduced in Elastic Stack version 8.8.0. Older Elastic Stack versions will display " + "unrendered Markdown in this guide." 
+ ) for rule in self.all_rules: - if not rule.contents.get('transform'): + if not rule.contents.get("transform"): continue - osquery = rule.contents.transform.get('osquery') + osquery = rule.contents.transform.get("osquery") if osquery and osquery_note_pattern not in rule.contents.data.note: - self.fail(f'{self.rule_str(rule)} Investigation guides using the Osquery Markdown must contain ' - f'the following note:\n{osquery_note_pattern}') + self.fail( + f"{self.rule_str(rule)} Investigation guides using the Osquery Markdown must contain " + f"the following note:\n{osquery_note_pattern}" + ) - investigate = rule.contents.transform.get('investigate') + investigate = rule.contents.transform.get("investigate") if investigate and invest_note_pattern not in rule.contents.data.note: - self.fail(f'{self.rule_str(rule)} Investigation guides using the Investigate Markdown must contain ' - f'the following note:\n{invest_note_pattern}') + self.fail( + f"{self.rule_str(rule)} Investigation guides using the Investigate Markdown must contain " + f"the following note:\n{invest_note_pattern}" + ) def test_plugin_placeholders_match_entries(self): """Test that the number of plugin entries match their respective placeholders in note.""" for rule in self.all_rules: - has_transform = rule.contents.get('transform') is not None - has_note = rule.contents.data.get('note') is not None + has_transform = rule.contents.get("transform") is not None + has_note = rule.contents.data.get("note") is not None note = rule.contents.data.note if has_transform: if not has_note: - self.fail(f'{self.rule_str(rule)} transformed defined with no note') - else: - if not has_note: - continue + self.fail(f"{self.rule_str(rule)} transformed defined with no note") + elif not has_note: + continue note_template = PatchedTemplate(note) - identifiers = [i for i in note_template.get_identifiers() if '_' in i] + identifiers = [i for i in note_template.get_identifiers() if "_" in i] if not has_transform: if identifiers: - self.fail(f'{self.rule_str(rule)} note contains plugin placeholders with no transform entries') + self.fail(f"{self.rule_str(rule)} note contains plugin placeholders with no transform entries") else: continue @@ -1369,38 +1425,40 @@ def test_plugin_placeholders_match_entries(self): note_counts = defaultdict(int) for identifier in identifiers: # "$" is used for other things, so this verifies the pattern of a trailing "_" followed by ints - if '_' not in identifier: + if "_" not in identifier: continue - dash_index = identifier.rindex('_') - if dash_index == len(identifier) or not identifier[dash_index + 1:].isdigit(): + dash_index = identifier.rindex("_") + if dash_index == len(identifier) or not identifier[dash_index + 1 :].isdigit(): continue - plugin, _ = identifier.split('_') + plugin, _ = identifier.split("_") if plugin in transform_counts: note_counts[plugin] += 1 - err_msg = f'{self.rule_str(rule)} plugin entry count mismatch between transform and note' + err_msg = f"{self.rule_str(rule)} plugin entry count mismatch between transform and note" self.assertDictEqual(transform_counts, note_counts, err_msg) def test_if_plugins_explicitly_defined(self): """Check if plugins are explicitly defined with the pattern in note vs using transform.""" for rule in self.all_rules: - note = rule.contents.data.get('note') + note = rule.contents.data.get("note") if note is not None: - results = re.search(r'(!{osquery|!{investigate)', note, re.I | re.M) - err_msg = f'{self.rule_str(rule)} investigation guide plugin pattern detected! 
Use Transform' + results = re.search(r"(!{osquery|!{investigate)", note, re.I | re.M) + err_msg = f"{self.rule_str(rule)} investigation guide plugin pattern detected! Use Transform" self.assertIsNone(results, err_msg) class TestAlertSuppression(BaseRuleTest): """Test rule alert suppression.""" - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.8.0"), - "Test only applicable to 8.8+ stacks for rule alert suppression feature.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.8.0"), + "Test only applicable to 8.8+ stacks for rule alert suppression feature.", + ) def test_group_field_in_schemas(self): """Test to ensure the fields are defined is in ECS/Beats/Integrations schema.""" for rule in self.all_rules: - if rule.contents.data.get('alert_suppression'): + if rule.contents.data.get("alert_suppression"): if isinstance(rule.contents.data.alert_suppression, AlertSuppressionMapping): group_by_fields = rule.contents.data.alert_suppression.group_by elif isinstance(rule.contents.data.alert_suppression, ThresholdAlertSuppression): @@ -1411,8 +1469,8 @@ def test_group_field_in_schemas(self): else: min_stack_version = Version.parse(min_stack_version) integration_tag = rule.contents.metadata.get("integration") - ecs_version = get_stack_schemas()[str(min_stack_version)]['ecs'] - beats_version = get_stack_schemas()[str(min_stack_version)]['beats'] + ecs_version = get_stack_schemas()[str(min_stack_version)]["ecs"] + beats_version = get_stack_schemas()[str(min_stack_version)]["beats"] queryvalidator = QueryValidator(rule.contents.data.query) _, _, schema = queryvalidator.get_beats_schema([], beats_version, ecs_version) if integration_tag: @@ -1422,24 +1480,26 @@ def test_group_field_in_schemas(self): for ints in integration_tag: integration_schema = integration_schemas[ints] int_schema = integration_schema[list(integration_schema.keys())[-1]] - for data_source in int_schema.keys(): + for data_source in int_schema: schema.update(**int_schema[data_source]) for fld in group_by_fields: - if fld not in schema.keys(): - self.fail(f"{self.rule_str(rule)} alert suppression field {fld} not \ - found in ECS, Beats, or non-ecs schemas") + if fld not in schema: + self.fail( + f"{self.rule_str(rule)} alert suppression field {fld} not \ + found in ECS, Beats, or non-ecs schemas" + ) - @unittest.skipIf(PACKAGE_STACK_VERSION < Version.parse("8.14.0") or # noqa: W504 - PACKAGE_STACK_VERSION >= Version.parse("8.18.0"), # noqa: W504 - "Test is applicable to 8.14 --> 8.17 stacks for eql non-sequence rule alert suppression feature.") + @unittest.skipIf( + PACKAGE_STACK_VERSION < Version.parse("8.14.0") or PACKAGE_STACK_VERSION >= Version.parse("8.18.0"), + "Test is applicable to 8.14 --> 8.17 stacks for eql non-sequence rule alert suppression feature.", + ) def test_eql_non_sequence_support_only(self): for rule in self.all_rules: if ( - isinstance(rule.contents.data, EQLRuleData) and rule.contents.data.get("alert_suppression") - and rule.contents.data.is_sequence # noqa: W503 + isinstance(rule.contents.data, EQLRuleData) + and rule.contents.data.get("alert_suppression") + and rule.contents.data.is_sequence ): # is_sequence method not yet available during schema validation # so we have to check in a unit test - self.fail( - f"{self.rule_str(rule)} Sequence rules cannot have alert suppression" - ) + self.fail(f"{self.rule_str(rule)} Sequence rules cannot have alert suppression") diff --git a/tests/test_gh_workflows.py b/tests/test_gh_workflows.py index 10983d2057c..114ca7642ee 100644 --- 
a/tests/test_gh_workflows.py +++ b/tests/test_gh_workflows.py @@ -6,27 +6,26 @@ """Tests for GitHub workflow functionality.""" import unittest -from pathlib import Path import yaml -from detection_rules.schemas import get_stack_versions, RULES_CONFIG -from detection_rules.utils import get_path +from detection_rules.schemas import RULES_CONFIG, get_stack_versions +from detection_rules.utils import ROOT_DIR -GITHUB_FILES = Path(get_path()) / '.github' -GITHUB_WORKFLOWS = GITHUB_FILES / 'workflows' +GITHUB_FILES = ROOT_DIR / ".github" +GITHUB_WORKFLOWS = GITHUB_FILES / "workflows" class TestWorkflows(unittest.TestCase): """Test GitHub workflow functionality.""" - @unittest.skipIf(RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_matrix_to_lock_version_defaults(self): """Test that the default versions in the lock-versions workflow mirror those from the schema-map.""" - lock_workflow_file = GITHUB_WORKFLOWS / 'lock-versions.yml' + lock_workflow_file = GITHUB_WORKFLOWS / "lock-versions.yml" lock_workflow = yaml.safe_load(lock_workflow_file.read_text()) - lock_versions = lock_workflow[True]['workflow_dispatch']['inputs']['branches']['default'].split(',') + lock_versions = lock_workflow[True]["workflow_dispatch"]["inputs"]["branches"]["default"].split(",") matrix_versions = get_stack_versions(drop_patch=True) - err_msg = 'lock-versions workflow default does not match current matrix in stack-schema-map' + err_msg = "lock-versions workflow default does not match current matrix in stack-schema-map" self.assertListEqual(lock_versions, matrix_versions[:-1], err_msg) diff --git a/tests/test_hunt_data.py b/tests/test_hunt_data.py index d4a6b3e9eb9..c7b1e04c481 100644 --- a/tests/test_hunt_data.py +++ b/tests/test_hunt_data.py @@ -4,11 +4,12 @@ # 2.0. 
"""Test for hunt toml files.""" + import unittest from hunting.definitions import HUNTING_DIR from hunting.markdown import load_toml -from hunting.utils import load_index_file, load_all_toml +from hunting.utils import load_all_toml, load_index_file class TestHunt(unittest.TestCase): @@ -33,9 +34,7 @@ def test_toml_loading(self): config = load_toml(example_toml) self.assertEqual(config.author, "Elastic") self.assertEqual(config.integration, "aws_bedrock.invocation") - self.assertEqual( - config.name, "Denial of Service or Resource Exhaustion Attacks Detection" - ) + self.assertEqual(config.name, "Denial of Service or Resource Exhaustion Attacks Detection") self.assertEqual(config.language, "ES|QL") def test_load_toml_files(self): @@ -53,9 +52,7 @@ def test_load_toml_files(self): def test_markdown_existence(self): """Ensure each TOML file has a corresponding Markdown file in the docs directory.""" for toml_file in HUNTING_DIR.rglob("*.toml"): - expected_markdown_path = ( - toml_file.parent.parent / "docs" / toml_file.with_suffix(".md").name - ) + expected_markdown_path = toml_file.parent.parent / "docs" / toml_file.with_suffix(".md").name self.assertTrue( expected_markdown_path.exists(), @@ -65,9 +62,7 @@ def test_markdown_existence(self): def test_toml_existence(self): """Ensure each Markdown file has a corresponding TOML file in the queries directory.""" for markdown_file in HUNTING_DIR.rglob("*/docs/*.md"): - expected_toml_path = ( - markdown_file.parent.parent / "queries" / markdown_file.with_suffix(".toml").name - ) + expected_toml_path = markdown_file.parent.parent / "queries" / markdown_file.with_suffix(".toml").name self.assertTrue( expected_toml_path.exists(), @@ -85,17 +80,19 @@ def setUpClass(cls): def test_mitre_techniques_present(self): """Ensure each query has at least one MITRE technique.""" - for folder, queries in self.hunting_index.items(): + for queries in self.hunting_index.values(): for query_uuid, query_data in queries.items(): - self.assertTrue(query_data.get('mitre'), - f"No MITRE techniques found for query: {query_data.get('name', query_uuid)}") + self.assertTrue( + query_data.get("mitre"), + f"No MITRE techniques found for query: {query_data.get('name', query_uuid)}", + ) def test_valid_structure(self): """Ensure each query entry has a valid structure.""" - required_fields = ['name', 'path', 'mitre'] + required_fields = ["name", "path", "mitre"] - for folder, queries in self.hunting_index.items(): - for query_uuid, query_data in queries.items(): + for queries in self.hunting_index.values(): + for query_data in queries.values(): for field in required_fields: self.assertIn(field, query_data, f"Missing field '{field}' in query: {query_data}") @@ -105,14 +102,11 @@ def test_all_files_in_index(self): all_toml_data = load_all_toml(HUNTING_DIR) uuids = [hunt.uuid for hunt, path in all_toml_data] - for folder, queries in self.hunting_index.items(): - for query_uuid in queries: - if query_uuid not in uuids: - missing_index_entries.append(query_uuid) + for queries in self.hunting_index.values(): + missing_index_entries.extend([query_uuid for query_uuid in queries if query_uuid not in uuids]) self.assertFalse( - missing_index_entries, - f"Missing index entries for the following queries: {missing_index_entries}" + missing_index_entries, f"Missing index entries for the following queries: {missing_index_entries}" ) diff --git a/tests/test_packages.py b/tests/test_packages.py index e0adab2fac9..c6ed9e7e870 100644 --- a/tests/test_packages.py +++ b/tests/test_packages.py @@ -4,16 
+4,16 @@ # 2.0. """Test that the packages are built correctly.""" + import unittest import uuid -from semver import Version + from marshmallow import ValidationError +from semver import Version from detection_rules import rule_loader -from detection_rules.schemas.registry_package import (RegistryPackageManifestV1, - RegistryPackageManifestV3) from detection_rules.packaging import PACKAGE_FILE, Package - +from detection_rules.schemas.registry_package import RegistryPackageManifestV1, RegistryPackageManifestV3 from tests.base import BaseRuleTest package_configs = Package.load_configs() @@ -25,7 +25,7 @@ class TestPackages(BaseRuleTest): @staticmethod def get_test_rule(version=1, count=1): def get_rule_contents(): - contents = { + return { "author": ["Elastic"], "description": "test description", "language": "kuery", @@ -35,17 +35,12 @@ def get_rule_contents(): "risk_score": 21, "rule_id": str(uuid.uuid4()), "severity": "low", - "type": "query" + "type": "query", } - return contents - rules = [rule_loader.TOMLRule('test.toml', get_rule_contents()) for i in range(count)] + rules = [rule_loader.TOMLRule("test.toml", get_rule_contents()) for i in range(count)] version_info = { - rule.id: { - 'rule_name': rule.name, - 'sha256': rule.contents.get_hash(), - 'version': version - } for rule in rules + rule.id: {"rule_name": rule.name, "sha256": rule.contents.get_hash(), "version": version} for rule in rules } return rules, version_info @@ -53,44 +48,42 @@ def get_rule_contents(): def test_package_loader_production_config(self): """Test that packages are loading correctly.""" - @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_package_loader_default_configs(self): """Test configs in detection_rules/etc/packages.yaml.""" Package.from_config(rule_collection=self.rc, config=package_configs) - @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_package_summary(self): """Test the generation of the package summary.""" rules = self.rc - package = Package(rules, 'test-package') + package = Package(rules, "test-package") package.generate_summary_and_changelog(package.changed_ids, package.new_ids, package.removed_ids) - @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(rule_loader.RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_rule_versioning(self): """Test that all rules are properly versioned and tracked""" self.maxDiff = None rules = self.rc original_hashes = [] - post_bump_hashes = [] # test that no rules have versions defined for rule in rules: - self.assertGreaterEqual(rule.contents.autobumped_version, 1, '{} - {}: version is not being set in package') + self.assertGreaterEqual(rule.contents.autobumped_version, 1, "{} - {}: version is not being set in package") original_hashes.append(rule.contents.get_hash()) - package = Package(rules, 'test-package') + package = Package(rules, "test-package") # test that all rules have versions defined - # package.bump_versions(save_changes=False) for rule in package.rules: - self.assertGreaterEqual(rule.contents.autobumped_version, 1, '{} - {}: version is not being set in package') + self.assertGreaterEqual(rule.contents.autobumped_version, 1, "{} - {}: version is not being set in package") # test that rules validate with 
version - for rule in package.rules: - post_bump_hashes.append(rule.contents.get_hash()) + + post_bump_hashes = [rule.contents.get_hash() for rule in package.rules] # test that no hashes changed as a result of the version bumps - self.assertListEqual(original_hashes, post_bump_hashes, 'Version bumping modified the hash of a rule') + self.assertListEqual(original_hashes, post_bump_hashes, "Version bumping modified the hash of a rule") class TestRegistryPackage(unittest.TestCase): @@ -98,11 +91,11 @@ class TestRegistryPackage(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - - assert 'registry_data' in package_configs, f'Missing registry_data in {PACKAGE_FILE}' - cls.registry_config = package_configs['registry_data'] - stack_version = Version.parse(cls.registry_config['conditions']['kibana.version'].strip("^"), - optional_minor_and_patch=True) + assert "registry_data" in package_configs, f"Missing registry_data in {PACKAGE_FILE}" + cls.registry_config = package_configs["registry_data"] + stack_version = Version.parse( + cls.registry_config["conditions"]["kibana.version"].strip("^"), optional_minor_and_patch=True + ) if stack_version >= Version.parse("8.12.0"): RegistryPackageManifestV3.from_dict(cls.registry_config) else: @@ -111,7 +104,7 @@ def setUpClass(cls) -> None: def test_registry_package_config(self): """Test that the registry package is validating properly.""" registry_config = self.registry_config.copy() - registry_config['version'] += '7.1.1.' + registry_config["version"] += "7.1.1." with self.assertRaises(ValidationError): RegistryPackageManifestV1.from_dict(registry_config) diff --git a/tests/test_python_library.py b/tests/test_python_library.py index 3d0ab8b09b6..7a02f8900a1 100644 --- a/tests/test_python_library.py +++ b/tests/test_python_library.py @@ -59,9 +59,7 @@ def test_eql_in_set(self): with self.assertRaisesRegex(ValueError, expected_error_message): rc.load_dict(eql_rule) # Change to appropriate destination.address field - eql_rule["rule"][ - "query" - ] = """ + eql_rule["rule"]["query"] = """ sequence by host.id, process.entity_id with maxspan = 10s [network where destination.address in ("192.168.1.1", "::1")] """ diff --git a/tests/test_rules_remote.py b/tests/test_rules_remote.py index e422239ce62..11ff1c36be3 100644 --- a/tests/test_rules_remote.py +++ b/tests/test_rules_remote.py @@ -5,17 +5,19 @@ import unittest -from .base import BaseRuleTest from detection_rules.misc import get_default_config -# from detection_rules.remote_validation import RemoteValidator +from detection_rules.remote_validation import RemoteValidator + +from .base import BaseRuleTest -@unittest.skipIf(get_default_config() is None, 'Skipping remote validation due to missing config') +@unittest.skipIf(get_default_config() is None, "Skipping remote validation due to missing config") class TestRemoteRules(BaseRuleTest): """Test rules against a remote Elastic stack instance.""" - # def test_esql_rules(self): - # """Temporarily explicitly test all ES|QL rules remotely pending parsing lib.""" - # esql_rules = [r for r in self.all_rules if r.contents.data.type == 'esql'] - # rv = RemoteValidator(parse_config=True) - # rv.validate_rules(esql_rules) + @unittest.skip("Temporarily disabled") + def test_esql_rules(self): + """Temporarily explicitly test all ES|QL rules remotely pending parsing lib.""" + esql_rules = [r for r in self.all_rules if r.contents.data.type == "esql"] + rv = RemoteValidator(parse_config=True) + rv.validate_rules(esql_rules) diff --git a/tests/test_schemas.py 
b/tests/test_schemas.py
index 5adb5a77073..0d8a8d3ddd9 100644
--- a/tests/test_schemas.py
+++ b/tests/test_schemas.py
@@ -4,18 +4,20 @@
 # 2.0.
 
 """Test stack versioned schemas."""
+
 import copy
 import unittest
 import uuid
-from semver import Version
 import eql
+from marshmallow import ValidationError
+from semver import Version
+
 from detection_rules import utils
 from detection_rules.config import load_current_package_version
 from detection_rules.rule import TOMLRuleContents
-from detection_rules.schemas import downgrade, RULES_CONFIG
+from detection_rules.schemas import RULES_CONFIG, downgrade
 from detection_rules.version_lock import VersionLockFile
-from marshmallow import ValidationError
 
 
 class TestSchemas(unittest.TestCase):
@@ -42,7 +44,7 @@ def setUpClass(cls):
                 "tactic": {
                     "id": "TA0001",
                     "name": "Execution",
-                    "reference": "https://attack.mitre.org/tactics/TA0001/"
+                    "reference": "https://attack.mitre.org/tactics/TA0001/",
                 },
                 "technique": [
                     {
@@ -52,25 +54,25 @@ def setUpClass(cls):
                         }
                     ],
                 }
-            ]
+            ],
         }
 
         cls.v79_kql = dict(cls.v78_kql, author=["Elastic"], license="Elastic License v2")
         cls.v711_kql = copy.deepcopy(cls.v79_kql)
 
         # noinspection PyTypeChecker
-        cls.v711_kql["threat"][0]["technique"][0]["subtechnique"] = [{
-            "id": "T1059.001",
-            "name": "PowerShell",
-            "reference": "https://attack.mitre.org/techniques/T1059/001/"
-        }]
+        cls.v711_kql["threat"][0]["technique"][0]["subtechnique"] = [
+            {"id": "T1059.001", "name": "PowerShell", "reference": "https://attack.mitre.org/techniques/T1059/001/"}
+        ]
 
         # noinspection PyTypeChecker
-        cls.v711_kql["threat"].append({
-            "framework": "MITRE ATT&CK",
-            "tactic": {
-                "id": "TA0008",
-                "name": "Lateral Movement",
-                "reference": "https://attack.mitre.org/tactics/TA0008/"
-            },
-        })
+        cls.v711_kql["threat"].append(
+            {
+                "framework": "MITRE ATT&CK",
+                "tactic": {
+                    "id": "TA0008",
+                    "name": "Lateral Movement",
+                    "reference": "https://attack.mitre.org/tactics/TA0008/",
+                },
+            }
+        )
 
         cls.v79_threshold_contents = {
             "author": ["Elastic"],
@@ -88,14 +90,14 @@ def setUpClass(cls):
             },
             "type": "threshold",
         }
-        cls.v712_threshold_rule = dict(copy.deepcopy(cls.v79_threshold_contents), threshold={
-            'field': ['destination.bytes', 'process.args'],
-            'value': 75,
-            'cardinality': [{
-                'field': 'user.name',
-                'value': 2
-            }]
-        })
+        cls.v712_threshold_rule = dict(
+            copy.deepcopy(cls.v79_threshold_contents),
+            threshold={
+                "field": ["destination.bytes", "process.args"],
+                "value": 75,
+                "cardinality": [{"field": "user.name", "value": 2}],
+            },
+        )
 
     def test_query_downgrade_7_x(self):
         """Downgrade a standard KQL rule."""
@@ -142,21 +144,21 @@ def test_threshold_downgrade_7_x(self):
             return
 
         api_contents = self.v712_threshold_rule
-        self.assertDictEqual(downgrade(api_contents, '7.13'), api_contents)
-        self.assertDictEqual(downgrade(api_contents, '7.13.1'), api_contents)
+        self.assertDictEqual(downgrade(api_contents, "7.13"), api_contents)
+        self.assertDictEqual(downgrade(api_contents, "7.13.1"), api_contents)
 
-        exc_msg = 'Cannot downgrade a threshold rule that has multiple threshold fields defined'
+        exc_msg = "Cannot downgrade a threshold rule that has multiple threshold fields defined"
         with self.assertRaisesRegex(ValueError, exc_msg):
-            downgrade(api_contents, '7.9')
+            downgrade(api_contents, "7.9")
 
         v712_threshold_contents_single_field = copy.deepcopy(api_contents)
-        v712_threshold_contents_single_field['threshold']['field'].pop()
+        v712_threshold_contents_single_field["threshold"]["field"].pop()
 
         with self.assertRaisesRegex(ValueError, "Cannot downgrade a threshold rule that has a defined cardinality"):
             downgrade(v712_threshold_contents_single_field, "7.9")
 
         v712_no_cardinality = copy.deepcopy(v712_threshold_contents_single_field)
-        v712_no_cardinality['threshold'].pop('cardinality')
+        v712_no_cardinality["threshold"].pop("cardinality")
         self.assertEqual(downgrade(v712_no_cardinality, "7.9"), self.v79_threshold_contents)
 
         with self.assertRaises(ValueError):
@@ -191,14 +193,14 @@ def test_eql_validation(self):
             "risk_score": 21,
             "rule_id": str(uuid.uuid4()),
             "severity": "low",
-            "type": "eql"
+            "type": "eql",
         }
 
         def build_rule(query):
             metadata = {
                 "creation_date": "1970/01/01",
                 "updated_date": "1970/01/01",
-                "min_stack_version": load_current_package_version()
+                "min_stack_version": load_current_package_version(),
             }
             data = base_fields.copy()
             data["query"] = query
@@ -209,12 +211,22 @@ def build_rule(query):
             process where process.name == "cmd.exe"
         """)
 
-        example_text_fields = ['client.as.organization.name.text', 'client.user.full_name.text',
-                               'client.user.name.text', 'destination.as.organization.name.text',
-                               'destination.user.full_name.text', 'destination.user.name.text',
-                               'error.message', 'error.stack_trace.text', 'file.path.text',
-                               'file.target_path.text', 'host.os.full.text', 'host.os.name.text',
-                               'host.user.full_name.text', 'host.user.name.text']
+        example_text_fields = [
+            "client.as.organization.name.text",
+            "client.user.full_name.text",
+            "client.user.name.text",
+            "destination.as.organization.name.text",
+            "destination.user.full_name.text",
+            "destination.user.name.text",
+            "error.message",
+            "error.stack_trace.text",
+            "file.path.text",
+            "file.target_path.text",
+            "host.os.full.text",
+            "host.os.name.text",
+            "host.user.full_name.text",
+            "host.user.name.text",
+        ]
         for text_field in example_text_fields:
             with self.assertRaises(eql.parser.EqlSchemaError):
                 build_rule(f"""
@@ -247,7 +259,7 @@ def setUpClass(cls):
                     "rule_name": "Remote File Download via PowerShell",
                    "sha256": "8679cd72bf85b67dde3dcfdaba749ed1fa6560bca5efd03ed41c76a500ce31d6",
                     "type": "eql",
-                    "version": 4
+                    "version": 4,
                 },
                 "34fde489-94b0-4500-a76f-b8a157cf9269": {
                     "min_stack_version": "8.2",
@@ -256,29 +268,29 @@ def setUpClass(cls):
                     "rule_name": "Telnet Port Activity",
                     "sha256": "3dd4a438c915920e6ddb0a5212603af5d94fb8a6b51a32f223d930d7e3becb89",
                     "type": "query",
-                    "version": 9
+                    "version": 9,
                 }
             },
             "rule_name": "Telnet Port Activity",
             "sha256": "b0bdfa73639226fb83eadc0303ad1801e0707743f96a36209aa58228d3bf6a89",
             "type": "query",
-            "version": 10
-        }
+            "version": 10,
+        },
     }
 
     def test_version_lock_no_previous(self):
         """Pass field validation on version lock without nested previous fields"""
         version_lock_contents = copy.deepcopy(self.version_lock_contents)
-        VersionLockFile.from_dict(dict(data=version_lock_contents))
+        VersionLockFile.from_dict({"data": version_lock_contents})
 
-    @unittest.skipIf(RULES_CONFIG.bypass_version_lock, 'Version lock bypassed')
+    @unittest.skipIf(RULES_CONFIG.bypass_version_lock, "Version lock bypassed")
     def test_version_lock_has_nested_previous(self):
         """Fail field validation on version lock with nested previous fields"""
         version_lock_contents = copy.deepcopy(self.version_lock_contents)
         with self.assertRaises(ValidationError):
             previous = version_lock_contents["34fde489-94b0-4500-a76f-b8a157cf9269"]["previous"]
             version_lock_contents["34fde489-94b0-4500-a76f-b8a157cf9269"]["previous"]["previous"] = previous
-            VersionLockFile.from_dict(dict(data=version_lock_contents))
+            VersionLockFile.from_dict({"data": version_lock_contents})
 
 
 class TestVersions(unittest.TestCase):
@@ -287,6 +299,6 @@ class TestVersions(unittest.TestCase):
     def test_stack_schema_map(self):
         """Test to ensure that an entry exists in the stack-schema-map for the current package version."""
         package_version = Version.parse(load_current_package_version(), optional_minor_and_patch=True)
-        stack_map = utils.load_etc_dump('stack-schema-map.yaml')
-        err_msg = f'There is no entry defined for the current package ({package_version}) in the stack-schema-map'
+        stack_map = utils.load_etc_dump(["stack-schema-map.yaml"])
+        err_msg = f"There is no entry defined for the current package ({package_version}) in the stack-schema-map"
         self.assertIn(package_version, [Version.parse(v) for v in stack_map], err_msg)
diff --git a/tests/test_specific_rules.py b/tests/test_specific_rules.py
index 2236745a488..ccefa0b7f63 100644
--- a/tests/test_specific_rules.py
+++ b/tests/test_specific_rules.py
@@ -5,20 +5,18 @@
 
 import unittest
 from copy import deepcopy
-from pathlib import Path
 
 import eql.ast
-
+import kql
 from semver import Version
 
-import kql
+from detection_rules import ecs
+from detection_rules.config import load_current_package_version
 from detection_rules.integrations import (
     find_latest_compatible_version,
     load_integrations_manifests,
     load_integrations_schemas,
 )
-from detection_rules import ecs
-from detection_rules.config import load_current_package_version
 from detection_rules.packaging import current_stack_version
 from detection_rules.rule import QueryValidator
 from detection_rules.rule_loader import RuleCollection
@@ -40,14 +38,14 @@ class TestEndpointQuery(BaseRuleTest):
     def test_os_and_platform_in_query(self):
         """Test that all endpoint rules have an os defined and linux includes platform."""
         for rule in self.all_rules:
-            if not rule.contents.data.get('language') in ('eql', 'kuery'):
+            if rule.contents.data.get("language") not in ("eql", "kuery"):
                 continue
             if rule.path.parent.name not in ("windows", "macos", "linux"):
                 # skip cross-platform for now
                 continue
             ast = rule.contents.data.ast
-            fields = [str(f) for f in ast if isinstance(f, (kql.ast.Field, eql.ast.Field))]
+            fields = [str(f) for f in ast if isinstance(f, (kql.ast.Field | eql.ast.Field))]
 
             err_msg = f"{self.rule_str(rule)} missing required field for endpoint rule"
             if "host.os.type" not in fields:
@@ -58,11 +56,6 @@ def test_os_and_platform_in_query(self):
             else:
                 self.assertIn("host.os.type", fields, err_msg)
 
-            # going to bypass this for now
-            # if rule.path.parent.name == 'linux':
-            #     err_msg = f'{self.rule_str(rule)} missing required field for linux endpoint rule'
-            #     self.assertIn('host.os.platform', fields, err_msg)
-
 
 class TestNewTerms(BaseRuleTest):
     """Test new term rules."""
@@ -75,14 +68,13 @@ def test_history_window_start(self):
 
         for rule in self.all_rules:
             if rule.contents.data.type == "new_terms":
-                # validate history window start field exists and is correct
-                assert (
-                    rule.contents.data.new_terms.history_window_start
-                ), "new terms field found with no history_window_start field defined"
-                assert (
-                    rule.contents.data.new_terms.history_window_start[0].field == "history_window_start"
-                ), f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'"
+                assert rule.contents.data.new_terms.history_window_start, (
+                    "new terms field found with no history_window_start field defined"
+                )
+                assert rule.contents.data.new_terms.history_window_start[0].field == "history_window_start", (
+                    f"{rule.contents.data.new_terms.history_window_start} should be 'history_window_start'"
+                )
 
     @unittest.skipIf(
         PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
@@ -91,9 +83,9 @@ def test_new_terms_field_exists(self):
         # validate new terms and history window start fields are correct
         for rule in self.all_rules:
             if rule.contents.data.type == "new_terms":
-                assert (
-                    rule.contents.data.new_terms.field == "new_terms_fields"
-                ), f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type"
+                assert rule.contents.data.new_terms.field == "new_terms_fields", (
+                    f"{rule.contents.data.new_terms.field} should be 'new_terms_fields' for new_terms rule type"
+                )
 
     @unittest.skipIf(
         PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
@@ -115,9 +107,9 @@ def test_new_terms_fields(self):
                 else min_stack_version
             )
 
-            assert (
-                min_stack_version >= feature_min_stack
-            ), f"New Terms rule types only compatible with {feature_min_stack}+"
+            assert min_stack_version >= feature_min_stack, (
+                f"New Terms rule types only compatible with {feature_min_stack}+"
+            )
 
             ecs_version = get_stack_schemas()[str(min_stack_version)]["ecs"]
             beats_version = get_stack_schemas()[str(min_stack_version)]["beats"]
@@ -139,12 +131,10 @@ def test_new_terms_fields(self):
                    )
                    if latest_tag_compat_ver:
                        integration_schema = integration_schemas[tag][latest_tag_compat_ver]
-                        for policy_template in integration_schema.keys():
+                        for policy_template in integration_schema:
                            schema.update(**integration_schemas[tag][latest_tag_compat_ver][policy_template])
            for new_terms_field in rule.contents.data.new_terms.value:
-                assert (
-                    new_terms_field in schema.keys()
-                ), f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas"
+                assert new_terms_field in schema, f"{new_terms_field} not found in ECS, Beats, or non-ecs schemas"
 
     @unittest.skipIf(
         PACKAGE_STACK_VERSION < Version.parse("8.4.0"), "Test only applicable to 8.4+ stacks for new terms feature."
@@ -167,9 +157,9 @@ def test_new_terms_max_limit(self):
                 else min_stack_version
             )
             if feature_min_stack <= min_stack_version < feature_min_stack_extended_fields:
-                assert (
-                    len(rule.contents.data.new_terms.value) == 1
-                ), f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}"
+                assert len(rule.contents.data.new_terms.value) == 1, (
+                    f"new terms have a max limit of 1 for stack versions below {feature_min_stack_extended_fields}"
+                )
 
     @unittest.skipIf(
         PACKAGE_STACK_VERSION < Version.parse("8.6.0"), "Test only applicable to 8.4+ stacks for new terms feature."
@@ -179,9 +169,9 @@ def test_new_terms_fields_unique(self):
         # validate fields are unique
         for rule in self.all_rules:
             if rule.contents.data.type == "new_terms":
-                assert len(set(rule.contents.data.new_terms.value)) == len(
-                    rule.contents.data.new_terms.value
-                ), f"new terms fields values are not unique - {rule.contents.data.new_terms.value}"
+                assert len(set(rule.contents.data.new_terms.value)) == len(rule.contents.data.new_terms.value), (
+                    f"new terms fields values are not unique - {rule.contents.data.new_terms.value}"
+                )
 
 
 class TestESQLRules(BaseRuleTest):
@@ -190,7 +180,7 @@ class TestESQLRules(BaseRuleTest):
     def run_esql_test(self, esql_query, expectation, message):
         """Test that the query validation is working correctly."""
         rc = RuleCollection()
-        file_path = Path(get_path("tests", "data", "command_control_dummy_production_rule.toml"))
+        file_path = get_path(["tests", "data", "command_control_dummy_production_rule.toml"])
         original_production_rule = load_rule_contents(file_path)
 
         # Test that a ValidationError is raised if the query doesn't match the schema
@@ -200,19 +190,3 @@ def run_esql_test(self, esql_query, expectation, message):
         expectation.match_expr = message
         with expectation:
             rc.load_dict(production_rule)
-
-    def test_esql_queries(self):
-        """Test ESQL queries."""
-        # test_cases = [
-        #     # invalid queries
-        #     ('from .ds-logs-endpoint.events.process-default-* | wheres process.name like "Microsoft*"',
-        #      pytest.raises(marshmallow.exceptions.ValidationError), r"ESQL query failed"),
-        #     ('from .ds-logs-endpoint.events.process-default-* | where process.names like "Microsoft*"',
-        #      pytest.raises(marshmallow.exceptions.ValidationError), r"ESQL query failed"),
-        #
-        #     # valid queries
-        #     ('from .ds-logs-endpoint.events.process-default-* | where process.name like "Microsoft*"',
-        #      does_not_raise(), None),
-        # ]
-        # for esql_query, expectation, message in test_cases:
-        #     self.run_esql_test(esql_query, expectation, message)
diff --git a/tests/test_toml_formatter.py b/tests/test_toml_formatter.py
index 4787354fa83..dea5e850eff 100644
--- a/tests/test_toml_formatter.py
+++ b/tests/test_toml_formatter.py
@@ -5,34 +5,39 @@
 
 import copy
 import json
-import os
 import unittest
+from pathlib import Path
 
 import pytoml
 
 from detection_rules.rule_formatter import nested_normalize, toml_write
 from detection_rules.utils import get_etc_path
 
-tmp_file = 'tmp_file.toml'
+tmp_file = "tmp_file.toml"
 
 
 class TestRuleTomlFormatter(unittest.TestCase):
     """Test that the custom toml formatting is not compromising the integrity of the data."""
-    with open(get_etc_path("test_toml.json"), "r") as f:
-        test_data = json.load(f)
+
+    maxDiff = None
+
+    def setUp(self):
+        with get_etc_path(["test_toml.json"]).open() as f:
+            self.test_data = json.load(f)
 
     def compare_formatted(self, data, callback=None, kwargs=None):
         """Compare formatted vs expected."""
+        tmp_path = Path(tmp_file)
         try:
-            toml_write(copy.deepcopy(data), tmp_file)
+            toml_write(copy.deepcopy(data), tmp_path)
 
-            with open(tmp_file, 'r') as f:
-                formatted_contents = pytoml.load(f)
+            formatted_data = tmp_path.read_text()
+            formatted_contents = pytoml.loads(formatted_data)
 
             # callbacks such as nested normalize leave in line breaks, so this must be manually done
-            query = data.get('rule', {}).get('query')
+            query = data.get("rule", {}).get("query")
             if query:
-                data['rule']['query'] = query.strip()
+                data["rule"]["query"] = query.strip()
 
             original = json.dumps(copy.deepcopy(data), sort_keys=True)
@@ -41,15 +46,15 @@ def compare_formatted(self, data, callback=None, kwargs=None):
                 formatted_contents = callback(formatted_contents, **kwargs)
 
             # callbacks such as nested normalize leave in line breaks, so this must be manually done
-            query = formatted_contents.get('rule', {}).get('query')
+            query = formatted_contents.get("rule", {}).get("query")
             if query:
-                formatted_contents['rule']['query'] = query.strip()
+                formatted_contents["rule"]["query"] = query.strip()
 
             formatted = json.dumps(formatted_contents, sort_keys=True)
 
-            self.assertEqual(original, formatted, 'Formatting may be modifying contents')
-
+            self.assertEqual(original, formatted, "Formatting may be modifying contents")
         finally:
-            os.remove(tmp_file)
+            if tmp_path.exists():
+                tmp_path.unlink()
 
     def compare_test_data(self, test_dicts, callback=None):
         """Compare test data against expected."""
@@ -67,12 +72,3 @@ def test_formatter_rule(self):
     def test_formatter_deep(self):
         """Test that the data remains unchanged from formatting."""
         self.compare_test_data(self.test_data[1:])
-    #
-    # def test_format_of_all_rules(self):
-    #     """Test all rules."""
-    #     rules = rule_loader.load_rules().values()
-    #
-    #     for rule in rules:
-    #         is_eql_rule = isinstance(rule.contents.data, EQLRuleData)
-    #         self.compare_formatted(
-    #             rule.rule_format(formatted_query=False), callback=nested_normalize, kwargs={'eql_rule': is_eql_rule})
diff --git a/tests/test_transform_fields.py b/tests/test_transform_fields.py
index 53352584181..ebd95a9f0a4 100644
--- a/tests/test_transform_fields.py
+++ b/tests/test_transform_fields.py
@@ -4,6 +4,7 @@
 # 2.0.
 
 """Test fields in TOML [transform]."""
+
 import copy
 import unittest
 from textwrap import dedent
@@ -22,9 +23,9 @@ class TestGuideMarkdownPlugins(unittest.TestCase):
     def setUpClass(cls) -> None:
         cls.osquery_patterns = [
             """!{osquery{"label":"Osquery - Retrieve DNS Cache","query":"SELECT * FROM dns_cache"}}""",
-            """!{osquery{"label":"Osquery - Retrieve All Services","query":"SELECT description, display_name, name, path, pid, service_type, start_type, status, user_account FROM services"}}""",  # noqa: E501
-            """!{osquery{"label":"Osquery - Retrieve Services Running on User Accounts","query":"SELECT description, display_name, name, path, pid, service_type, start_type, status, user_account FROM services WHERE NOT (user_account LIKE '%LocalSystem' OR user_account LIKE '%LocalService' OR user_account LIKE '%NetworkService' OR user_account == null)"}}""",  # noqa: E501
-            """!{osquery{"label":"Retrieve Service Unisgned Executables with Virustotal Link","query":"SELECT concat('https://www.virustotal.com/gui/file/', sha1) AS VtLink, name, description, start_type, status, pid, services.path FROM services JOIN authenticode ON services.path = authenticode.path OR services.module_path = authenticode.path JOIN hash ON services.path = hash.path WHERE authenticode.result != 'trusted'"}}""",  # noqa: E501
+            """!{osquery{"label":"Osquery - Retrieve All Services","query":"SELECT description, display_name, name, path, pid, service_type, start_type, status, user_account FROM services"}}""",
+            """!{osquery{"label":"Osquery - Retrieve Services Running on User Accounts","query":"SELECT description, display_name, name, path, pid, service_type, start_type, status, user_account FROM services WHERE NOT (user_account LIKE '%LocalSystem' OR user_account LIKE '%LocalService' OR user_account LIKE '%NetworkService' OR user_account == null)"}}""",
+            """!{osquery{"label":"Retrieve Service Unisgned Executables with Virustotal Link","query":"SELECT concat('https://www.virustotal.com/gui/file/', sha1) AS VtLink, name, description, start_type, status, pid, services.path FROM services JOIN authenticode ON services.path = authenticode.path OR services.module_path = authenticode.path JOIN hash ON services.path = hash.path WHERE authenticode.result != 'trusted'"}}""",
         ]
 
     @staticmethod
@@ -45,7 +46,7 @@ def load_rule() -> TOMLRule:
             "license": "Elastic License v2",
             "from": "now-9m",
             "name": "Test Suspicious Print Spooler SPL File Created",
-            "note": 'Test note',
+            "note": "Test note",
             "references": ["https://safebreach.com/Post/How-we-bypassed-CVE-2020-1048-Patch-and-got-CVE-2020-1337"],
             "risk_score": 47,
             "rule_id": "43716252-4a45-4694-aff0-5245b7b6c7cd",
@@ -85,13 +86,13 @@ def load_rule() -> TOMLRule:
                 "language": "eql",
             },
         }
-        sample_rule = rc.load_dict(windows_rule)
-        return sample_rule
+        return rc.load_dict(windows_rule)
 
     def test_transform_guide_markdown_plugins(self) -> None:
         sample_rule = self.load_rule()
         rule_dict = sample_rule.contents.to_dict()
-        osquery_toml = dedent("""
+        osquery_toml = dedent(
+            """
             [transform]
             [[transform.osquery]]
             label = "Osquery - Retrieve DNS Cache"
@@ -108,9 +109,11 @@ def test_transform_guide_markdown_plugins(self) -> None:
             [[transform.osquery]]
             label = "Retrieve Service Unisgned Executables with Virustotal Link"
             query = "SELECT concat('https://www.virustotal.com/gui/file/', sha1) AS VtLink, name, description, start_type, status, pid, services.path FROM services JOIN authenticode ON services.path = authenticode.path OR services.module_path = authenticode.path JOIN hash ON services.path = hash.path WHERE authenticode.result != 'trusted'"
-            """.strip())  # noqa: E501
+            """.strip()
+        )
 
-        sample_note = dedent("""
+        sample_note = dedent(
+            """
             ## Triage and analysis
 
             ### Investigating Unusual Process For a Windows Host
@@ -135,15 +138,16 @@ def test_transform_guide_markdown_plugins(self) -> None:
             - $osquery_2
             - $osquery_3
             - Retrieve the files' SHA-256 hash values using the PowerShell `Get-FileHash` cmdlet and search for the existence and reputation of the hashes in resources like VirusTotal, Hybrid-Analysis, CISCO Talos, Any.run, etc.
-            """.strip())  # noqa: E501
+            """.strip()
+        )
 
         transform = pytoml.loads(osquery_toml)
-        rule_dict['rule']['note'] = sample_note
+        rule_dict["rule"]["note"] = sample_note
         rule_dict.update(**transform)
 
         new_rule_contents = TOMLRuleContents.from_dict(rule_dict)
         new_rule = TOMLRule(path=sample_rule.path, contents=new_rule_contents)
-        rendered_note = new_rule.contents.to_api_format()['note']
+        rendered_note = new_rule.contents.to_api_format()["note"]
 
         for pattern in self.osquery_patterns:
             self.assertIn(pattern, rendered_note)
@@ -152,7 +156,7 @@ def test_plugin_conversion(self):
         """Test the conversion function to ensure parsing is correct."""
         sample_rule = self.load_rule()
         rule_dict = sample_rule.contents.to_dict()
-        rule_dict['rule']['note'] = "$osquery_0"
+        rule_dict["rule"]["note"] = "$osquery_0"
 
         for pattern in self.osquery_patterns:
             transform = guide_plugin_convert_(contents=pattern)
@@ -160,6 +164,6 @@ def test_plugin_conversion(self):
             rule_dict_copy.update(**transform)
             new_rule_contents = TOMLRuleContents.from_dict(rule_dict_copy)
             new_rule = TOMLRule(path=sample_rule.path, contents=new_rule_contents)
-            rendered_note = new_rule.contents.to_api_format()['note']
+            rendered_note = new_rule.contents.to_api_format()["note"]
 
             self.assertIn(pattern, rendered_note)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4788906420f..b853f357cb2 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -4,37 +4,38 @@
 # 2.0.
"""Test util time functions.""" + import random import time import unittest -from detection_rules.utils import normalize_timing_and_sort, cached -from detection_rules.eswrap import Events from detection_rules.ecs import get_kql_schema +from detection_rules.eswrap import Events +from detection_rules.utils import cached, normalize_timing_and_sort class TestTimeUtils(unittest.TestCase): """Test util time functions.""" @staticmethod - def get_events(timestamp_field='@timestamp'): + def get_events(timestamp_field="@timestamp"): """Get test data.""" date_formats = { - 'epoch_millis': lambda x: int(round(time.time(), 3) + x) * 1000, - 'epoch_second': lambda x: round(time.time()) + x, - 'unix_micros': lambda x: time.time() + x, - 'unix_millis': lambda x: round(time.time(), 3) + x, - 'strict_date_optional_time': lambda x: '2020-05-13T04:36:' + str(15 + x) + '.394Z' + "epoch_millis": lambda x: int(round(time.time(), 3) + x) * 1000, + "epoch_second": lambda x: round(time.time()) + x, + "unix_micros": lambda x: time.time() + x, + "unix_millis": lambda x: round(time.time(), 3) + x, + "strict_date_optional_time": lambda x: "2020-05-13T04:36:" + str(15 + x) + ".394Z", } def _get_data(func): data = [ - {timestamp_field: func(0), 'foo': 'bar', 'id': 1}, - {timestamp_field: func(1), 'foo': 'bar', 'id': 2}, - {timestamp_field: func(2), 'foo': 'bar', 'id': 3}, - {timestamp_field: func(3), 'foo': 'bar', 'id': 4}, - {timestamp_field: func(4), 'foo': 'bar', 'id': 5}, - {timestamp_field: func(5), 'foo': 'bar', 'id': 6} + {timestamp_field: func(0), "foo": "bar", "id": 1}, + {timestamp_field: func(1), "foo": "bar", "id": 2}, + {timestamp_field: func(2), "foo": "bar", "id": 3}, + {timestamp_field: func(3), "foo": "bar", "id": 4}, + {timestamp_field: func(4), "foo": "bar", "id": 5}, + {timestamp_field: func(5), "foo": "bar", "id": 6}, ] random.shuffle(data) return data @@ -43,8 +44,8 @@ def _get_data(func): def assert_sort(self, normalized_events, date_format): """Assert normalize and sort.""" - order = [e['id'] for e in normalized_events] - self.assertListEqual([1, 2, 3, 4, 5, 6], order, 'Sorting failed for date_format: {}'.format(date_format)) + order = [e["id"] for e in normalized_events] + self.assertListEqual([1, 2, 3, 4, 5, 6], order, f"Sorting failed for date_format: {date_format}") def test_time_normalize(self): """Test normalize_timing_from_date_format.""" @@ -57,8 +58,8 @@ def test_event_class_normalization(self): """Test that events are normalized properly within Events.""" events_data = self.get_events() for date_format, events in events_data.items(): - normalized = Events({'winlogbeat': events}) - self.assert_sort(normalized.events['winlogbeat'], date_format) + normalized = Events({"winlogbeat": events}) + self.assert_sort(normalized.events["winlogbeat"], date_format) def test_schema_multifields(self): """Tests that schemas are loading multifields correctly.""" @@ -88,15 +89,15 @@ def increment(*args, **kwargs): self.assertEqual(increment(), 1) self.assertEqual(increment(["hello", "world"]), 2) - self.assertEqual(increment({"hello": [("world", )]}), 3) - self.assertEqual(increment({"hello": [("world", )]}), 3) + self.assertEqual(increment({"hello": [("world",)]}), 3) + self.assertEqual(increment({"hello": [("world",)]}), 3) self.assertEqual(increment(), 1) self.assertEqual(increment(["hello", "world"]), 2) - self.assertEqual(increment({"hello": [("world", )]}), 3) + self.assertEqual(increment({"hello": [("world",)]}), 3) increment.clear() - self.assertEqual(increment({"hello": [("world", )]}), 4) + 
self.assertEqual(increment({"hello": [("world",)]}), 4) self.assertEqual(increment(["hello", "world"]), 5) self.assertEqual(increment(), 6) self.assertEqual(increment(None), 7) diff --git a/tests/test_version_locking.py b/tests/test_version_locking.py index 37e0e5e420e..bbb8555aef6 100644 --- a/tests/test_version_locking.py +++ b/tests/test_version_locking.py @@ -10,20 +10,20 @@ from semver import Version from detection_rules.schemas import get_min_supported_stack_version -from detection_rules.version_lock import loaded_version_lock, RULES_CONFIG +from detection_rules.version_lock import RULES_CONFIG, loaded_version_lock class TestVersionLock(unittest.TestCase): """Test version locking.""" - @unittest.skipIf(RULES_CONFIG.bypass_version_lock, 'Version lock bypassed') + @unittest.skipIf(RULES_CONFIG.bypass_version_lock, "Version lock bypassed") def test_previous_entries_gte_current_min_stack(self): """Test that all previous entries for all locks in the version lock are >= the current min_stack.""" errors = {} min_version = get_min_supported_stack_version() for rule_id, lock in loaded_version_lock.version_lock.to_dict().items(): - if 'previous' in lock: - prev_vers = [Version.parse(v, optional_minor_and_patch=True) for v in list(lock['previous'])] + if "previous" in lock: + prev_vers = [Version.parse(v, optional_minor_and_patch=True) for v in list(lock["previous"])] outdated = [f"{v.major}.{v.minor}" for v in prev_vers if v < min_version] if outdated: errors[rule_id] = outdated @@ -31,7 +31,9 @@ def test_previous_entries_gte_current_min_stack(self): # This should only ever happen when bumping the backport matrix support up, which is based on the # stack-schema-map if errors: - err_str = '\n'.join(f'{k}: {", ".join(v)}' for k, v in errors.items()) - self.fail(f'The following version.lock entries have previous locked versions which are lower than the ' - f'currently supported min_stack ({min_version}). To address this, run the ' - f'`dev trim-version-lock {min_version}` command.\n\n{err_str}') + err_str = "\n".join(f"{k}: {', '.join(v)}" for k, v in errors.items()) + self.fail( + f"The following version.lock entries have previous locked versions which are lower than the " + f"currently supported min_stack ({min_version}). To address this, run the " + f"`dev trim-version-lock {min_version}` command.\n\n{err_str}" + )
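For context on the last hunk, a minimal sketch of the semver comparison that test_previous_entries_gte_current_min_stack relies on, assuming semver >= 3.0, where optional_minor_and_patch=True lets a short version string such as "8.3" parse as 8.3.0. The version strings below are hypothetical; only Version.parse and the outdated-version filter come from the diff above.

    from semver import Version

    # hypothetical current minimum supported stack version
    min_version = Version.parse("8.14", optional_minor_and_patch=True)

    # hypothetical "previous" keys from one version.lock entry
    previous_locked = ["8.3", "8.12", "8.15.1"]

    prev_vers = [Version.parse(v, optional_minor_and_patch=True) for v in previous_locked]
    outdated = [f"{v.major}.{v.minor}" for v in prev_vers if v < min_version]
    print(outdated)  # ['8.3', '8.12'] -> entries that would trigger the failure message above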