diff --git a/config/config.yaml b/config/config.yaml index b23a8822..f007deba 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -35,6 +35,9 @@ intervals: min_nmer: 500 num_gvcf_intervals: 50 db_scatter_factor: 0.15 + min_contig_length: 0 + db_max_intervals_per_shard: 200 + db_max_contigs_per_shard: 200 callable_sites: diff --git a/docs/config-fields-supported.md b/docs/config-fields-supported.md index bcfe739b..4696bdef 100644 --- a/docs/config-fields-supported.md +++ b/docs/config-fields-supported.md @@ -38,6 +38,9 @@ These are the supported v2 config keys used by the main workflow. Unknown keys a | `intervals.min_nmer` | integer | no | `500` | `>= 1` | | `intervals.num_gvcf_intervals` | integer | no | `50` | `>= 1` | | `intervals.db_scatter_factor` | number | no | `0.15` | `>= 0` | +| `intervals.min_contig_length` | integer | no | `0` | `>= 0`; excludes interval records on shorter contigs when nonzero | +| `intervals.db_max_intervals_per_shard` | integer | no | `200` | `>= 0`; `0` disables cap | +| `intervals.db_max_contigs_per_shard` | integer | no | `200` | `>= 0`; `0` disables cap | | `callable_sites.generate_bed_file` | boolean | no | `true` | Controls final callable BED target | | `callable_sites.coverage.enabled` | boolean | no | `true` | Requires BAM-backed samples if true | | `callable_sites.coverage.fraction` | number | no | `1.0` | `0..1` | diff --git a/docs/setup.md b/docs/setup.md index 9a0635e4..7e2b3644 100644 --- a/docs/setup.md +++ b/docs/setup.md @@ -118,7 +118,10 @@ The following options can be adjusted based on your needs and your dataset. | ---- | -------------| ------ | |`intervals.min_nmer`| Minimum span of Ns used to split reference for interval generation. | `int`| |`intervals.num_gvcf_intervals` | Maximum number of GVCF intervals to create. | `int`| -|`intervals.db_scatter_factor` | Used to calculate number of DB intervals (`num_db_intervals = scatter_factor * num_samples * num_gvcf_intervals`). 
| `float`| +|`intervals.db_scatter_factor` | Used to calculate the target number of DB intervals before complexity-aware post-processing (`num_db_intervals = scatter_factor * num_samples * num_gvcf_intervals`). | `float`| +|`intervals.db_max_intervals_per_shard` | Maximum interval records allowed in each final DB shard; set to `0` to disable this cap. | `int`| +|`intervals.db_max_contigs_per_shard` | Maximum unique contigs allowed in each final DB shard; set to `0` to disable this cap. | `int`| +|`intervals.min_contig_length` | Exclude interval records on contigs shorter than this length; `0` keeps all contigs. This changes called regions but does not filter the reference used for mapping. | `int`| | `variant_calling.expected_coverage` | Coverage profile used to set caller tuning (`low` or `high`). | `str` | | `variant_calling.ploidy` | Ploidy for variant calling step. | `int` | | `variant_calling.gatk.het_prior` | Heterozygosity prior passed to GATK GenotypeGVCFs. | `float` | @@ -138,6 +141,7 @@ When `variant_calling.tool` is `bcftools`, `deepvariant`, or `parabricks`, sampl `intervals.enabled` controls interval-split HaplotypeCaller for the GATK backend. Parabricks uses interval-split joint genotyping (GenomicsDBImport/GenotypeGVCFs) regardless of `intervals.enabled`. +snpArcher automatically reserves native memory for GenomicsDBImport and applies GATK contig merging to DB shards with many whole-contig intervals. Parabricks execution expects NVIDIA GPUs and an Apptainer/Singularity image path in `variant_calling.parabricks.container_image`. Parabricks HaplotypeCaller also follows `variant_calling.expected_coverage` to set `--min-pruning` and `--min-dangling-branch-length`. 
# Path of the interval-list helper script exercised directly by the tests below.
INTERVAL_LIST_TOOLS = WORKFLOW_DIR / "scripts" / "interval_list_tools.py"


def write_interval_list(path, contig_lengths, records):
    """Write a small Picard/GATK interval_list for script tests.

    Args:
        path: Destination ``Path``.
        contig_lengths: Mapping of contig name to length; one ``@SQ`` line each.
        records: Iterable of ``(contig, start, end)`` tuples; one record each.
    """
    lines = ["@HD\tVN:1.6\n"]
    lines.extend(f"@SQ\tSN:{contig}\tLN:{length}\n" for contig, length in contig_lengths.items())
    lines.extend(f"{contig}\t{start}\t{end}\t+\tACGTmer\n" for contig, start, end in records)
    path.write_text("".join(lines))


def read_interval_records(path):
    """Return the non-empty, non-header (not starting with ``@``) lines of *path*."""
    return [
        line
        for line in path.read_text().splitlines()
        if line and not line.startswith("@")
    ]


def read_interval_header(path):
    """Return the header lines (starting with ``@``) of *path*."""
    return [line for line in path.read_text().splitlines() if line.startswith("@")]


def run_interval_list_tool(*args):
    """Run the interval-list script with *args* and assert that it exits with 0.

    Returns the ``subprocess.CompletedProcess`` so callers can inspect
    ``stdout`` and ``stderr``.
    """
    result = subprocess.run(
        [sys.executable, str(INTERVAL_LIST_TOOLS), *map(str, args)],
        capture_output=True,
        text=True,
    )
    assert result.returncode == 0, result.stderr + result.stdout
    return result


def write_interval_complexity_config(base_config, out_dir):
    """Write a config copy with the new interval complexity keys explicitly set.

    The ``intervals:`` block is located by pattern. The indentation used by the
    base config is captured and reused, and any interval complexity keys that
    already follow ``db_scatter_factor`` are replaced rather than duplicated, so
    the written values are the only ones present for those keys.

    Raises:
        AssertionError: If the expected ``intervals:`` block is not found.
    """
    text = Path(base_config).read_text()
    pattern = re.compile(
        r"(?P<block>intervals:\n"
        r"(?P<indent>[ \t]+)enabled: (?:true|false)\n"
        r"(?P=indent)min_nmer: \d+\n"
        r"(?P=indent)num_gvcf_intervals: \d+\n"
        r"(?P=indent)db_scatter_factor: [0-9.]+\n)"
        # Consume complexity keys already present so they are not duplicated.
        r"(?:(?P=indent)(?:min_contig_length|db_max_intervals_per_shard"
        r"|db_max_contigs_per_shard): \d+\n)*"
    )
    match = pattern.search(text)
    if match is None:
        raise AssertionError("Expected intervals block not found in config")

    indent = match.group("indent")
    overrides = (
        f"{indent}min_contig_length: 1000\n"
        f"{indent}db_max_intervals_per_shard: 200\n"
        f"{indent}db_max_contigs_per_shard: 200\n"
    )

    out_path = Path(out_dir) / "config_interval_complexity.yaml"
    out_path.write_text(
        text[: match.start()] + match.group("block") + overrides + text[match.end():]
    )
    return out_path


def test_db_interval_split_caps_fragmented_tail(tmp_path):
    """A 1005-contig raw shard is split into six shards that respect both caps."""
    raw_dir = tmp_path / "raw"
    out_dir = tmp_path / "out"
    raw_dir.mkdir()
    contig_lengths = {f"ctg{i:04d}": 1000 for i in range(1005)}
    records = [(contig, 1, length) for contig, length in contig_lengths.items()]
    raw_path = raw_dir / "0591-scattered.interval_list"
    write_interval_list(raw_path, contig_lengths, records)

    run_interval_list_tool(
        "split-db",
        "--input-dir",
        raw_dir,
        "--output-dir",
        out_dir,
        "--fof",
        out_dir / "intervals.txt",
        "--max-intervals-per-shard",
        200,
        "--max-contigs-per-shard",
        200,
    )

    shard_paths = [Path(line) for line in (out_dir / "intervals.txt").read_text().splitlines()]
    assert len(shard_paths) == 6

    combined_records = []
    expected_header = read_interval_header(raw_path)
    for shard_path in shard_paths:
        shard_records = read_interval_records(shard_path)
        shard_contigs = {record.split("\t", 1)[0] for record in shard_records}
        assert len(shard_records) <= 200
        assert len(shard_contigs) <= 200
        assert read_interval_header(shard_path) == expected_header
        combined_records.extend(shard_records)

    # Splitting must neither drop nor reorder records.
    assert combined_records == read_interval_records(raw_path)


def test_interval_filter_uses_contig_length_not_interval_length(tmp_path):
    """Short intervals on a long contig are kept; intervals on a short contig are dropped."""
    input_path = tmp_path / "input.interval_list"
    output_path = tmp_path / "filtered.interval_list"
    contig_lengths = {"short_contig": 900, "long_contig": 4000}
    records = [
        ("short_contig", 1, 900),
        ("long_contig", 1, 50),
        ("long_contig", 1000, 2000),
    ]
    write_interval_list(input_path, contig_lengths, records)

    run_interval_list_tool(
        "filter",
        "--input",
        input_path,
        "--output",
        output_path,
        "--min-contig-length",
        1000,
    )

    assert read_interval_header(output_path) == read_interval_header(input_path)
    assert read_interval_records(output_path) == [
        "long_contig\t1\t50\t+\tACGTmer",
        "long_contig\t1000\t2000\t+\tACGTmer",
    ]


def test_db_interval_split_disabled_preserves_raw_shard_contents(tmp_path):
    """With both caps set to 0, shards are renumbered but their contents are unchanged."""
    raw_dir = tmp_path / "raw"
    out_dir = tmp_path / "out"
    raw_dir.mkdir()

    first_path = raw_dir / "0010-scattered.interval_list"
    second_path = raw_dir / "0020-scattered.interval_list"
    write_interval_list(
        first_path,
        {"ctg1": 1000, "ctg2": 1000},
        [("ctg1", 1, 1000), ("ctg2", 1, 1000)],
    )
    write_interval_list(
        second_path,
        {"ctg3": 1000},
        [("ctg3", 1, 1000)],
    )

    run_interval_list_tool(
        "split-db",
        "--input-dir",
        raw_dir,
        "--output-dir",
        out_dir,
        "--fof",
        out_dir / "intervals.txt",
        "--max-intervals-per-shard",
        0,
        "--max-contigs-per-shard",
        0,
    )

    shard_paths = [Path(line) for line in (out_dir / "intervals.txt").read_text().splitlines()]
    assert [path.name for path in shard_paths] == [
        "0000-scattered.interval_list",
        "0001-scattered.interval_list",
    ]
    assert shard_paths[0].read_text() == first_path.read_text()
    assert shard_paths[1].read_text() == second_path.read_text()


def test_db_interval_split_rewrites_pathological_contig_shards(tmp_path):
    """A shard above the merge-contigs threshold is rewritten to whole-contig records."""
    raw_dir = tmp_path / "raw"
    out_dir = tmp_path / "out"
    raw_dir.mkdir()
    input_path = raw_dir / "0000-scattered.interval_list"
    write_interval_list(
        input_path,
        {"ctg1": 1000, "ctg2": 2000, "ctg3": 3000},
        [
            ("ctg1", 10, 100),
            ("ctg1", 500, 900),
            ("ctg2", 20, 200),
            ("ctg3", 30, 300),
        ],
    )

    run_interval_list_tool(
        "split-db",
        "--input-dir",
        raw_dir,
        "--output-dir",
        out_dir,
        "--fof",
        out_dir / "intervals.txt",
        "--max-intervals-per-shard",
        0,
        "--max-contigs-per-shard",
        3,
        "--merge-contigs-threshold",
        2,
    )

    shard_paths = [Path(line) for line in (out_dir / "intervals.txt").read_text().splitlines()]
    assert [path.name for path in shard_paths] == [
        "0000-scattered.interval_list",
    ]
    assert read_interval_records(shard_paths[0]) == [
        "ctg1\t1\t1000\t+\tctg1",
        "ctg2\t1\t2000\t+\tctg2",
        "ctg3\t1\t3000\t+\tctg3",
    ]
    assert read_interval_header(shard_paths[0]) == read_interval_header(input_path)


def test_genomicsdb_merge_contigs_arg_only_for_whole_contig_pathology(tmp_path):
    """The merge-contigs argument is printed only when every record is a whole contig."""
    whole_path = tmp_path / "whole.interval_list"
    partial_path = tmp_path / "partial.interval_list"
    fai_path = tmp_path / "ref.fa.fai"
    contig_lengths = {"ctg1": 1000, "ctg2": 2000, "ctg3": 3000}
    write_interval_list(
        whole_path,
        contig_lengths,
        [("ctg1", 1, 1000), ("ctg2", 1, 2000), ("ctg3", 1, 3000)],
    )
    write_interval_list(
        partial_path,
        contig_lengths,
        [("ctg1", 1, 1000), ("ctg2", 20, 200), ("ctg3", 1, 3000)],
    )

    result = run_interval_list_tool(
        "genomicsdb-merge-contigs-arg",
        "--input",
        whole_path,
        "--threshold",
        2,
    )
    assert result.stdout.strip() == "--merge-contigs-into-num-partitions 2"

    result = run_interval_list_tool(
        "genomicsdb-merge-contigs-arg",
        "--input",
        partial_path,
        "--threshold",
        2,
    )
    assert result.stdout.strip() == ""
    assert "not all records are whole-contig intervals" in result.stderr

    # A .fai input is treated as one whole-contig record per line.
    fai_path.write_text("ctg1\t1000\t0\t80\t81\nctg2\t2000\t1020\t80\t81\nctg3\t3000\t3040\t80\t81\n")
    result = run_interval_list_tool(
        "genomicsdb-merge-contigs-arg",
        "--input",
        fai_path,
        "--threshold",
        2,
    )
    assert result.stdout.strip() == "--merge-contigs-into-num-partitions 2"


@pytest.mark.dry_run
def test_setup_dry_run_accepts_interval_complexity_config(request):
    """The setup dry run succeeds with the complexity keys set and plans the filter rule."""
    no_conda = request.config.getoption("--no-conda")
    with tempfile.TemporaryDirectory() as tmpdir:
        smk = SnakemakeRunner(Path(tmpdir), use_conda=not no_conda)
        cfg = write_interval_complexity_config(get_config_file(), tmpdir)

        result = smk.dry_run(
            target="setup",
            configfile=cfg,
            samples=get_samples_file(),
        )

        result.assert_success()
        output = result.stdout + result.stderr
        assert "filter_picard_intervals" in output


@pytest.mark.dry_run
def test_genomicsdb_import_dry_run_uses_internal_memory_and_contig_merge_guard(request):
    """The GenomicsDBImport dry run shows the reduced heap and the merge-contigs guard."""
    no_conda = request.config.getoption("--no-conda")
    with tempfile.TemporaryDirectory() as tmpdir:
        smk = SnakemakeRunner(Path(tmpdir), use_conda=not no_conda)
        cfg = write_intervals_config(get_config_file(), tmpdir, enabled=False)

        result = smk.dry_run(
            target="call_variants",
            configfile=cfg,
            samples=get_samples_file(),
        )
        result.assert_success()

        output = result.stdout + result.stderr
        assert "--java-options '-Xmx3072m'" in output
        assert "genomicsdb-merge-contigs-arg" in output
        assert "--threshold 50" in output
# Share of the job's mem_mb handed to the JVM heap for GenomicsDBImport.
GENOMICSDB_IMPORT_HEAP_FRACTION = 0.75
# Threshold passed to the interval-list tooling as the contig-merge cutoff.
GENOMICSDB_MERGE_CONTIG_THRESHOLD = 50


def _coerce_resource_mem_mb(resources, default_mem_mb=4096):
    """Return ``resources.mem_mb`` as an int, or *default_mem_mb* if unusable.

    The default is returned when the attribute is missing, is not numeric
    (``TypeError``/``ValueError``), or is infinite (``OverflowError`` from
    ``int(float("inf"))``).
    """
    mem_mb = getattr(resources, "mem_mb", default_mem_mb)
    try:
        return int(float(mem_mb))
    except (TypeError, ValueError, OverflowError):
        return default_mem_mb


def get_gatk_genomicsdb_import_java_opts(resources, default_mem_mb=4096):
    """Build the ``-Xmx…m`` Java option for GenomicsDBImport.

    The heap is ``GENOMICSDB_IMPORT_HEAP_FRACTION`` of the job's ``mem_mb``
    (falling back to *default_mem_mb*), and never less than 1 MB.
    """
    mem_mb = _coerce_resource_mem_mb(resources, default_mem_mb)
    return f"-Xmx{max(1, int(mem_mb * GENOMICSDB_IMPORT_HEAP_FRACTION))}m"
"logs/intervals_picard.txt", shell: """ picard ScatterIntervalsByNs \ @@ -33,15 +36,38 @@ rule picard_intervals: &> {log} """ -checkpoint create_gvcf_intervals: + +rule filter_picard_intervals: input: intervals="results/intervals/picard.interval_list", + output: + intervals="results/intervals/filtered.interval_list", + params: + min_contig_length=config["intervals"]["min_contig_length"], + interval_tools=INTERVAL_LIST_TOOLS, + benchmark: + "benchmarks/intervals_filter.txt" + log: + "logs/intervals_filter.txt", + shell: + """ + python {params.interval_tools} filter \ + --input {input.intervals} \ + --output {output.intervals} \ + --min-contig-length {params.min_contig_length} \ + &> {log} + """ + + +checkpoint create_gvcf_intervals: + input: **REF_FILES, + intervals="results/intervals/filtered.interval_list", output: fof="results/intervals/gvcf/intervals.txt", out_dir=directory("results/intervals/gvcf"), params: - java_opts=lambda wildcards, resources: f"-Xmx{int(resources.mem_mb * 0.9)}m", + java_opts=lambda wildcards, resources: f"-Xmx{int(resources.mem_mb*0.9)}m", scatter=config["intervals"]["num_gvcf_intervals"], resources: mem_mb=4096, @@ -50,7 +76,7 @@ checkpoint create_gvcf_intervals: benchmark: "benchmarks/intervals_gvcf.txt" log: - "logs/intervals_gvcf.txt" + "logs/intervals_gvcf.txt", shell: """ gatk SplitIntervals \ @@ -68,14 +94,18 @@ checkpoint create_gvcf_intervals: checkpoint create_db_intervals: input: - intervals="results/intervals/picard.interval_list", **REF_FILES, + intervals="results/intervals/filtered.interval_list", output: fof="results/intervals/db/intervals.txt", out_dir=directory("results/intervals/db"), params: - java_opts=lambda wildcards, resources: f"-Xmx{int(resources.mem_mb * 0.9)}m", + java_opts=lambda wildcards, resources: f"-Xmx{int(resources.mem_mb*0.9)}m", scatter=get_db_interval_count, + interval_tools=INTERVAL_LIST_TOOLS, + max_intervals=config["intervals"]["db_max_intervals_per_shard"], + 
max_contigs=config["intervals"]["db_max_contigs_per_shard"], + merge_contig_threshold=GENOMICSDB_MERGE_CONTIG_THRESHOLD, resources: mem_mb=4096, conda: @@ -83,17 +113,26 @@ checkpoint create_db_intervals: benchmark: "benchmarks/intervals_db.txt" log: - "logs/intervals_db.txt" + "logs/intervals_db.txt", shell: """ + mkdir -p {output.out_dir} + raw_dir=$(mktemp -d {output.out_dir}/gatk_split_raw.XXXXXX) gatk SplitIntervals \ --java-options '{params.java_opts}' \ -R {input.ref} \ -L {input.intervals} \ - -O {output.out_dir} \ + -O "$raw_dir" \ --scatter-count {params.scatter} \ --subdivision-mode INTERVAL_SUBDIVISION \ --interval-merging-rule OVERLAPPING_ONLY \ &> {log} - ls -1 {output.out_dir}/*-scattered.interval_list > {output.fof} + python {params.interval_tools} split-db \ + --input-dir "$raw_dir" \ + --output-dir {output.out_dir} \ + --fof {output.fof} \ + --max-intervals-per-shard {params.max_intervals} \ + --max-contigs-per-shard {params.max_contigs} \ + --merge-contigs-threshold {params.merge_contig_threshold} \ + >> {log} 2>&1 """ diff --git a/workflow/rules/variant_calling/gatk_intervals.smk b/workflow/rules/variant_calling/gatk_intervals.smk index b473976b..cb5b171b 100644 --- a/workflow/rules/variant_calling/gatk_intervals.smk +++ b/workflow/rules/variant_calling/gatk_intervals.smk @@ -319,7 +319,9 @@ rule gatk_genomics_db_import: db=temp(directory("results/gatk_genomics_db/L{interval}")), tar="results/gatk_genomics_db/L{interval}.tar", params: - java_mem=lambda wildcards, resources: f"-Xmx{int(resources.mem_mb * 0.9)}m", + java_mem=lambda wildcards, resources: get_gatk_genomicsdb_import_java_opts(resources), + interval_tools=INTERVAL_LIST_TOOLS, + merge_contig_threshold=GENOMICSDB_MERGE_CONTIG_THRESHOLD, threads: 1 resources: mem_mb=4096, @@ -331,18 +333,22 @@ rule gatk_genomics_db_import: "logs/gatk_genomics_db_import/{interval}.txt" shell: """ + : > {log} export TILEDB_DISABLE_FILE_LOCKING=1 + MERGE_CONTIGS_ARG=$(python {params.interval_tools} 
genomicsdb-merge-contigs-arg \ + --input {input.interval} \ + --threshold {params.merge_contig_threshold} 2>> {log}) gatk GenomicsDBImport \ --java-options '{params.java_mem}' \ --genomicsdb-shared-posixfs-optimizations true \ --batch-size 25 \ --genomicsdb-workspace-path {output.db} \ - --merge-input-intervals \ + --merge-input-intervals $MERGE_CONTIGS_ARG \ --reader-threads {threads} \ -L {input.interval} \ --tmp-dir {resources.tmpdir} \ --sample-name-map {input.db_mapfile} \ - &> {log} + &>> {log} tar -cf {output.tar} {output.db} &>> {log} """ diff --git a/workflow/rules/variant_calling/joint_gvcf.smk b/workflow/rules/variant_calling/joint_gvcf.smk index 92909dc9..e625e4c0 100644 --- a/workflow/rules/variant_calling/joint_gvcf.smk +++ b/workflow/rules/variant_calling/joint_gvcf.smk @@ -35,7 +35,9 @@ rule joint_genomics_db_import: db=temp(directory("results/gatk_genomics_db")), tar="results/gatk_genomics_db.tar", params: - java_opts=lambda wildcards, resources: _java_opts_from_resources(resources), + java_opts=lambda wildcards, resources: get_gatk_genomicsdb_import_java_opts(resources), + interval_tools=INTERVAL_LIST_TOOLS, + merge_contig_threshold=GENOMICSDB_MERGE_CONTIG_THRESHOLD, threads: 1 conda: "../../envs/gatk.yaml" @@ -45,18 +47,22 @@ rule joint_genomics_db_import: "logs/joint_genomics_db_import.txt" shell: """ + : > {log} export TILEDB_DISABLE_FILE_LOCKING=1 + MERGE_CONTIGS_ARG=$(python {params.interval_tools} genomicsdb-merge-contigs-arg \ + --input {input.ref_fai} \ + --threshold {params.merge_contig_threshold} 2>> {log}) gatk GenomicsDBImport \ --java-options '{params.java_opts}' \ --genomicsdb-shared-posixfs-optimizations true \ --batch-size 25 \ --genomicsdb-workspace-path {output.db} \ - --merge-input-intervals \ + --merge-input-intervals $MERGE_CONTIGS_ARG \ --reader-threads {threads} \ -L {input.ref_fai} \ --tmp-dir {resources.tmpdir} \ --sample-name-map {input.db_mapfile} \ - &> {log} + &>> {log} tar -cf {output.tar} {output.db} >> {log} 2>&1 
""" diff --git a/workflow/rules/variant_calling/joint_gvcf_intervals.smk b/workflow/rules/variant_calling/joint_gvcf_intervals.smk index be492ecd..6b1f7c52 100644 --- a/workflow/rules/variant_calling/joint_gvcf_intervals.smk +++ b/workflow/rules/variant_calling/joint_gvcf_intervals.smk @@ -191,7 +191,9 @@ rule gatk_genomics_db_import: db=temp(directory("results/gatk_genomics_db/L{interval}")), tar="results/gatk_genomics_db/L{interval}.tar", params: - java_mem=lambda wildcards, resources: _java_opts_from_resources(resources), + java_mem=lambda wildcards, resources: get_gatk_genomicsdb_import_java_opts(resources), + interval_tools=INTERVAL_LIST_TOOLS, + merge_contig_threshold=GENOMICSDB_MERGE_CONTIG_THRESHOLD, threads: 1 resources: mem_mb=4096, @@ -203,18 +205,22 @@ rule gatk_genomics_db_import: "logs/gatk_genomics_db_import/{interval}.txt" shell: """ + : > {log} export TILEDB_DISABLE_FILE_LOCKING=1 + MERGE_CONTIGS_ARG=$(python {params.interval_tools} genomicsdb-merge-contigs-arg \ + --input {input.interval} \ + --threshold {params.merge_contig_threshold} 2>> {log}) gatk GenomicsDBImport \ --java-options '{params.java_mem}' \ --genomicsdb-shared-posixfs-optimizations true \ --batch-size 25 \ --genomicsdb-workspace-path {output.db} \ - --merge-input-intervals \ + --merge-input-intervals $MERGE_CONTIGS_ARG \ --reader-threads {threads} \ -L {input.interval} \ --tmp-dir {resources.tmpdir} \ --sample-name-map {input.db_mapfile} \ - &> {log} + &>> {log} tar -cf {output.tar} {output.db} >> {log} 2>&1 """ diff --git a/workflow/schemas/config.schema.yaml b/workflow/schemas/config.schema.yaml index 133a31e6..1d164963 100644 --- a/workflow/schemas/config.schema.yaml +++ b/workflow/schemas/config.schema.yaml @@ -152,6 +152,18 @@ properties: type: number minimum: 0 default: 0.15 + min_contig_length: + type: integer + minimum: 0 + default: 0 + db_max_intervals_per_shard: + type: integer + minimum: 0 + default: 200 + db_max_contigs_per_shard: + type: integer + minimum: 0 + default: 
#!/usr/bin/env python3
"""Utilities for Picard/GATK interval_list filtering and shard splitting."""

from __future__ import annotations

import argparse
import shutil
import sys
from collections import Counter
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class IntervalRecord:
    """One interval record, either read from a file or synthesized."""

    # Text written back on output: the original line (including its newline)
    # for parsed records, or a generated whole-contig line.
    raw: str
    contig: str
    start: int
    end: int


@dataclass(frozen=True)
class IntervalList:
    """Parsed interval source: header lines, records and ``@SQ`` contig lengths."""

    # Header lines (starting with "@") exactly as read, newlines included.
    header: list[str]
    records: list[IntervalRecord]
    # Contig name -> length, from @SQ SN:/LN: fields (or from a FAI).
    contig_lengths: dict[str, int]


def parse_interval_list(path: Path) -> IntervalList:
    """Parse a Picard/GATK interval_list file.

    Lines starting with ``@`` are kept as header; ``@SQ`` lines that carry both
    ``SN:`` and ``LN:`` also populate the contig-length map. Empty lines are
    skipped. Every other line must have at least three tab-separated fields
    with integer start and end.

    Raises:
        ValueError: On a non-integer ``@SQ`` length or a malformed record.
    """
    header: list[str] = []
    records: list[IntervalRecord] = []
    contig_lengths: dict[str, int] = {}

    with path.open() as handle:
        for line in handle:
            stripped = line.rstrip("\n")
            if stripped.startswith("@"):
                header.append(line)
                if stripped.startswith("@SQ"):
                    fields = stripped.split("\t")
                    contig = None
                    length = None
                    for field in fields[1:]:
                        if field.startswith("SN:"):
                            contig = field[3:]
                        elif field.startswith("LN:"):
                            try:
                                length = int(field[3:])
                            except ValueError as err:
                                raise ValueError(
                                    f"Invalid @SQ length in {path}: {field}"
                                ) from err
                    if contig is not None and length is not None:
                        contig_lengths[contig] = length
                continue

            if not stripped:
                continue

            fields = stripped.split("\t")
            if len(fields) < 3:
                raise ValueError(f"Invalid interval record in {path}: {stripped}")
            try:
                start = int(fields[1])
                end = int(fields[2])
            except ValueError as err:
                raise ValueError(f"Invalid interval record in {path}: {stripped}") from err
            records.append(
                IntervalRecord(raw=line, contig=fields[0], start=start, end=end)
            )

    return IntervalList(header=header, records=records, contig_lengths=contig_lengths)


def parse_fai(path: Path) -> IntervalList:
    """Parse a FASTA index, producing one whole-contig record per line.

    Only the first two tab-separated columns (name, length) are used. The
    returned list has an empty header.

    Raises:
        ValueError: On a line with fewer than two fields or a non-integer length.
    """
    records: list[IntervalRecord] = []
    contig_lengths: dict[str, int] = {}

    with path.open() as handle:
        for line in handle:
            stripped = line.rstrip("\n")
            if not stripped:
                continue

            fields = stripped.split("\t")
            if len(fields) < 2:
                raise ValueError(f"Invalid FAI record in {path}: {stripped}")
            try:
                length = int(fields[1])
            except ValueError as err:
                raise ValueError(f"Invalid FAI length in {path}: {stripped}") from err

            contig = fields[0]
            contig_lengths[contig] = length
            records.append(
                IntervalRecord(
                    raw=f"{contig}\t1\t{length}\t+\t{contig}\n",
                    contig=contig,
                    start=1,
                    end=length,
                )
            )

    return IntervalList(header=[], records=records, contig_lengths=contig_lengths)


def parse_interval_source(path: Path) -> IntervalList:
    """Parse *path* as a FAI when its suffix is ``.fai``, else as an interval_list."""
    if path.suffix == ".fai":
        return parse_fai(path)
    return parse_interval_list(path)


def write_interval_list(path: Path, header: list[str], records: list[IntervalRecord]) -> None:
    """Write *header* followed by each record's ``raw`` text, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w") as handle:
        handle.writelines(header)
        handle.writelines(record.raw for record in records)


def filter_intervals(args: argparse.Namespace) -> int:
    """Drop interval records that lie on contigs shorter than the minimum length.

    With ``args.min_contig_length <= 0`` the input is copied byte-for-byte.
    The result is validated before anything is written, so a failed run does
    not leave an empty output file behind.

    Raises:
        ValueError: If a record's contig has no ``@SQ`` length, or if no
            interval would be retained.
    """
    interval_list = parse_interval_list(args.input)
    filtering = args.min_contig_length > 0

    if filtering:
        retained_records: list[IntervalRecord] = []
        for record in interval_list.records:
            if record.contig not in interval_list.contig_lengths:
                raise ValueError(
                    f"Cannot filter {args.input}: contig {record.contig!r} has no @SQ LN entry"
                )
            if interval_list.contig_lengths[record.contig] >= args.min_contig_length:
                retained_records.append(record)
    else:
        retained_records = interval_list.records

    retained = len(retained_records)
    if retained == 0:
        raise ValueError(
            f"Interval filtering removed all intervals from {args.input}; "
            "lower intervals.min_contig_length or disable it with 0"
        )

    if filtering:
        write_interval_list(args.output, interval_list.header, retained_records)
    else:
        # Nothing to filter: keep the file exactly as produced upstream.
        args.output.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(args.input, args.output)

    print(
        f"Retained {retained} of {len(interval_list.records)} interval(s) "
        f"with min contig length {args.min_contig_length}"
    )
    return 0


def would_exceed_limit(
    records: list[IntervalRecord],
    contigs: set[str],
    next_record: IntervalRecord,
    max_intervals: int,
    max_contigs: int,
) -> bool:
    """Return True if adding *next_record* would break either per-shard cap.

    A cap of 0 (or less) is treated as disabled.
    """
    if max_intervals > 0 and len(records) + 1 > max_intervals:
        return True

    return (
        max_contigs > 0
        and next_record.contig not in contigs
        and len(contigs) + 1 > max_contigs
    )


def split_records(
    records: list[IntervalRecord],
    max_intervals: int,
    max_contigs: int,
) -> list[list[IntervalRecord]]:
    """Split *records*, in order, into chunks that respect both caps.

    When both caps are disabled the records are returned as a single chunk.
    """
    if max_intervals <= 0 and max_contigs <= 0:
        return [records]

    chunks: list[list[IntervalRecord]] = []
    current: list[IntervalRecord] = []
    current_contigs: set[str] = set()

    for record in records:
        # Close the current chunk before the record that would overflow it.
        if current and would_exceed_limit(
            current, current_contigs, record, max_intervals, max_contigs
        ):
            chunks.append(current)
            current = []
            current_contigs = set()

        current.append(record)
        current_contigs.add(record.contig)

    if current:
        chunks.append(current)

    return chunks


def remove_existing_final_shards(output_dir: Path, fof: Path) -> None:
    """Create *output_dir* and delete its top-level shard files and *fof*.

    The glob is not recursive, so shard files in subdirectories are untouched.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    for path in output_dir.glob("*-scattered.interval_list"):
        if path.is_file():
            path.unlink()
    if fof.exists():
        fof.unlink()


def count_unique_contigs(records: list[IntervalRecord]) -> int:
    """Return the number of distinct contigs among *records*."""
    return len({record.contig for record in records})


def whole_contig_records(
    records: list[IntervalRecord],
    contig_lengths: dict[str, int],
    path: Path,
) -> list[IntervalRecord]:
    """Return one ``1..length`` record per contig, in first-seen order.

    Raises:
        ValueError: If a contig has no entry in *contig_lengths*; *path* is
            only used in the error message.
    """
    whole_records: list[IntervalRecord] = []
    seen: set[str] = set()

    for record in records:
        if record.contig in seen:
            continue
        if record.contig not in contig_lengths:
            raise ValueError(
                f"Cannot create whole-contig DB interval for {record.contig!r} "
                f"from {path}: contig has no @SQ LN entry"
            )

        length = contig_lengths[record.contig]
        whole_records.append(
            IntervalRecord(
                raw=f"{record.contig}\t1\t{length}\t+\t{record.contig}\n",
                contig=record.contig,
                start=1,
                end=length,
            )
        )
        seen.add(record.contig)

    return whole_records


def maybe_rewrite_pathological_contig_chunk(
    records: list[IntervalRecord],
    contig_lengths: dict[str, int],
    path: Path,
    threshold: int,
) -> list[IntervalRecord]:
    """Rewrite a chunk to whole-contig records when it has more than *threshold* contigs.

    A threshold of 0 (or less) disables the rewrite and returns *records* as is.
    """
    if threshold <= 0 or count_unique_contigs(records) <= threshold:
        return records
    return whole_contig_records(records, contig_lengths, path)


def records_are_whole_contigs(
    records: list[IntervalRecord],
    contig_lengths: dict[str, int],
) -> bool:
    """Return True if each record spans ``1..length`` of a distinct, known contig."""
    seen: set[str] = set()
    for record in records:
        if record.contig in seen:
            return False
        if record.contig not in contig_lengths:
            return False
        if record.start != 1 or record.end != contig_lengths[record.contig]:
            return False
        seen.add(record.contig)
    return True


def split_db_intervals(args: argparse.Namespace) -> int:
    """Re-shard GATK scattered interval lists under the per-shard caps.

    Every ``*-scattered.interval_list`` in ``args.input_dir`` is split with
    ``split_records``. Chunks with more contigs than
    ``args.merge_contigs_threshold`` are rewritten to whole-contig records,
    but only when none of their contigs appears in another chunk: expanding a
    contig that is shared between shards would make those shards overlap.
    Such chunks are left unchanged and a warning is printed to stderr.

    Output shards are numbered from 0 and listed, one path per line, in
    ``args.fof``.

    Raises:
        ValueError: If no input shard exists or no shard contains a record.
    """
    input_files = sorted(args.input_dir.glob("*-scattered.interval_list"))
    if not input_files:
        raise ValueError(f"No GATK scattered interval lists found in {args.input_dir}")

    # First pass: chunk every input shard without rewriting anything yet.
    pending: list[tuple[IntervalList, Path, list[IntervalRecord]]] = []
    total_records = 0
    for path in input_files:
        interval_list = parse_interval_list(path)
        total_records += len(interval_list.records)
        if not interval_list.records:
            continue
        pending.extend(
            (interval_list, path, chunk)
            for chunk in split_records(
                interval_list.records,
                args.max_intervals_per_shard,
                args.max_contigs_per_shard,
            )
        )

    if not pending:
        raise ValueError(f"No interval records found in {args.input_dir}")

    # Number of chunks each contig appears in; > 1 means the contig is shared.
    chunks_per_contig: Counter[str] = Counter()
    for _, _, chunk in pending:
        chunks_per_contig.update({record.contig for record in chunk})

    threshold = args.merge_contigs_threshold
    output_chunks: list[tuple[list[str], list[IntervalRecord]]] = []
    for interval_list, path, chunk in pending:
        over_threshold = threshold > 0 and count_unique_contigs(chunk) > threshold
        shared = (
            sorted(
                {
                    record.contig
                    for record in chunk
                    if chunks_per_contig[record.contig] > 1
                }
            )
            if over_threshold
            else []
        )
        if shared:
            print(
                f"WARNING: not rewriting a shard from {path} to whole-contig "
                f"intervals: {len(shared)} contig(s) also appear in other shards "
                f"(e.g. {shared[0]!r})",
                file=sys.stderr,
            )
        else:
            chunk = maybe_rewrite_pathological_contig_chunk(
                chunk,
                interval_list.contig_lengths,
                path,
                threshold,
            )
        output_chunks.append((interval_list.header, chunk))

    remove_existing_final_shards(args.output_dir, args.fof)

    # Zero-pad to at least four digits, more if there are enough shards.
    width = max(4, len(str(len(output_chunks) - 1)))
    output_paths: list[Path] = []
    for index, (header, records) in enumerate(output_chunks):
        output_path = args.output_dir / f"{index:0{width}d}-scattered.interval_list"
        write_interval_list(output_path, header, records)
        output_paths.append(output_path)

    with args.fof.open("w") as handle:
        handle.writelines(f"{path}\n" for path in output_paths)

    print(
        f"Wrote {len(output_paths)} DB interval shard(s) from {len(input_files)} "
        f"GATK shard(s) and {total_records} interval record(s)"
    )
    return 0


def genomicsdb_merge_contigs_arg(args: argparse.Namespace) -> int:
    """Print the ``--merge-contigs-into-num-partitions`` argument when it applies.

    The argument is printed to stdout only when the threshold is positive, the
    input has more unique contigs than the threshold, and every record is a
    whole-contig interval. If only the last condition fails, a warning goes to
    stderr. Nothing is printed otherwise, and the return value is always 0.
    """
    interval_list = parse_interval_source(args.input)
    contigs = count_unique_contigs(interval_list.records)
    # A threshold of 0 disables merging, matching the split-db convention.
    if args.threshold <= 0 or contigs <= args.threshold:
        return 0

    if records_are_whole_contigs(interval_list.records, interval_list.contig_lengths):
        print(f"--merge-contigs-into-num-partitions {args.threshold}")
        return 0

    print(
        f"WARNING: {args.input} has {contigs} contigs, but not all records are "
        "whole-contig intervals; not enabling --merge-contigs-into-num-partitions",
        file=sys.stderr,
    )
    return 0


def nonnegative_int(value: str) -> int:
    """argparse type: parse *value* as an int and reject negative numbers."""
    parsed = int(value)
    if parsed < 0:
        raise argparse.ArgumentTypeError("value must be >= 0")
    return parsed


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with the filter, split-db and merge-arg subcommands."""
    parser = argparse.ArgumentParser(description=__doc__)
    subparsers = parser.add_subparsers(dest="command", required=True)

    filter_parser = subparsers.add_parser("filter")
    filter_parser.add_argument("--input", type=Path, required=True)
    filter_parser.add_argument("--output", type=Path, required=True)
    filter_parser.add_argument("--min-contig-length", type=nonnegative_int, required=True)
    filter_parser.set_defaults(func=filter_intervals)

    split_parser = subparsers.add_parser("split-db")
    split_parser.add_argument("--input-dir", type=Path, required=True)
    split_parser.add_argument("--output-dir", type=Path, required=True)
    split_parser.add_argument("--fof", type=Path, required=True)
    split_parser.add_argument(
        "--max-intervals-per-shard", type=nonnegative_int, required=True
    )
    split_parser.add_argument("--max-contigs-per-shard", type=nonnegative_int, required=True)
    split_parser.add_argument("--merge-contigs-threshold", type=nonnegative_int, default=0)
    split_parser.set_defaults(func=split_db_intervals)

    merge_parser = subparsers.add_parser("genomicsdb-merge-contigs-arg")
    merge_parser.add_argument("--input", type=Path, required=True)
    merge_parser.add_argument("--threshold", type=nonnegative_int, required=True)
    merge_parser.set_defaults(func=genomicsdb_merge_contigs_arg)

    return parser


def main(argv: list[str] | None = None) -> int:
    """Run the selected subcommand; return 0 on success and 1 on any error.

    Errors are reported on stderr as ``ERROR: …`` rather than as a traceback.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    try:
        return args.func(args)
    except Exception as err:  # top-level CLI boundary: report and exit non-zero
        print(f"ERROR: {err}", file=sys.stderr)
        return 1


if __name__ == "__main__":
    raise SystemExit(main())