"""Module containing code for calculating ancestral states."""
import math
import numpy as np
import scipy.linalg as la
from lmpy import Matrix, TreeWrapper
# .............................................................................
[docs]def _get_node_label(node):
"""Returns the node label or taxon label if node is a tip.
Args:
node (Node): A tree node to get the label for.
Returns:
str: The node label or taxon label.
"""
if node.label is not None:
return node.label
else:
return node.taxon.label
# .............................................................................
[docs]def calculate_ancestral_distributions(tree, char_mtx):
"""Calculates ancestral distributions.
Args:
tree (Tree): A dendropy tree or TreeWrapper object.
char_mtx (Matrix): A Matrix object with character information. Each
row should represent a tip in the tree and each column should be a
bin to calculate ancestral distribution.
Returns:
A matrix of character data with the following dimensions:
* rows: nodes / tips in the tree
* columns: character variables
* depth: first is the calculated value, second layer is standard
error if desired
"""
return calculate_continuous_ancestral_states(
tree, char_mtx,
sum_to_one=True,
calc_std_err=True
)
# .............................................................................
[docs]def calculate_continuous_ancestral_states(
tree,
char_mtx,
sum_to_one=False,
calc_std_err=False
):
"""Calculates the continuous ancestral states for the nodes in a tree.
Args:
tree (Tree): A dendropy tree or TreeWrapper object.
char_mtx (Matrix): A Matrix object with character information. Each
row should represent a tip in the tree and each column should be a
variable to calculate ancestral state for.
calc_std_err (:obj:`bool`, optional): If True, calculate standard error
for each variable. Defaults to False.
sum_to_one (:obj:`bool`, optional): If True, standardize the character
matrix so that the values in a row sum to one. Defaults to False.
Raises:
ValueError: Raised if none of the tree tips were found in the character data.
Returns:
A matrix of character data with the following dimensions:
* rows: nodes / tips in the tree
* columns: character variables
* depth: first is the calculated value, second layer is standard
error if desired
Todo:
* Add function for consistent label handling.
"""
# Wrap tree if dendropy tree
if not isinstance(tree, TreeWrapper):
tree = TreeWrapper.from_base_tree(tree)
# Assign labels to nodes that don't have them
tree.add_node_labels()
# Synchronize tree and character data
# Prune tree
prune_taxa = []
keep_taxon_labels = []
init_row_headers = char_mtx.get_row_headers()
for taxon in tree.taxon_namespace:
label = taxon.label.replace(' ', '_')
if label not in init_row_headers:
prune_taxa.append(taxon)
print(
'Could not find {} in character matrix, pruning'.format(label))
else:
keep_taxon_labels.append(label)
if len(keep_taxon_labels) == 0:
raise ValueError('None of the tree tips were found in the character data')
tree.prune_taxa(prune_taxa)
tree.purge_taxon_namespace()
# Prune character data
keep_rows = []
i = 0
for label in init_row_headers:
if label in keep_taxon_labels:
keep_rows.append(i)
else:
print('Could not find {} in tree tips, pruning'.format(label))
i += 1
char_mtx = char_mtx.slice(keep_rows)
# Standardize character matrix if requested
tip_count, num_vars = char_mtx.shape
if sum_to_one:
for i in range(tip_count):
sc = float(1.0) / np.sum(char_mtx[i])
for j in range(num_vars):
char_mtx[i, j] *= sc
# Initialize data matrix
num_nodes = len(tree.nodes())
data_shape = (num_nodes, num_vars, 2 if calc_std_err else 1)
data = np.zeros(data_shape, dtype=float)
# Initialize headers
row_headers = []
tip_col_headers = char_mtx.get_column_headers()
tip_row_headers = char_mtx.get_row_headers()
tip_lookup = {
tip_row_headers[i].replace('_', ' '): i for i in range(tip_count)}
# Get the number of internal nodes in the tree
internal_node_count = num_nodes - tip_count
# Loop through the tree and set the matrix index for each node
# Also set data values
node_headers = []
node_i = tip_count
tip_i = 0
node_index_lookup = {}
for node in tree.nodes():
label = _get_node_label(node)
if len(node.child_nodes()) == 0:
# Tip
node_index_lookup[label] = tip_i
row_headers.append(label)
data[tip_i, :, 0] = char_mtx[tip_lookup[label]]
tip_i += 1
else:
node_index_lookup[label] = node_i
node_headers.append(label)
# Internal node
data[node_i, :, 0] = np.zeros((1, num_vars), dtype=float)
node_i += 1
# Row headers should be extended with node headers
row_headers.extend(node_headers)
# For each variable
for x in range(num_vars):
# Compute the ML estimate of the root
full_mcp = np.zeros((internal_node_count, internal_node_count),
dtype=float)
full_vcp = np.zeros(internal_node_count, dtype=float)
for k in tree.postorder_edge_iter():
i = k.head_node
if len(i.child_nodes()) != 0:
node_num_i = node_index_lookup[_get_node_label(i)] - tip_count
for j in i.child_nodes():
tbl = 2./j.edge_length
full_mcp[node_num_i][node_num_i] += tbl
node_num_j = node_index_lookup[_get_node_label(j)]
if len(j.child_nodes()) == 0:
full_vcp[node_num_i] += (data[node_num_j, x, 0] * tbl)
else:
node_num_j -= tip_count
full_mcp[node_num_i][node_num_j] -= tbl
full_mcp[node_num_j][node_num_i] -= tbl
full_mcp[node_num_j][node_num_j] += tbl
b = la.cho_factor(full_mcp)
# these are the ML estimates for the ancestral states
ml_est = la.cho_solve(b, full_vcp)
sos = 0
for k in tree.postorder_edge_iter():
i = k.head_node
node_num_i = node_index_lookup[_get_node_label(i)]
if len(i.child_nodes()) != 0:
data[node_num_i, x, 0] = ml_est[node_num_i - tip_count]
if calc_std_err:
for j in i.child_nodes():
node_num_j = node_index_lookup[_get_node_label(j)]
temp = data[node_num_i, x, 0] - data[node_num_j, x, 0]
sos += temp * temp / j.edge_length
# nni is node_num_i adjusted for only nodes
nni = node_num_i - tip_count
qpq = full_mcp[nni][nni]
tm1 = np.delete(full_mcp, (nni), axis=0)
tm = np.delete(tm1, (nni), axis=1)
b = la.cho_factor(tm)
sol = la.cho_solve(b, tm1[:, nni])
temp_std_err = qpq - np.inner(tm1[:, nni], sol)
data[node_num_i, x, 1] = math.sqrt(2.0 * sos / (
(internal_node_count - 1) * temp_std_err))
depth_headers = ['maximum_likelihood']
if calc_std_err:
depth_headers.append('standard_error')
mtx_headers = {
'0': row_headers,
'1': tip_col_headers,
'2': depth_headers
}
return tree, Matrix(data, headers=mtx_headers)