Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ intervals:
min_nmer: 500
num_gvcf_intervals: 50
db_scatter_factor: 0.15
min_contig_length: 0
db_max_intervals_per_shard: 200
db_max_contigs_per_shard: 200


callable_sites:
Expand Down
3 changes: 3 additions & 0 deletions docs/config-fields-supported.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ These are the supported v2 config keys used by the main workflow. Unknown keys a
| `intervals.min_nmer` | integer | no | `500` | `>= 1` |
| `intervals.num_gvcf_intervals` | integer | no | `50` | `>= 1` |
| `intervals.db_scatter_factor` | number | no | `0.15` | `>= 0` |
| `intervals.min_contig_length` | integer | no | `0` | `>= 0`; excludes interval records on shorter contigs when nonzero |
| `intervals.db_max_intervals_per_shard` | integer | no | `200` | `>= 0`; `0` disables cap |
| `intervals.db_max_contigs_per_shard` | integer | no | `200` | `>= 0`; `0` disables cap |
| `callable_sites.generate_bed_file` | boolean | no | `true` | Controls final callable BED target |
| `callable_sites.coverage.enabled` | boolean | no | `true` | Requires BAM-backed samples if true |
| `callable_sites.coverage.fraction` | number | no | `1.0` | `0..1` |
Expand Down
6 changes: 5 additions & 1 deletion docs/setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ The following options can be adjusted based on your needs and your dataset.
| ---- | -------------| ------ |
|`intervals.min_nmer`| Minimum span of Ns used to split reference for interval generation. | `int`|
|`intervals.num_gvcf_intervals` | Maximum number of GVCF intervals to create. | `int`|
|`intervals.db_scatter_factor` | Used to calculate number of DB intervals (`num_db_intervals = scatter_factor * num_samples * num_gvcf_intervals`). | `float`|
|`intervals.db_scatter_factor` | Used to calculate the target number of DB intervals before complexity-aware post-processing (`num_db_intervals = scatter_factor * num_samples * num_gvcf_intervals`). | `float`|
|`intervals.db_max_intervals_per_shard` | Maximum interval records allowed in each final DB shard; set to `0` to disable this cap. | `int`|
|`intervals.db_max_contigs_per_shard` | Maximum unique contigs allowed in each final DB shard; set to `0` to disable this cap. | `int`|
|`intervals.min_contig_length` | Exclude interval records on contigs shorter than this length; `0` keeps all contigs. This changes called regions but does not filter the reference used for mapping. | `int`|
| `variant_calling.expected_coverage` | Coverage profile used to set caller tuning (`low` or `high`). | `str` |
| `variant_calling.ploidy` | Ploidy for variant calling step. | `int` |
| `variant_calling.gatk.het_prior` | Heterozygosity prior passed to GATK GenotypeGVCFs. | `float` |
Expand All @@ -138,6 +141,7 @@ When `variant_calling.tool` is `bcftools`, `deepvariant`, or `parabricks`, sampl

`intervals.enabled` controls interval-split HaplotypeCaller for the GATK backend.
Parabricks uses interval-split joint genotyping (GenomicsDBImport/GenotypeGVCFs) regardless of `intervals.enabled`.
snpArcher automatically reserves native memory for GenomicsDBImport and applies GATK contig merging to DB shards with many whole-contig intervals.

Parabricks execution expects NVIDIA GPUs and an Apptainer/Singularity image path in `variant_calling.parabricks.container_image`.
Parabricks HaplotypeCaller also follows `variant_calling.expected_coverage` to set `--min-pruning` and `--min-dangling-branch-length`.
Expand Down
3 changes: 3 additions & 0 deletions tests/data/fixtures/config/config.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

299 changes: 297 additions & 2 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
import shutil
import socket
import subprocess
import sys
import tempfile
import threading
from contextlib import contextmanager
from pathlib import Path

import pytest

from conftest import SnakemakeRunner, WORKFLOW_DIR, TEST_DATA_DIR
from conftest import TEST_DATA_DIR, WORKFLOW_DIR, SnakemakeRunner

TEST_DIR = Path(__file__).parent
CONFIGS_DIR = TEST_DIR / "configs"
SAMPLES_DIR = TEST_DIR / "sample_sheets"
METADATA_DIR = TEST_DIR / "sample_metadata"
INTERVAL_LIST_TOOLS = WORKFLOW_DIR / "scripts" / "interval_list_tools.py"


def get_samples_file():
Expand All @@ -41,6 +42,261 @@ def get_multistage_config_file():
return CONFIGS_DIR / "local_genome_multistage.yaml"


def write_interval_list(path, contig_lengths, records):
"""Write a small Picard/GATK interval_list for script tests."""
lines = ["@HD\tVN:1.6\n"]
lines.extend(f"@SQ\tSN:{contig}\tLN:{length}\n" for contig, length in contig_lengths.items())
lines.extend(f"{contig}\t{start}\t{end}\t+\tACGTmer\n" for contig, start, end in records)
path.write_text("".join(lines))


def read_interval_records(path):
return [
line
for line in path.read_text().splitlines()
if line and not line.startswith("@")
]


def read_interval_header(path):
return [line for line in path.read_text().splitlines() if line.startswith("@")]


def run_interval_list_tool(*args):
result = subprocess.run(
[sys.executable, str(INTERVAL_LIST_TOOLS), *map(str, args)],
capture_output=True,
text=True,
)
assert result.returncode == 0, result.stderr + result.stdout
return result


def write_interval_complexity_config(base_config, out_dir):
"""Write a config copy with the new interval complexity keys explicitly set."""
text = Path(base_config).read_text()
pattern = re.compile(
r"(intervals:\n"
r" enabled: (?:true|false)\n"
r" min_nmer: \d+\n"
r" num_gvcf_intervals: \d+\n"
r" db_scatter_factor: [0-9.]+\n)"
)
if not pattern.search(text):
raise AssertionError("Expected intervals block not found in config")

out_path = Path(out_dir) / "config_interval_complexity.yaml"
out_path.write_text(
pattern.sub(
r"\1"
" min_contig_length: 1000\n"
" db_max_intervals_per_shard: 200\n"
" db_max_contigs_per_shard: 200\n",
text,
count=1,
)
)
return out_path


def test_db_interval_split_caps_fragmented_tail(tmp_path):
raw_dir = tmp_path / "raw"
out_dir = tmp_path / "out"
raw_dir.mkdir()
contig_lengths = {f"ctg{i:04d}": 1000 for i in range(1005)}
records = [(contig, 1, length) for contig, length in contig_lengths.items()]
raw_path = raw_dir / "0591-scattered.interval_list"
write_interval_list(raw_path, contig_lengths, records)

run_interval_list_tool(
"split-db",
"--input-dir",
raw_dir,
"--output-dir",
out_dir,
"--fof",
out_dir / "intervals.txt",
"--max-intervals-per-shard",
200,
"--max-contigs-per-shard",
200,
)

shard_paths = [Path(line) for line in (out_dir / "intervals.txt").read_text().splitlines()]
assert len(shard_paths) == 6

combined_records = []
expected_header = read_interval_header(raw_path)
for shard_path in shard_paths:
shard_records = read_interval_records(shard_path)
shard_contigs = {record.split("\t", 1)[0] for record in shard_records}
assert len(shard_records) <= 200
assert len(shard_contigs) <= 200
assert read_interval_header(shard_path) == expected_header
combined_records.extend(shard_records)

assert combined_records == read_interval_records(raw_path)


def test_interval_filter_uses_contig_length_not_interval_length(tmp_path):
input_path = tmp_path / "input.interval_list"
output_path = tmp_path / "filtered.interval_list"
contig_lengths = {"short_contig": 900, "long_contig": 4000}
records = [
("short_contig", 1, 900),
("long_contig", 1, 50),
("long_contig", 1000, 2000),
]
write_interval_list(input_path, contig_lengths, records)

run_interval_list_tool(
"filter",
"--input",
input_path,
"--output",
output_path,
"--min-contig-length",
1000,
)

assert read_interval_header(output_path) == read_interval_header(input_path)
assert read_interval_records(output_path) == [
"long_contig\t1\t50\t+\tACGTmer",
"long_contig\t1000\t2000\t+\tACGTmer",
]


def test_db_interval_split_disabled_preserves_raw_shard_contents(tmp_path):
raw_dir = tmp_path / "raw"
out_dir = tmp_path / "out"
raw_dir.mkdir()

first_path = raw_dir / "0010-scattered.interval_list"
second_path = raw_dir / "0020-scattered.interval_list"
write_interval_list(
first_path,
{"ctg1": 1000, "ctg2": 1000},
[("ctg1", 1, 1000), ("ctg2", 1, 1000)],
)
write_interval_list(
second_path,
{"ctg3": 1000},
[("ctg3", 1, 1000)],
)

run_interval_list_tool(
"split-db",
"--input-dir",
raw_dir,
"--output-dir",
out_dir,
"--fof",
out_dir / "intervals.txt",
"--max-intervals-per-shard",
0,
"--max-contigs-per-shard",
0,
)

shard_paths = [Path(line) for line in (out_dir / "intervals.txt").read_text().splitlines()]
assert [path.name for path in shard_paths] == [
"0000-scattered.interval_list",
"0001-scattered.interval_list",
]
assert shard_paths[0].read_text() == first_path.read_text()
assert shard_paths[1].read_text() == second_path.read_text()


def test_db_interval_split_rewrites_pathological_contig_shards(tmp_path):
raw_dir = tmp_path / "raw"
out_dir = tmp_path / "out"
raw_dir.mkdir()
input_path = raw_dir / "0000-scattered.interval_list"
write_interval_list(
input_path,
{"ctg1": 1000, "ctg2": 2000, "ctg3": 3000},
[
("ctg1", 10, 100),
("ctg1", 500, 900),
("ctg2", 20, 200),
("ctg3", 30, 300),
],
)

run_interval_list_tool(
"split-db",
"--input-dir",
raw_dir,
"--output-dir",
out_dir,
"--fof",
out_dir / "intervals.txt",
"--max-intervals-per-shard",
0,
"--max-contigs-per-shard",
3,
"--merge-contigs-threshold",
2,
)

shard_paths = [Path(line) for line in (out_dir / "intervals.txt").read_text().splitlines()]
assert [path.name for path in shard_paths] == [
"0000-scattered.interval_list",
]
assert read_interval_records(shard_paths[0]) == [
"ctg1\t1\t1000\t+\tctg1",
"ctg2\t1\t2000\t+\tctg2",
"ctg3\t1\t3000\t+\tctg3",
]
assert read_interval_header(shard_paths[0]) == read_interval_header(input_path)


def test_genomicsdb_merge_contigs_arg_only_for_whole_contig_pathology(tmp_path):
whole_path = tmp_path / "whole.interval_list"
partial_path = tmp_path / "partial.interval_list"
fai_path = tmp_path / "ref.fa.fai"
contig_lengths = {"ctg1": 1000, "ctg2": 2000, "ctg3": 3000}
write_interval_list(
whole_path,
contig_lengths,
[("ctg1", 1, 1000), ("ctg2", 1, 2000), ("ctg3", 1, 3000)],
)
write_interval_list(
partial_path,
contig_lengths,
[("ctg1", 1, 1000), ("ctg2", 20, 200), ("ctg3", 1, 3000)],
)

result = run_interval_list_tool(
"genomicsdb-merge-contigs-arg",
"--input",
whole_path,
"--threshold",
2,
)
assert result.stdout.strip() == "--merge-contigs-into-num-partitions 2"

result = run_interval_list_tool(
"genomicsdb-merge-contigs-arg",
"--input",
partial_path,
"--threshold",
2,
)
assert result.stdout.strip() == ""
assert "not all records are whole-contig intervals" in result.stderr

fai_path.write_text("ctg1\t1000\t0\t80\t81\nctg2\t2000\t1020\t80\t81\nctg3\t3000\t3040\t80\t81\n")
result = run_interval_list_tool(
"genomicsdb-merge-contigs-arg",
"--input",
fai_path,
"--threshold",
2,
)
assert result.stdout.strip() == "--merge-contigs-into-num-partitions 2"


def write_config_for_tool(base_config, out_dir, tool, parabricks_image=None):
"""Write a config copy with variant_calling.tool overridden."""
text = Path(base_config).read_text()
Expand Down Expand Up @@ -426,6 +682,7 @@ def test_setup_dry_run(request):
"prepare_reference",
"index_reference",
"picard_intervals",
"filter_picard_intervals",
"create_gvcf_intervals",
"create_db_intervals",
]
Expand All @@ -434,6 +691,24 @@ def test_setup_dry_run(request):
assert rule in output, f"Expected rule '{rule}' not found"


@pytest.mark.dry_run
def test_setup_dry_run_accepts_interval_complexity_config(request):
no_conda = request.config.getoption("--no-conda")
with tempfile.TemporaryDirectory() as tmpdir:
smk = SnakemakeRunner(Path(tmpdir), use_conda=not no_conda)
cfg = write_interval_complexity_config(get_config_file(), tmpdir)

result = smk.dry_run(
target="setup",
configfile=cfg,
samples=get_samples_file(),
)

result.assert_success()
output = result.stdout + result.stderr
assert "filter_picard_intervals" in output


@pytest.mark.dry_run
def test_full_pipeline_dry_run(request):
no_conda = request.config.getoption("--no-conda")
Expand Down Expand Up @@ -700,6 +975,26 @@ def test_gatk_without_intervals_dry_run(request):
assert "joint_genomics_db_import" in output


@pytest.mark.dry_run
def test_genomicsdb_import_dry_run_uses_internal_memory_and_contig_merge_guard(request):
no_conda = request.config.getoption("--no-conda")
with tempfile.TemporaryDirectory() as tmpdir:
smk = SnakemakeRunner(Path(tmpdir), use_conda=not no_conda)
cfg = write_intervals_config(get_config_file(), tmpdir, enabled=False)

result = smk.dry_run(
target="call_variants",
configfile=cfg,
samples=get_samples_file(),
)
result.assert_success()

output = result.stdout + result.stderr
assert "--java-options '-Xmx3072m'" in output
assert "genomicsdb-merge-contigs-arg" in output
assert "--threshold 50" in output


@pytest.mark.full_run
@pytest.mark.parametrize("compressed", [False, True])
def test_reference_url_sources(request, compressed):
Expand Down
Loading
Loading