diff --git a/pred.py b/pred.py index 1c20cba9..ddf9df1b 100644 --- a/pred.py +++ b/pred.py @@ -138,14 +138,17 @@ def main(): if item["_id"] not in has_data: data.append(item) - data_subsets = [data[i::args.n_proc] for i in range(args.n_proc)] - processes = [] - for rank in range(args.n_proc): - p = mp.Process(target=get_pred, args=(data_subsets[rank], args, fout)) - p.start() - processes.append(p) - for p in processes: - p.join() + if args.n_proc == 1: + get_pred(data, args, fout) + else: + data_subsets = [data[i::args.n_proc] for i in range(args.n_proc)] + processes = [] + for rank in range(args.n_proc): + p = mp.Process(target=get_pred, args=(data_subsets[rank], args, fout)) + p.start() + processes.append(p) + for p in processes: + p.join() if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -156,4 +159,4 @@ def main(): parser.add_argument("--rag", "-rag", type=int, default=0) # set to 0 if RAG is not used, otherwise set to N when using top-N retrieved context parser.add_argument("--n_proc", "-n", type=int, default=16) args = parser.parse_args() - main() \ No newline at end of file + main()