Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 111 additions & 76 deletions seqdata/_io/readers/bam.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generic, List, Literal, Optional, Type, Union, cast
from typing import (
Any,
Dict,
Generic,
List,
Literal,
Optional,
Type,
Union,
cast,
)

import joblib
import numpy as np
Expand All @@ -17,6 +28,7 @@
)
from numpy.typing import NDArray
from tqdm import tqdm
from typing_extensions import assert_never

from seqdata._io.utils import _get_row_batcher
from seqdata.types import DTYPE, PathType, RegionReader
Expand All @@ -30,28 +42,29 @@


class CountMethod(str, Enum):
DEPTH = "depth-only"
TN5_CUTSITE = "tn5-cutsite"
TN5_FRAGMENT = "tn5-fragment"
READS = "reads"
FRAGMENTS = "fragments"
ENDS = "ends"


class BAM(RegionReader, Generic[DTYPE]):
def __init__(
self,
name: str,
bams: Union[str, Path, List[str], List[Path]],
bams: Union[PathType, List[PathType]],
samples: Union[str, List[str]],
batch_size: int,
n_jobs=1,
threads_per_job=1,
count_method: Union[CountMethod, Literal["reads", "fragments", "ends"]],
n_jobs=-1,
threads_per_job=-1,
dtype: Union[str, Type[np.number]] = np.uint16,
sample_dim: Optional[str] = None,
offset_tn5=False,
count_method: Union[
CountMethod, Literal["depth-only", "tn5-cutsite", "tn5-fragment"]
] = "depth-only",
pos_shift: Optional[int] = None,
neg_shift: Optional[int] = None,
min_mapping_quality: Optional[int] = None,
) -> None:
"""Reader for BAM files.
"""Reader for next-generation sequencing paired-end BAM files. This reader will only count
reads that are properly paired and not secondary alignments.

Parameters
----------
Expand All @@ -65,24 +78,31 @@ def __init__(
Number of sequences to write at a time. Note this also sets the chunksize
along the sequence dimension.
n_jobs : int, optional
Number of BAMs to process in parallel, by default 1, which disables
multiprocessing. Don't set this higher than the number of BAMs or number of
cores available.
Number of BAMs to process in parallel, by default -1. If -1, use the number
of available cores or the number of BAMs, whichever is smaller. If 0 or 1, process
BAMs sequentially. Not recommended to set this higher than the number of BAMs.
threads_per_job : int, optional
Threads to use per job, by default 1. Make sure the number of available
cores is >= n_jobs * threads_per_job.
Threads to use per job, by default -1. If -1, the available cores are divided evenly
among the jobs (available cores // n_jobs, with a minimum of 1). Not recommended to set this
higher than the number of cores available divided by n_jobs.
dtype : Union[str, Type[np.number]], optional
Data type to write the coverage as, by default np.uint16.
sample_dim : Optional[str], optional
Name of the sample dimension, by default None
offset_tn5 : bool, optional
Whether to adjust read lengths to account for Tn5 binding, by default False
count_method : Union[CountMethod, Literal["depth-only", "tn5-cutsite", "tn5-fragment"]]
Count method, by default "depth-only"
count_method : Union[CountMethod, Literal["reads", "fragments", "ends"]]
Count method:
- "reads" counts the base pairs spanning the aligned sequences of reads.
- "fragments" counts the base pairs spanning from the start of R1 to the end of R2.
- "ends" counts only the single base positions for the start of R1 and the end of R2.

pos_shift : Optional[int], optional
Shift the start of the forward read by this many base pairs (clamped at the region start), by default None
neg_shift : Optional[int], optional
Shift the end of the reverse read by this many base pairs (clamped at the region end), by default None
min_mapping_quality : Optional[int], optional
Minimum mapping quality for reads to be counted, by default None
"""
if isinstance(bams, str):
bams = [bams]
elif isinstance(bams, Path):
if isinstance(bams, str) or isinstance(bams, Path):
bams = [bams]
if isinstance(samples, str):
samples = [samples]
Expand All @@ -92,12 +112,26 @@ def __init__(
self.bams = bams
self.samples = samples
self.batch_size = batch_size
self.n_jobs = n_jobs
self.threads_per_job = threads_per_job
self.count_method = CountMethod(count_method)
self.dtype = np.dtype(dtype)
self.sample_dim = f"{name}_sample" if sample_dim is None else sample_dim
self.offset_tn5 = offset_tn5
self.count_method = CountMethod(count_method)
self.pos_shift = pos_shift
self.neg_shift = neg_shift
self.min_mapping_quality = min_mapping_quality

n_cpus = len(os.sched_getaffinity(0))
if n_jobs == -1:
n_jobs = min(n_cpus, len(bams))
elif n_jobs == 0:
n_jobs = 1

if threads_per_job == -1:
threads_per_job = 1
if n_cpus > n_jobs:
threads_per_job = n_cpus // n_jobs

self.n_jobs = n_jobs
self.threads_per_job = threads_per_job

def _write(
self,
Expand Down Expand Up @@ -180,7 +214,7 @@ def _write_fixed_length(
)
for bam, sample_idx in zip(self.bams, sample_idxs)
]
with joblib.parallel_backend(
with joblib.parallel_config(
"loky", n_jobs=self.n_jobs, inner_max_num_threads=self.threads_per_job
):
joblib.Parallel()(tasks)
Expand Down Expand Up @@ -321,10 +355,7 @@ def read_cb(x: pysam.AlignedSegment):
def _reader(self, bed: pl.DataFrame, f: pysam.AlignmentFile):
for row in tqdm(bed.iter_rows(), total=len(bed)):
contig, start, end = row[:3]
if self.count_method is CountMethod.DEPTH:
coverage = self._count_depth_only(f, contig, start, end)
else:
coverage = self._count_tn5(f, contig, start, end)
coverage = self._count(f, contig, start, end)
yield coverage

def _spliced_reader(self, bed: pl.DataFrame, f: pysam.AlignmentFile):
Expand All @@ -337,40 +368,31 @@ def _spliced_reader(self, bed: pl.DataFrame, f: pysam.AlignmentFile):
for row in rows:
pbar.update()
contig, start, end = row[:3]
if self.count_method is CountMethod.DEPTH:
coverage = self._count_depth_only(f, contig, start, end)
else:
coverage = self._count_tn5(f, contig, start, end)
coverage = self._count(f, contig, start, end)
unspliced.append(coverage)
yield cast(NDArray[DTYPE], np.concatenate(coverage)) # type: ignore

def _count_depth_only(
self, f: pysam.AlignmentFile, contig: str, start: int, end: int
):
a, c, g, t = f.count_coverage(
contig,
max(start, 0),
end,
read_callback=lambda x: x.is_proper_pair and not x.is_secondary,
)
coverage = np.vstack([a, c, g, t]).sum(0).astype(self.dtype)
if (pad_len := end - start - len(coverage)) > 0:
pad_arr = np.zeros(pad_len, dtype=self.dtype)
pad_left = start < 0
if pad_left:
coverage = np.concatenate([pad_arr, coverage])
else:
coverage = np.concatenate([coverage, pad_arr])
return coverage

def _count_tn5(self, f: pysam.AlignmentFile, contig: str, start: int, end: int):
def _count(
self,
f: pysam.AlignmentFile,
contig: str,
start: int,
end: int,
) -> NDArray[DTYPE]:
length = end - start
out_array = np.zeros(length, dtype=self.dtype)

read_cache: Dict[str, pysam.AlignedSegment] = {}

for i, read in enumerate(f.fetch(contig, max(0, start), end)):
if not read.is_proper_pair or read.is_secondary:
for read in f.fetch(contig, max(0, start), end):
if (
not read.is_proper_pair
or read.is_secondary
or (
self.min_mapping_quality is not None
and read.mapping_quality < self.min_mapping_quality
)
):
continue

if read.query_name not in read_cache:
Expand All @@ -388,47 +410,60 @@ def _count_tn5(self, f: pysam.AlignmentFile, contig: str, start: int, end: int):
rel_start = forward_read.reference_start - start
# 0-based, 1 past aligned
# e.g. start:end == 0:2 == [0, 1] so position of end == 1
rel_end = cast(int, reverse_read.reference_end) - start
rel_end = reverse_read.reference_end - start # type: ignore | reference_end is defined for proper pairs

# Shift read if accounting for offset
if self.offset_tn5:
rel_start += 4
rel_end -= 5
if self.pos_shift:
rel_start = max(0, rel_start + self.pos_shift)
if self.neg_shift:
rel_end = min(length, rel_end + self.neg_shift)

# Check count method
if self.count_method is CountMethod.TN5_CUTSITE:
if self.count_method is CountMethod.ENDS:
# Add the two fragment end positions (start of forward read, end of reverse read) to out_array
if rel_start >= 0 and rel_start < length:
out_array[rel_start] += 1
if rel_end >= 0 and rel_end <= length:
out_array[rel_end - 1] += 1
elif self.count_method is CountMethod.TN5_FRAGMENT:
elif self.count_method is CountMethod.FRAGMENTS:
# Add range to out array
out_array[rel_start:rel_end] += 1
elif self.count_method is CountMethod.READS:
out_array[rel_start : forward_read.reference_end - start] += 1 # type: ignore | reference_end is defined for proper pairs
out_array[reverse_read.reference_start - start : rel_end] += 1
else:
assert_never(self.count_method)

# if any reads are still in the cache, then their mate isn't in the region
# if any reads are still in the cache, then their mate isn't in the region or didn't meet quality threshold
for read in read_cache.values():
# for reverse reads, their mate is in the 5' <- direction
if read.is_reverse:
rel_end = cast(int, read.reference_end) - start
if self.offset_tn5:
rel_end -= 5
rel_end = read.reference_end - start # type: ignore | reference_end is defined for proper pairs
if self.neg_shift:
rel_end = min(length, rel_end + self.neg_shift)
if rel_end < 0 or rel_end > length:
continue
if self.count_method is CountMethod.TN5_CUTSITE:
if self.count_method is CountMethod.ENDS:
out_array[rel_end - 1] += 1
elif self.count_method is CountMethod.TN5_FRAGMENT:
elif self.count_method is CountMethod.FRAGMENTS:
out_array[:rel_end] += 1
elif self.count_method is CountMethod.READS:
out_array[read.reference_start - start : rel_end] += 1
else:
assert_never(self.count_method)
# for forward reads, their mate is in the 3' -> direction
else:
rel_start = read.reference_start - start
if self.offset_tn5:
rel_start += 4
if self.pos_shift:
rel_start = max(0, rel_start + self.pos_shift)
if rel_start < 0 or rel_start >= length:
continue
if self.count_method is CountMethod.TN5_CUTSITE:
if self.count_method is CountMethod.ENDS:
out_array[rel_start] += 1
elif self.count_method is CountMethod.TN5_FRAGMENT:
elif self.count_method is CountMethod.FRAGMENTS:
out_array[rel_start:] += 1
elif self.count_method is CountMethod.READS:
out_array[rel_start : read.reference_end - start] += 1 # type: ignore | reference_end is defined for proper pairs
else:
assert_never(self.count_method)

return out_array