Skip to content

Commit fbc6ee5

Browse files
committed
Clean up shuffle.py args and defaults - deliberate compatibility break
1 parent 4013629 commit fbc6ee5

6 files changed

Lines changed: 40 additions & 76 deletions

File tree

python/selfplay/distributed/download_and_upload_and_shuffle_and_export_loop.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ cp -r "$GITROOTDIR"/python/selfplay "$DATED_ARCHIVE"
9090
sleep 10
9191

9292
echo "BEGINNING SHUFFLE------------------------------"
93-
./shuffle.sh "$basedir" "$tmpdir" "$NTHREADS" "$BATCHSIZE" -summary-file "$basedir"/selfplay.summary.json "$@"
93+
./shuffle.sh "$basedir" "$tmpdir" "$NTHREADS" -summary-file "$basedir"/selfplay.summary.json "$@"
9494
sleep "$SHUFFLEPERIOD"
9595
done
9696
fi

python/selfplay/shuffle.sh

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@ set -o pipefail
44
#Shuffles and copies selfplay training from selfplay/ to shuffleddata/current/
55
#Should be run periodically.
66

7-
if [[ $# -lt 4 ]]
7+
if [[ $# -lt 3 ]]
88
then
9-
echo "Usage: $0 BASEDIR TMPDIR NTHREADS BATCHSIZE"
9+
echo "Usage: $0 BASEDIR TMPDIR NTHREADS"
1010
echo "Currently expects to be run from within the 'python' directory of the KataGo repo, or otherwise in the same dir as shuffle.py."
1111
echo "BASEDIR containing selfplay data and models and related directories"
1212
echo "TMPDIR scratch space, ideally on fast local disk, unique to this loop"
1313
echo "NTHREADS number of parallel threads/processes to use in shuffle"
14-
echo "BATCHSIZE number of samples to concat together per batch for training"
1514
exit 0
1615
fi
1716
BASEDIR="$1"
@@ -20,8 +19,6 @@ TMPDIR="$1"
2019
shift
2120
NTHREADS="$1"
2221
shift
23-
BATCHSIZE="$1"
24-
shift
2522

2623
#------------------------------------------------------------------------------
2724

@@ -50,10 +47,9 @@ then
5047
-out-tmp-dir "$TMPDIR"/train \
5148
-approx-rows-per-out-file 70000 \
5249
-num-processes "$NTHREADS" \
53-
-batch-size "$BATCHSIZE" \
50+
-keep-target-rows 20000000 \
5451
-only-include-md5-path-prop-lbound 0.00 \
5552
-only-include-md5-path-prop-ubound 1.00 \
56-
-output-npz \
5753
"$@" \
5854
2>&1 | tee "$BASEDIR"/shuffleddata/"$OUTDIR".tmp/outtrain.txt &
5955

@@ -70,10 +66,9 @@ else
7066
-out-tmp-dir "$TMPDIR"/val \
7167
-approx-rows-per-out-file 70000 \
7268
-num-processes "$NTHREADS" \
73-
-batch-size "$BATCHSIZE" \
69+
-keep-target-rows 20000000 \
7470
-only-include-md5-path-prop-lbound 0.95 \
7571
-only-include-md5-path-prop-ubound 1.00 \
76-
-output-npz \
7772
"$@" \
7873
2>&1 | tee "$BASEDIR"/shuffleddata/"$OUTDIR".tmp/outval.txt &
7974

@@ -88,10 +83,9 @@ else
8883
-out-tmp-dir "$TMPDIR"/train \
8984
-approx-rows-per-out-file 70000 \
9085
-num-processes "$NTHREADS" \
91-
-batch-size "$BATCHSIZE" \
86+
-keep-target-rows 20000000 \
9287
-only-include-md5-path-prop-lbound 0.00 \
9388
-only-include-md5-path-prop-ubound 0.95 \
94-
-output-npz \
9589
"$@" \
9690
2>&1 | tee "$BASEDIR"/shuffleddata/"$OUTDIR".tmp/outtrain.txt &
9791

python/selfplay/shuffle_and_export_loop.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ cp -r "$GITROOTDIR"/python/muon "$DATED_ARCHIVE"
4848
cd "$basedir"/scripts
4949
while true
5050
do
51-
./shuffle.sh "$basedir" "$tmpdir" "$NTHREADS" "$BATCHSIZE" "$@"
51+
./shuffle.sh "$basedir" "$tmpdir" "$NTHREADS" "$@"
5252
sleep 20
5353
done
5454
) >> "$basedir"/logs/outshuffle.txt 2>&1 & disown

python/selfplay/shuffle_loop.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ cp -r "$GITROOTDIR"/python/muon "$DATED_ARCHIVE"
6060

6161
for i in {1..10}
6262
do
63-
./shuffle.sh "$basedir" "$tmpdir" "$NTHREADS" "$BATCHSIZE" -summary-file "$basedir"/selfplay.summary.json "$@"
63+
./shuffle.sh "$basedir" "$tmpdir" "$NTHREADS" -summary-file "$basedir"/selfplay.summary.json "$@"
6464
sleep 600
6565
done
6666
done

python/selfplay/synchronous_loop.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ do
102102
(
103103
# Skip validate since peeling off 5% of data is actually a bit too chunky and discrete when running at a small scale, and validation data
104104
# doesn't actually add much to debugging a fast-changing RL training.
105-
time SKIP_VALIDATE=1 ./shuffle.sh "$BASEDIR" "$SCRATCHDIR" "$NUM_THREADS_FOR_SHUFFLING" "$BATCHSIZE" -min-rows "$SHUFFLE_MINROWS" -keep-target-rows "$SHUFFLE_KEEPROWS" -taper-window-scale "$TAPER_WINDOW_SCALE" | tee -a "$BASEDIR"/logs/outshuffle.txt
105+
time SKIP_VALIDATE=1 ./shuffle.sh "$BASEDIR" "$SCRATCHDIR" "$NUM_THREADS_FOR_SHUFFLING" -min-rows "$SHUFFLE_MINROWS" -keep-target-rows "$SHUFFLE_KEEPROWS" -taper-window-scale "$TAPER_WINDOW_SCALE" | tee -a "$BASEDIR"/logs/outshuffle.txt
106106
)
107107

108108
echo "Train"

python/shuffle.py

Lines changed: 31 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -264,45 +264,33 @@ def write_one_output_file(
264264
arrs: Sequence[np.ndarray | None],
265265
out_file_start: int,
266266
out_file_stop: int,
267-
batch_size: int,
268-
ensure_batch_multiple: int,
269267
include_meta: bool,
270268
include_qvalues: bool,
271269
):
272270
"""Write rows [out_file_start, out_file_stop) of arrs to one output npz + json metadata
273-
274-
Truncates to a whole multiple of (batch_size * ensure_batch_multiple) rows, dropping
275-
the leftover partial batch at the end. Returns the number of rows actually written.
271+
Returns the number of rows written.
276272
"""
277273
num_rows = out_file_stop - out_file_start
278274

279-
# Just truncate and lose the batch at the end, it's fine
280-
num_batches = (num_rows // (batch_size * ensure_batch_multiple)) * ensure_batch_multiple
281-
282-
start = out_file_start
283-
stop = out_file_start + num_batches*batch_size
284-
285275
save_output_npz(
286276
filename=filename,
287277
arrs=arrs,
288278
include_meta=include_meta,
289279
include_qvalues=include_qvalues,
290-
start=start,
291-
stop=stop,
280+
start=out_file_start,
281+
stop=out_file_stop,
292282
)
293283

294284
jsonfilename = os.path.splitext(filename)[0] + ".json"
295285
with open(jsonfilename,"w") as f:
296-
json.dump({"num_rows":num_rows,"num_batches":num_batches},f)
286+
json.dump({"num_rows":num_rows},f)
297287

298-
return num_batches * batch_size
288+
return num_rows
299289

300290
def merge_bucket(
301291
out_filenames: list[str],
302292
num_shards_to_merge: int,
303293
out_tmp_dir: str,
304-
batch_size: int,
305-
ensure_batch_multiple: int,
306294
include_meta: bool,
307295
include_qvalues: bool
308296
):
@@ -365,8 +353,6 @@ def merge_bucket(
365353
arrs=concatenated_arrs,
366354
out_file_start=out_file_start,
367355
out_file_stop=out_file_stop,
368-
batch_size=batch_size,
369-
ensure_batch_multiple=ensure_batch_multiple,
370356
include_meta=include_meta,
371357
include_qvalues=include_qvalues,
372358
)
@@ -582,8 +568,6 @@ def run_two_phase_shuffle(
582568
bucket_tmp_dir,
583569
worker_group_size,
584570
keep_prob,
585-
batch_size,
586-
ensure_batch_multiple,
587571
include_meta,
588572
include_qvalues,
589573
fill_in_qvalues,
@@ -636,7 +620,7 @@ def run_two_phase_shuffle(
636620
os.path.join(out_dir, "%s%d_%d.npz" % (out_file_prefix, b, j))
637621
for j in range(num_out_files_per_bucket)
638622
],
639-
num_shards_to_merge, bucket_tmp_dirs[b], batch_size, ensure_batch_multiple,
623+
num_shards_to_merge, bucket_tmp_dirs[b],
640624
include_meta, include_qvalues
641625
)
642626
for b in range(num_buckets)
@@ -737,16 +721,13 @@ def __exit__(self, exception_type, exception_val, trace):
737721
If you want to control the "scale" of the power law differently than the min rows, you can specify -taper-window-scale as well.
738722
There is also a bit of a hack to cap the number of random rows (rows generated by random play without a neural net), since random row generation at the start of a run can be very fast due to not hitting the GPU, and overpopulate the run.
739723
740-
Additionally, NOT all of the shuffled window is output, only a random shuffled 20M rows will be kept. Adjust this using -keep-target-rows. The intention is that this script will be repeatedly run as new data comes in, such that well before train.py would need more than 20M rows, the data would have been shuffled again and a new random 20M rows chosen.
724+
Additionally, NOT all of the shuffled window need be output: -keep-target-rows controls how many rows are randomly sampled and kept (pass 'all' to keep the whole window). For ongoing self-play training the intention is that this script is rerun as new data comes in, such that well before train.py would need more than -keep-target-rows rows, the data would have been reshuffled and a fresh random sample chosen.
741725
742-
If you are NOT doing ongoing self-play training, but simply want to shuffle an entire dataset (not just a window of it) and want to output all of it (not just 20M of it) then you can use arguments like:
743-
-taper-window-exponent 1.0 \\
744-
-expand-window-per-row 1.0 \\
745-
-keep-target-rows SOME_VERY_LARGE_NUMBER
726+
If you are NOT doing ongoing self-play training, but simply want to shuffle an entire dataset (not just a window of it) and output all of it, the default window args already select the whole dataset, so you just need:
727+
-keep-target-rows all
746728
747729
If you ARE doing ongoing self-play training, but want a fixed window size, then you can use arguments like:
748730
-min-rows YOUR_DESIRED_SIZE \\
749-
-taper-window-exponent 1.0 \\
750731
-expand-window-per-row 0.0
751732
752733
==================================================================
@@ -766,7 +747,7 @@ def __exit__(self, exception_type, exception_val, trace):
766747
--dry-run-print-resource-cost NUM_DATASET_ROWS
767748
which assumes the dataset has NUM_DATASET_ROWS total rows and prints estimates instead of shuffling.
768749
""")
769-
parser.add_argument('dirs', metavar='DIR', nargs='+', help='Directories of training data files')
750+
parser.add_argument('dirs', metavar='DIR', nargs='*', help='Directories of training data files (not required in --dry-run-print-resource-cost mode)')
770751

771752
required_args = parser.add_argument_group('required arguments')
772753
optional_args = parser.add_argument_group('optional arguments')
@@ -777,24 +758,21 @@ def __exit__(self, exception_type, exception_val, trace):
777758
default=argparse.SUPPRESS,
778759
help='show this help message and exit'
779760
)
780-
optional_args.add_argument('-min-rows', type=int, required=False, help='Minimum training rows to use, default 250k')
781-
optional_args.add_argument('-max-rows', type=int, required=False, help='Maximum training rows to use, default unbounded')
782-
optional_args.add_argument('-keep-target-rows', type=int, required=False, help='Target number of rows to actually keep in the final data set, default 20M')
783-
required_args.add_argument('-expand-window-per-row', type=float, required=True, help='Beyond min rows, initially expand the window by this much every post-random data row')
784-
required_args.add_argument('-taper-window-exponent', type=float, required=True, help='Make the window size asymtotically grow as this power of the data rows')
761+
optional_args.add_argument('-min-rows', type=int, required=False, help='Minimum size of the desired training window, default 250k')
762+
optional_args.add_argument('-max-rows', type=int, required=False, help='Maximum size of the desired training window, default unbounded')
763+
required_args.add_argument('-keep-target-rows', required=True, help="Target number of rows to actually sample and keep in the final output shuffle, or 'all' to keep the whole window")
764+
optional_args.add_argument('-expand-window-per-row', type=float, required=False, default=1.0, help='Beyond min rows, initially expand the window by this much every post-random data row (default 1.0)')
765+
optional_args.add_argument('-taper-window-exponent', type=float, required=False, default=1.0, help='Make the window size asymtotically grow as this power of the data rows (default 1.0)')
785766
optional_args.add_argument('-taper-window-scale', type=float, required=False, help='The scale at which the power law applies, defaults to -min-rows')
786767
optional_args.add_argument('-add-to-data-rows', type=float, required=False, help='Compute the window size as if the number of data rows were this much larger/smaller')
787-
optional_args.add_argument('-add-to-window-size', type=float, required=False, help='DEPRECATED due to being misnamed name, use -add-to-data-rows')
788768
optional_args.add_argument('-summary-file', required=False, help='Summary json file for directory contents')
789-
required_args.add_argument('-out-dir', required=True, help='Dir to output training files')
790-
required_args.add_argument('-out-tmp-dir', required=True, help='Dir to use as scratch space')
769+
optional_args.add_argument('-out-dir', required=False, help='Dir to output training files (not required in --dry-run-print-resource-cost mode)')
770+
optional_args.add_argument('-out-tmp-dir', required=False, help='Dir to use as scratch space (not required in --dry-run-print-resource-cost mode)')
791771
optional_args.add_argument('-approx-rows-per-out-file', type=int, required=False, default=70000, help='Number of rows per output file, default 70k')
792772
optional_args.add_argument('-approx-rows-per-bucket', type=int, required=False, help='Each merge worker takes one whole bucket in RAM and splits it equally into output files. Bigger buckets means shard files. Must be a multiple of -approx-rows-per-out-file. Default: equal to -approx-rows-per-out-file.')
793773
optional_args.add_argument('-num-waves', type=int, required=False, default=1, help='If > 1, shuffle in this many waves to bound peak intermediate shard count and temp disk usage for very large (whole-dataset) shuffles. Default 1 (no waves).')
794774
optional_args.add_argument('--dry-run-print-resource-cost', type=int, required=False, metavar='NUM_DATASET_ROWS', help='Do not actually shuffle (or even scan the dataset). Assume the dataset has this many total rows, run the window-size / keep / md5-filter math, and print rough estimates of output files, peak intermediate shard count, peak temp disk usage, and peak memory. Assumes 19x19 data and typical measured per-row sizes.')
795775
required_args.add_argument('-num-processes', type=int, required=True, help='Number of multiprocessing processes for shuffling in parallel')
796-
required_args.add_argument('-batch-size', type=int, required=True, help='Batch size to write training examples in')
797-
optional_args.add_argument('-ensure-batch-multiple', type=int, required=False, help='Ensure each file is a multiple of this many batches')
798776
optional_args.add_argument('-worker-group-size', type=int, required=False, default=80000, help='Internally, target having many rows per parallel sharding worker (doesnt affect merge)')
799777
optional_args.add_argument('-exclude', required=False, help='Text file with npzs to ignore, one per line')
800778
optional_args.add_argument('-exclude-prefix', required=False, help='Prefix to concat to lines in exclude to produce the full file path')
@@ -803,24 +781,23 @@ def __exit__(self, exception_type, exception_val, trace):
803781
optional_args.add_argument('-only-include-md5-path-prop-ubound', type=float, required=False, help='Just before sharding, include only filepaths hashing to float < this')
804782
optional_args.add_argument('-skip-mtime-range-start', type=float, required=False, help='')
805783
optional_args.add_argument('-skip-mtime-range-end', type=float, required=False, help='')
806-
optional_args.add_argument('-output-npz', action="store_true", required=False, help='Output results as npz files')
807784
optional_args.add_argument('-include-meta', action="store_true", required=False, help='Include sgf metadata inputs')
808785
optional_args.add_argument('-exclude-qvalues', action="store_true", required=False, help='Exclude Q-value targets (for backwards compatibility with pre-v1.16)')
809786

810787
args = parser.parse_args()
811788
dirs = args.dirs
812789
min_rows = args.min_rows
813790
max_rows = args.max_rows
814-
keep_target_rows = args.keep_target_rows
791+
# -keep-target-rows is required, and accepts 'all' to mean "keep the whole window"
792+
# (represented internally as None, i.e. no cap).
793+
if str(args.keep_target_rows).lower() == "all":
794+
keep_target_rows = None
795+
else:
796+
keep_target_rows = int(args.keep_target_rows)
815797
expand_window_per_row = args.expand_window_per_row
816798
taper_window_exponent = args.taper_window_exponent
817799
taper_window_scale = args.taper_window_scale
818800
add_to_data_rows = args.add_to_data_rows
819-
if args.add_to_data_rows is not None and args.add_to_window_size is not None:
820-
print("Cannot specify both -add-to-data-rows and -add-to-window-size. Please use only -add-to-data-rows, -add-to-window-size is deprecated")
821-
if args.add_to_data_rows is None and args.add_to_window_size is not None:
822-
print("WARNING: -add-to-window-size is deprecated due to being misnamed, use -add-to-data-rows")
823-
add_to_data_rows = args.add_to_window_size
824801

825802
summary_file = args.summary_file
826803
out_dir = args.out_dir
@@ -843,10 +820,6 @@ def __exit__(self, exception_type, exception_val, trace):
843820
if num_waves < 1:
844821
raise ValueError("-num-waves must be >= 1")
845822
num_processes = args.num_processes
846-
batch_size = args.batch_size
847-
ensure_batch_multiple = 1
848-
if args.ensure_batch_multiple is not None:
849-
ensure_batch_multiple = args.ensure_batch_multiple
850823
worker_group_size = args.worker_group_size
851824
exclude = args.exclude
852825
exclude_prefix = args.exclude_prefix
@@ -857,21 +830,22 @@ def __exit__(self, exception_type, exception_val, trace):
857830
only_include_md5_path_prop_ubound = args.only_include_md5_path_prop_ubound
858831
skip_mtime_range_start = args.skip_mtime_range_start
859832
skip_mtime_range_end = args.skip_mtime_range_end
860-
output_npz = args.output_npz
861833
include_meta = args.include_meta
862834
include_qvalues = not args.exclude_qvalues
863835
dry_run_print_resource_cost = args.dry_run_print_resource_cost
864836

865-
if not output_npz and dry_run_print_resource_cost is None:
866-
raise AssertionError("No longer supports outputting tensorflow data")
837+
# dirs / out-dir / out-tmp-dir are only needed for a real run, not for the dry run.
838+
if dry_run_print_resource_cost is None:
839+
if len(dirs) <= 0:
840+
raise ValueError("At least one input directory is required (except in --dry-run-print-resource-cost mode)")
841+
if out_dir is None:
842+
raise ValueError("-out-dir is required (except in --dry-run-print-resource-cost mode)")
843+
if out_tmp_dir is None:
844+
raise ValueError("-out-tmp-dir is required (except in --dry-run-print-resource-cost mode)")
867845

868846
if min_rows is None:
869847
print("NOTE: -min-rows was not specified, defaulting to requiring 250K rows before shuffling.")
870848
min_rows = 250000
871-
if keep_target_rows is None:
872-
print("NOTE: -keep-target-rows was not specified, defaulting to sampling a random 20M rows out of the computed window.")
873-
print("If you intended to shuffle the whole dataset instead, pass in -keep-target-rows <very large number>")
874-
keep_target_rows = 20000000
875849
if add_to_data_rows is None:
876850
add_to_data_rows = 0
877851

@@ -1234,8 +1208,6 @@ def num_usable_rows():
12341208
bucket_tmp_dir=bucket_tmp_dir,
12351209
worker_group_size=worker_group_size,
12361210
keep_prob=keep_prob,
1237-
batch_size=batch_size,
1238-
ensure_batch_multiple=ensure_batch_multiple,
12391211
include_meta=include_meta,
12401212
include_qvalues=include_qvalues,
12411213
fill_in_qvalues=True,
@@ -1316,8 +1288,6 @@ def num_usable_rows():
13161288
bucket_tmp_dir=bucket_tmp_dir,
13171289
worker_group_size=worker_group_size,
13181290
keep_prob=1.0, # keep_prob already applied in phase 1
1319-
batch_size=batch_size,
1320-
ensure_batch_multiple=ensure_batch_multiple,
13211291
include_meta=include_meta,
13221292
include_qvalues=include_qvalues,
13231293
fill_in_qvalues=False, # wave shards already contain qValueTargetsNCMove

0 commit comments

Comments
 (0)