From e8d8f5069b06c0cfe9b9f8311abb30bf8153015b Mon Sep 17 00:00:00 2001 From: d-laub Date: Sat, 28 Dec 2024 22:16:58 -0800 Subject: [PATCH 1/4] feat: option to use corrections other than +4/-5 from BAMs. --- seqdata/_io/readers/bam.py | 89 +++++++++++++++----------------------- 1 file changed, 36 insertions(+), 53 deletions(-) diff --git a/seqdata/_io/readers/bam.py b/seqdata/_io/readers/bam.py index e2319cd..7f29808 100644 --- a/seqdata/_io/readers/bam.py +++ b/seqdata/_io/readers/bam.py @@ -30,9 +30,8 @@ class CountMethod(str, Enum): - DEPTH = "depth-only" - TN5_CUTSITE = "tn5-cutsite" - TN5_FRAGMENT = "tn5-fragment" + FRAGMENTS = "fragments" + ENDS = "ends" class BAM(RegionReader, Generic[DTYPE]): @@ -46,10 +45,9 @@ def __init__( threads_per_job=1, dtype: Union[str, Type[np.number]] = np.uint16, sample_dim: Optional[str] = None, - offset_tn5=False, - count_method: Union[ - CountMethod, Literal["depth-only", "tn5-cutsite", "tn5-fragment"] - ] = "depth-only", + count_method: Union[CountMethod, Literal["fragments", "ends"]] = "fragments", + left_shift: Optional[int] = None, + right_shift: Optional[int] = None, ) -> None: """Reader for BAM files. @@ -75,10 +73,12 @@ def __init__( Data type to write the coverage as, by default np.uint16. sample_dim : Optional[str], optional Name of the sample dimension, by default None - offset_tn5 : bool, optional - Whether to adjust read lengths to account for Tn5 binding, by default False - count_method : Union[CountMethod, Literal["depth-only", "tn5-cutsite", "tn5-fragment"]] - Count method, by default "depth-only" + count_method : Union[CountMethod, Literal["fragments", "ends"]], optional + Count method, by default "fragments" + left_shift : Optional[int], optional + Shift the left end of the fragment by this amount, by default None + right_shift : Optional[int], optional + Shift the right end of the fragment by this amount, by default None """ if isinstance(bams, str): bams = [bams] @@ -96,8 +96,9 @@ def __init__( self.threads_per_job = threads_per_job self.dtype = np.dtype(dtype) self.sample_dim = f"{name}_sample" if sample_dim is None else sample_dim - self.offset_tn5 = offset_tn5 self.count_method = CountMethod(count_method) + self.left_shift = left_shift + self.right_shift = right_shift def _write( self, @@ -321,10 +322,7 @@ def read_cb(x: pysam.AlignedSegment): def _reader(self, bed: pl.DataFrame, f: pysam.AlignmentFile): for row in tqdm(bed.iter_rows(), total=len(bed)): contig, start, end = row[:3] - if self.count_method is CountMethod.DEPTH: - coverage = self._count_depth_only(f, contig, start, end) - else: - coverage = self._count_tn5(f, contig, start, end) + coverage = self._count(f, contig, start, end) yield coverage def _spliced_reader(self, bed: pl.DataFrame, f: pysam.AlignmentFile): @@ -337,33 +335,17 @@ def _spliced_reader(self, bed: pl.DataFrame, f: pysam.AlignmentFile): for row in rows: pbar.update() contig, start, end = row[:3] - if self.count_method is CountMethod.DEPTH: - coverage = self._count_depth_only(f, contig, start, end) - else: - coverage = self._count_tn5(f, contig, start, end) + coverage = self._count(f, contig, start, end) unspliced.append(coverage) yield cast(NDArray[DTYPE], np.concatenate(coverage)) # type: ignore - def _count_depth_only( - self, f: pysam.AlignmentFile, contig: str, start: int, end: int - ): - a, c, g, t = f.count_coverage( - contig, - max(start, 0), - end, - read_callback=lambda x: x.is_proper_pair and not x.is_secondary, - ) - coverage = np.vstack([a, c, g, t]).sum(0).astype(self.dtype) - if (pad_len := end - start - len(coverage)) > 0: - pad_arr = np.zeros(pad_len, dtype=self.dtype) - pad_left = start < 0 - if pad_left: - coverage = np.concatenate([pad_arr, coverage]) - else: - coverage = np.concatenate([coverage, pad_arr]) - return coverage - - def _count_tn5(self, f: pysam.AlignmentFile, contig: str, start: int, end: int): + def _count( + self, + f: pysam.AlignmentFile, + contig: str, + start: int, + end: int, + ) -> NDArray[DTYPE]: length = end - start out_array = np.zeros(length, dtype=self.dtype) @@ -391,18 +373,19 @@ def _count_tn5(self, f: pysam.AlignmentFile, contig: str, start: int, end: int): rel_end = cast(int, reverse_read.reference_end) - start # Shift read if accounting for offset - if self.offset_tn5: - rel_start += 4 - rel_end -= 5 + if self.left_shift: + rel_start = max(0, rel_start + self.left_shift) + if self.right_shift: + rel_end = min(length, rel_end + self.right_shift) # Check count method - if self.count_method is CountMethod.TN5_CUTSITE: + if self.count_method is CountMethod.ENDS: # Add cut sites to out_array if rel_start >= 0 and rel_start < length: out_array[rel_start] += 1 if rel_end >= 0 and rel_end <= length: out_array[rel_end - 1] += 1 - elif self.count_method is CountMethod.TN5_FRAGMENT: + elif self.count_method is CountMethod.FRAGMENTS: # Add range to out array out_array[rel_start:rel_end] += 1 @@ -411,24 +394,24 @@ def _count_tn5(self, f: pysam.AlignmentFile, contig: str, start: int, end: int): # for reverse reads, their mate is in the 5' <- direction if read.is_reverse: rel_end = cast(int, read.reference_end) - start - if self.offset_tn5: - rel_end -= 5 + if self.right_shift: + rel_end = min(length, rel_end + self.right_shift) if rel_end < 0 or rel_end > length: continue - if self.count_method is CountMethod.TN5_CUTSITE: + if self.count_method is CountMethod.ENDS: out_array[rel_end - 1] += 1 - elif self.count_method is CountMethod.TN5_FRAGMENT: + elif self.count_method is CountMethod.FRAGMENTS: out_array[:rel_end] += 1 # for forward reads, their mate is in the 3' -> direction else: rel_start = read.reference_start - start - if self.offset_tn5: - rel_start += 4 + if self.left_shift: + rel_start = max(0, rel_start + self.left_shift) if rel_start < 0 or rel_start >= length: continue - if self.count_method is CountMethod.TN5_CUTSITE: + if self.count_method is CountMethod.ENDS: out_array[rel_start] += 1 - elif self.count_method is CountMethod.TN5_FRAGMENT: + elif self.count_method is CountMethod.FRAGMENTS: out_array[rel_start:] += 1 return out_array From 74387bb723d051249201a472e08bee01743ec679 Mon Sep 17 00:00:00 2001 From: d-laub Date: Fri, 3 Jan 2025 14:06:22 -0800 Subject: [PATCH 2/4] feat: adds read count method and mapping quality threshold for BAM reader. --- seqdata/_io/readers/bam.py | 94 +++++++++++++++++++++++++++----------- 1 file changed, 67 insertions(+), 27 deletions(-) diff --git a/seqdata/_io/readers/bam.py b/seqdata/_io/readers/bam.py index 7f29808..b5cd024 100644 --- a/seqdata/_io/readers/bam.py +++ b/seqdata/_io/readers/bam.py @@ -1,6 +1,16 @@ from enum import Enum from pathlib import Path -from typing import Any, Dict, Generic, List, Literal, Optional, Type, Union, cast +from typing import ( + Any, + Dict, + Generic, + List, + Literal, + Optional, + Type, + Union, + cast, +) import joblib import numpy as np @@ -17,6 +27,7 @@ ) from numpy.typing import NDArray from tqdm import tqdm +from typing_extensions import assert_never from seqdata._io.utils import _get_row_batcher from seqdata.types import DTYPE, PathType, RegionReader @@ -30,6 +41,7 @@ class CountMethod(str, Enum): + READS = "reads" FRAGMENTS = "fragments" ENDS = "ends" @@ -41,15 +53,17 @@ def __init__( bams: Union[str, Path, List[str], List[Path]], samples: Union[str, List[str]], batch_size: int, + count_method: Union[CountMethod, Literal["reads", "fragments", "ends"]], n_jobs=1, threads_per_job=1, dtype: Union[str, Type[np.number]] = np.uint16, sample_dim: Optional[str] = None, - count_method: Union[CountMethod, Literal["fragments", "ends"]] = "fragments", - left_shift: Optional[int] = None, - right_shift: Optional[int] = None, + pos_shift: Optional[int] = None, + neg_shift: Optional[int] = None, + min_mapping_quality: Optional[int] = None, ) -> None: - """Reader for BAM files. + """Reader for next-generation sequencing paired-end BAM files. This reader will only count + reads that are properly paired and not secondary alignments. Parameters ---------- @@ -73,12 +87,18 @@ def __init__( Data type to write the coverage as, by default np.uint16. sample_dim : Optional[str], optional Name of the sample dimension, by default None - count_method : Union[CountMethod, Literal["fragments", "ends"]], optional - Count method, by default "fragments" - left_shift : Optional[int], optional - Shift the left end of the fragment by this amount, by default None - right_shift : Optional[int], optional - Shift the right end of the fragment by this amount, by default None + count_method : Union[CountMethod, Literal["reads", "fragments", "ends"]] + Count method: + - "reads" counts the base pairs spanning the aligned sequences of reads. + - "fragments" counts the base pairs spanning from the start of R1 to the end of R2. + - "ends" counts only the single base positions for the start of R1 and the end of R2. + + pos_shift : Optional[int], optional + Shift the forward read by this amount, by default None + neg_shift : Optional[int], optional + Shift the negative read by this amount, by default None + min_mapping_quality : Optional[int], optional + Minimum mapping quality for reads to be counted, by default None """ if isinstance(bams, str): bams = [bams] @@ -97,8 +117,9 @@ def __init__( self.dtype = np.dtype(dtype) self.sample_dim = f"{name}_sample" if sample_dim is None else sample_dim self.count_method = CountMethod(count_method) - self.left_shift = left_shift - self.right_shift = right_shift + self.pos_shift = pos_shift + self.neg_shift = neg_shift + self.min_mapping_quality = min_mapping_quality def _write( self, @@ -351,8 +372,15 @@ def _count( read_cache: Dict[str, pysam.AlignedSegment] = {} - for i, read in enumerate(f.fetch(contig, max(0, start), end)): - if not read.is_proper_pair or read.is_secondary: + for read in f.fetch(contig, max(0, start), end): + if ( + not read.is_proper_pair + or read.is_secondary + or ( + self.min_mapping_quality is not None + and read.mapping_quality < self.min_mapping_quality + ) + ): continue if read.query_name not in read_cache: @@ -370,15 +398,14 @@ def _count( rel_start = forward_read.reference_start - start # 0-based, 1 past aligned # e.g. start:end == 0:2 == [0, 1] so position of end == 1 - rel_end = cast(int, reverse_read.reference_end) - start + rel_end = reverse_read.reference_end - start # type: ignore | reference_end is defined for proper pairs # Shift read if accounting for offset - if self.left_shift: - rel_start = max(0, rel_start + self.left_shift) - if self.right_shift: - rel_end = min(length, rel_end + self.right_shift) + if self.pos_shift: + rel_start = max(0, rel_start + self.pos_shift) + if self.neg_shift: + rel_end = min(length, rel_end + self.neg_shift) - # Check count method if self.count_method is CountMethod.ENDS: # Add cut sites to out_array if rel_start >= 0 and rel_start < length: @@ -388,30 +415,43 @@ def _count( elif self.count_method is CountMethod.FRAGMENTS: # Add range to out array out_array[rel_start:rel_end] += 1 + elif self.count_method is CountMethod.READS: + out_array[rel_start : forward_read.reference_end - start] += 1 # type: ignore | reference_end is defined for proper pairs + out_array[reverse_read.reference_start - start : rel_end] += 1 + else: + assert_never(self.count_method) - # if any reads are still in the cache, then their mate isn't in the region + # if any reads are still in the cache, then their mate isn't in the region or didn't meet quality threshold for read in read_cache.values(): # for reverse reads, their mate is in the 5' <- direction if read.is_reverse: - rel_end = cast(int, read.reference_end) - start - if self.right_shift: - rel_end = min(length, rel_end + self.right_shift) + rel_end = read.reference_end - start # type: ignore | reference_end is defined for proper pairs + if self.neg_shift: + rel_end = min(length, rel_end + self.neg_shift) if rel_end < 0 or rel_end > length: continue if self.count_method is CountMethod.ENDS: out_array[rel_end - 1] += 1 elif self.count_method is CountMethod.FRAGMENTS: out_array[:rel_end] += 1 + elif self.count_method is CountMethod.READS: + out_array[read.reference_start - start : rel_end] += 1 + else: + assert_never(self.count_method) # for forward reads, their mate is in the 3' -> direction else: rel_start = read.reference_start - start - if self.left_shift: - rel_start = max(0, rel_start + self.left_shift) + if self.pos_shift: + rel_start = max(0, rel_start + self.pos_shift) if rel_start < 0 or rel_start >= length: continue if self.count_method is CountMethod.ENDS: out_array[rel_start] += 1 elif self.count_method is CountMethod.FRAGMENTS: out_array[rel_start:] += 1 + elif self.count_method is CountMethod.READS: + out_array[rel_start : read.reference_end - start] += 1 # type: ignore | reference_end is defined for proper pairs + else: + assert_never(self.count_method) return out_array From 355e6367704a91fe5d13090fdda6183b107fef7d Mon Sep 17 00:00:00 2001 From: d-laub Date: Fri, 3 Jan 2025 16:47:45 -0800 Subject: [PATCH 3/4] Automatically infer best n_jobs and threads_per_job. --- seqdata/_io/readers/bam.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/seqdata/_io/readers/bam.py b/seqdata/_io/readers/bam.py index b5cd024..0f1e8e5 100644 --- a/seqdata/_io/readers/bam.py +++ b/seqdata/_io/readers/bam.py @@ -1,3 +1,4 @@ +import os from enum import Enum from pathlib import Path from typing import ( @@ -54,8 +55,8 @@ def __init__( samples: Union[str, List[str]], batch_size: int, count_method: Union[CountMethod, Literal["reads", "fragments", "ends"]], - n_jobs=1, - threads_per_job=1, + n_jobs=-1, + threads_per_job=-1, dtype: Union[str, Type[np.number]] = np.uint16, sample_dim: Optional[str] = None, pos_shift: Optional[int] = None, @@ -77,12 +78,13 @@ def __init__( Number of sequences to write at a time. Note this also sets the chunksize along the sequence dimension. n_jobs : int, optional - Number of BAMs to process in parallel, by default 1, which disables - multiprocessing. Don't set this higher than the number of BAMs or number of - cores available. + Number of BAMs to process in parallel, by default -1. If -1, use the number + of available cores or the number of BAMs, whichever is smaller. If 0 or 1, process + BAMs sequentially. Not recommended to set this higher than the number of BAMs. threads_per_job : int, optional - Threads to use per job, by default 1. Make sure the number of available - cores is >= n_jobs * threads_per_job. + Threads to use per job, by default -1. If -1, uses any extra cores available after + allocating them to n_jobs. Not recommended to set this higher than the number of cores + available divided by n_jobs. dtype : Union[str, Type[np.number]], optional Data type to write the coverage as, by default np.uint16. sample_dim : Optional[str], optional @@ -112,8 +114,6 @@ def __init__( self.bams = bams self.samples = samples self.batch_size = batch_size - self.n_jobs = n_jobs - self.threads_per_job = threads_per_job self.dtype = np.dtype(dtype) self.sample_dim = f"{name}_sample" if sample_dim is None else sample_dim self.count_method = CountMethod(count_method) @@ -121,6 +121,20 @@ def __init__( self.neg_shift = neg_shift self.min_mapping_quality = min_mapping_quality + n_cpus = len(os.sched_getaffinity(0)) + if n_jobs == -1: + n_jobs = min(n_cpus, len(self.bams)) + elif n_jobs == 0: + n_jobs = 1 + + if threads_per_job == -1: + threads_per_job = 1 + if n_cpus > n_jobs: + threads_per_job = n_cpus // n_jobs + + self.n_jobs = n_jobs + self.threads_per_job = threads_per_job + def _write( self, out: PathType, @@ -202,7 +216,7 @@ def _write_fixed_length( ) for bam, sample_idx in zip(self.bams, sample_idxs) ] - with joblib.parallel_backend( + with joblib.parallel_config( "loky", n_jobs=self.n_jobs, inner_max_num_threads=self.threads_per_job ): joblib.Parallel()(tasks) From ae0291a3713a1e3818c5be44cc6916b5f604c4b7 Mon Sep 17 00:00:00 2001 From: d-laub Date: Fri, 3 Jan 2025 16:51:43 -0800 Subject: [PATCH 4/4] fix: update bams argument type --- seqdata/_io/readers/bam.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/seqdata/_io/readers/bam.py b/seqdata/_io/readers/bam.py index 0f1e8e5..a157a4a 100644 --- a/seqdata/_io/readers/bam.py +++ b/seqdata/_io/readers/bam.py @@ -51,7 +51,7 @@ class BAM(RegionReader, Generic[DTYPE]): def __init__( self, name: str, - bams: Union[str, Path, List[str], List[Path]], + bams: Union[PathType, List[PathType]], samples: Union[str, List[str]], batch_size: int, count_method: Union[CountMethod, Literal["reads", "fragments", "ends"]], @@ -102,9 +102,7 @@ def __init__( min_mapping_quality : Optional[int], optional Minimum mapping quality for reads to be counted, by default None """ - if isinstance(bams, str): - bams = [bams] - elif isinstance(bams, Path): + if isinstance(bams, str) or isinstance(bams, Path): bams = [bams] if isinstance(samples, str): samples = [samples] @@ -114,16 +112,16 @@ def __init__( self.bams = bams self.samples = samples self.batch_size = batch_size + self.count_method = CountMethod(count_method) self.dtype = np.dtype(dtype) self.sample_dim = f"{name}_sample" if sample_dim is None else sample_dim - self.count_method = CountMethod(count_method) self.pos_shift = pos_shift self.neg_shift = neg_shift self.min_mapping_quality = min_mapping_quality n_cpus = len(os.sched_getaffinity(0)) if n_jobs == -1: - n_jobs = min(n_cpus, len(self.bams)) + n_jobs = min(n_cpus, len(bams)) elif n_jobs == 0: n_jobs = 1