From 0f7c5b4b759309b54050f66c8d59020546ca96fc Mon Sep 17 00:00:00 2001
From: HZY <41548394+hzyhhzy@users.noreply.github.com>
Date: Wed, 13 Nov 2024 23:59:38 +0800
Subject: [PATCH] Load the next npz file while yielding the current one
 during training

This overlaps reading the next file with training on the current one,
reducing disk-read bottlenecks and accelerating training of small nets,
where data loading is a larger fraction of total time.
---
 python/data_processing_pytorch.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/python/data_processing_pytorch.py b/python/data_processing_pytorch.py
index 55fcc585a..c47b55fe5 100644
--- a/python/data_processing_pytorch.py
+++ b/python/data_processing_pytorch.py
@@ -8,6 +8,9 @@
 
 import modelconfigs
 
+# Used to prefetch the next npz file on a background thread
+import concurrent.futures
+
 def read_npz_training_data(
     npz_files,
     batch_size: int,
@@ -24,7 +27,11 @@ def read_npz_training_data(
     num_global_features = modelconfigs.get_num_global_input_features(model_config)
     (h_base,h_builder) = build_history_matrices(model_config, device)
 
-    for npz_file in npz_files:
+    # Single background worker that loads the next file while the current one trains
+    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+    future = None
+
+    def load_npz_file(npz_file):
         with np.load(npz_file) as npz:
             binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
             globalInputNC = npz["globalInputNC"]
@@ -47,6 +54,19 @@ def read_npz_training_data(
         assert binaryInputNCHW.shape[1] == num_bin_features
         assert globalInputNC.shape[1] == num_global_features
 
+        return [binaryInputNCHW, globalInputNC, policyTargetsNCMove, globalTargetsNC, scoreDistrN, valueTargetsNCHW, metadataInputNC if include_meta else None]
+
+    # Start loading the first file before entering the loop
+    future = executor.submit(load_npz_file, npz_files[0])
+
+    # Shift by one and append a sentinel "" so the final load is still consumed
+    npz_files = npz_files[1:] + [""]
+    for npz_file in npz_files:
+        binaryInputNCHW, globalInputNC, policyTargetsNCMove, globalTargetsNC, scoreDistrN, valueTargetsNCHW, metadataInputNC = future.result()
+
+        if npz_file != "":
+            future = executor.submit(load_npz_file, npz_file)
+
         num_samples = binaryInputNCHW.shape[0]
         # Just discard stuff that doesn't divide evenly
         num_whole_steps = num_samples // (batch_size * world_size)
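
Reviewer note: the patch is a one-slot prefetch (double buffering) with a
single worker thread. While the training loop consumes the current file,
the worker loads the next one, and future.result() blocks only if loading
is slower than training. Below is a minimal self-contained sketch of the
same pattern; the names prefetched and load are illustrative, not part of
the patch or of KataGo's API:

    import concurrent.futures

    def prefetched(paths, load):
        """Yield load(path) for each path, loading the next file on a
        background thread while the caller consumes the current one."""
        if not paths:
            return
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            # Begin loading the first file immediately
            future = executor.submit(load, paths[0])
            for next_path in paths[1:]:
                data = future.result()                     # wait for the current file
                future = executor.submit(load, next_path)  # start loading the next one
                yield data
            yield future.result()                          # final file has no successor

    # Stand-in loader for demonstration; in real use this would be
    # np.load plus the unpacking done in load_npz_file
    for data in prefetched(["a.npz", "b.npz", "c.npz"], load=str.upper):
        print(data)

Compared with the patch's sentinel-"" trick, the sketch yields the final
future explicitly and uses a with-block, so the worker is shut down when
the generator is closed rather than left to garbage collection.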