diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..8116f5b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+.env
+.venv
+__pycache__
+.mypy_cache
+.pytest_cache
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..7b5aae8
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,37 @@
+fail_fast: true
+
+repos:
+  - repo: https://github.com/psf/black
+    rev: 22.3.0
+    hooks:
+      - id: black
+        args: [--diff, --check]
+
+  - repo: local
+    hooks:
+      - id: pylint
+        name: pylint
+        entry: pylint
+        language: system
+        types: [python]
+        require_serial: true
+        args: ["--min-public-methods=1"]
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.4.1
+    hooks:
+      - id: mypy
+        exclude: ^tests/
+        additional_dependencies: ['types-pytz']
+        args: [--strict, --ignore-missing-imports]
+
+  - repo: local
+    hooks:
+      - id: pytest-check
+        name: pytest-check
+        stages: [commit]
+        types: [python]
+        entry: pytest
+        language: system
+        pass_filenames: false
+        always_run: true
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..164f9e7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,46 @@
+# Subscriber Usage Report Generator
+Generates the daily subscriber usage report: it queries data from the MongoDB databases and writes the report to a MySQL database.
+# Requirements
+- Python 3.11
+- MongoDB 6.0 Instance
+- MySQL 8.0
+  - Before installing `mysqlclient`, follow the instructions in the `Install` section [here](https://github.com/PyMySQL/mysqlclient).
+# How to Install?
+- Create a `.env` file from the `sample.env` file.
+```bash
+cp sample.env .env
+nano .env
+```
+- Create and activate the Virtual Environment.
+
+**Linux/MacOS**
+```bash
+python3.11 -m venv .venv
+source .venv/bin/activate
+```
+**Windows**
+```powershell
+python3.11 -m venv .venv
+.venv\Scripts\Activate.ps1
+```
+- Install the packages.
+```bash
+(.venv) pip install -r requirements.txt
+```
+# How to Run?
+```bash
+python run.py
+```
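+The report covers the previous day's audit window, so a daily scheduled run is typical. A minimal scheduling sketch, assuming cron and a hypothetical install path of `/opt/usage-report`:
+```bash
+# Run the report daily at 01:00 using the virtual environment's interpreter
+0 1 * * * cd /opt/usage-report && .venv/bin/python run.py >> report.log 2>&1
+```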
+ """ + env_var = os.environ.get(env_var_key, None) + if env_var is None: + raise ValueError(f"Missing Environment Variable: {env_var_key}") + + return env_var + + +# Reporting MySQL DB +REPORTING_MYSQL_DB = { + "SERVER": get_env_var("REPORTING_SQL_SERVER"), + "PORT": get_env_var("REPORTING_SQL_PORT"), + "DATABASE": get_env_var("REPORTING_SQL_DATABASE"), + "USERNAME": get_env_var("REPORTING_SQL_USERNAME"), + "PASSWORD": get_env_var("REPORTING_SQL_PASSWORD"), +} +REPORTING_AULDATALEAK_TABLENAME = get_env_var("REPORTING_AULDATALEAK_TABLENAME") + +# Audit MongoDB +AUDIT_MONGODB = { + "SERVER": get_env_var("AUDIT_MONGO_SERVER"), + "REPLICASET": get_env_var("AUDIT_MONGO_REPLICASET"), + "USERNAME": get_env_var("AUDIT_MONGO_USERNAME"), + "PASSWORD": get_env_var("AUDIT_MONGO_PASSWORD"), + "DATABASE": get_env_var("AUDIT_MONGO_DATABASE"), + "COLLECTION": get_env_var("AUDIT_MONGO_DATABASE"), +} + +# Usage MongoDB's +ARC_MONGODB_NODES = { + "A": { + "SERVER": get_env_var("ARC_MONGO_SERVER_A"), + "REPLICASET": get_env_var("ARC_MONGO_REPLICASET_A"), + }, + "B": { + "SERVER": get_env_var("ARC_MONGO_SERVER_B"), + "REPLICASET": get_env_var("ARC_MONGO_REPLICASET_B"), + }, + "C": { + "SERVER": get_env_var("ARC_MONGO_SERVER_C"), + "REPLICASET": get_env_var("ARC_MONGO_REPLICASET_C"), + }, +} +ARC_MONGODB_DETAILS = { + "USERNAME": get_env_var("ARC_MONGO_USERNAME"), + "PASSWORD": get_env_var("ARC_MONGO_PASSWORD"), + "DATABASE": get_env_var("ARC_MONGO_DATABASE"), + "COLLECTION": get_env_var("ARC_MONGO_COLLECTION"), +} + +# General MongoDB Settings +GENERAL_MONGODB_SETTINGS = { + "AUTHSOURCE": get_env_var("MONGO_AUTHSOURCE"), + "AUTHMECHANISM": "SCRAM-SHA-1", + "READ_PREFERENCE": "secondary", +} diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..748df5d --- /dev/null +++ b/app/main.py @@ -0,0 +1,268 @@ +"""All important functions goes here.""" +from datetime import datetime +from ast import literal_eval +import pandas as pd +import pytz +from pandas import DataFrame + +from app.utils.db_utils import MySQLClient, MongoCollection +from app.utils.time_utils import get_datetime_today, get_datetime_range_today +from app.config import ( + REPORTING_MYSQL_DB, + REPORTING_AULDATALEAK_TABLENAME, + AUDIT_MONGODB, + ARC_MONGODB_NODES, + ARC_MONGODB_DETAILS, + GENERAL_MONGODB_SETTINGS, +) + + +def get_reporting_client() -> MySQLClient: + """Build MySQL Client for Reporting. + + Returns: + MySQLClient: MySQL Client for Reporting. + """ + return MySQLClient(mysql_details=REPORTING_MYSQL_DB) + + +def get_audit_collection() -> MongoCollection: + """Build MongoDB Collection for Auditing. + + Returns: + MongoCollection: MongoDB Collection for Auditing + """ + return MongoCollection( + mongodb_details=AUDIT_MONGODB, mongodb_general_settings=GENERAL_MONGODB_SETTINGS + ) + + +def get_usage_collection(node: str) -> MongoCollection: + """Build MongoDB Collection for Usage Reporting. + + Returns: + MongoCollection: MongoDB Collection for Usage Reporting. + """ + mongodb_details = {**ARC_MONGODB_NODES[node], **ARC_MONGODB_DETAILS} + return MongoCollection( + mongodb_details=mongodb_details, + mongodb_general_settings=GENERAL_MONGODB_SETTINGS, + ) + + +def init_auldata_leak_reporting_table(client: MySQLClient) -> None: + """Initialize Table for Data Leak. + + Args: + client (MySQLClient): MySQL Client for Reporting. + """ + print("Creating table... 
" + REPORTING_AULDATALEAK_TABLENAME) + + reporting_table_create_query = f"CREATE TABLE IF NOT EXISTS \ + {REPORTING_AULDATALEAK_TABLENAME} ( \ + `SUBSCRIBERID` VARCHAR(100), \ + `MDN` VARCHAR(100), \ + `BAN` VARCHAR(100), \ + `USAGESTART` DATETIME, \ + `USAGEEND` DATETIME, \ + `TOTALMB` DECIMAL, \ + `AUDITDATE` DATETIME \ + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;" + + reporting_table_create_index = f"CREATE INDEX idx_AUDITDATE \ + ON {REPORTING_AULDATALEAK_TABLENAME} (AUDITDATE);" + + client.execute_query(reporting_table_create_query) + client.execute_query(reporting_table_create_index) + + +def get_auldata_subscribers( + audit_collection: MongoCollection, + audit_range_start: datetime, + audit_range_end: datetime, +) -> DataFrame: + """Get data of all Subscribers from the Audit Database. + + Args: + audit_collection (MongoCollection): MongoDB Collection for Auditing. + audit_range_start (datetime): Start Datetime of Audit. + audit_range_end (datetime): End Datetime of Audit. + + Returns: + DataFrame: Data of all Subscribers. + """ + + audit_query = [ + { + "$match": { + "$and": [ + { + "details": { + "$elemMatch": { + "state": "ADD", + "data.payload.payloads": { + "$elemMatch": { + "requestpayload.subscriptions": { + "$elemMatch": {"offerName": "MYOFFERNAME"} + } + } + }, + }, + }, + }, + { + "lastModifiedDate": { + "$gte": audit_range_start, + "$lte": audit_range_end, + }, + }, + ] + }, + }, + {"$unwind": {"path": "$details"}}, + { + "$match": { + "details.state": "ADD", + "details.data.payload.payloads": { + "$elemMatch": { + "requestpayload.subscriptions": { + "$elemMatch": {"offerName": "MYOFFERNAME"} + }, + }, + }, + }, + }, + {"$unwind": {"path": "$details.data.payload.payloads"}}, + { + "$unwind": { + "path": "$details.data.payload.payloads.requestpayload.subscriptions" + }, + }, + { + "$project": { + "_id": 0.0, + "ban": 1.0, + "subscriberId": "$details.data.payload.subscriberId", + "effectiveDate": "$details.data.payload.payloads.\ + requestpayload.subscriptions.effectiveDate", + "expiryDate": "$details.data.payload.payloads.\ + requestpayload.subscriptions.expiryDate", + }, + }, + ] + + return audit_collection.run_mongo_query_agr(audit_query) + + +def run_compare_on_node( + node: str, sub_list: DataFrame, reporting_client: MySQLClient +) -> None: + """Create a report of Usage of Subscribers for a single node. + Write the report on the Reporting Database. + + Args: + node (str): Choose which Node to compare. + sub_list (DataFrame): Sub-list of Subscribers. + reporting_client (MySQLClient): MySQL Client for Reporting. 
+ """ + if len(sub_list) == 0: + return None + + audit_date = get_datetime_today() + usage_collection = get_usage_collection(node) + + usage_result = DataFrame( + columns=["extSubId", "MDN", "BAN", "start", "end", "bytesIn", "bytesOut"] + ) + + for _, subscriber in sub_list.iterrows(): + print(subscriber["subscriberId"]) + effective_date = datetime.strptime( + subscriber["effectiveDate"], "%Y-%m-%dT%H:%M:%SZ" + ).astimezone(pytz.timezone("US/Eastern")) + expiry_date = datetime.strptime( + subscriber["expiryDate"], "%Y-%m-%dT%H:%M:%SZ" + ).astimezone(pytz.timezone("US/Eastern")) + + usage_query = { + "$and": [ + {"end": {"$gte": effective_date, "$lte": expiry_date}}, + {"extSubId": literal_eval(subscriber["subscriberId"])}, + {"usageType": "OVER"}, + {"$or": [{"bytesIn": {"$gt": 0}, "bytesOut": {"$gt": 0}}]}, + ] + } + usage_project = { + "_id": 0, + "extSubId": 1, + "MDN": 1, + "BAN": 1, + "start": 1, + "end": 1, + "bytesIn": 1, + "bytesOut": 1, + } + query_result = usage_collection.run_mongo_query(usage_query, usage_project) + usage_result = pd.concat([usage_result, query_result], axis=0) + + if len(usage_result) == 0: + continue + + usage_result_reporting_query = f"INSERT INTO {REPORTING_AULDATALEAK_TABLENAME} \ + (SUBSCRIBERID, MDN, BAN, USAGESTART, USAGEEND, TOTALMB, AUDITDATE) VALUES " + for _, row in usage_result.iterrows(): + usage_result_reporting_query += f"('{row['extSubId']}', \ + {row['MDN']}, {row['BAN']}, '{row['start']}', \ + '{row['end']}', '{int(row['bytesIn']) + int(row['bytesOut'])}', \ + '{audit_date}')," + usage_result_reporting_query = usage_result_reporting_query[:-1] + reporting_client.execute_query(usage_result_reporting_query) + print(usage_result.size + " rows written to " + REPORTING_AULDATALEAK_TABLENAME) + + return None + + +def compare_auldata(auldata_subs: DataFrame, reporting_client: MySQLClient) -> None: + """Generate Report of Usage of Subscribers for three nodes. + + Args: + auldata_subs (DataFrame): Data of all Subscribers. + reporting_client (MySQLClient): MySQL Client for Reporting. + """ + nodes = list(ARC_MONGODB_NODES.keys()) + sub_lists: list[DataFrame] = [[], [], []] + + for _, row in auldata_subs.iterrows(): + remainder = int(row["ban"]) % 3 + sub_lists[remainder].append(row) + + for node, sub_list in zip(nodes, sub_lists): + run_compare_on_node(node, sub_list, reporting_client) + + +def cleanup_auldata_leak_reporting_table(client: MySQLClient) -> None: + """Delete the older Subscriber Data Reports older than 1 month. + + Args: + client (MySQLClient): MySQL Client for Reporting. + """ + reporting_table_delete_query = f"DELETE FROM {REPORTING_AULDATALEAK_TABLENAME} \ + WHERE AUDITDATE < DATE_SUB(NOW(), INTERVAL 1 MONTH)" + client.execute_query(reporting_table_delete_query) + + print("Deleting table... 
" + REPORTING_AULDATALEAK_TABLENAME) + + +def run_program() -> None: + """Run the Reporting Process""" + reporting_client = get_reporting_client() + init_auldata_leak_reporting_table(reporting_client) + + audit_collection = get_audit_collection() + audit_range_start, audit_range_end = get_datetime_range_today() + auldata_subs = get_auldata_subscribers( + audit_collection, audit_range_start, audit_range_end + ) + + compare_auldata(auldata_subs, reporting_client) + cleanup_auldata_leak_reporting_table(reporting_client) diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/utils/db_utils.py b/app/utils/db_utils.py new file mode 100644 index 0000000..305f808 --- /dev/null +++ b/app/utils/db_utils.py @@ -0,0 +1,125 @@ +"""Classes and Methods for Database clients.""" +from typing import Union, Optional, Any +from sqlalchemy import create_engine +from sqlalchemy.exc import SQLAlchemyError +from pymongo import MongoClient +from pymongo import DESCENDING +from pandas import DataFrame + + +class MySQLClient: + """MySQL Client Object.""" + + def __init__(self, mysql_details: dict[str, str]) -> None: + server = mysql_details["SERVER"] + port = mysql_details["PORT"] + database = mysql_details["DATABASE"] + username = mysql_details["USERNAME"] + password = mysql_details["PASSWORD"] + + mysql_uri = f"mysql://{username}:{password}@{server}:{port}\ + /{database}?charset=utf8" + self.engine = create_engine(mysql_uri, pool_recycle=3600) + + def execute_query(self, query: str) -> Union[int, str]: + """Execute a Query via the MySQL connection. + + Args: + query (str): Query strng. + + Returns: + Union[int, str]: 0 if successful. Error Message if not. + """ + try: + with self.engine.connect() as connection: + connection.execute(query) + + return 0 + except SQLAlchemyError as sql_error: + error = str(sql_error.__dict__["orig"]) + return error + + +class MongoCollection: + """MongoDB Collection Object""" + + def __init__( + self, + mongodb_details: dict[str, str], + mongodb_general_settings: dict[str, str], + ) -> None: + server = mongodb_details["SERVER"] + replica_set = mongodb_details["REPLICASET"] + username = mongodb_details["USERNAME"] + password = mongodb_details["PASSWORD"] + database = mongodb_details["DATABASE"] + collection = mongodb_details["COLLECTION"] + + auth_source = mongodb_general_settings["AUTHSOURCE"] + auth_mechanism = mongodb_general_settings["AUTHMECHANISM"] + read_preference = mongodb_general_settings["READ_PREFERENCE"] + + mongo_uri = f"mongodb://{username}:{password}@{server}" + self.client = MongoClient( + mongo_uri, + replicaSet=replica_set, + authSource=auth_source, + readPreference=read_preference, + authMechanism=auth_mechanism, + ) + try: + self.database = self.client[database] + except Exception as exc: + raise ValueError(f"Database {database} does not exist.") from exc + + try: + self.collection = self.database[collection] + except Exception as exc: + raise ValueError(f"Collection {collection} does not exist.") from exc + + def run_mongo_query( + self, + query: Union[list[Any], dict[str, Any]], + project: Optional[dict[str, Any]] = None, + sort_field: Optional[str] = "eventTime", + limit_count: Optional[int] = None, + ) -> DataFrame: + """Run a MongoDB Query on the Collection. + + Args: + query (dict): MongoDB Query. + project (dict, optional): Specify the fields to be returned. Defaults to None. + sort_field (str, optional): Field to sort on. Defaults to "eventTime". 
+        mysql_uri = (
+            f"mysql://{username}:{password}@{server}:{port}"
+            f"/{database}?charset=utf8"
+        )
+        self.engine = create_engine(mysql_uri, pool_recycle=3600)
+
+    def execute_query(self, query: str) -> Union[int, str]:
+        """Execute a Query via the MySQL connection.
+
+        Args:
+            query (str): Query string.
+
+        Returns:
+            Union[int, str]: 0 if successful. Error Message if not.
+        """
+        try:
+            with self.engine.connect() as connection:
+                # SQLAlchemy 2.0 requires textual SQL to be wrapped in text()
+                # and DML/DDL changes to be committed explicitly.
+                connection.execute(text(query))
+                connection.commit()
+
+            return 0
+        except SQLAlchemyError as sql_error:
+            error = str(sql_error.__dict__["orig"])
+            return error
+
+
+class MongoCollection:
+    """MongoDB Collection Object."""
+
+    def __init__(
+        self,
+        mongodb_details: dict[str, str],
+        mongodb_general_settings: dict[str, str],
+    ) -> None:
+        server = mongodb_details["SERVER"]
+        replica_set = mongodb_details["REPLICASET"]
+        username = mongodb_details["USERNAME"]
+        password = mongodb_details["PASSWORD"]
+        database = mongodb_details["DATABASE"]
+        collection = mongodb_details["COLLECTION"]
+
+        auth_source = mongodb_general_settings["AUTHSOURCE"]
+        auth_mechanism = mongodb_general_settings["AUTHMECHANISM"]
+        read_preference = mongodb_general_settings["READ_PREFERENCE"]
+
+        mongo_uri = f"mongodb://{username}:{password}@{server}"
+        self.client = MongoClient(
+            mongo_uri,
+            replicaSet=replica_set,
+            authSource=auth_source,
+            readPreference=read_preference,
+            authMechanism=auth_mechanism,
+        )
+        try:
+            self.database = self.client[database]
+        except Exception as exc:
+            raise ValueError(f"Database {database} does not exist.") from exc
+
+        try:
+            self.collection = self.database[collection]
+        except Exception as exc:
+            raise ValueError(f"Collection {collection} does not exist.") from exc
+
+    def run_mongo_query(
+        self,
+        query: Union[list[Any], dict[str, Any]],
+        project: Optional[dict[str, Any]] = None,
+        sort_field: Optional[str] = "eventTime",
+        limit_count: Optional[int] = None,
+    ) -> DataFrame:
+        """Run a MongoDB Query on the Collection.
+
+        Args:
+            query (dict): MongoDB Query.
+            project (dict, optional): Fields to be returned. Defaults to None.
+            sort_field (str, optional): Field to sort on, descending.
+                Defaults to "eventTime".
+            limit_count (int, optional): Maximum number of results.
+                Defaults to None (no limit).
+
+        Returns:
+            DataFrame: Query Result.
+        """
+        results = []
+        if project is not None:
+            db_query = self.collection.find(query, project)
+        else:
+            db_query = self.collection.find(query)
+        if sort_field is not None:
+            db_query.sort(sort_field, DESCENDING)
+        if limit_count:
+            db_query.limit(limit_count)
+        for doc in db_query:
+            results.append(doc)
+
+        results_df = DataFrame(list(results))
+        return results_df
+
+    def run_mongo_query_agr(self, query: Union[list[Any], dict[str, Any]]) -> DataFrame:
+        """Return the aggregated result of a MongoDB Query.
+
+        Args:
+            query (list): MongoDB Aggregation Pipeline.
+
+        Returns:
+            DataFrame: Aggregated Query result.
+        """
+        results = self.collection.aggregate(query, cursor={})
+        results_df = DataFrame(list(results))
+        return results_df
diff --git a/app/utils/time_utils.py b/app/utils/time_utils.py
new file mode 100644
index 0000000..4c3e6d1
--- /dev/null
+++ b/app/utils/time_utils.py
@@ -0,0 +1,26 @@
+"""Time-related utility functions"""
+from datetime import datetime, timedelta, date, time
+from typing import Tuple
+
+
+def get_datetime_range_today() -> Tuple[datetime, datetime]:
+    """Get a tuple with the first and last datetimes of the audit day,
+    which is the previous calendar day.
+
+    Returns:
+        Tuple[datetime, datetime]: Start and End Datetimes of the audit day.
+    """
+    audit_day = date.today() - timedelta(days=1)
+    range_start = datetime.combine(audit_day, time(0, 0, 0))
+    range_end = datetime.combine(audit_day, time(23, 59, 59))
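+    # Example: when run on 2023-07-15, this returns
+    # (datetime(2023, 7, 14, 0, 0, 0), datetime(2023, 7, 14, 23, 59, 59)).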
+
+    return range_start, range_end
+
+
+def get_datetime_today() -> str:
+    """Get a string-formatted datetime for now.
+
+    Returns:
+        str: String-formatted datetime for now.
+    """
+    return datetime.today().strftime("%Y-%m-%d %H:%M:%S")
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..be42644
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,23 @@
+astroid==2.15.6
+cfgv==3.3.1
+dill==0.3.6
+distlib==0.3.6
+filelock==3.12.2
+identify==2.5.24
+iniconfig==2.0.0
+isort==5.12.0
+lazy-object-proxy==1.9.0
+mccabe==0.7.0
+mypy==1.4.1
+mypy-extensions==1.0.0
+nodeenv==1.8.0
+packaging==23.1
+platformdirs==3.8.1
+pluggy==1.2.0
+pre-commit==3.3.3
+pylint==2.17.4
+pytest==7.4.0
+PyYAML==6.0
+tomlkit==0.11.8
+virtualenv==20.23.1
+wrapt==1.15.0
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..7c48907
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+dnspython==2.3.0
+greenlet==2.0.2
+mysqlclient==2.2.0
+numpy==1.25.1
+pandas==2.0.3
+pymongo==4.4.0
+python-dateutil==2.8.2
+python-dotenv==1.0.0
+pytz==2023.3
+six==1.16.0
+SQLAlchemy==2.0.18
+types-pytz==2023.3.0.0
+typing_extensions==4.7.1
+tzdata==2023.3
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..8766086
--- /dev/null
+++ b/run.py
@@ -0,0 +1,5 @@
+"""Entry point to the program."""
+from app.main import run_program
+
+if __name__ == "__main__":
+    run_program()
diff --git a/sample.env b/sample.env
new file mode 100644
index 0000000..8bdb9f5
--- /dev/null
+++ b/sample.env
@@ -0,0 +1,30 @@
+# Reporting MySQL Server
+REPORTING_SQL_SERVER =
+REPORTING_SQL_PORT =
+REPORTING_SQL_DATABASE =
+REPORTING_SQL_USERNAME =
+REPORTING_SQL_PASSWORD =
+REPORTING_AULDATALEAK_TABLENAME =
+
+# Audit MongoDB Server
+AUDIT_MONGO_SERVER =
+AUDIT_MONGO_REPLICASET =
+AUDIT_MONGO_USERNAME =
+AUDIT_MONGO_PASSWORD =
+AUDIT_MONGO_DATABASE =
+AUDIT_MONGO_COLLECTION =
+
+# Comparison MongoDB Servers
+ARC_MONGO_SERVER_A =
+ARC_MONGO_SERVER_B =
+ARC_MONGO_SERVER_C =
+ARC_MONGO_REPLICASET_A =
+ARC_MONGO_REPLICASET_B =
+ARC_MONGO_REPLICASET_C =
+ARC_MONGO_USERNAME =
+ARC_MONGO_PASSWORD =
+ARC_MONGO_DATABASE =
+ARC_MONGO_COLLECTION =
+
+# General MongoDB Settings
+MONGO_AUTHSOURCE =
\ No newline at end of file
diff --git a/sample_script.py b/sample_script.py
deleted file mode 100644
index 7454a6d..0000000
--- a/sample_script.py
+++ /dev/null
@@ -1,269 +0,0 @@
-import os
-
-import pytz
-import pandas as pd
-from datetime import datetime, timedelta, date, time
-from pymongo import MongoClient
-from pymongo.collection import Collection
-from pymongo import DESCENDING
-from sqlalchemy import create_engine
-from sqlalchemy.exc import SQLAlchemyError
-
-REPORTING_SQL_SERVER = '127.0.0.1'
-REPORTING_SQL_PORT = '3306'
-REPORTING_SQL_DATABASE = 'myreportingdatabase'
-REPORTING_SQL_USERNAME = os.environ.get('REPORTING_SQL_USERNAME')
-REPORTING_SQL_PASSWORD = os.environ.get('REPORTING_SQL_PASSWORD')
-
-AUDIT_SERVER = "127.0.0.1:27018"
-AUDIT_REPLICASET = "rs4"
-AUDIT_USERNAME = os.environ.get('MONGO_AUDIT_USERNAME')
-AUDIT_PASSWORD = os.environ.get('MONGO_AUDIT_PASSWORD')
-AUDIT_DATABASE = "mydb"
-AUDIT_COLLECTION = "myauditcollection"
-
-SERVER_A = "127.0.0.1:27017"
-SERVER_B = "127.0.0.1:27017"
-SERVER_C = "127.0.0.1:27017"
-REPLICASET_A = "rs0"
-REPLICASET_B = "rs1"
-REPLICASET_C = "rs2"
-USERNAME = os.environ.get('mongo_USERNAME')
-PASSWORD = os.environ.get('mongo_PASSWORD')
-DATABASE = "mydb"
-COLLECTION = "mycollection"
-ARC_MONGO_PORT = '27017'
-ARC_MONGO_AUTHMECHANISM = "SCRAM-SHA-1"
-ARC_MONGO_AUTHSOURCE = "admin"
-ARC_MONGO_DATABASE = 'admin'
-ARC_MONGO_READ_PREFERENCE = "secondary"
-
-REPORTING_AULDATALEAK_TABLENAME = "auldata_leak"
-
-
-def get_mongo_client(mongoServers: str, mongoReplicaset: str, username: str, password: str):
-    mongo_uri = 'mongodb://%s:%s@%s' % (username, password, mongoServers)
-    return MongoClient(mongo_uri, replicaSet=mongoReplicaset, authSource=ARC_MONGO_AUTHSOURCE,
-                       readPreference=ARC_MONGO_READ_PREFERENCE,
-                       authMechanism=ARC_MONGO_AUTHMECHANISM)
-
-
-def connect_to_mysql():
-    mysql_uri = 'mysql://%s:%s@%s:%s/%s?charset=utf8' % (REPORTING_SQL_USERNAME, REPORTING_SQL_PASSWORD,
-                                                         REPORTING_SQL_SERVER, REPORTING_SQL_PORT,
-                                                         REPORTING_SQL_DATABASE)
-    return create_engine(mysql_uri, pool_recycle=3600)
-
-
-def run_mongo_query(collection: Collection, query: dict, project: dict = None, sort: bool = True,
-                    sort_field: str = 'eventTime',
-                    limit_results: bool = False, limit_count: int = 10):
-    results = []
-    if project is not None:
-        db_query = collection.find(query, project)
-    else:
-        db_query = collection.find(query)
-    if sort:
-        db_query.sort(sort_field, DESCENDING)
-    if limit_results:
-        db_query.limit(limit_count)
-    for doc in db_query:
-        results.append(doc)
-
-    results_df = pd.DataFrame(list(results))
-    return results_df
-
-
-def run_mongo_query_agr(collection: Collection, query: dict):
-    results = collection.aggregate(query, cursor={})
-    results_df = pd.DataFrame(list(results))
-    return results_df
-
-
-def create_mysql_table(sql_client, q, tbl_name):
-    try:
-        sql_client.execute(q)
-        return 0
-    except SQLAlchemyError as e:
-        error = str(e.__dict__['orig'])
-        return error
-
-
-def init_aludata_leak_reporting_table(client):
-    print('Creating table... ' + REPORTING_AULDATALEAK_TABLENAME)
-
-    reportingTableCreateQuery = f'CREATE TABLE IF NOT EXISTS {REPORTING_AULDATALEAK_TABLENAME} ( \
-        `SUBSCRIBERID` VARCHAR(100), \
-        `MDN` VARCHAR(100), \
-        `BAN` VARCHAR(100), \
-        `USAGESTART` DATETIME, \
-        `USAGEEND` DATETIME, \
-        `TOTALMB` DECIMAL, \
-        `AUDITDATE` DATETIME \
-        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;'
-
-    reportingTableCreateIndex = f'CREATE INDEX idx_AUDITDATE \
-        ON {REPORTING_AULDATALEAK_TABLENAME} (AUDITDATE);'
-
-    create_mysql_table(client, reportingTableCreateQuery, REPORTING_AULDATALEAK_TABLENAME)
-    create_mysql_table(client, reportingTableCreateIndex, REPORTING_AULDATALEAK_TABLENAME)
-
-
-def get_auldata_subscribers(auditRangeStart: datetime, auditRangeEnd: datetime):
-    auditClient = get_mongo_client(
-        mongoServers=AUDIT_SERVER,
-        mongoReplicaset=AUDIT_REPLICASET,
-        username=AUDIT_USERNAME,
-        password=AUDIT_PASSWORD)[ARC_AUDIT_DATABASE]
-    auditCollection = auditClient[AUDIT_COLLECTION]
-
-    # print(auditRangeStart.strftime('%Y-%m-%dT%H:%M:%SZ'))
-    # print(auditRangeEnd.strftime('%Y-%m-%dT%H:%M:%SZ'))
-
-    auditQuery = [
-        {
-            "$match": {
-                "$and": [
-                    {
-                        "details": {
-                            "$elemMatch": {
-                                "state": "ADD",
-                                "data.payload.payloads": {
-                                    "$elemMatch": {
-                                        "requestpayload.subscriptions": {
-                                            "$elemMatch": {
-                                                "offerName": "MYOFFERNAME"
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                    },
-                    {
-                        "lastModifiedDate": {
-                            "$gte": auditRangeStart,
-                            "$lte": auditRangeEnd
-                        }
-                    }
-                ]
-            }
-        },
-        {
-            "$unwind": {
-                "path": "$details"
-            }
-        },
-        {
-            "$match": {
-                "details.state": "ADD",
-                "details.data.payload.payloads": {
-                    "$elemMatch": {
-                        "requestpayload.subscriptions": {
-                            "$elemMatch": {
-                                "offerName": "MYOFFERNAME"
-                            }
-                        }
-                    }
-                }
-            }
-        },
-        {
-            "$unwind": {
-                "path": "$details.data.payload.payloads"
-            }
-        },
-        {
-            "$unwind": {
-                "path": "$details.data.payload.payloads.requestpayload.subscriptions"
-            }
-        },
-        {
-            "$project": {
-                "_id": 0.0,
-                "ban": 1.0,
-                "subscriberId": "$details.data.payload.subscriberId",
-                "effectiveDate": "$details.data.payload.payloads.requestpayload.subscriptions.effectiveDate",
-                "expiryDate": "$details.data.payload.payloads.requestpayload.subscriptions.expiryDate"
-            }
-        }
-    ]
-
-    return run_mongo_query_agr(auditCollection, auditQuery)
-
-
"MDN": 1, "BAN": 1, "start": 1, "end": 1, "bytesIn": 1, "bytesOut": 1} - queryResult = run_mongo_query(usageCollection, usageQuery, usageProject) - usageResult = pd.concat([usageResult, queryResult], axis=0) - - if len(usageResult) > 0: - usageResultReportingQuery = f"INSERT INTO {REPORTING_AULDATALEAK_TABLENAME} (SUBSCRIBERID, MDN, BAN, USAGESTART, USAGEEND, TOTALMB, AUDITDATE) VALUES " - for index, row in usageResult.iterrows(): - usageResultReportingQuery = usageResultReportingQuery + f"(\'{row['extSubId']}\', {row['MDN']}, {row['BAN']}, \'{row['start']}\', \'{row['end']}\', \'{int(row['bytesIn']) + int(row['bytesOut'])}\', \'{auditDate}\')," - usageResultReportingQuery = usageResultReportingQuery[:-1] - reportingClient.execute(usageResultReportingQuery) - print(usageResult.size + " rows written to " + REPORTING_AULDATALEAK_TABLENAME) - -def compare(auldataSubs): - subListA = [] - subListB = [] - subListC = [] - - for index, row in auldataSubs.iterrows(): - remainder = int(row["ban"]) % 3 - if remainder == 0: - subListA.append(row) - elif remainder == 1: - subListB.append(row) - elif remainder == 2: - subListC.append(row) - - run_compare_on_node("A", subListA) - run_compare_on_node("B", subListB) - run_compare_on_node("C", subListC) - - -def aludata_leak_reporting_table_cleanup(client): - reportingTableDeleteQuery = f"DELETE FROM {REPORTING_AULDATALEAK_TABLENAME} WHERE AUDITDATE < DATE_SUB(NOW(), INTERVAL 1 MONTH)" - client.execute(reportingTableDeleteQuery) - - -if __name__ == '__main__': - reportingClient = connect_to_mysql() - init_aludata_leak_reporting_table(reportingClient) - auditDate = date.today() - timedelta(days=1) - auditRangeStart = (datetime.combine(auditDate, time(0, 0, 0))) - auditRangeEnd = (datetime.combine(auditDate, time(23, 59, 59))) - - auldataSubs = get_auldata_subscribers(auditRangeStart, auditRangeEnd) - compare(auldataSubs) - aludata_leak_reporting_table_cleanup(reportingClient) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..3a5dde6 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,175 @@ +"""Unit tests for the important functions on app.main""" +import unittest +from unittest.mock import MagicMock, patch +from datetime import datetime +from pandas import DataFrame + +from app.main import ( + init_auldata_leak_reporting_table, + get_auldata_subscribers, + compare_auldata, + run_compare_on_node, + cleanup_auldata_leak_reporting_table, +) + + +class TestInitAuldataLeakReportingTable(unittest.TestCase): + """Test the init_auldata_leak_reporting_table method""" + + def test_init_auldata_leak_reporting_table(self): + """Cross check with execute_query method calls""" + # Create a mock MySQLClient object + mock_client = MagicMock() + + # Call the function with the mock client + init_auldata_leak_reporting_table(mock_client) + + # Assert that the execute_query method was called twice with the expected arguments + self.assertEqual(mock_client.execute_query.call_count, 2) + self.assertIn( + "CREATE TABLE IF NOT EXISTS", + mock_client.execute_query.call_args_list[0][0][0], + ) + self.assertIn( + "CREATE INDEX idx_AUDITDATE", + mock_client.execute_query.call_args_list[1][0][0], + ) + + +class TestGetAuldataSubscribers(unittest.TestCase): + """Test the get_auldata_subscribers method""" + + def test_get_auldata_subscribers(self): + """Cross check with run_mongo_query_agr method calls""" + # Create a mock MongoCollection object + 
+        mock_collection = MagicMock()
+
+        # Set the return value of the run_mongo_query_agr method
+        mock_collection.run_mongo_query_agr.return_value = "mock result"
+
+        # Define the audit range start and end datetimes
+        audit_range_start = datetime(2022, 1, 1, 0, 0, 0)
+        audit_range_end = datetime(2022, 1, 1, 23, 59, 59)
+
+        # Call the function with the mock collection and audit range start and end datetimes
+        result = get_auldata_subscribers(
+            mock_collection, audit_range_start, audit_range_end
+        )
+
+        # Assert that the run_mongo_query_agr method was called once with
+        # the expected arguments
+        self.assertEqual(mock_collection.run_mongo_query_agr.call_count, 1)
+        self.assertIn(
+            "$match", mock_collection.run_mongo_query_agr.call_args_list[0][0][0][0]
+        )
+
+        # Assert that the result is equal to the expected value
+        self.assertEqual(result, "mock result")
+
+
+class TestCompareAuldata(unittest.TestCase):
+    """Test the compare_auldata method"""
+
+    def test_compare_auldata(self):
+        """Cross check with run_compare_on_node method calls"""
+        # Create a mock MySQLClient object
+        mock_client = MagicMock()
+
+        # Create a mock DataFrame object
+        mock_subs = DataFrame({"ban": [0, 1, 2]})
+
+        # Create a mock run_compare_on_node function
+        mock_run_compare_on_node = MagicMock()
+
+        # Patch the run_compare_on_node function with the mock function
+        with patch("app.main.run_compare_on_node", mock_run_compare_on_node):
+            # Call the compare_auldata function with the mock subs and client
+            compare_auldata(mock_subs, mock_client)
+
+        # Assert that the run_compare_on_node function was called three
+        # times with the expected arguments
+        self.assertEqual(mock_run_compare_on_node.call_count, 3)
+        self.assertEqual(mock_run_compare_on_node.call_args_list[0][0][0], "A")
+        self.assertEqual(mock_run_compare_on_node.call_args_list[1][0][0], "B")
+        self.assertEqual(mock_run_compare_on_node.call_args_list[2][0][0], "C")
+
+
+class TestRunCompareOnNode(unittest.TestCase):
+    """Test the run_compare_on_node method"""
+
+    def test_run_compare_on_node(self):
+        """Cross check with get_usage_collection method calls"""
+        # Create a mock MySQLClient object
+        mock_client = MagicMock()
+
+        # Create a mock DataFrame object
+        mock_data = {
+            "effectiveDate": ["2022-01-01T00:00:00Z", "2022-01-02T00:00:00Z"],
+            "expiryDate": ["2022-01-01T23:59:59Z", "2022-01-02T23:59:59Z"],
+            "subscriberId": ["{'$eq': 'sample1'}", "{'$eq': 'sample2'}"],
+        }
+        mock_subs = DataFrame(mock_data)
+
+        # Create a mock get_usage_collection function
+        mock_get_usage_collection = MagicMock()
+
+        # Set the return value of the run_mongo_query method of the mock collection
+        mock_collection = MagicMock()
+        mock_collection.run_mongo_query.return_value = DataFrame()
+        mock_get_usage_collection.return_value = mock_collection
+
+        # Patch the get_usage_collection function with the mock function
+        with patch("app.main.get_usage_collection", mock_get_usage_collection):
+            # Call the run_compare_on_node function with the node, mock subs, and mock client
+            run_compare_on_node("A", mock_subs, mock_client)
+
+        # Assert that the get_usage_collection function was called once
+        # with the expected arguments
+        self.assertEqual(mock_get_usage_collection.call_count, 1)
+        self.assertEqual(mock_get_usage_collection.call_args_list[0][0][0], "A")
+
+        # Assert that the execute_query method of the mock client was not called
+        self.assertEqual(mock_client.execute_query.call_count, 0)
+
+    def test_run_compare_on_node_empty(self):
+        """Cross check with get_usage_collection method calls but
with an empty Subscribers List""" + # Create a mock MySQLClient object + mock_client = MagicMock() + + # Create a mock empty DataFrame object + mock_subs = DataFrame() + + # Create a mock get_usage_collection function + mock_get_usage_collection = MagicMock() + + # Set the return value of the run_mongo_query method of the mock collection + mock_collection = MagicMock() + mock_collection.run_mongo_query.return_value = DataFrame() + mock_get_usage_collection.return_value = mock_collection + + # Patch the get_usage_collection function with the mock function + with patch("app.main.get_usage_collection", mock_get_usage_collection): + # Call the run_compare_on_node function with the node, mock subs, + # and mock client + run_compare_on_node("A", mock_subs, mock_client) + + # Assert that the get_usage_collection function was not called + self.assertEqual(mock_get_usage_collection.call_count, 0) + + +class TestAuldataLeakReportingTableCleanup(unittest.TestCase): + """Test the cleanup_auldata_leak_reporting_table method""" + + def test_auldata_leak_reporting_table_cleanup(self): + """Cross check with execute_query method calls""" + # Create a mock MySQLClient object + mock_client = MagicMock() + + # Call the auldata_leak_reporting_table_cleanup function with the mock client + cleanup_auldata_leak_reporting_table(mock_client) + + # Assert that the execute_query method of the mock client was called once + # with the expected arguments + self.assertEqual(mock_client.execute_query.call_count, 1) + self.assertIn("DELETE FROM", mock_client.execute_query.call_args_list[0][0][0])