From 6c9e5046e7d540ff8d3470f9065d9163d9d7459f Mon Sep 17 00:00:00 2001 From: Tom Bedor Date: Tue, 15 Apr 2025 16:13:55 -0700 Subject: [PATCH] skip multiprocessing if n_proc == 1 with mp, adding breakpoints results in bad exit errors. Being able to debug in the foreground makes it a bit easier --- pred.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/pred.py b/pred.py index 1c20cba9..ddf9df1b 100644 --- a/pred.py +++ b/pred.py @@ -138,14 +138,17 @@ def main(): if item["_id"] not in has_data: data.append(item) - data_subsets = [data[i::args.n_proc] for i in range(args.n_proc)] - processes = [] - for rank in range(args.n_proc): - p = mp.Process(target=get_pred, args=(data_subsets[rank], args, fout)) - p.start() - processes.append(p) - for p in processes: - p.join() + if args.n_proc == 1: + get_pred(data, args, fout) + else: + data_subsets = [data[i::args.n_proc] for i in range(args.n_proc)] + processes = [] + for rank in range(args.n_proc): + p = mp.Process(target=get_pred, args=(data_subsets[rank], args, fout)) + p.start() + processes.append(p) + for p in processes: + p.join() if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -156,4 +159,4 @@ def main(): parser.add_argument("--rag", "-rag", type=int, default=0) # set to 0 if RAG is not used, otherwise set to N when using top-N retrieved context parser.add_argument("--n_proc", "-n", type=int, default=16) args = parser.parse_args() - main() \ No newline at end of file + main()