diff --git a/Utilities/for_trees/CladeGrabbing.py b/Utilities/for_trees/CladeGrabbing.py index 9bc7027..00775b4 100644 --- a/Utilities/for_trees/CladeGrabbing.py +++ b/Utilities/for_trees/CladeGrabbing.py @@ -1,10 +1,10 @@ -#Author, date: Auden Cote-L'Heureux, last updated Apr 1st 2024 by GA +#Author, date: Auden Cote-L'Heureux, last updated Aug 18th 2025 by AKG #Motivation: Select robust sequences from trees #Intent: Select clades of interest from large trees using taxonomic specifications #Dependencies: Python3, ete3, Biopython #Inputs: A folder containing: all PTLp2 output trees and all corresponding unaligned .fasta (pre-guidance) files #Outputs: A folder of grabbed clades and filtered unaligned fasta files -#Example: python CladeGrabbing.py --input /Path/To/TreesandPreGuidance --target Sr_rh --min_presence 20 +#Example: python3 CladeGrabbing.py --input /Path/To/TreesandPreGuidance --target Sr_rh --min_presence 20 #IMPORTANT: key parameters explained in "add_argument" section below #Dependencies @@ -28,6 +28,8 @@ def get_args(): parser.add_argument('-nr', '--required_taxa_num', type = int, default = 0, help = 'The number of species belonging to taxa in the --required_taxa list that must be present in the clade. Default is 0.') parser.add_argument('-o', '--outgroup', type = str, default = '', help = 'A comma-separated list of any number of digits/characters (e.g. Sr_ci_S OR Am_t), or a file with the extension .txt containing a list of complete or partial taxon codes, to describe taxa that will be included as outgroups in the output unaligned fasta files (which will contain only sequences from a single selected clade, and all outgroup sequences in the tree captured by this argument).') parser.add_argument('-c', '--contaminants', type = float, default = 2, help = 'The number of non-ingroup contaminants allowed in a clade, or if less than 1 the proportion of sequences in a clade that can be non-ingroup (i.e. presumed contaminants). Default is to allow 2 contaminants.') + parser.add_argument('-ft', '--first_target', type=str, default='', help='[Optional] A comma-separated list or .txt file of complete/partial taxon codes for an initial, broad clade search. If provided, the script will first find clades with these taxa before applying the main --target filter.') + parser.add_argument('-fm', '--first_min_presence', type=int, default=0, help='[Optional] Minimum number of sequences from --first_target required in a clade for it to be used in the second-stage search. Ignored if --first_target is not provided.') return parser.parse_args() @@ -85,86 +87,155 @@ def reroot(tree): def get_subtrees(args, file): - newick = get_newick(args.input + '/' + file) + newick = get_newick(args.input + '/' + file) + tree = ete3.Tree(newick) - tree = ete3.Tree(newick) + majs = list(dict.fromkeys([leaf.name[:2] for leaf in tree])) - majs = list(dict.fromkeys([leaf.name[:2] for leaf in tree])) + # Only try to reroot trees with more than 2 major clades (original behavior) + if len(majs) > 2: + tree = reroot(tree) - #Only try to reroot trees with more than 2 major clades. This was added to fix the ETE3 "Cannot set myself as outgroup" error - if len(majs) > 2: - tree = reroot(tree) + # ------------------------------- + # FIRST-STAGE (optional) FILTER + # ------------------------------- + def get_outer_leafsets(): + """ + Return a list of sets, each set = leaf names of an outer clade + that passes --first_target, --first_min_presence, children_keep, + and contaminants logic (using args.contaminants). + If --first_target is not used, return one set containing ALL leaves. + """ + if not args.first_target or args.first_min_presence == 0: + return [set(leaf.name for leaf in tree)] # no outer filter → whole tree - #Getting a clean list of all target taxa - if '.' in args.target: - try: - target_codes = [l.strip() for l in open(args.target, 'r').readlines() if l.strip() != ''] - except AttributeError: - print('\n\nError: invalid "target" argument. This must be a comma-separated list of any number of digits/characters to describe focal taxa (e.g. Sr_ci_S OR Am_t), or a file with the extension .txt containing a list of complete or partial taxon codes. All sequences containing the complete/partial code will be identified as belonging to target taxa.\n\n') - else: - target_codes = [code.strip() for code in args.target.split(',') if code.strip() != ''] + # Parse first_target codes + if '.' in args.first_target: + first_target_codes = [l.strip() for l in open(args.first_target, 'r').readlines() if l.strip() != ''] + else: + first_target_codes = [code.strip() for code in args.first_target.split(',') if code.strip() != ''] - #Getting a clean list of all "at least" taxa - if '.' in args.required_taxa: - try: - required_taxa_codes = [l.strip() for l in open(args.required_taxa, 'r').readlines() if l.strip() != ''] - except AttributeError: - print('\n\nError: invalid "required_taxa" argument. This must be a comma-separated list of any number of digits/characters (e.g. Sr_ci_S OR Am_t), or a file with the extension .txt containing a list of complete or partial taxon codes, to describe taxa that MUST be present in a clade for it to be selected (e.g. you may want at least one whole genome).\n\n') - else: - required_taxa_codes = [code.strip() for code in args.required_taxa.split(',') if code.strip() != ''] + outer_sets = [] + seen_leaves = [] - target_codes = list(dict.fromkeys(target_codes + required_taxa_codes)) - + for node in tree.traverse('levelorder'): + # large enough and not subsumed by already accepted outer node + if len(node) >= args.first_min_presence and len(set(seen_leaves) & set([leaf.name for leaf in node])) == 0: + leaves = [leaf.name for leaf in node] - #Creating a record of selected subtrees, and all of the leaves in those subtrees - selected_nodes = []; seen_leaves = [] + # children_keep logic but for first_target + children_keep = 0 + for child in node.children: + taken = False + for code in first_target_codes: + for leaf in child: + if leaf.name.startswith(code): + children_keep += 1 + taken = True + break + if taken: + break + if children_keep != len(node.children): + continue - #Iterating through all nodes in tree, starting at "root" then working towards leaves - for node in tree.traverse('levelorder'): - #If a node is large enough and is not contained in an already selected clade + # count first-target hits (use [:10] uniqueness like original) + first_hits = set() + for code in first_target_codes: + for leaf in leaves[::-1]: + if leaf.startswith(code): + first_hits.add(leaf[:10]) + leaves.remove(leaf) - if len(node) >= args.min_presence and len(list(set(seen_leaves) & set([leaf.name for leaf in node]))) == 0: - leaves = [leaf.name for leaf in node] + # contaminants logic applied to FIRST-STAGE (reuse args.contaminants) + passes_contam = ((args.contaminants < 1 and len(leaves) <= args.contaminants * len(first_hits)) or + (args.contaminants >= 1 and len(leaves) <= args.contaminants)) - #Accounting for cases where e.g. one child is a contaminant, and the other child is a good clade with 1 fewer than the max number of contaminants - children_keep = 0 - for child in node.children: - for code in target_codes: - taken = False - for leaf in child: - if leaf.name.startswith(code): - children_keep += 1 - taken = True - break - if taken: - break + if len(first_hits) >= args.first_min_presence and passes_contam: + outer_sets.append(set(leaf.name for leaf in node)) + seen_leaves.extend([leaf.name for leaf in node]) - if children_keep == len(node.children): - target_leaves = set(); required_taxa_leaves = set() - - for code in target_codes: - for leaf in leaves[::-1]: - #print(leaf) - if leaf.startswith(code): - target_leaves.add(leaf[:10]) + return outer_sets - for req in required_taxa_codes: - if leaf.startswith(req): - required_taxa_leaves.add(leaf[:10]) - break - leaves.remove(leaf) + # Build outer sets; if user supplied first-stage args, we'll restrict inner search to these + using_first = bool(args.first_target) and args.first_min_presence > 0 + outer_leafsets = get_outer_leafsets() + # -------------------------------- + # ORIGINAL INNER FILTER (unchanged) + # -------------------------------- + # Getting a clean list of all target taxa + if '.' in args.target: + try: + target_codes = [l.strip() for l in open(args.target, 'r').readlines() if l.strip() != ''] + except AttributeError: + print('\n\nError: invalid "target" argument. This must be a comma-separated list of any number of digits/characters to describe focal taxa (e.g. Sr_ci_S OR Am_t), or a file with the extension .txt containing a list of complete or partial taxon codes. All sequences containing the complete/partial code will be identified as belonging to target taxa.\n\n') + else: + target_codes = [code.strip() for code in args.target.split(',') if code.strip() != ''] + # Getting a clean list of all "at least" taxa + if '.' in args.required_taxa: + try: + required_taxa_codes = [l.strip() for l in open(args.required_taxa, 'r').readlines() if l.strip() != ''] + except AttributeError: + print('\n\nError: invalid "required_taxa" argument. This must be a comma-separated list of any number of digits/characters (e.g. Sr_ci_S OR Am_t), or a file with the extension .txt containing a list of complete or partial taxon codes, to describe taxa that MUST be present in a clade for it to be selected (e.g. you may want at least one whole genome).\n\n') + else: + required_taxa_codes = [code.strip() for code in args.required_taxa.split(',') if code.strip() != ''] - #Grab a clade as a subtree if 1) it has enough target taxa; 2) it has enough "at least" taxa; 3) it does not have too many contaminants - if len(target_leaves) >= args.min_presence and len(required_taxa_leaves) >= args.required_taxa_num and ((args.contaminants < 1 and len(leaves) <= args.contaminants * len(target_leaves)) or len(leaves) <= args.contaminants): - selected_nodes.append(node) - seen_leaves.extend([leaf.name for leaf in node]) - #Write the subtrees to output .tre files - for i, node in enumerate(selected_nodes[::-1]): - with open('Subtrees/' + '.'.join(file.split('.')[:-1]) + '_' + str(i) + '.tre', 'w') as o: - o.write(node.write()) + target_codes = list(dict.fromkeys(target_codes + required_taxa_codes)) + + # Creating a record of selected subtrees, and all of the leaves in those subtrees + selected_nodes = []; seen_leaves = [] + + # Iterating through all nodes in tree, starting at "root" then working towards leaves + for node in tree.traverse('levelorder'): + # If using first-stage filter, only consider nodes fully inside some outer clade + if using_first: + node_leafs = set(leaf.name for leaf in node) + # require subset (node fully contained in an accepted outer clade) + if not any(node_leafs.issubset(S) for S in outer_leafsets): + continue + + # If a node is large enough and is not contained in an already selected clade + if len(node) >= args.min_presence and len(list(set(seen_leaves) & set([leaf.name for leaf in node]))) == 0: + leaves = [leaf.name for leaf in node] + + # Accounting for cases where e.g. one child is a contaminant, and the other child is a good clade + children_keep = 0 + for child in node.children: + for code in target_codes: + taken = False + for leaf in child: + if leaf.name.startswith(code): + children_keep += 1 + taken = True + break + if taken: + break + + if children_keep == len(node.children): + target_leaves = set(); required_taxa_leaves = set() + + for code in target_codes: + for leaf in leaves[::-1]: + if leaf.startswith(code): + target_leaves.add(leaf[:10]) + + for req in required_taxa_codes: + if leaf.startswith(req): + required_taxa_leaves.add(leaf[:10]) + break + leaves.remove(leaf) + + # Grab a clade as a subtree if it passes all filters + if len(target_leaves) >= args.min_presence and len(required_taxa_leaves) >= args.required_taxa_num and ((args.contaminants < 1 and len(leaves) <= args.contaminants * len(target_leaves)) or len(leaves) <= args.contaminants): + selected_nodes.append(node) + seen_leaves.extend([leaf.name for leaf in node]) + + # Write the subtrees to output .tre files + for i, node in enumerate(selected_nodes[::-1]): + with open('Subtrees/' + '.'.join(file.split('.')[:-1]) + '_' + str(i) + '.tre', 'w') as o: + o.write(node.write()) def make_new_unaligned(args):