Source code for graphdot.model.tree_search.graph_transformer

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from scipy.stats import norm
from graphdot.util.iterable import argmax
from ._tree import Tree


[docs]class MCTSGraphTransformer: '''A varient of Monte Carlo tree search for optimization and root-finding in a space of graphs. Parameters ---------- rewriter: callable A callable that implements the :py:class:`Rewriter` abstract class. surrogate: object A predictor used to calculate the target property of a given graph. exploration_bias: float Tunes the preference of the MCTS model between exploitation and exploration of the search space. precision: float Target precision of MCTS search outcome. ''' def __init__(self, rewriter, surrogate, exploration_bias=1.0, precision=0.01): self.rewriter = rewriter self.surrogate = surrogate self.exploration_bias = exploration_bias self.precision = precision
[docs] def seek(self, g0, target, maxiter=500, return_tree=False, random_state=None): '''Transforms an initial graph into one with a specific desired target property value. Parameters ---------- g0: object A graph to start the tree search with. target: float Target property value of the desired graph. maxiter: int Maximum number of MCTS iterations to perform. return_tree: Boolean Whether or not to return the search tree in its original form or as a flattened dataframe. random_state: int or :py:`np.random.Generator` The seed to the random number generator (RNG), or the RNG itself. If None, the default RNG in numpy will be used. Returns ------- tree: DataFrame If `return_tree` is True, a hierarchical dataframe representing the search tree will be returned; otherwise, a flattened dataframe will be returned. ''' random_state = self._parse_random_state(random_state) tree = self._spawn(None, [g0]) self._evaluate(tree) for _ in range(maxiter): self._mcts_step( tree, lambda nodes: self._likelihood_ucb(target, nodes), random_state=random_state ) if return_tree is True: return tree else: df = tree.flat df['likelihood'] = self._likelihood(target, df) return df.to_pandas().sort_values(['likelihood'], ascending=False)
@staticmethod def _parse_random_state(random_state): if isinstance(random_state, np.random.Generator): return random_state elif random_state is not None: return np.random.Generator(np.random.PCG64(random_state)) else: return np.random.default_rng() def _spawn(self, node, leaves): return Tree( parent=[node] * len(leaves), children=[None] * len(leaves), g=leaves, visits=np.zeros(len(leaves), dtype=np.int) ) def _likelihood(self, target, nodes): return norm.pdf( target, nodes.tree_mean, np.maximum(nodes.tree_std, self.precision) # This line below does not work, especially the '+' part: # target, nodes.tree_mean, nodes.tree_std + self.precision ) def _confidence_bounds(self, nodes): return self.exploration_bias * np.sqrt( np.log(nodes.parent[0].visits) / nodes.visits ) def _likelihood_ucb(self, target, nodes): return self._likelihood(target, nodes) + self._confidence_bounds(nodes) def _evaluate(self, nodes): mean, cov = self.surrogate.predict(nodes.g, return_cov=True) nodes['self_mean'] = mean.copy() nodes['tree_mean'] = mean.copy() nodes['self_std'] = cov.diagonal()**0.5 nodes['tree_std'] = cov.diagonal()**0.5 nodes['score'] = np.zeros_like(mean) nodes.visits += 1 def _mcts_step(self, tree, score_fn, random_state): '''selection''' n = next(tree.iternodes()) n.visits += 1 while n.children is not None: n = argmax( n.children.iternodes(), lambda i, j: i.score < j.score ) n.visits += 1 '''expansion''' n.children = self._spawn(n, self.rewriter(n, random_state)) '''simulate''' self._evaluate(n.children) '''back-propagate''' p = n while p: p.tree_mean = np.average( p.children.tree_mean, weights=p.children.tree_std**-2 ) p.tree_std = np.average( (p.children.tree_mean - p.tree_mean)**2, weights=p.children.tree_std**-2 )**0.5 p.children['score'] = score_fn(p.children) p = p.parent