diff --git a/docs/source/input_output.rst b/docs/source/input_output.rst
index c8b81d02..55bea018 100644
--- a/docs/source/input_output.rst
+++ b/docs/source/input_output.rst
@@ -107,3 +107,8 @@
 the entities, with the first dimension being the number of entities and the
 second being the dimension of the embedding. Just like for the model parameters
 file, the optimizer state dict and additional metadata is also included.
+
+HDFS Format
+^^^^^^^^^^^
+
+Include the ``hdfs://`` prefix in the entity, edge, and checkpoint paths when running on a distributed HDFS cluster.
diff --git a/test/resources/edges_0_0.h5 b/test/resources/edges_0_0.h5
new file mode 100644
index 00000000..5b47bf5f
Binary files /dev/null and b/test/resources/edges_0_0.h5 differ
diff --git a/test/resources/invalidFile.h5 b/test/resources/invalidFile.h5
new file mode 100644
index 00000000..e69de29b
diff --git a/test/resources/text.txt b/test/resources/text.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/test/test_storage_manager.py b/test/test_storage_manager.py
new file mode 100644
index 00000000..5b7655d1
--- /dev/null
+++ b/test/test_storage_manager.py
@@ -0,0 +1,257 @@
+import shutil
+import tempfile
+from contextlib import AbstractContextManager
+from io import TextIOWrapper
+from pathlib import Path
+from unittest import TestCase, main
+import h5py
+
+from torchbiggraph.storage_repository import CUSTOM_PATH
+from torchbiggraph.storage_repository import LocalPath, HDFSPath, HDFSFileContextManager, LocalFileContextManager
+from torchbiggraph.util import run_external_cmd, url_scheme
+
+HDFS_TEST_PATH = ''
+
+
+def _touch_file(name: str):
+    file_path = HDFS_TEST_PATH + "/" + name
+    run_external_cmd("hadoop fs -touchz " + file_path)
+    return file_path
+
+
+class TestLocalFileContextManager(TestCase):
+    def setUp(self):
+        self.resource_dir = Path(__file__).parent.absolute() / 'resources'
+
+    def test_get_resource_valid_h5(self):
+        filepath_h5 = self.resource_dir / 'edges_0_0.h5'
+        file_path = str(filepath_h5)
+        self.assertIs(type(LocalFileContextManager.get_resource(file_path, 'r')), h5py.File)
+
+    def test_get_resource_invalid_h5(self):
+        filepath_h5 = self.resource_dir / 'invalidFile.h5'
+        file_path = str(filepath_h5)
+
+        with self.assertRaises(ValueError):
+            LocalFileContextManager.get_resource(str(file_path), 'r')
+
+    def test_get_resource_valid_text_file(self):
+        filepath_txt = self.resource_dir / 'text.txt'
+        file_path = str(filepath_txt)
+        self.assertIs(type(LocalFileContextManager.get_resource(str(file_path), 'r')), TextIOWrapper)
+
+
+class TestHDFSFileContextManager(TestCase):
+    def setUp(self):
+        if not HDFSFileContextManager.hdfs_file_exists(HDFS_TEST_PATH):
+            self.skipTest('skipped test due to skip_tests_flag')
+
+        self.resource_dir = Path(__file__).parent.absolute() / 'resources'
+        run_external_cmd("hadoop fs -mkdir -p " + HDFS_TEST_PATH)
+
+    def tearDown(self):
+        run_external_cmd("hadoop fs -rm -r " + HDFS_TEST_PATH)
+
+    def test_prepare_hdfs_path(self):
+        actual = HDFSFileContextManager.get_hdfs_path(Path.cwd() / '/some/path')
+        expected = '/some/path'
+        self.assertEqual(str(expected), actual)
+
+    def test_hdfs_file_exists(self):
+        valid_path = _touch_file('abc')
+        self.assertTrue(HDFSFileContextManager.hdfs_file_exists(valid_path))
+
+    def test_hdfs_file_doesnt_exists(self):
+        invalid_path = HDFS_TEST_PATH + "/invalid_loc"
+        self.assertFalse(HDFSFileContextManager.hdfs_file_exists(invalid_path))
+
+    def test_get_from_hdfs_valid(self):
+        valid_hdfs_file = _touch_file('valid.file')
+        local_file = 
Path(str(Path.cwd()) + valid_hdfs_file) + file_ctx = HDFSFileContextManager(local_file, 'r') + + # valid path + file_ctx.get_from_hdfs(reload=True) + self.assertTrue(Path(file_ctx._path).exists()) + + def test_get_from_hdfs_valid_dont_reload(self): + valid_hdfs_file = _touch_file('valid.file') + local_file = Path(str(Path.cwd()) + valid_hdfs_file) + file_ctx = HDFSFileContextManager(local_file, 'r') + + # valid path + file_ctx.get_from_hdfs(reload=False) + self.assertTrue(Path(file_ctx._path).exists()) + + def test_get_from_hdfs_invalid(self): + invalid_hdfs_file = Path('./' + HDFS_TEST_PATH + "/invalid_loc").resolve() + file_ctx = HDFSFileContextManager(invalid_hdfs_file, 'r') + + # invalid path + with self.assertRaises(FileNotFoundError): + file_ctx.get_from_hdfs(reload=True) + + def test_put_to_hdfs(self): + local_file_name = 'test_local.file' + local_file = Path(str(Path.cwd()) + HDFS_TEST_PATH + '/' + local_file_name) + file_ctx = HDFSFileContextManager(local_file, 'w') + + # clean up local + if local_file.exists(): + local_file.unlink() + + # invalid path + with self.assertRaises(FileNotFoundError): + file_ctx.put_to_hdfs() + + # create local file + local_file.touch() + file_ctx.put_to_hdfs() + self.assertTrue(HDFSFileContextManager.hdfs_file_exists(HDFS_TEST_PATH + '/' + local_file_name)) + + +class TestLocalPath(TestCase): + def setUp(self): + self.resource_dir = Path(__file__).parent.absolute() / 'resources' + + def test_init(self): + path = LocalPath(Path.cwd()) + self.assertIs(type(path), LocalPath) + + path = LocalPath('some/path') + self.assertIs(type(path), LocalPath) + + def test_stem_suffix(self): + path = LocalPath('some/path/name.txt') + self.assertTrue(path.stem == 'name') + self.assertTrue(path.suffix == '.txt') + self.assertIsInstance(path.stem, str) + + def test_name(self): + path = LocalPath('some/path/name') + self.assertTrue(path.name == 'name') + self.assertIsInstance(path.name, str) + + def test_resolve(self): + path = LocalPath('some/path/name') + actual = path.resolve(strict=False) + expected = Path.cwd() / Path(str(path)) + self.assertTrue(str(actual) == str(expected)) + + def test_exists(self): + invalid_path = LocalPath('some/path/name') + self.assertFalse(invalid_path.exists()) + + valid_path = LocalPath(Path(__file__)) + self.assertTrue(valid_path.exists()) + + def test_append_path(self): + path = LocalPath('/some/path/name') + actual = path / 'storage_manager.py' + expected = '/some/path/name/storage_manager.py' + self.assertTrue(str(actual) == expected) + + def test_open(self): + file_path = Path(__file__) + with file_path.open('r') as fh: + self.assertGreater(len(fh.readlines()), 0) + + def test_mkdir(self): + path = LocalPath(self.resource_dir) + path.parent.mkdir(parents=True, exist_ok=True) + + def test_with_plugin_empty_scheme(self): + local_path = '/some/path/file.txt' + actual_path = CUSTOM_PATH.get_class(url_scheme(local_path))(local_path) + expected_path = '/some/path/file.txt' + self.assertEqual(str(actual_path), str(expected_path)) + + def test_with_plugin_file_scheme(self): + local_path = 'file:///some/path/file.txt' + actual_path = CUSTOM_PATH.get_class(url_scheme(local_path))(local_path) + expected_path = '/some/path/file.txt' + self.assertEqual(str(expected_path), str(actual_path)) + + +class TestHDFSDataPath(TestCase): + + def setUp(self): + if not HDFSFileContextManager.hdfs_file_exists(HDFS_TEST_PATH): + self.skipTest('skipped test due to skip_tests_flag') + + self.resource_dir = Path(__file__).parent.absolute() / 'resources' + 
run_external_cmd("hadoop fs -mkdir -p " + HDFS_TEST_PATH) + + def tearDown(self): + run_external_cmd("hadoop fs -rm -r " + HDFS_TEST_PATH) + + def test_delete_valid(self): + valid_path = _touch_file('abc.txt') + local_temp_dir = str(Path.cwd()) + '/' + 'axp' + + # create resolved path based on the hdfs path + remote_path = HDFSPath(valid_path).resolve(strict=False) + remote_path.parent.mkdir(parents=True, exist_ok=True) + remote_path.touch() + remote_path.unlink() + + # remove local path + shutil.rmtree(local_temp_dir, ignore_errors=True) + + def test_delete_invalid(self): + invalid_path = HDFSPath(HDFS_TEST_PATH + '/invalid.file') + with self.assertRaises(FileNotFoundError): + invalid_path.unlink() + + def test_open(self): + filepath_h5 = self.resource_dir / 'edges_0_0.h5' + hdfs = HDFSPath(filepath_h5).resolve(strict=False) + with hdfs.open('r') as fh: + self.assertEqual(len(fh.keys()), 3) + self.assertIsInstance(fh, AbstractContextManager) + + def test_open_reload_False(self): + filepath_h5 = self.resource_dir / 'edges_0_0.h5' + hdfs = HDFSPath(filepath_h5).resolve(strict=False) + with hdfs.open('r', reload=False) as fh: + self.assertEqual(len(fh.keys()), 3) + self.assertIsInstance(fh, AbstractContextManager) + + def test_name(self): + hdfs = HDFSPath('/some/path/file.txt') + self.assertEqual(hdfs.name, 'file.txt') + + def test_with_plugin(self): + hdfs_path = 'hdfs:///some/path/file.txt' + actual_path = CUSTOM_PATH.get_class(url_scheme(hdfs_path))(hdfs_path).resolve(strict=False) + expected_path = Path.cwd() / 'some/path/file.txt' + self.assertEqual(str(actual_path), str(expected_path)) + + def test_append_path(self): + path = HDFSPath('/some/path/name') + actual = path.resolve(strict = False) / 'storage_manager.py' + expected = str(Path.cwd() / 'some/path/name/storage_manager.py') + self.assertEqual(expected, str(actual)) + + def test_stem_suffix(self): + path = HDFSPath('some/path/name.txt') + self.assertTrue(path.stem == 'name') + self.assertTrue(path.suffix == '.txt') + self.assertIsInstance(path.stem, str) + + def test_cleardir(self): + # create empty files + + tempdir = tempfile.mkdtemp() + Path(tempdir + 'file1.txt').touch() + Path(tempdir + 'file2.txt').touch() + Path(tempdir + 'file3.txt').touch() + + dirpath = HDFSPath(tempdir) + dirpath.cleardir() + + self.assertFalse(any(dirpath.iterdir())) + + +if __name__ == "__main__": + main() diff --git a/torchbiggraph/checkpoint_storage.py b/torchbiggraph/checkpoint_storage.py index 9df28e88..e84d6d25 100644 --- a/torchbiggraph/checkpoint_storage.py +++ b/torchbiggraph/checkpoint_storage.py @@ -10,7 +10,6 @@ import logging import os from abc import ABC, abstractmethod -from pathlib import Path from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple import h5py @@ -18,7 +17,8 @@ import torch from torchbiggraph.plugin import URLPluginRegistry from torchbiggraph.types import EntityName, FloatTensorType, ModuleStateDict, Partition -from torchbiggraph.util import CouldNotLoadData, allocate_shared_tensor +from torchbiggraph.util import CouldNotLoadData, allocate_shared_tensor, url_scheme +from torchbiggraph.storage_repository import CUSTOM_PATH, AbstractPath as Path logger = logging.getLogger("torchbiggraph") @@ -208,6 +208,7 @@ def process_dataset(public_name, dataset) -> None: @CHECKPOINT_STORAGES.register_as("") # No scheme @CHECKPOINT_STORAGES.register_as("file") +@CHECKPOINT_STORAGES.register_as("hdfs") class FileCheckpointStorage(AbstractCheckpointStorage): """Reads and writes checkpoint data to/from disk. 
@@ -241,9 +242,8 @@ class FileCheckpointStorage(AbstractCheckpointStorage): """ def __init__(self, path: str) -> None: - if path.startswith("file://"): - path = path[len("file://") :] - self.path: Path = Path(path).resolve(strict=False) + self.path: Path = CUSTOM_PATH.get_class(url_scheme(path))(path).resolve(strict=False) + self.prepare() def get_version_file(self, *, path: Optional[Path] = None) -> Path: if path is None: @@ -319,7 +319,7 @@ def save_entity_partition( ) -> None: path = self.get_entity_partition_file(version, entity_name, partition) logger.debug(f"Saving to {path}") - with h5py.File(path, "w") as hf: + with path.open("w") as hf: hf.attrs[FORMAT_VERSION_ATTR] = FORMAT_VERSION for k, v in metadata.items(): hf.attrs[k] = v @@ -338,11 +338,13 @@ def load_entity_partition( path = self.get_entity_partition_file(version, entity_name, partition) logger.debug(f"Loading from {path}") try: - with h5py.File(path, "r") as hf: + with path.open("r") as hf: if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION: raise RuntimeError(f"Version mismatch in embeddings file {path}") embs = load_embeddings(hf, out=out) optim_state = load_optimizer_state_dict(hf) + except FileNotFoundError as err: + raise CouldNotLoadData() from err except OSError as err: # h5py refuses to make it easy to figure out what went wrong. The errno # attribute is set to None. See https://github.com/h5py/h5py/issues/493. @@ -368,7 +370,7 @@ def save_model( ) -> None: path = self.get_model_file(version) logger.debug(f"Saving to {path}") - with h5py.File(path, "w") as hf: + with path.open("w") as hf: hf.attrs[FORMAT_VERSION_ATTR] = FORMAT_VERSION for k, v in metadata.items(): hf.attrs[k] = v @@ -383,11 +385,13 @@ def load_model( path = self.get_model_file(version) logger.debug(f"Loading from {path}") try: - with h5py.File(path, "r") as hf: + with path.open("r") as hf: if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION: raise RuntimeError(f"Version mismatch in model file {path}") state_dict = load_model_state_dict(hf) optim_state = load_optimizer_state_dict(hf) + except FileNotFoundError as err: + raise CouldNotLoadData() from err except OSError as err: # h5py refuses to make it easy to figure out what went wrong. The errno # attribute is set to None. See https://github.com/h5py/h5py/issues/493. 
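The checkpoint storage above now resolves its path through the ``CUSTOM_PATH`` plugin registry instead of stripping a hard-coded ``file://`` prefix. As a minimal sketch of how that scheme-based dispatch is expected to behave (it mirrors the ``CUSTOM_PATH.get_class(url_scheme(path))(path).resolve(strict=False)`` expression used in the patch; the example URLs are made up and the snippet assumes a build of torchbiggraph that includes this patch):

    # Sketch: resolve a configured path to a LocalPath or HDFSPath based on its URL scheme.
    from torchbiggraph.storage_repository import CUSTOM_PATH
    from torchbiggraph.util import url_scheme

    for url in ("checkpoints", "file:///data/checkpoints", "hdfs:///user/pbg/checkpoints"):
        # "" and "file" are registered to LocalPath, "hdfs" to HDFSPath.
        path_cls = CUSTOM_PATH.get_class(url_scheme(url))
        path = path_cls(url).resolve(strict=False)
        # LocalPath resolves against the local filesystem; HDFSPath mirrors the HDFS
        # path under the current working directory and syncs files when opened.
        print(type(path).__name__, path)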
diff --git a/torchbiggraph/examples/configs/distributedCluster_config.py b/torchbiggraph/examples/configs/distributedCluster_config.py
new file mode 100644
index 00000000..d0596199
--- /dev/null
+++ b/torchbiggraph/examples/configs/distributedCluster_config.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+
+def get_torchbiggraph_config():
+
+    config = dict(  # noqa
+        # I/O data
+        entity_path='hdfs://',  # set entity_path
+        edge_paths=['hdfs://'],  # set edge_path
+        checkpoint_path='hdfs://',  # set checkpoint_path
+        # Graph structure
+        entities={"all": {"num_partitions": 20}},
+        relations=[
+            {
+                "name": "all_edges",
+                "lhs": "all",
+                "rhs": "all",
+                "operator": "complex_diagonal",
+            }
+        ],
+        dynamic_relations=True,
+        verbose=1,
+        # Scoring model
+        dimension=100,
+        batch_size=1000,
+        workers=10,
+        global_emb=False,
+        # Training
+        num_epochs=25,
+        num_machines=10,
+        num_uniform_negs=100,
+        num_batch_negs=50,
+        comparator='cos',
+        loss_fn='softmax',
+        distributed_init_method='env://',
+        lr=0.02,
+        eval_fraction=0.01  # fraction of edges withheld from training for evaluation
+    )
+
+    return config
+
diff --git a/torchbiggraph/examples/distributedCluster/GraphTrainingPoc.scala b/torchbiggraph/examples/distributedCluster/GraphTrainingPoc.scala
new file mode 100644
index 00000000..b247d3cb
--- /dev/null
+++ b/torchbiggraph/examples/distributedCluster/GraphTrainingPoc.scala
@@ -0,0 +1,169 @@
+// compile this Scala code and use it to run the spark-submit job
+
+import java.net._
+import java.nio.file.{Files, Paths}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SparkSession
+
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, Future}
+import scala.sys.process.Process
+
+object GraphTrainingPoc {
+
+  val ZIP_FILE: String = sys.env("ZIP_FILE") //'torchbiggraph'
+  val TRAIN_WRAPPER: String = sys.env("TRAIN_WRAPPER") // "train_wrapper.py"
+  val PBG_CONFIG: String = sys.env("PBG_CONFIG") // "distributedCluster_config.py"
+  val NUM_MACHINES: Int = sys.env("NUM_MACHINES").toInt
+
+  def main(args: Array[String]) {
+
+    val num_machines = NUM_MACHINES
+
+    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
+    implicit val sparkContext: SparkContext = spark.sparkContext
+
+    //Step 1: Get host and port of the driver
+    val (driverHost: String, driverPort: Int) = getHostAndPort
+
+    //Step 2: Ask the resource manager for the number of executors required to run all workers
+    requestExecutorsAndWait(num_machines)
+
+    //Step 3: execute the PBG process on the driver
+    // The driver will wait until all trainers have joined the process group
+    val driverFuture = Future {
+      executePbgProcess(0, driverHost, driverPort)
+    }
+
+    //Step 4: execute the PBG process on each worker trainer, i.e. on each partition
+    sparkContext
+      .makeRDD((1 until num_machines).toList, 1)
+      .repartition(num_machines - 1)
+      .mapPartitionsWithIndex(
+        (index: Int, iterator: Iterator[Int]) => {
+          executePbgProcess(index + 1, driverHost, driverPort)
+          iterator
+        }
+      ).collect()
+
+    Await.result(driverFuture, Duration.Inf)
+  }
+
+  def extractZip(targetPath: String, zipName: String, logPrefix: Option[String] = None): Boolean = {
+
+    val prefix = logPrefix.getOrElse("")
+
+    val containerPath = Paths.get(".").toAbsolutePath.toString
+    print(s"${prefix}ContainerPath - $containerPath")
+    print(s"${prefix}Check if path exists - $targetPath/$zipName")
+
+    val pathAlreadyFound = Files.exists(Paths.get(s"$targetPath/$zipName"))
+
+    // unzip the contents if the path doesn't exist
+    if 
(pathAlreadyFound) { + print(s"${prefix}Folder $zipName already exists. Skip unzip..") + } + else { + print(s"${prefix}Folder $zipName not found. Unzip $zipName.zip") + + val pzip = Process(s"unzip $targetPath/$zipName.zip -d $targetPath") + if (pzip.! != 0) { + val errorMsg = s"${prefix}Error doing unzip of $zipName.zip" + print(errorMsg) + throw new Exception(errorMsg) + } + + print(s"${prefix}Unzip $zipName successful..") + } + + true + } + + /** + * This function is used to get Host and port of Trainer 0 which will be executed on driver + */ + def getHostAndPort: (String, Int) = { + // get ip address and port of the machine + val localhost: String = InetAddress.getLocalHost.getHostAddress + var port = -1 + try { + val s = new ServerSocket(0) + port = s.getLocalPort + s.close() + } + catch { + case e: Exception => print(e.getMessage) + } + + (localhost, port) + } + + /** + * This function is used get executor from resource manager + */ + def requestExecutorsAndWait(numExecutors: Int)(implicit sparkContext: SparkContext): Boolean = { + + val workerExecCount = numExecutors.toInt - 1 + if (sparkContext.requestTotalExecutors(workerExecCount, 0, Map.empty)) { + + var counter = 1 + + while (sparkContext.getExecutorStorageStatus.length - 1 < workerExecCount) { + print(s"Waiting for $workerExecCount executors. Iter - $counter") + Thread.sleep(60000) + counter += 1 + } + + if (counter > 15) { + val errorMsg = s"Unable to get the required executors in 15 minutes " + + s"- Expected: $workerExecCount, Got: ${sparkContext.getExecutorStorageStatus.length}" + print(errorMsg) + throw new Exception(errorMsg) + } + } else { + + return false + } + + print(s"Got requested executors. Count - $workerExecCount") + true + } + + def executePbgProcess(rank: Int, driverHost: String, driverPort: Int): Int = { + + val executionPath = Paths.get(".").toAbsolutePath.toString + + // Setup driver machine, remove PBG zip code if exists and place fresh copy + extractZip(zipName = ZIP_FILE, targetPath = executionPath, logPrefix = Some(s"[Rank $rank] : ")) + + val command = "python3.7 " + executionPath + "/" + TRAIN_WRAPPER + " --rank " + rank.toString + " " + PBG_CONFIG + + print(s"[Rank $rank] : command: " + command) + val rCode = runCommand(command = command, + path = executionPath, + extraEnv = List("MASTER_ADDR" -> driverHost, + "MASTER_PORT" -> driverPort.toString, + "WORLD_SIZE" -> NUM_MACHINES.toString, + "PYTHONPATH" -> ".") + ) + rCode + } + + def runCommand(command: String, path: String, extraEnv: List[(String, String)] = List.empty): Int = { + val ps = Process(command = Seq("/bin/sh", "-c", command), + cwd = Paths.get(path).toFile, + extraEnv = extraEnv: _* + ) + + val proc = ps.run() + val exitValue = proc.exitValue() + + if (exitValue != 0) + throw new RuntimeException(s"Process exited with code $exitValue") + + exitValue + + } +} \ No newline at end of file diff --git a/torchbiggraph/examples/distributedCluster/start_spark.sh b/torchbiggraph/examples/distributedCluster/start_spark.sh new file mode 100644 index 00000000..9a4c68c1 --- /dev/null +++ b/torchbiggraph/examples/distributedCluster/start_spark.sh @@ -0,0 +1,47 @@ +EXAMPLES_PATH="...../torchbiggraph/examples" # set correct path of pbg codebase + +# compile the scala code and update below variables +JAR="" # set jar filename +CLASS_FILE="" # set scala class file + +ZIP_FLDR='torchbiggraph' +TRAIN_WRAPPER='train_wrapper.py' +PBG_CONFIG="distributedCluster_config.py" +NUM_MACHINES=10 + +# create a zip file and send it cluster as part of spark submit 
+cwd=`pwd` +cd ${EXAMPLES_PATH}/../.. +zip -r ${ZIP_FLDR}.zip ${ZIP_FLDR} + +# start spark +spark-submit --class ${CLASS_FILE} \ + --master yarn \ + --deploy-mode cluster \ + --queue entgraph \ + --driver-memory=25G \ + --executor-memory=25G \ + --conf spark.executor.memoryOverhead=3G \ + --conf spark.driver.memoryOverhead=3G \ + --conf spark.driver.cores=10 \ + --conf spark.executor.cores=9 \ + --conf spark.task.cpus=9 \ + --conf spark.yarn.maxAppAttempts=1 \ + --conf spark.task.maxFailures=1 \ + --conf spark.dynamicAllocation.enabled=false \ + --conf spark.yarn.nodemanager.vmem-check-enabled=false \ + --conf spark.locality.wait=0s \ + --conf spark.speculation=false \ + --conf spark.executorEnv.SPARK_YARN_USER_ENV=PYTHONHASHSEED=0 \ + --conf spark.yarn.appMasterEnv.ZIP_FILE=${ZIP_FLDR} \ + --conf spark.executorEnv.ZIP_FILE=${ZIP_FLDR} \ + --conf spark.yarn.appMasterEnv.TRAIN_WRAPPER=${TRAIN_WRAPPER} \ + --conf spark.executorEnv.TRAIN_WRAPPER=${TRAIN_WRAPPER} \ + --conf spark.yarn.appMasterEnv.PBG_CONFIG=${PBG_CONFIG} \ + --conf spark.executorEnv.PBG_CONFIG=${PBG_CONFIG} \ + --conf spark.yarn.appMasterEnv.NUM_MACHINES=${NUM_MACHINES} \ + --conf spark.executorEnv.NUM_MACHINES=${NUM_MACHINES} \ + --files ${ZIP_FLDR}.zip,${EXAMPLES_PATH}/distributedCluster/${TRAIN_WRAPPER},${EXAMPLES_PATH}/configs/${PBG_CONFIG} \ + ${JAR} + +cd $cwd diff --git a/torchbiggraph/examples/distributedCluster/train_wrapper.py b/torchbiggraph/examples/distributedCluster/train_wrapper.py new file mode 100644 index 00000000..7cf5e54f --- /dev/null +++ b/torchbiggraph/examples/distributedCluster/train_wrapper.py @@ -0,0 +1,18 @@ +import sys +import os +import subprocess +from torchbiggraph.train import main + + +print("MASTER_ADDR::", os.environ['MASTER_ADDR']) + +if __name__ == '__main__': + child = subprocess.Popen(['pgrep', '-f', 'train_wrapper.py'], stdout=subprocess.PIPE, shell=False) + response = child.communicate()[0] + running_procs = [int(pid) for pid in response.split()] + if len(running_procs) > 1: + for pid in running_procs: + print('Already running train_wrapper.py is '+ str(pid)) + + sys.exit(main()) + diff --git a/torchbiggraph/graph_storages.py b/torchbiggraph/graph_storages.py index 2e25d59d..90266c44 100644 --- a/torchbiggraph/graph_storages.py +++ b/torchbiggraph/graph_storages.py @@ -11,7 +11,6 @@ import logging from abc import ABC, abstractmethod from contextlib import contextmanager -from pathlib import Path from types import TracebackType from typing import ContextManager, Dict, Iterator, List, Optional, Type @@ -23,7 +22,8 @@ from torchbiggraph.plugin import URLPluginRegistry from torchbiggraph.tensorlist import TensorList from torchbiggraph.types import Partition -from torchbiggraph.util import CouldNotLoadData, allocate_shared_tensor, div_roundup +from torchbiggraph.util import CouldNotLoadData, allocate_shared_tensor, div_roundup, url_scheme +from torchbiggraph.storage_repository import CUSTOM_PATH, AbstractPath as Path logger = logging.getLogger("torchbiggraph") @@ -180,11 +180,10 @@ def load_names(path: Path) -> List[str]: @ENTITY_STORAGES.register_as("") # No scheme @ENTITY_STORAGES.register_as("file") +@ENTITY_STORAGES.register_as("hdfs") class FileEntityStorage(AbstractEntityStorage): def __init__(self, path: str) -> None: - if path.startswith("file://"): - path = path[len("file://") :] - self.path = Path(path).resolve(strict=False) + self.path = CUSTOM_PATH.get_class(url_scheme(path))(path).resolve(strict=False) def get_count_file(self, entity_name: str, partition: Partition) -> Path: 
return self.path / f"entity_count_{entity_name}_{partition}.txt" @@ -218,11 +217,10 @@ def load_names(self, entity_name: str, partition: Partition) -> List[str]: @RELATION_TYPE_STORAGES.register_as("") # No scheme @RELATION_TYPE_STORAGES.register_as("file") +@RELATION_TYPE_STORAGES.register_as("hdfs") class FileRelationTypeStorage(AbstractRelationTypeStorage): def __init__(self, path: str) -> None: - if path.startswith("file://"): - path = path[len("file://") :] - self.path = Path(path).resolve(strict=False) + self.path = CUSTOM_PATH.get_class(url_scheme(path))(path).resolve(strict=False) def get_count_file(self) -> Path: return self.path / "dynamic_rel_count.txt" @@ -367,6 +365,7 @@ def append_edges(self, edgelist: EdgeList) -> None: @EDGE_STORAGES.register_as("") # No scheme @EDGE_STORAGES.register_as("file") +@EDGE_STORAGES.register_as("hdfs") class FileEdgeStorage(AbstractEdgeStorage): """Reads partitioned edgelists from disk, in the format created by edge_downloader.py. @@ -378,9 +377,7 @@ class FileEdgeStorage(AbstractEdgeStorage): """ def __init__(self, path: str) -> None: - if path.startswith("file://"): - path = path[len("file://") :] - self.path = Path(path).resolve(strict=False) + self.path = CUSTOM_PATH.get_class(url_scheme(path))(path).resolve(strict=False) def get_edges_file(self, lhs_p: Partition, rhs_p: Partition) -> Path: return self.path / f"edges_{lhs_p}_{rhs_p}.h5" @@ -394,7 +391,7 @@ def has_edges(self, lhs_p: Partition, rhs_p: Partition) -> bool: def get_number_of_edges(self, lhs_p: Partition, rhs_p: Partition) -> int: file_path = self.get_edges_file(lhs_p, rhs_p) try: - with h5py.File(file_path, "r") as hf: + with file_path.open("r") as hf: if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION: raise RuntimeError(f"Version mismatch in edge file {file_path}") return hf["rel"].len() @@ -415,7 +412,7 @@ def load_chunk_of_edges( ) -> EdgeList: file_path = self.get_edges_file(lhs_p, rhs_p) try: - with h5py.File(file_path, "r") as hf: + with file_path.open("r", reload=False) as hf: if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION: raise RuntimeError(f"Version mismatch in edge file {file_path}") lhs_ds = hf["lhs"] @@ -486,3 +483,6 @@ def save_edges_by_appending( hf.attrs[FORMAT_VERSION_ATTR] = FORMAT_VERSION yield appender tmp_file_path.rename(file_path) + + def cleardir(self): + self.path.cleardir() diff --git a/torchbiggraph/storage_repository.py b/torchbiggraph/storage_repository.py new file mode 100644 index 00000000..78f8f787 --- /dev/null +++ b/torchbiggraph/storage_repository.py @@ -0,0 +1,255 @@ +import logging +import pathlib +from abc import abstractmethod +from contextlib import AbstractContextManager +from io import TextIOWrapper +from pathlib import Path +from types import TracebackType +from typing import IO +from typing import Optional, Type, Union + +import h5py + +from torchbiggraph.plugin import PluginRegistry +from torchbiggraph.util import url_path, run_external_cmd + +logger = logging.getLogger("torchbiggraph") + + +class Constants: + GET = "hadoop fs -get {remote_path} {local_path}" + PUT = "hadoop fs -put -f {local_path} {remote_path}" + TEST_FILE = "hadoop fs -test -e {remote_path}" + REMOVE = "hadoop fs -rm {remote_path}" + H5 = "h5" + RELOAD = 'reload' + WRITE_MODES = ['w', 'x', 'a'] + READ_MODE = 'r' + + +class LocalFileContextManager(AbstractContextManager): + + def __init__(self, path: Path, mode: str, **kwargs) -> None : + self._path: Path = path + self.mode = mode + self.kwargs = kwargs + + def __enter__(self) -> 
Union[h5py.File, TextIOWrapper]: + self._file = self.get_resource(str(self._path), self.mode) + return self._file + + def __exit__(self, exception_type: Optional[Type[BaseException]], + exception_value: Optional[BaseException], + traceback: Optional[TracebackType]) -> bool: + self._file.close() + self._file = None + return True + + @staticmethod + def get_resource(filepath: str, mode: str) -> Union[h5py.File, IO[str]]: + # get file handler + if filepath.split(".")[-1] == Constants.H5: + # check if the file being read is a valid h5 file + if 'r' in mode: + if Path(filepath).exists() and not h5py.is_hdf5(filepath): + raise ValueError('Invalid .h5 file', filepath) + return h5py.File(filepath, mode) + else: + return open(filepath, mode) + + +class HDFSFileContextManager(AbstractContextManager): + def __init__(self, path: Path, mode: str, **kwargs) -> None : + self._path: Path = path + self.mode = mode + self.kwargs = kwargs + + def __enter__(self) -> Union[h5py.File, TextIOWrapper]: + reload = True + if Constants.RELOAD in self.kwargs: + reload = self.kwargs[Constants.RELOAD] + + if Constants.READ_MODE in self.mode: + self.get_from_hdfs(reload) + + self._file = LocalFileContextManager.get_resource(str(self._path), self.mode) + return self._file + + + def __exit__(self, exception_type: Optional[Type[BaseException]], + exception_value: Optional[BaseException], + traceback: Optional[TracebackType]) -> bool: + self._file.close() + self._file = None + + if any(ext in self.mode for ext in Constants.WRITE_MODES): + self.put_to_hdfs() + + return True + + @staticmethod + def get_hdfs_path(path) -> str: + return str(path).replace(str(Path.cwd()), '', 1) + + @staticmethod + def hdfs_file_exists(path: str) -> bool: + rcode = 1 + try: + rcode, output, errors = run_external_cmd(Constants.TEST_FILE.format(remote_path=path)) + except: + pass + + return rcode == 0 + + def get_from_hdfs(self, reload): + local_path = self._path + hdfs_loc = self.get_hdfs_path(local_path) + + if hdfs_loc != str(local_path): + # check if hdfs file exists before running get command + if self.hdfs_file_exists(hdfs_loc): + local_path.parent.mkdir(parents=True, exist_ok=True) + + if local_path.exists(): + if reload: + local_path.unlink() + run_external_cmd(Constants.GET.format(remote_path=hdfs_loc, local_path=str(local_path))) + else: + logger.info(f"Skip get: reload is {reload} and local file exists : {local_path}") + else: + run_external_cmd(Constants.GET.format(remote_path=hdfs_loc, local_path=str(local_path))) + else: + raise FileNotFoundError(f"File {hdfs_loc} not found.") + else: + logger.info(f"identical local and hdfs path. Skipping get {str(local_path)}") + + def put_to_hdfs(self): + local_loc_str = str(self._path) + hdfs_loc = self.get_hdfs_path(local_loc_str) + + if self._path.exists(): + if hdfs_loc != local_loc_str: + run_external_cmd(Constants.PUT.format(local_path=local_loc_str, remote_path=hdfs_loc)) + else: + logger.info('identical local and hdfs path. 
Skipping put ..', local_loc_str) + else: + raise FileNotFoundError("File " + local_loc_str + " not found.") + + +class AbstractPath(type(pathlib.Path())): + @abstractmethod + def __init__(self, path: Union[str, Path]): + raise NotImplementedError + + @abstractmethod + def __truediv__(self, key: str) -> 'AbstractPath': + raise NotImplementedError + + @abstractmethod + def __str__(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def parent(self) -> 'AbstractPath': + raise NotImplementedError + + @abstractmethod + def resolve(self, strict: bool = False) -> 'AbstractPath': + raise NotImplementedError + + @abstractmethod + def open(self, mode='r', buffering=-1, encoding=None, + errors=None, newline=None, **kwargs): + pass + + @abstractmethod + def cleardir(self) -> None: + raise NotImplementedError + +CUSTOM_PATH = PluginRegistry[AbstractPath]() + +@CUSTOM_PATH.register_as("") +@CUSTOM_PATH.register_as("file") +class LocalPath(AbstractPath): + def __init__(self, path: Union[str, Path]): + _tmp_path = '' + if isinstance(path, str): + _tmp_path = path + elif isinstance(path, Path): + _tmp_path = str(path) + + self._path = Path(url_path(_tmp_path)) + + def __truediv__(self, key: str): + return LocalPath(self._path / key) + + def __str__(self): + return str(self._path) + + @property + def parent(self): + return LocalPath(self._path.parent) + + def resolve(self, strict: bool = False) -> 'LocalPath': + resolved_path = self._path.resolve(strict) + return LocalPath(resolved_path) + + def cleardir(self) -> None: + pass + + def open(self, mode='r', buffering=-1, encoding=None, + errors=None, newline=None, **kwargs): + return LocalFileContextManager(self._path, mode, **kwargs) + + +@CUSTOM_PATH.register_as("hdfs") +class HDFSPath(AbstractPath): + def __init__(self, path: Union[str, Path]): + _tmp_path = '' + if isinstance(path, str): + _tmp_path = path + elif isinstance(path, Path): + _tmp_path = str(path) + + self._path = Path(url_path(_tmp_path)) + + def __truediv__(self, key: str) -> 'HDFSPath': + return HDFSPath(self._path / key) + + def __str__(self): + return str(self._path) + + @property + def parent(self) -> 'HDFSPath': + return HDFSPath(self._path.parent) + + def resolve(self, strict: bool = False) -> 'HDFSPath': + resolved_path = Path('./' + str(self._path)).resolve(strict) + return HDFSPath(resolved_path) + + def cleardir(self) -> None: + + resolved_path = self._path.resolve(strict=True) + + if not resolved_path.is_dir(): + raise ValueError(f"Not a directory: {resolved_path}") + else: + logger.info(f"Clean directory : {resolved_path}") + for file in resolved_path.iterdir(): + if file.is_file(): + logger.info(f"Deleting file : {file}") + file.unlink() + + def unlink(self) -> None: + _hdfs_path = HDFSFileContextManager.get_hdfs_path(self._path) + + if HDFSFileContextManager.hdfs_file_exists(_hdfs_path): + run_external_cmd(Constants.REMOVE.format(remote_path=str(_hdfs_path))) + else: + logger.info('hdfs file not found. 
Skipping : {}'.format(_hdfs_path)) + self._path.unlink() + + def open(self, mode='r', buffering=-1, encoding=None, + errors=None, newline=None, **kwargs): + return HDFSFileContextManager(self._path, mode, **kwargs) \ No newline at end of file diff --git a/torchbiggraph/train_cpu.py b/torchbiggraph/train_cpu.py index b87f0349..e6e1723e 100644 --- a/torchbiggraph/train_cpu.py +++ b/torchbiggraph/train_cpu.py @@ -709,6 +709,8 @@ def train(self) -> None: epoch_idx, edge_path_idx, edge_chunk_idx, current_index ) + self._maybe_clear_edge_storage_dir(edge_storage) + # now we're sure that all partition files exist, # so be strict about loading them self.strict = True @@ -1005,3 +1007,7 @@ def _maybe_write_checkpoint( self.checkpoint_manager.preserve_current_version(config, epoch_idx + 1) if not preserve_old_checkpoint: self.checkpoint_manager.remove_old_version(config) + + def _maybe_clear_edge_storage_dir(self, edge_storage) -> None: + # clean edge storage path for scheme 'hdfs' and not local (i.e. '' or 'file') + edge_storage.cleardir() diff --git a/torchbiggraph/util.py b/torchbiggraph/util.py index 660ddc5d..550beb80 100644 --- a/torchbiggraph/util.py +++ b/torchbiggraph/util.py @@ -34,7 +34,8 @@ from torch.optim import Optimizer from torchbiggraph.config import ConfigSchema from torchbiggraph.types import Bucket, EntityName, FloatTensorType, Partition, Side - +import subprocess +from urllib.parse import urlparse logger = logging.getLogger("torchbiggraph") @@ -434,3 +435,24 @@ def get_num_workers(override: Optional[int]) -> int: f"couldn't be auto-detected; defaulting to {result} workers." ) return result + + +def url_path(url): + return urlparse(url).path + + +def url_scheme(url): + return urlparse(url).scheme + + +def run_external_cmd(command: str) -> Tuple[int, str, str]: + logger.info('run_external_cmd : {0}'.format(command)) + + args_list = command.split() + proc = subprocess.Popen(args_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + (output, errors) = proc.communicate() + if proc.returncode: + raise RuntimeError( + 'Error running command: %s. Return code: %d, Error: %s' % ( + ' '.join(args_list), proc.returncode, errors)) + return proc.returncode, str(output), str(errors)
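For reference, a small usage sketch of the helpers added to torchbiggraph/util.py above. The ``hadoop fs -ls`` command and the ``/user/pbg`` paths are illustrative only, and the snippet assumes a working ``hadoop`` CLI on the PATH:

    from torchbiggraph.util import url_scheme, url_path, run_external_cmd

    assert url_scheme("hdfs:///user/pbg/edges") == "hdfs"
    assert url_path("hdfs:///user/pbg/edges") == "/user/pbg/edges"
    assert url_scheme("/data/edges") == ""  # no scheme -> handled by the local path class

    # run_external_cmd returns (returncode, stdout, stderr) and raises
    # RuntimeError when the command exits with a non-zero status.
    rc, out, err = run_external_cmd("hadoop fs -ls /user/pbg")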