-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfilter_cells.py
66 lines (59 loc) · 2.41 KB
/
filter_cells.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import cupy as cp
import cudf
import numpy as np
import scipy
import math
from cuml.linear_model import LinearRegression
def filter_cells(sparse_gpu_array, min_genes, max_genes, rows_per_batch=10000, barcodes=None):
    """
    Filter cells that have genes greater than a max number of genes or less than
    a minimum number of genes.

    The matrix is walked in consecutive row batches. Each batch is filtered by
    ``_filter_cells`` (which also copies the batch to the host) and the
    per-batch results are stacked in their original order.

    Parameters
    ----------
    sparse_gpu_array : cupy.sparse.csr_matrix of shape (n_cells, n_genes)
        CSR matrix to filter
    min_genes : int
        Lower bound (inclusive) on number of genes to keep
    max_genes : int
        Upper bound (inclusive) on number of genes to keep
    rows_per_batch : int
        Batch size to use for filtering. This can be adjusted for performance
        to trade-off memory use. Must be at least 1.
    barcodes : series
        cudf series containing cell barcodes.

    Returns
    -------
    filtered : scipy.sparse.csr_matrix of shape (n_kept_cells, n_genes)
        Matrix on host with filtered cells
    barcodes : If barcodes are provided, also returns a series of
        filtered barcodes, re-indexed from 0.

    Raises
    ------
    ValueError
        If ``rows_per_batch`` is smaller than 1.
    """
    # Validate before touching the data: a zero batch size would otherwise
    # surface as ZeroDivisionError and a negative one as an unrelated
    # stacking error further down.
    if rows_per_batch < 1:
        raise ValueError(
            "rows_per_batch must be a positive integer, got {!r}".format(rows_per_batch))

    n_cells = sparse_gpu_array.shape[0]

    # With no rows there are no batches, and stacking/concatenating an empty
    # list raises. Return the (empty) host copy directly instead.
    if n_cells == 0:
        empty = sparse_gpu_array.get()
        if barcodes is None:
            return empty
        return empty, barcodes[:0].reset_index(drop=True)

    filtered_data = []
    filtered_barcodes = []
    for start_idx in range(0, n_cells, rows_per_batch):
        # The last batch may be shorter than rows_per_batch.
        stop_idx = min(start_idx + rows_per_batch, n_cells)
        arr_batch = sparse_gpu_array[start_idx:stop_idx]

        if barcodes is None:
            filtered_data.append(_filter_cells(arr_batch,
                                               min_genes=min_genes,
                                               max_genes=max_genes))
        else:
            batch_data, batch_barcodes = _filter_cells(
                arr_batch,
                min_genes=min_genes,
                max_genes=max_genes,
                barcodes=barcodes[start_idx:stop_idx])
            filtered_data.append(batch_data)
            filtered_barcodes.append(batch_barcodes)

    filtered = scipy.sparse.vstack(filtered_data)
    if barcodes is None:
        return filtered

    # Drop the original positions so the barcode index lines up with the
    # rows of the filtered matrix.
    return filtered, cudf.concat(filtered_barcodes).reset_index(drop=True)
def _filter_cells(sparse_gpu_array, min_genes, max_genes, barcodes=None):
    """
    Filter one batch of rows by their per-row entry count.

    The count for a row is the difference between consecutive ``indptr``
    values, i.e. the number of entries stored for that row. Rows whose count
    lies in ``[min_genes, max_genes]`` (both ends inclusive) are kept.

    Parameters
    ----------
    sparse_gpu_array : sparse GPU matrix exposing ``indptr`` and ``get()``
        Batch of rows to filter
    min_genes : int
        Smallest per-row count to keep
    max_genes : int
        Largest per-row count to keep
    barcodes : series, optional
        Labels for the rows of this batch; indexed with the same mask.

    Returns
    -------
    The host copy of the batch restricted to the kept rows. If ``barcodes``
    is given, a tuple of that matrix and the masked barcodes.
    """
    entries_per_row = cp.diff(sparse_gpu_array.indptr)
    at_least_min = min_genes <= entries_per_row
    at_most_max = entries_per_row <= max_genes

    # Bring the boolean mask to the host; it indexes the host-side matrix.
    keep_mask = (at_least_min & at_most_max).ravel().get()

    kept_rows = sparse_gpu_array.get()[keep_mask]
    if barcodes is not None:
        return kept_rows, barcodes[keep_mask]
    return kept_rows