diff --git a/classify/blast.py b/classify/blast.py
index 3641a930..a4a04fd5 100644
--- a/classify/blast.py
+++ b/classify/blast.py
@@ -8,20 +8,32 @@
 import tools
 import tools.samtools
 import util.misc
-
+import time
 
 TOOL_NAME = "blastn"
 
-'''
-#Creating task.log
-logging.basicConfig(
-    level=logging.DEBUG,
-    format='%(asctime)s - %(levelname)s - %(message)s',
-    handlers=[
-        logging.FileHandler("blast_py.log"),
-        logging.StreamHandler()
-    ]
-)
-'''
+# Set up logging to blast_py.log
+try:
+    log_directory = os.getcwd()
+
+    # Ensure the log directory exists (a no-op for the current working directory)
+    if not os.path.exists(log_directory):
+        os.makedirs(log_directory)
+
+    log_file_path = os.path.join(log_directory, 'blast_py.log')
+
+    logging.basicConfig(
+        level=logging.DEBUG,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        handlers=[
+            logging.FileHandler(log_file_path),
+            logging.StreamHandler()
+        ]
+    )
+except Exception as e:
+    print("Failed to set up logging:", str(e))
+    raise
+
 _log = logging.getLogger(__name__)
 
 class BlastTools(tools.Tool):
@@ -49,12 +61,15 @@ class BlastnTool(BlastTools):
     """ Tool wrapper for blastn """
     subtool_name = 'blastn'
 
-    def get_hits_pipe(self, inPipe, db, threads=None, task=None, outfmt='6', max_target_seqs=1, output_type="read_id"):
-        _log.debug(f"Executing get_hits_pipe function. Called with outfmt: {outfmt}")
+    def get_hits_pipe(self, inPipe, db, threads=None, outfmt='6', task=None, max_target_seqs=1, output_type="read_id", taxidlist=None):
+        start_time = time.time()
+
+        # Toggle between emitting read IDs only or full blast output lines
         if output_type not in ['read_id', 'full_line']:
             _log.warning(f"Invalid output_type '{output_type}' specified. Defaulting to 'read_id'.")
             output_type = 'read_id'
+        _log.debug(f"Executing get_hits_pipe with task: {task} (type: {type(task)}), outfmt: {outfmt}")
 
         # run blastn and emit list of read IDs
         threads = util.misc.sanitize_thread_count(threads)
         cmd = [self.install_and_get_path(),
@@ -66,56 +81,57 @@ def get_hits_pipe(self, inPipe, db, threads=None, task=None, outfmt='6', max_tar
             '-max_target_seqs', str(max_target_seqs),
             '-task', str(task) if task else 'blastn',
         ]
+        # Add taxidlist if specified by the user
+        if taxidlist:
+            cmd.extend(['-taxidlist', taxidlist])
+            _log.info(f"Using taxidlist: {taxidlist} in BLAST command")
+        cmd = [str(x) for x in cmd]
 
         #Log BLAST command executed
-        _log.debug('Running blastn command: {}'.format(' '.join(cmd)))
-        _log.debug('| ' + ' '.join(cmd) + ' |')
-        blast_pipe = subprocess.Popen(cmd, stdin=inPipe, stdout=subprocess.PIPE)
-        output, error = blast_pipe.communicate()
-
-        #Display error message if BLAST failed
-        if blast_pipe.returncode!= 0:
-            _log.error('Error running blastn command: {}'.format(error))
-            raise subprocess.CalledProcessError(blast_pipe.returncode, cmd)
+        _log.info('Running blastn command: {}'.format(' '.join(cmd)))
 
-        # If read_id is defined, strip tab output to just query read ID names and emit (default)
-        last_read_id = None
-        for line in output.decode('UTF-8').splitlines():
-            if output_type == 'read_id':
-                #Split line by tab, and take the first element
-                read_id = line.split('\t')[0]
-                # Only emit if it is not a duplicate of the previous read ID
-                if read_id != last_read_id:
-                    last_read_id = read_id
-                    yield read_id
-            #Yield the full line without stripping whitespace
-            elif output_type == 'full_line':
-                yield line
-
-        #Display on CMD if BLAST fails
-        if blast_pipe.returncode!= 0:
-            _log.error('Error running blastn command: {}'.format(error))
-            raise subprocess.CalledProcessError(blast_pipe.returncode, cmd)
-        #Logging configuration written to blast_py.log if BLAST passes/fails
-        if blast_pipe.returncode == 0:
-            _log.info("Blastn process completed succesfully.")
-        else:
-            _log.error("Blastn process failed with exit code: %s", blast_pipe.returncode)
-            raise subprocess.CalledProcessError(blast_pipe.returncode, cmd)
-
+        # try/finally ensures subprocess resources are cleaned up regardless of any error raised
+        try:
+            with subprocess.Popen(cmd, stdin=inPipe, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as blast_pipe:
+                output, error = blast_pipe.communicate()
+
+                if blast_pipe.returncode != 0:
+                    _log.error(f'Error running blastn command: {error.decode()}')
+                    raise subprocess.CalledProcessError(blast_pipe.returncode, cmd, output=output, stderr=error)
+
+                # Process the output line by line in a generator fashion
+                last_read_id = None
+                for line in output.decode('UTF-8').splitlines():
+                    if output_type == 'read_id':
+                        read_id = line.split('\t')[0]
+                        # Only emit a read ID if it differs from the previous one
+                        if read_id != last_read_id:
+                            last_read_id = read_id
+                            yield read_id
+                    elif output_type == 'full_line':
+                        yield line
+        finally:
+            # The context manager has already closed the subprocess pipes at this point
+            _log.info("Cleaning up subprocess resources.")
+
+        # Log successful completion and time taken
+        elapsed_time = time.time() - start_time
+        _log.debug(f"get_hits_pipe executed in {elapsed_time:.2f} seconds")
 
     def get_hits_bam(self, inBam, db, threads=None):
         return self.get_hits_pipe(tools.samtools.SamtoolsTool().bam2fa_pipe(inBam),
             db,
             threads=threads)
 
Called with outfmt: {outfmt}") + def get_hits_fasta(self, inFasta, db, threads, outfmt, task, max_target_seqs=1, output_type='read_id', taxidlist=None): + start_time = time.time() + _log.debug(f"Executing get_hits_fasta function. Called with outfmt: {outfmt}, taxidlist: {taxidlist}") with open(inFasta, 'rt') as inf: - for hit in self.get_hits_pipe(inf, db, threads=threads, task=None, outfmt=outfmt, max_target_seqs=max_target_seqs, output_type=output_type): + for hit in self.get_hits_pipe(inf, db=db, threads=threads, outfmt=outfmt, task=task, max_target_seqs=max_target_seqs, output_type=output_type, taxidlist=taxidlist): yield hit - - + elapsed_time = time.time() - start_time + _log.debug(f"get_hits_fasta exectued in {elapsed_time:.2f} seconds") + class MakeblastdbTool(BlastTools): """ Tool wrapper for makeblastdb """ subtool_name = 'makeblastdb' diff --git a/taxon_filter.py b/taxon_filter.py old mode 100755 new mode 100644 index c09ab170..dc03e79f --- a/taxon_filter.py +++ b/taxon_filter.py @@ -18,7 +18,7 @@ import shutil import concurrent.futures import contextlib - +import time from Bio import SeqIO import pysam @@ -37,19 +37,8 @@ import classify.bmtagger import read_utils -''' -#Adding logging configuration to identify issues/ time spent -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler("task.log"), # Log file name - logging.StreamHandler() # Keeps the console output if desired - ] -) -''' - log = logging.getLogger(__name__) + # ======================= # *** deplete *** # ======================= @@ -403,114 +392,186 @@ def multi_db_deplete_bam(inBam, refDbs, deplete_method, outBam, **kwargs): # ======================== -# *** deplete_blastn *** +# *** chunk_blastn *** # ======================== -def _run_blastn_chunk(db, input_fasta, out_hits, blast_threads, task=None, outfmt='6', max_target_seqs=1, output_type='read_id'): +def _run_blastn_chunk(db, input_fasta, out_hits, blast_threads, outfmt="6", task=None, max_target_seqs=1, output_type='read_id', taxidlist=None): """ run blastn on the input fasta file. 
this is intended to be run in parallel by blastn_chunked_fasta """ - #Might need to remove this path, not absolute - #os.environ['BLASTDB']= 'viral-classify/blast' + start_time = time.time() try: with util.file.open_or_gzopen(out_hits, 'wt') as outf: - for line in classify.blast.BlastnTool().get_hits_fasta(input_fasta, db, threads=blast_threads, task=task, outfmt=outfmt, output_type=output_type): + for line in classify.blast.BlastnTool().get_hits_fasta(inFasta=input_fasta, db=db, threads=blast_threads, outfmt=outfmt, task=task, max_target_seqs=max_target_seqs, output_type=output_type, taxidlist=taxidlist): outf.write(line + '\n') - log.info("_run_blastn_chunk completed succesfully.") + log.debug("_run_blastn_chunk completed succesfully for one chunk.") except Exception as e: - log.error("An error occurred in _run_blastn_chunk.:%s", e) + log.error("An error occurred in _run_blastn_chunk.: %s", str(e)) raise e + elapsed_time = time.time() - start_time + log.debug(f"_run_blastn_chunk executed in {elapsed_time:.2f} seconds") -def blastn_chunked_fasta(fasta, db, out_hits, chunkSize=1000000, threads=None, task=None, outfmt='6', max_target_seqs=1, output_type='read_id'): +def blastn_chunked_fasta(fasta, db, out_hits, threads, outfmt="6", chunkSize=1000000, task=None, max_target_seqs=1, output_type='read_id', taxidlist=None): """ Helper function: blastn a fasta file, overcoming apparent memory leaks on an input with many query sequences, by splitting it into multiple chunks and running a new blastn process on each chunk. Return a list of output filenames containing hits """ + log.debug(f"Executing blastn_chunked_fasta function. Called with outfmt: {outfmt}, taxidlist: {taxidlist}") + start_time = time.time() # the lower bound of how small a fasta chunk can be. # too small and the overhead of spawning a new blast process # will be detrimental relative to actual computation time - - #checks if the blastn_chunked_fasta function is being called - log.info("Calling blastn_chunked_fasta function...") MIN_CHUNK_SIZE = 20000 - # just in case blast is not installed, install it once, not many times in parallel! + + #checks if the blastn_chunked_fasta function is being called classify.blast.BlastnTool().install() # clamp threadcount to number of CPU cores threads = util.misc.sanitize_thread_count(threads) + log.info(f"Sanitized thread count: {threads}") # determine size of input data; records in fasta file number_of_reads = util.file.fasta_length(fasta) - log.debug("number of reads in fasta file %s" % number_of_reads) + log.info("number of reads in fasta file %s" % number_of_reads) + + #Error raised if empty read file if number_of_reads == 0: + log.info("Number of reads is 0. 
Empty output file.") util.file.make_empty(out_hits) - - # divide (max, single-thread) chunksize by thread count - # to find the absolute max chunk size per thread - chunk_max_size_per_thread = chunkSize // threads - - # find the chunk size if evenly divided among blast threads - reads_per_thread = number_of_reads // threads - - # use the smaller of the two chunk sizes so we can run more copies of blast in parallel - chunkSize = min(reads_per_thread, chunk_max_size_per_thread) - - # if the chunk size is too small, impose a sensible size + return + + #----CHUNKING----# + # Setting each blast thread count to 4 + + blast_threads = 4 + max_workers_cpu = max(1, (threads // blast_threads)) + + # Calculate base number of reads per chunk + base_reads_per_chunk = number_of_reads // max_workers_cpu + + # Calculate remainder + remainder_reads = number_of_reads % max_workers_cpu + + # Adjust chunk sizes to distribute remainder reads + chunk_sizes = [base_reads_per_chunk + (1 if i < remainder_reads else 0) for i in range(max_workers_cpu)] + log.info(f"Print chunk sizes {chunk_sizes}") + + # Ensure that the user-defined chunk size is respected + chunkSize = min(chunkSize, max(chunk_sizes)) + + # Ensure the chunk size is not smaller than the minimum chunk size chunkSize = max(chunkSize, MIN_CHUNK_SIZE) - log.debug("chunk_max_size_per_thread %s" % chunk_max_size_per_thread) - - # adjust chunk size so we don't have a small fraction - # of a chunk running in its own blast process - # if the size of the last chunk is <80% the size of the others, - # decrease the chunk size until the last chunk is 80% - # this is bounded by the MIN_CHUNK_SIZE - while (number_of_reads / chunkSize) % 1 < 0.8 and chunkSize > MIN_CHUNK_SIZE: - chunkSize = chunkSize - 1 - - log.debug("blastn chunk size %s" % chunkSize) - log.debug("number of chunks to create %s" % (number_of_reads / chunkSize)) - log.debug("blastn parallel instances %s" % threads) - log.debug(f"outfmt value: {outfmt}") + log.info(f"blastn parallel instances {threads}") + log.info(f"outfmt value: {outfmt}") + # chunk the input file. This is a sequential operation input_fastas = [] with open(fasta, "rt") as fastaFile: record_iter = SeqIO.parse(fastaFile, "fasta") for batch in util.misc.batch_iterator(record_iter, chunkSize): chunk_fasta = mkstempfname('.fasta') - with open(chunk_fasta, "wt") as handle: count= SeqIO.write(batch, handle, "fasta") batch = None - #detail chunk sizes being processed - log.info(f"Created chunk {chunk_fasta} with {count} records") + log.info(f"Created chunk {chunk_fasta} with {count} reads.") input_fastas.append(chunk_fasta) num_chunks = len(input_fastas) - log.debug("number of chunk files to be processed by blastn %d" % num_chunks) - + log.info("number of chunk files to be processed by blastn %d" % num_chunks) + + #----EXECUTOR-----# + #Executor start time + start_time_executor = time.time() + # run blastn on each of the fasta input chunks + # Log the number of workers that will be used + log.info(f"Initializing executor with {max_workers_cpu} max_workers.") hits_files = list(mkstempfname('.hits.txt') for f in input_fastas) - with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor: + with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers_cpu) as executor: # If we have so few chunks that there are cpus left over, # divide extra cpus evenly among chunks where possible # rounding to 1 if there are more chunks than extra threads. # Then double up this number to better maximize CPU usage. 
-        cpus_leftover = threads - num_chunks
-        blast_threads = 2*max(1, int(cpus_leftover / num_chunks))
-        for i in range(num_chunks):
-            executor.submit(_run_blastn_chunk, db, input_fastas[i], hits_files[i], blast_threads, task=task, outfmt=outfmt, max_target_seqs=max_target_seqs, output_type=output_type)
-
-        # merge results and clean up
+        log.info(f"Total CPU threads: {threads}; blast threads per chunk: {blast_threads}; max workers: {max_workers_cpu}")
+
+        # Submit each fasta chunk to the executor
+        futures = []
+        for i, (chunk_fasta, chunk_hits) in enumerate(zip(input_fastas, hits_files)):
+            future = executor.submit(_run_blastn_chunk, db, chunk_fasta, chunk_hits, blast_threads, outfmt, task, max_target_seqs, output_type, taxidlist)
+            futures.append(future)
+            log.info(f"Submitted chunk {i} to executor with {blast_threads} threads per chunk.")
+
+        # Track completion of futures; result() re-raises any exception from a worker
+        for i, future in enumerate(concurrent.futures.as_completed(futures)):
+            future.result()
+            log.info(f"{i + 1}/{num_chunks} chunks completed.")
+
+    # Measure executor runtime
+    executor_elapsed_time = time.time() - start_time_executor
+    log.info(f"Executor for all chunks finished in {executor_elapsed_time:.2f} seconds.")
+
+    #----CLEAN UP----#
+    clean_up_start_time = time.time()
+
+    # Merge per-chunk results and clean up temporary files
     util.file.cat(out_hits, hits_files)
     for i in range(num_chunks):
         os.unlink(input_fastas[i])
         os.unlink(hits_files[i])
+    log.info("Cleaned up all temporary files.")
+
+    # Measure clean-up runtime
+    elapsed_clean_up = time.time() - clean_up_start_time
+    log.debug(f"Clean-up finished in {elapsed_clean_up:.2f} seconds")
+
+    #----OVERALL RUNTIME----#
+    elapsed_time = time.time() - start_time
+    log.info(f"Completed blastn_chunked_fasta in {elapsed_time:.2f} seconds.")
+
+def chunk_blast_hits(inFasta, db, blast_hits_output, threads=None, outfmt="6", chunkSize=1000000, task=None, max_target_seqs=1, output_type='read_id', taxidlist=None):
+    '''Process BLAST hits from a FASTA file by dividing it into smaller chunks for parallel processing (blastn_chunked_fasta).'''
+    log.debug(f"Executing chunk_blast_hits function. Called with outfmt: {outfmt}, taxidlist: {taxidlist}")
+    if chunkSize > 0:
+        log.info("Running BLASTN on %s against database %s with chunkSize: %s", inFasta, db, chunkSize)
+        blastn_chunked_fasta(fasta=inFasta, db=db, out_hits=blast_hits_output, threads=threads, outfmt=outfmt, chunkSize=chunkSize, task=task, max_target_seqs=max_target_seqs, output_type=output_type, taxidlist=taxidlist)
+    else:
+        log.warning("Invalid or zero chunkSize provided (%s), running BLASTN without chunking.", chunkSize)
+        try:
+            with open(blast_hits_output, 'wt') as outf:
+                for output in classify.blast.BlastnTool().get_hits_fasta(inFasta=inFasta, db=db, threads=threads, task=task, outfmt=outfmt, max_target_seqs=max_target_seqs, output_type=output_type, taxidlist=taxidlist):
+                    if output_type == 'read_id':
+                        read_id = output.split('\t')[0]
+                        outf.write(read_id + '\n')
+                    else:
+                        outf.write(output + '\n')
+        except Exception as e:
+            log.error("An error occurred while running BLASTN without chunking: %s", e)
+            raise
+
+def parser_chunk_blast_hits(parser=argparse.ArgumentParser()):
+    parser.add_argument('inFasta', help='Input FASTA file.')
+    parser.add_argument('db', help='BLASTN database.')
+    parser.add_argument('blast_hits_output', help='Output file to store hits from BLASTN.')
+    parser.add_argument("--outfmt", type=str, default="6", help="Output format for BLAST results.")
+    parser.add_argument("--chunkSize", type=int, default=1000000, help='Size of FASTA chunks for processing.')
+    parser.add_argument("--task", type=str, help="Type of BLAST search to perform, e.g., megablast, blastn, etc.")
+    parser.add_argument("--max_target_seqs", type=int, default=1, help="Maximum number of target sequences to return per query.")
+    parser.add_argument("--output_type", choices=["read_id", "full_line"], default="read_id", help="Specify the output type: read IDs or full BLAST output lines.")
+    parser.add_argument("--taxidlist", help="Optional path to a taxidlist file for limiting the BLAST search to specific taxa.", required=False)
+    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
+    util.cmd.attach_main(parser, chunk_blast_hits, split_args=True)
+    return parser
+
+__commands__.append(('chunk_blast_hits', parser_chunk_blast_hits))
+
+# ========================
+# *** deplete_blastn ***
+# ========================
 
 def deplete_blastn_bam(inBam, db, outBam, threads=None, chunkSize=1000000, JVMmemory=None):
 #def deplete_blastn_bam(inBam, db, outBam, threads, chunkSize=0, JVMmemory=None):
@@ -529,46 +590,14 @@ def deplete_blastn_bam(inBam, db, outBam, threads=None, chunkSize=1000000, JVMme
     else:
         ## pipe tools together and run blastn multithreaded
         with open(blast_hits, 'wt') as outf:
             for read_id in classify.blast.BlastnTool().get_hits_bam(inBam, db_prefix, threads=threads):
                 outf.write(read_id + '\n')
 
+    # Deplete BAM of hits
     tools.picard.FilterSamReadsTool().execute(inBam, True, blast_hits, outBam, JVMmemory=JVMmemory)
     os.unlink(blast_hits)
 
-def chunk_blast_hits(inFasta, db, blast_hits_output, threads=None, chunkSize=1000000, task=None, outfmt='6', max_target_seqs=1, output_type= 'read_id'):
-    'Process BLAST hits from a FASTA file by dividing the file into smaller chunks for parallel processing (blastn_chunked_fasta).'
-    if chunkSize:
-        log.info("Running BLASTN on %s against database %s", inFasta, db)
-        #Execute blastn_chunked_fasta
-        blastn_chunked_fasta(inFasta, db, blast_hits_output, chunkSize, threads, task, outfmt, max_target_seqs, output_type=output_type)
-    else:
-        #Pipe tools together and run blastn multithreaded
-        with open(blast_hits_output, 'wt') as outf:
-            for output in classify.blast.BlastnTool().get_hits_fasta(inFasta, db, threads, task=task, outfmt=outfmt, max_target_seqs=max_target_seqs, output_type=output_type):
-                #Account for read_ids extract only or full blast output run. Default = read_lines.
-                if output_type == 'read_id':
-                    # Extract the first clmn in the output (assuming its the read ID)
-                    read_id = output.split('\t')[0]
-                    outf.write(read_id + '\n')
-                else:
-                    #Extract and write full line if the output_type is not set to just read IDs
-                    outf.write(output + '\n')
-
-def parser_chunk_blast_hits(parser=argparse.ArgumentParser()):
-    parser = argparse.ArgumentParser(description="Run BLASTN on chunks of a FASTA file.")
-    parser.add_argument('inBam', help='Input BAM file.')
-    parser.add_argument('db', help='BLASTN database.')
-    parser.add_argument('blast_hits_output', help='Stores hits found by BLASTN.')
-    parser.add_argument("--chunkSize", type=int, default=1000000, help='FASTA chunk size (default: %(default)s)')
-    parser.add_argument("-task", help="details the type of search (i.e. megablast,blatn,etc)")
-    parser.add_argument("-outfmt", type=str, default=6, help="Custom output formats(default: %(default)s)")
-    parser.add_argument("-max_target_seqs", type=int, default=1, help="BLAST will return the first (if set to default) database hits for a sequence query. (default: %(default)s)")
-    parser.add_argument("--output_type", default= "read_id", choices=["read_id", "full_line"], help="Specify the type of output: 'read_id' for read IDs only, or 'full_line' for full BLAST output lines. Default is 'read_id'. Useful when adding taxonomy IDs to outfmt type 6.")
-    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
-    util.cmd.attach_main(parser, chunk_blast_hits)
-    return parser
-
 def parser_deplete_blastn_bam(parser=argparse.ArgumentParser()):
     parser.add_argument('inBam', help='Input BAM file.')
     parser.add_argument('refDbs', nargs='+', help='One or more reference databases for blast. '
@@ -902,6 +931,6 @@ def parser_bmtagger_build_db(parser=argparse.ArgumentParser()):
 
 def full_parser():
     return util.cmd.make_parser(__commands__, __doc__)
-
 if __name__ == '__main__':
     util.cmd.main_argparse(__commands__, __doc__)
+
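For reference, a minimal usage sketch of the new chunk_blast_hits entry point added above (a hypothetical call; file and database names are illustrative, not part of the diff):

    # Hypothetical usage of taxon_filter.chunk_blast_hits; paths are illustrative.
    import taxon_filter

    taxon_filter.chunk_blast_hits(
        inFasta='reads.fasta',           # query reads
        db='blastdb/viral_nt',           # BLASTN database prefix
        blast_hits_output='hits.txt',    # destination for hits
        threads=8,                       # clamped by util.misc.sanitize_thread_count
        outfmt='6',                      # tabular BLAST output
        chunkSize=500000,                # <= 0 disables chunking
        task='megablast',
        output_type='full_line',         # or 'read_id' for deduplicated read IDs only
        taxidlist='viral_taxids.txt')    # optional -taxidlist restriction

The same call is exposed on the command line through parser_chunk_blast_hits, e.g. python taxon_filter.py chunk_blast_hits reads.fasta blastdb/viral_nt hits.txt --task megablast --output_type full_line --taxidlist viral_taxids.txt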