diff --git a/examples/cluster_tree b/examples/cluster_tree
new file mode 100644
index 0000000..b9f2191
--- /dev/null
+++ b/examples/cluster_tree
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+
+"""Clustering Model Based on Decision Tree
+"""
+
+import numpy as np
+import numpy.linalg as LA
+from scipy.stats import entropy
+from scipy.spatial.distance import cdist
+
+import pandas as pd
+
+from treelib import Node, Tree
+
+from sklearn.base import ClusterMixin, BaseEstimator
+from sklearn.cluster import KMeans
+
+
+class TreeCluster(BaseEstimator, ClusterMixin, Tree):
+    """Decision Tree for classification/cluster
+
+    epsilon: the threshold of info gain or other select method
+    selection_method: the selection method
+    features_: the features of the input vars
+    classes_: the classes of output vars
+    """
+
+    def __init__(self, epsilon=0.6, features=None, classes=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.epsilon = epsilon
+        self.features_ = features
+        self.classes_ = classes
+
+    def fit(self, X, Y=None, mean=None, level=()):
+        """
+        calc cond_proba, proba, priori_proba, features
+        then call fit_with_proba
+
+        Arguments:
+            X {2D array|list|dataframe} -- input vars
+
+        Returns:
+            TreeCluster
+        """
+
+        kmeans = KMeans(n_clusters=2)
+
+        if mean is None:
+            mean = X.mean(axis=0)
+        self.add_node(Node(tag='-'.join(map(str, level)), identifier=level, data={'mean':mean}))
+
+        if len(X)>2:
+            kmeans.fit(X)
+            y = kmeans.predict(X)
+            classes_ = np.unique(y)
+            means_ = kmeans.cluster_centers_
+
+            gain = 1 - kmeans.inertia_ / LA.norm(X - mean, 'fro')**2
+
+            if gain > self.epsilon:
+                for k, m in zip(classes_, means_):
+                    t = TreeCluster(epsilon=self.epsilon)
+                    t.fit(X[y==k], mean=m, level=level+(k,))
+                    self.paste(level, t)
+
+        if level == ():
+            # get cluster centers from the data of the nodes
+            self.cluster_centers_ = [node.data['mean'] for node in self.all_nodes_itr() if node.is_leaf()]
+            self.classes_ = np.arange(len(self.cluster_centers_))
+
+        return self
+
+    def predict_proba(self, X):
+        distances = np.exp(-cdist(X, self.cluster_centers_))
+        # normalize per sample (rows must sum to 1), not per cluster column
+        return distances / distances.sum(axis=1)[:, None]
+
+    def predict(self, X):
+        p = self.predict_proba(X)
+        return self.classes_[np.argmax(p, axis=1)]
+
+
+if __name__ == '__main__':
+
+    from sklearn import datasets
+
+    iris = datasets.load_iris()
+    X_train, y_train = iris.data, iris.target
+
+    tc = TreeCluster(epsilon=0.5)
+    tc.fit(X_train)
+    y_ = tc.predict(X_train)
+
+    print(tc)
+    print(tc.cluster_centers_)
diff --git a/examples/huffman_tree.py b/examples/huffman_tree.py
new file mode 100644
index 0000000..783e44c
--- /dev/null
+++ b/examples/huffman_tree.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+
+
+"""Huffman coding
+"""
+
+from toolz import concat
+from treelib import Tree, Node
+
+import numpy as np
+
+
+def _get_symbols(tree):
+    """Get `symbols` from the root of a tree or a node
+
+    tree: Tree or Node
+    """
+    if isinstance(tree, Node):
+        a = tree.data["symbols"]
+    else:
+        a = tree.get_node(tree.root).data["symbols"]
+    if isinstance(a, str):
+        return [a]
+    else:
+        return a
+
+
+def _get_frequency(tree):
+    """Get `frequency` from the root of a tree or a node
+
+    tree: Tree or Node
+    """
+    if isinstance(tree, Node):
+        a = tree.data["frequency"]
+    else:
+        a = tree.get_node(tree.root).data["frequency"]
+    if isinstance(a, str):
+        return [a]
+    else:
+        return a
+
+
+def merge(trees, level=""):
+    """merge the trees to one tree by add a root
+
+    Args:
+        trees (list): list of trees or nodes
+        level (tuple, optional): the prefix for identifier
+
+    Returns:
+        Tree
+    """
+
+    data = list(concat(map(_get_symbols, trees)))
+    freq = sum(map(_get_frequency, trees))
+    t = Tree()
+    root = Node(identifier=level, data={"symbols": data, "frequency": freq, "code": ""})
+    t.add_node(root)
+    t.root = level
+    root.tag = f"root: {{{','.join(root.data['symbols'])}}}/{root.data['frequency']}"
+    for k, tree in enumerate(trees):
+        if isinstance(tree, Node):
+            tree.identifier = f"{k}" + tree.identifier
+            tree.data["code"] = f"{k}" + tree.data["code"]
+            tree.tag = f"{tree.data['code']}: {{{','.join(tree.data['symbols'])}}}/{tree.data['frequency']}"
+            t.add_node(tree, parent=level)
+        else:
+            for n in tree.all_nodes_itr():
+                n.identifier = f"{k}" + n.identifier
+                n.data["code"] = f"{k}" + n.data["code"]
+                n.tag = f"{n.data['code']}: {{{','.join(n.data['symbols'])}}}/{n.data['frequency']}"
+
+            tree._nodes = {n.identifier: n for k, n in tree._nodes.items()}
+            tree.root = f"{k}{tree.root}"
+            tid = tree.identifier
+            for n in tree.all_nodes_itr():
+                if n.is_root():
+                    n.set_successors([f"{k}{nid}" for nid in n._successors[tid]], tid)
+                elif n.is_leaf():
+                    n.set_predecessor(f"{k}{n._predecessor[tid]}", tid)
+                else:
+                    n.set_predecessor(f"{k}{n._predecessor[tid]}", tid)
+                    n.set_successors([f"{k}{nid}" for nid in n._successors[tid]], tid)
+
+            t.paste(level, tree, deep=True)
+    return t
+
+
+def huffman_tree(trees, level="", n_branches=2):
+    """Huffman coding
+
+    Args:
+        trees (list): list of trees or nodes
+        level (tuple, optional): the prefix for identifier
+            set n_branches=2 by default
+
+    Returns:
+        Tree: Huffman tree
+    """
+    assert len(trees) >= 2
+
+    if len(trees) == 2:
+        return merge(trees, level=level)
+    else:
+        ks = np.argsort([_get_frequency(tree) for tree in trees])[:n_branches]
+        t = merge([trees[k] for k in ks], level=level)
+        trees = [t, *(tree for k, tree in enumerate(trees) if k not in ks)]
+        return huffman_tree(trees, level=level)
+
+
+def make_node(s, f):
+    """Make `Node` object
+
+    s: str
+    f: number
+    """
+    return Node(identifier="", data={"symbols": s, "frequency": f, "code": ""})
+
+
+d = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
+nodes = [make_node(s, f) for s, f in d.items()]
+nodes = list(nodes)
+t = huffman_tree(nodes)
+print(t)
diff --git a/examples/random_tree.py b/examples/random_tree.py
new file mode 100644
index 0000000..96827e8
--- /dev/null
+++ b/examples/random_tree.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+
+"""
+Generate a tree randomly; Test the `apply` method;
+"""
+
+import random
+from treelib import Tree
+
+
+def _random(max_depth=5, min_width=1, max_width=2, offset=()):
+    # generate a tree randomly
+    tree = Tree()
+    tree.create_node(identifier=offset)
+    if max_depth == 0:
+        return tree
+    elif max_depth == 1:
+        nb = random.randint(min_width, max_width)
+        for i in range(nb):
+            identifier = offset + (i,)
+            tree.create_node(identifier=identifier, parent=offset)
+    else:
+        nb = random.randint(min_width, max_width)
+        for i in range(nb):
+            _offset = offset + (i,)
+            # depth decreases by one per level, not per sibling
+            subtree = _random(max_depth=max_depth - 1, max_width=max_width, offset=_offset)
+            tree.paste(offset, subtree)
+    return tree
+
+
+def _map(func, tree):
+    # tree as a functor
+    tree = tree._clone(with_tree=True)
+    print(tree)
+    for a in tree.all_nodes_itr():
+        func(a)
+    return tree
+
+
+def key(node):
+    node.tag = "-".join(map(str, node.identifier))
+
+
+print(_map(key, _random()))
+
+print(_random().apply(key))
diff --git a/treelib/node.py b/treelib/node.py
index 958c734..067f30b 100644
--- a/treelib/node.py
+++ b/treelib/node.py
@@ -267,7 +267,7 @@ def tag(self):
     @tag.setter
     def tag(self, value):
         """Set the value of `_tag`."""
-        self._tag = value if value is not None else None
+        self._tag = value
 
     def __repr__(self):
         name = self.__class__.__name__
diff --git a/treelib/tree.py b/treelib/tree.py
index 1cc9ac8..aadb62c 100644
--- a/treelib/tree.py
+++ b/treelib/tree.py
@@ -336,6 +336,12 @@ def all_nodes_itr(self):
         """
         return self._nodes.values()
 
+    def iternodes(self):
+        """
+        alias of `all_nodes_itr` but conform to the convention of Python.
+        """
+        return self._nodes.values()
+
     def ancestor(self, nid, level=None):
         """
         For a given id, get ancestor node object at a given level.
@@ -377,7 +383,7 @@ def children(self, nid):
 
     def contains(self, nid):
         """Check if the tree contains node of given id"""
-        return True if nid in self._nodes else False
+        return nid in self._nodes
 
     def create_node(self, tag=None, identifier=None, parent=None, data=None):
         """
@@ -515,6 +521,9 @@ def get_node(self, nid):
             return None
         return self._nodes[nid]
 
+    def get_root(self):
+        return self.get_node(self.root)
+
     def is_branch(self, nid):
         """
         Return the children (ID) list of nid.
@@ -689,10 +698,12 @@ def paste(self, nid, new_tree, deep=False):
         if set_joint:
             raise ValueError("Duplicated nodes %s exists." % list(map(text, set_joint)))
 
-        for cid, node in iteritems(new_tree.nodes):
-            if deep:
-                node = deepcopy(new_tree[node])
-            self._nodes.update({cid: node})
+        if deep:
+            new_nodes = {cid: deepcopy(node) for cid, node in iteritems(new_tree.nodes)}
+        else:
+            new_nodes = new_tree.nodes
+        self._nodes.update(new_nodes)
+        for _, node in iteritems(new_nodes):
             node.clone_pointers(new_tree.identifier, self._identifier)
 
         self.__update_bpointer(new_tree.root, nid)
@@ -1003,9 +1014,8 @@ def subtree(self, nid, identifier=None):
             # define nodes parent/children in this tree
             # all pointers are the same as copied tree, except the root
             st[node_n].clone_pointers(self._identifier, st.identifier)
-            if node_n == nid:
-                # reset root parent for the new tree
-                st[node_n].set_predecessor(None, st.identifier)
+        # reset root parent for the new tree
+        st[nid].set_predecessor(None, st.identifier)
         return st
 
     def update_node(self, nid, **attrs):
@@ -1129,6 +1139,32 @@ def to_graphviz(
 
         f.close()
 
+    def apply(self, key, deep=True):
+        """Morphism of tree
+        Work like the built-in `map`
+
+        Arguments
+            key -- impure function of a node
+            deep -- please keep it true
+        """
+        tree = self._clone(with_tree=True, deep=deep)
+        for a in tree.all_nodes():
+            key(a)
+        return tree
+
+    def apply_data(self, key, deep=True):
+        """morphism of tree, but acts on data of nodes.
+        It calls the method `apply`
+
+        Arguments
+            key -- pure function of node.data
+        """
+
+        def _key(a):
+            a.data = key(a.data)
+
+        return self.apply(_key, deep=deep)
+
     @classmethod
     def from_map(cls, child_parent_dict, id_func=None, data_func=None):
         """