From 77abb1d62b83679eca4b0aa7071b1e450d362219 Mon Sep 17 00:00:00 2001
From: yoonhyejin <0327jane@gmail.com>
Date: Wed, 29 Jan 2025 12:13:31 +0900
Subject: [PATCH] init fresh scripts

---
 .../examples/ml/mlflow_dh_client.py           | 509 ++++++++++++++++++
 .../examples/ml/mlflow_dh_client_sample.py    | 130 +++++
 2 files changed, 639 insertions(+)
 create mode 100644 metadata-ingestion/examples/ml/mlflow_dh_client.py
 create mode 100644 metadata-ingestion/examples/ml/mlflow_dh_client_sample.py

diff --git a/metadata-ingestion/examples/ml/mlflow_dh_client.py b/metadata-ingestion/examples/ml/mlflow_dh_client.py
new file mode 100644
index 00000000000000..f5581591e2fd84
--- /dev/null
+++ b/metadata-ingestion/examples/ml/mlflow_dh_client.py
@@ -0,0 +1,509 @@
+import logging
+import time
+from typing import Any, Dict, List, Optional, Union
+
+import datahub.metadata.schema_classes as models
+from datahub.api.entities.dataset.dataset import Dataset
+from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
+from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import (
+    DataProcessInstanceInput,
+    DataProcessInstanceOutput,
+)
+from datahub.metadata.schema_classes import (
+    ChangeTypeClass,
+    DataProcessInstanceRunResultClass,
+    DataProcessRunStatusClass,
+)
+from datahub.metadata.urns import (
+    ContainerUrn,
+    DataPlatformUrn,
+    MlModelGroupUrn,
+    MlModelUrn,
+    VersionSetUrn,
+)
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class MLflowDatahubClient:
+    """Client for creating and managing MLflow metadata in DataHub.
+
+    Every ``create_*`` / ``add_*`` method builds one or more metadata change
+    proposals and emits them through ``self.graph``.
+    """
+
+    def __init__(
+        self,
+        token: str,
+        server_url: str = "http://localhost:8080",
+        platform: str = "mlflow",
+    ) -> None:
+        """Initialize the MLflow DataHub client.
+
+        Args:
+            token: DataHub access token, sent as a bearer token.
+            server_url: URL of the DataHub server.
+            platform: Platform name used when building URNs.
+        """
+        self.token = token
+        self.server_url = server_url
+        self.platform = platform
+        self.graph = DataHubGraph(
+            DatahubClientConfig(
+                server=server_url,
+                token=token,
+                extra_headers={"Authorization": f"Bearer {token}"},
+            )
+        )
+
+    def _create_timestamp(
+        self, timestamp: Optional[int] = None
+    ) -> models.TimeStampClass:
+        """Build a TimeStampClass, defaulting to the current time in milliseconds.
+
+        An explicit ``0`` is kept as-is; only ``None`` selects the current time.
+        """
+        if timestamp is None:
+            timestamp = int(time.time() * 1000)
+        return models.TimeStampClass(time=timestamp, actor="urn:li:corpuser:datahub")
+
+    def _emit_mcps(
+        self,
+        mcps: Union[List[MetadataChangeProposalWrapper], MetadataChangeProposalWrapper],
+    ) -> None:
+        """Emit a single MCP or a list of MCPs through the graph client."""
+        if not isinstance(mcps, list):
+            mcps = [mcps]
+        # NOTE(review): leaving the ``with`` block invokes the graph's context
+        # exit on every call; confirm the client remains usable afterwards.
+        with self.graph:
+            for mcp in mcps:
+                self.graph.emit(mcp)
+
+    def _get_aspect(
+        self, entity_urn: str, aspect_type: Any, default_constructor: Any = None
+    ) -> Any:
+        """Fetch an aspect, falling back to a default when it is unavailable.
+
+        The fallback is used both when the lookup raises and when it yields
+        ``None``, so callers that pass ``default_constructor`` always receive an
+        object.
+        """
+        try:
+            aspect = self.graph.get_aspect(
+                entity_urn=entity_urn, aspect_type=aspect_type
+            )
+        except Exception as e:
+            logger.warning(f"Could not fetch aspect for {entity_urn}: {e}")
+            aspect = None
+        if aspect is None and default_constructor:
+            return default_constructor()
+        return aspect
+
+    def _create_properties_class(
+        self, props_class: Any, props_dict: Optional[Dict[str, Any]] = None
+    ) -> Any:
+        """Instantiate ``props_class`` from the non-None entries of ``props_dict``.
+
+        ``created`` / ``lastModified`` are filled with the current timestamp when
+        the class exposes those attributes and no value was supplied.
+        """
+        if props_dict is None:
+            props_dict = {}
+
+        filtered_props = {k: v for k, v in props_dict.items() if v is not None}
+
+        if hasattr(props_class, "created"):
+            filtered_props.setdefault("created", self._create_timestamp())
+        if hasattr(props_class, "lastModified"):
+            filtered_props.setdefault("lastModified", self._create_timestamp())
+
+        return props_class(**filtered_props)
+
+    def _update_list_property(
+        self, existing_list: Optional[List[str]], new_item: str
+    ) -> List[str]:
+        """Return ``existing_list`` plus ``new_item`` without duplicates.
+
+        The original order is preserved so repeated updates are deterministic.
+        """
+        # dict.fromkeys de-duplicates while keeping first-seen order.
+        return list(dict.fromkeys([*(existing_list or []), new_item]))
+
+    def _create_mcp(
+        self,
+        entity_urn: str,
+        aspect: Any,
+        entity_type: Optional[str] = None,
+        aspect_name: Optional[str] = None,
+        change_type: str = ChangeTypeClass.UPSERT,
+    ) -> MetadataChangeProposalWrapper:
+        """Build an MCP, passing the optional fields only when they are set."""
+        mcp_args: Dict[str, Any] = {"entityUrn": entity_urn, "aspect": aspect}
+        if entity_type:
+            mcp_args["entityType"] = entity_type
+        if aspect_name:
+            mcp_args["aspectName"] = aspect_name
+        mcp_args["changeType"] = change_type
+        return MetadataChangeProposalWrapper(**mcp_args)
+
+    def _update_entity_properties(
+        self,
+        entity_urn: str,
+        aspect_type: Any,
+        updates: Dict[str, Any],
+        entity_type: str,
+        skip_properties: Optional[List[str]] = None,
+    ) -> None:
+        """Merge ``updates`` into an entity's properties aspect and emit it.
+
+        A string update whose key exists on the current aspect is appended to
+        that (list-valued) property; any other update replaces the value.
+        """
+        existing_props = self._get_aspect(entity_urn, aspect_type, aspect_type)
+        props = self._copy_existing_properties(existing_props, skip_properties)
+
+        for key, value in updates.items():
+            if isinstance(value, str) and hasattr(existing_props, key):
+                existing_value = getattr(existing_props, key, [])
+                props[key] = self._update_list_property(existing_value, value)
+            else:
+                props[key] = value
+
+        updated_props = self._create_properties_class(aspect_type, props)
+        mcp = self._create_mcp(
+            entity_urn, updated_props, entity_type, f"{entity_type}Properties"
+        )
+        self._emit_mcps(mcp)
+
+    def _copy_existing_properties(
+        self, existing_props: Any, skip_properties: Optional[List[str]] = None
+    ) -> Dict[str, Any]:
+        """Copy the public, non-callable, non-None attributes of an aspect.
+
+        Names in ``skip_properties`` and the internal names listed in
+        ``internal_props`` are left out. The caller's list is never modified.
+        """
+        internal_props = {
+            "ASPECT_INFO",
+            "ASPECT_NAME",
+            "ASPECT_TYPE",
+            "RECORD_SCHEMA",
+        }
+        # Build a fresh set so the caller's list is not extended in place.
+        skipped = internal_props.union(skip_properties or [])
+
+        props: Dict[str, Any] = {}
+        if existing_props:
+            for prop in dir(existing_props):
+                if (
+                    prop.startswith("_")
+                    or callable(getattr(existing_props, prop))
+                    or prop in skipped
+                ):
+                    continue
+
+                value = getattr(existing_props, prop)
+                if value is not None:
+                    props[prop] = value
+
+        if hasattr(existing_props, "created"):
+            props.setdefault("created", self._create_timestamp())
+        if hasattr(existing_props, "lastModified"):
+            props.setdefault("lastModified", self._create_timestamp())
+
+        return props
+
+    def _create_run_event(
+        self,
+        status: str,
+        timestamp: int,
+        result: Optional[str] = None,
+        duration_millis: Optional[int] = None,
+    ) -> models.DataProcessInstanceRunEventClass:
+        """Build a run event; ``result`` and ``duration_millis`` are optional.
+
+        A duration of ``0`` is recorded; only ``None`` omits the field.
+        """
+        event_args: Dict[str, Any] = {
+            "timestampMillis": timestamp,
+            "status": status,
+            "attempt": 1,
+        }
+
+        if result:
+            event_args["result"] = DataProcessInstanceRunResultClass(
+                type=result, nativeResultType=str(result)
+            )
+        if duration_millis is not None:
+            event_args["durationMillis"] = duration_millis
+
+        return models.DataProcessInstanceRunEventClass(**event_args)
+
+    def create_model_group(
+        self,
+        group_id: str,
+        properties: Optional[models.MLModelGroupPropertiesClass] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Create an ML model group from a properties object or from kwargs.
+
+        Returns:
+            The URN of the model group as a string.
+        """
+        model_group_urn = MlModelGroupUrn(platform=self.platform, name=group_id)
+
+        if properties is None:
+            properties = self._create_properties_class(
+                models.MLModelGroupPropertiesClass, kwargs
+            )
+
+        mcp = self._create_mcp(
+            str(model_group_urn), properties, "mlModelGroup", "mlModelGroupProperties"
+        )
+        self._emit_mcps(mcp)
+        logger.info(f"Created model group: {model_group_urn}")
+        return str(model_group_urn)
+
+    def create_model(
+        self,
+        model_id: str,
+        version: str,
+        alias: Optional[str] = None,
+        properties: Optional[models.MLModelPropertiesClass] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Create an ML model together with its version set and version aspect.
+
+        ``properties.version`` is overwritten with ``version``, including on a
+        caller-supplied ``properties`` object.
+
+        Returns:
+            The URN of the model as a string.
+        """
+        model_urn = MlModelUrn(platform=self.platform, name=model_id)
+        version_set_urn = VersionSetUrn(
+            id=f"mlmodel_{model_id}_versions", entity_type="mlModel"
+        )
+
+        # Handle model properties
+        if properties is None:
+            # If no properties provided, create from kwargs
+            properties = self._create_properties_class(
+                models.MLModelPropertiesClass, kwargs
+            )
+
+        # Ensure version is set in model properties
+        version_tag = models.VersionTagClass(versionTag=str(version))
+        properties.version = version_tag
+
+        # Create version properties
+        version_props: Dict[str, Any] = {
+            "version": version_tag,
+            "versionSet": str(version_set_urn),
+            "sortId": "AAAAAAAA",
+        }
+
+        # Add alias if provided
+        if alias:
+            version_props["aliases"] = [models.VersionTagClass(versionTag=alias)]
+
+        version_properties = self._create_properties_class(
+            models.VersionPropertiesClass, version_props
+        )
+
+        # Create version set properties
+        version_set_properties = models.VersionSetPropertiesClass(
+            latest=str(model_urn),
+            versioningScheme="ALPHANUMERIC_GENERATED_BY_DATAHUB",
+        )
+
+        mcps = [
+            self._create_mcp(
+                str(model_urn), properties, "mlModel", "mlModelProperties"
+            ),
+            self._create_mcp(
+                str(version_set_urn),
+                version_set_properties,
+                "versionSet",
+                "versionSetProperties",
+            ),
+            self._create_mcp(
+                str(model_urn), version_properties, "mlModel", "versionProperties"
+            ),
+        ]
+        self._emit_mcps(mcps)
+        logger.info(f"Created model: {model_urn}")
+        return str(model_urn)
+
+    def create_experiment(
+        self,
+        experiment_id: str,
+        properties: Optional[models.ContainerPropertiesClass] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Create an ML experiment container from a properties object or kwargs.
+
+        Returns:
+            The URN of the experiment container as a string.
+        """
+        container_urn = ContainerUrn(guid=experiment_id)
+        platform_urn = DataPlatformUrn(platform_name=self.platform)
+
+        if properties is None:
+            properties = self._create_properties_class(
+                models.ContainerPropertiesClass, kwargs
+            )
+
+        container_subtype = models.SubTypesClass(typeNames=["ML Experiment"])
+        browse_path = models.BrowsePathsV2Class(path=[])
+        platform_instance = models.DataPlatformInstanceClass(platform=str(platform_urn))
+
+        mcps = MetadataChangeProposalWrapper.construct_many(
+            entityUrn=str(container_urn),
+            aspects=[container_subtype, properties, browse_path, platform_instance],
+        )
+        self._emit_mcps(mcps)
+        logger.info(f"Created experiment: {container_urn}")
+        return str(container_urn)
+
+    def create_training_run(
+        self,
+        run_id: str,
+        properties: Optional[models.DataProcessInstancePropertiesClass] = None,
+        training_run_properties: Optional[models.MLTrainingRunPropertiesClass] = None,
+        run_result: Optional[str] = None,
+        start_timestamp: Optional[int] = None,
+        end_timestamp: Optional[int] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Create a training run with properties and events.
+
+        A STARTED event is always emitted; a COMPLETE event is added only when
+        ``run_result`` is given. Missing timestamps default to the current time.
+
+        Returns:
+            The URN of the data process instance as a string.
+        """
+        dpi_urn = f"urn:li:dataProcessInstance:{run_id}"
+
+        # Create basic properties and aspects
+        aspects: List[Any] = [
+            (
+                properties
+                or self._create_properties_class(
+                    models.DataProcessInstancePropertiesClass, kwargs
+                )
+            ),
+            models.SubTypesClass(typeNames=["ML Training Run"]),
+        ]
+
+        # Add training run properties if provided
+        if training_run_properties:
+            aspects.append(training_run_properties)
+
+        # Handle run events; only None (not 0) falls back to the current time
+        current_time = int(time.time() * 1000)
+        start_ts = current_time if start_timestamp is None else start_timestamp
+        end_ts = current_time if end_timestamp is None else end_timestamp
+
+        # Create events
+        aspects.append(
+            self._create_run_event(
+                status=DataProcessRunStatusClass.STARTED, timestamp=start_ts
+            )
+        )
+
+        if run_result:
+            aspects.append(
+                self._create_run_event(
+                    status=DataProcessRunStatusClass.COMPLETE,
+                    timestamp=end_ts,
+                    result=run_result,
+                    duration_millis=end_ts - start_ts,
+                )
+            )
+
+        # Create and emit MCPs
+        mcps = [self._create_mcp(dpi_urn, aspect) for aspect in aspects]
+        self._emit_mcps(mcps)
+        logger.info(f"Created training run: {dpi_urn}")
+        return dpi_urn
+
+    def create_dataset(self, name: str, platform: str, **kwargs: Any) -> str:
+        """Create a dataset with flexible properties.
+
+        Raises:
+            ValueError: If the dataset has no URN after its MCPs are emitted.
+        """
+        dataset = Dataset(id=name, platform=platform, name=name, **kwargs)
+        mcps = list(dataset.generate_mcp())
+        self._emit_mcps(mcps)
+        if dataset.urn is None:
+            raise ValueError(f"Failed to create dataset URN for {name}")
+        return dataset.urn
+
+    def add_run_to_model(self, model_urn: str, run_urn: str) -> None:
+        """Add a run to a model while preserving existing properties."""
+        self._update_entity_properties(
+            entity_urn=model_urn,
+            aspect_type=models.MLModelPropertiesClass,
+            updates={"trainingJobs": run_urn},
+            entity_type="mlModel",
+            skip_properties=["trainingJobs"],
+        )
+        logger.info(f"Added run {run_urn} to model {model_urn}")
+
+    def add_run_to_model_group(self, model_group_urn: str, run_urn: str) -> None:
+        """Add a run to a model group while preserving existing properties."""
+        self._update_entity_properties(
+            entity_urn=model_group_urn,
+            aspect_type=models.MLModelGroupPropertiesClass,
+            updates={"trainingJobs": run_urn},
+            entity_type="mlModelGroup",
+            skip_properties=["trainingJobs"],
+        )
+        logger.info(f"Added run {run_urn} to model group {model_group_urn}")
+
+    def add_model_to_model_group(self, model_urn: str, group_urn: str) -> None:
+        """Add a model to a group while preserving existing properties."""
+        self._update_entity_properties(
+            entity_urn=model_urn,
+            aspect_type=models.MLModelPropertiesClass,
+            updates={"groups": group_urn},
+            entity_type="mlModel",
+            skip_properties=["groups"],
+        )
+        logger.info(f"Added model {model_urn} to group {group_urn}")
+
+    def add_run_to_experiment(self, run_urn: str, experiment_urn: str) -> None:
+        """Add a run to an experiment."""
+        mcp = self._create_mcp(
+            entity_urn=run_urn, aspect=models.ContainerClass(container=experiment_urn)
+        )
+        self._emit_mcps(mcp)
+        logger.info(f"Added run {run_urn} to experiment {experiment_urn}")
+
+    def add_input_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
+        """Add input datasets to a run."""
+        mcp = self._create_mcp(
+            entity_urn=run_urn,
+            entity_type="dataProcessInstance",
+            aspect_name="dataProcessInstanceInput",
+            aspect=DataProcessInstanceInput(inputs=dataset_urns),
+        )
+        self._emit_mcps(mcp)
+        logger.info(f"Added input datasets to run {run_urn}")
+
+    def add_output_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
+        """Add output datasets to a run."""
+        mcp = self._create_mcp(
+            entity_urn=run_urn,
+            entity_type="dataProcessInstance",
+            aspect_name="dataProcessInstanceOutput",
+            aspect=DataProcessInstanceOutput(outputs=dataset_urns),
+        )
+        self._emit_mcps(mcp)
+        logger.info(f"Added output datasets to run {run_urn}")
diff --git a/metadata-ingestion/examples/ml/mlflow_dh_client_sample.py b/metadata-ingestion/examples/ml/mlflow_dh_client_sample.py
new file mode 100644
index 00000000000000..867f118fa88392
--- /dev/null
+++ b/metadata-ingestion/examples/ml/mlflow_dh_client_sample.py
@@ -0,0 +1,130 @@
+import argparse
+
+from mlflow_dh_client import MLflowDatahubClient
+
+import datahub.metadata.schema_classes as models
+from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType
+
+if __name__ == "__main__":
+    # Example usage
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--token", required=True, help="DataHub access token")
+    args = parser.parse_args()
+
+    client = MLflowDatahubClient(token=args.token)
+
+    # Create model group
+    model_group_urn = client.create_model_group(
+        group_id="airline_forecast_models_group",
+        properties=models.MLModelGroupPropertiesClass(
+            name="Airline Forecast Models Group",
+            description="Group of models for airline passenger forecasting",
+            created=models.TimeStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
+        ),
+    )
+
+    # Creating a model with property classes
+    model_urn = client.create_model(
+        model_id="arima_model",
+        properties=models.MLModelPropertiesClass(
+            name="ARIMA Model",
+            description="ARIMA model for airline passenger forecasting",
+            customProperties={"team": "forecasting"},
+            trainingMetrics=[
+                models.MLMetricClass(name="accuracy", value="0.9"),
+                models.MLMetricClass(name="precision", value="0.8"),
+            ],
+            hyperParams=[
+                models.MLHyperParamClass(name="learning_rate", value="0.01"),
+                models.MLHyperParamClass(name="batch_size", value="32"),
+            ],
+            externalUrl="https://localhost:5000",
+            created=models.TimeStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
+            lastModified=models.TimeStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
+            tags=["forecasting", "arima"],
+        ),
+        version="1.0",
+        alias="champion",
+    )
+
+    # Creating an experiment with property class
+    experiment_urn = client.create_experiment(
+        experiment_id="airline_forecast_experiment",
+        properties=models.ContainerPropertiesClass(
+            name="Airline Forecast Experiment",
+            description="Experiment to forecast airline passenger numbers",
+            customProperties={"team": "forecasting"},
+            created=models.TimeStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
+            lastModified=models.TimeStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
+        ),
+    )
+
+    # Creating a training run with property classes
+    run_urn = client.create_training_run(
+        run_id="simple_training_run",
+        properties=models.DataProcessInstancePropertiesClass(
+            name="Simple Training Run",
+            created=models.AuditStampClass(
+                time=1628580000000, actor="urn:li:corpuser:datahub"
+            ),
+            customProperties={"team": "forecasting"},
+        ),
+        training_run_properties=models.MLTrainingRunPropertiesClass(
+            id="simple_training_run",
+            outputUrls=["s3://my-bucket/output"],
+            trainingMetrics=[models.MLMetricClass(name="accuracy", value="0.9")],
+            hyperParams=[models.MLHyperParamClass(name="learning_rate", value="0.01")],
+            externalUrl="https://localhost:5000",
+        ),
+        run_result=RunResultType.FAILURE,
+        start_timestamp=1628580000000,
+        end_timestamp=1628580001000,
+    )
+
+    # Create datasets
+    input_dataset_urn = client.create_dataset(
+        platform="snowflake",
+        name="iris_input",
+    )
+
+    output_dataset_urn = client.create_dataset(
+        platform="snowflake",
+        name="iris_output",
+    )
+
+    # Add run to experiment
+    client.add_run_to_experiment(run_urn=run_urn, experiment_urn=experiment_urn)
+
+    # Add model to model group
+    client.add_model_to_model_group(model_urn=model_urn, group_urn=model_group_urn)
+
+    # Add run to model
+    client.add_run_to_model(
+        model_urn=model_urn,
+        run_urn=run_urn,
+    )
+
+    # Add run to model group
+    client.add_run_to_model_group(
+        model_group_urn=model_group_urn,
+        run_urn=run_urn,
+    )
+
+    # Add input and output datasets to run
+    client.add_input_datasets_to_run(
+        run_urn=run_urn, dataset_urns=[str(input_dataset_urn)]
+    )
+
+    client.add_output_datasets_to_run(
+        run_urn=run_urn, dataset_urns=[str(output_dataset_urn)]
+    )