# Source code for myerson.pyg_explain.myerson

from myerson import MyersonCalculator, MyersonSampler

try:
    import torch
    import torch_geometric
except ImportError:
    raise ImportError("Failed to import torch and/or torch_geometric. PyG explanations not available.")
import numpy as np

import networkx as nx
from tqdm import tqdm
import logging
# try: 
#     from .myerson import fast_restrict
# except:
#     pass


class MyersonExplainer(MyersonCalculator):
    r"""Explains the prediction of a graph neural network (GNN) with Myerson values.

    The GNN is treated as the coalition function of a game and its prediction
    as the payoff of the game. The Myerson values show how much each node of
    the graph contributed to the final prediction.

    Args:
        graph (torch_geometric.data.Data): The data instance that is to be
            explained. It is moved to the selected device.
        coalition_function (torch.nn.Module): The GNN. It is called as
            ``coalition_function(x, edge_index, batch)`` and its output is
            converted with ``.item()``, so it must hold a single value.
        disable_tqdm (bool, optional): Disables progress bar. Defaults to True.

    Note:
        The train/eval mode of ``coalition_function`` is left untouched. Call
        ``model.eval()`` beforehand if the model contains stochastic layers
        (e.g. dropout), otherwise the computed worths are not reproducible.
    """

    # NOTE: an experimental C++-backed ``fast_restrict`` (together with the
    # ``set_restrict`` / ``_set_variables_for_fast_restrict`` helpers it
    # required) used to be kept here as commented-out code. It was never
    # active; the ``restrict`` inherited from ``MyersonCalculator`` is used.

    def __init__(self, graph: torch_geometric.data.Data,
                 coalition_function: torch.nn.Module,
                 disable_tqdm: bool = True) -> None:
        """Instantiate the class.

        Moves graph and model to ``cuda:0`` when CUDA is available (CPU
        otherwise), builds an undirected networkx view of the graph and warns
        when the graph consists of more than one connected component.
        """
        self.disable_tqdm = disable_tqdm
        self.log = logging.getLogger("MyersonExplainer")
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.log.info(f"using device {self.device}")
        self.pyg_graph = graph.to(self.device)
        self.coalition_function = coalition_function
        self.coalition_function.to(self.device)
        self.nx_graph = torch_geometric.utils.to_networkx(graph, to_undirected=True)
        # alias: set of players / set of nodes / F
        self.grand_coalition = list(self.nx_graph.nodes())
        cc = nx.number_connected_components(self.nx_graph)
        if cc > 1:
            # The worth of the grand coalition is summed over the connected
            # components (see calculate_worth_of_grand_coalition), whereas the
            # GNN sees the whole graph at once, so the two may disagree.
            self.log.warning(f"Your graph has {cc} individual components. The worth"
                             " of the grand coalition and the prediction of a GNN can"
                             " differ.")
            pred = self.calculate_prediction()
            worth = self.calculate_worth_of_grand_coalition()
            self.log.warning(f"Prediction={pred:.4f}, Worth={worth:.4f}")

    @torch.no_grad()
    def calculate_worth_of_single_graph_restricted_coalition(
            self, graph_restricted_coalition: tuple,
            pyg_graph: torch_geometric.data.Data) -> float:
        """Calculate the worth of a graph restricted coalition, i. e. a single
        connected component.

        Runs without autograd: only the scalar output is needed, and this is
        evaluated once per coalition, so recording gradients would only waste
        memory.

        Args:
            graph_restricted_coalition (tuple): Graph restricted coalition as
                node indices.
            pyg_graph (torch_geometric.data.Data): Graph from which a subgraph
                of the connected component will be extracted according to the
                graph restricted coalition.

        Returns:
            float: Worth, the output of the coalition function for the
            connected subgraph; ``0.`` for the empty coalition.
        """
        if not graph_restricted_coalition:
            return 0.
        subgraph = self.subgraph_from_coalition(graph_restricted_coalition, pyg_graph)
        out = self.coalition_function(subgraph.x, subgraph.edge_index,
                                      self._batch_var(subgraph))
        return out.cpu().item()

    def calculate_worth_of_graph_restricted_coalitions(self, graph_restricted_coalitions: list) -> dict:
        """Calculate the worth of every graph restricted coalition.

        Args:
            graph_restricted_coalitions (list): Set of connected components as
                tuples of node indices.

        Returns:
            dict: Dictionary mapping each connected component to its worth.
        """
        self.log.info("Calculating worth of graph restricted coalitions.")
        progress = tqdm(graph_restricted_coalitions,
                        desc="Calculating worth of graph restricted coalitions",
                        disable=self.disable_tqdm)
        return {
            coalition: self.calculate_worth_of_single_graph_restricted_coalition(coalition, self.pyg_graph)
            for coalition in progress
        }

    def calculate_worth_of_grand_coalition(self) -> float:
        """Calculate the payoff of the game, i.e. the worth of the grand coalition.

        The grand coalition is split with ``self.restrict`` (inherited; assumed
        to return the connected components as node tuples) and the worths of
        those components are summed. For a disconnected graph (two or more
        components) this sum can differ from the model prediction returned by
        ``calculate_prediction``.

        Returns:
            float: Payoff of the game / worth of grand coalition.
        """
        restricted_grand_coalition = self.restrict(self.grand_coalition, self.nx_graph)
        return sum(self.calculate_worth_of_single_graph_restricted_coalition(component, self.pyg_graph)
                   for component in restricted_grand_coalition)

    @torch.no_grad()
    def calculate_prediction(self) -> float:
        """Calculate the prediction of the GNN for the investigated graph.

        When the graph is disconnected this prediction may differ from the
        worth of the grand coalition. Runs without autograd.

        Returns:
            float: Prediction.
        """
        out = self.coalition_function(self.pyg_graph.x, self.pyg_graph.edge_index,
                                      self._batch_var(self.pyg_graph))
        return out.cpu().item()

    def _batch_var(self, pyg_graph: torch_geometric.data.Data) -> torch.Tensor:
        """Return a batch argument for a single graph, as required by models
        trained in batches.

        Args:
            pyg_graph (torch_geometric.data.Data): Graph for which to generate
                the batch vector.

        Returns:
            torch.Tensor: ``int64`` zeros, one per node, on the device of
            ``pyg_graph.x`` (every node belongs to batch 0).
        """
        return torch.zeros(pyg_graph.x.shape[0], dtype=torch.long,
                           device=pyg_graph.x.device)

    def subgraph_from_coalition(self, graph_restricted_coalition: tuple,
                                pyg_graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """Generate the subgraph induced by a coalition (a subset of nodes /
        players).

        Only ``x`` and ``edge_index`` are carried over; other attributes of
        ``pyg_graph`` are not copied.

        Args:
            graph_restricted_coalition (tuple): Nodes which form the subgraph.
                Duplicates are ignored.
            pyg_graph (torch_geometric.data.Data): Graph in which the subgraph
                is induced.

        Returns:
            torch_geometric.data.Data: The new subgraph with nodes relabelled
            to ``0..k-1`` in ascending order of their original index.
        """
        device = pyg_graph.x.device
        num_nodes = pyg_graph.x.shape[0]
        # Boolean-mask indexing below keeps nodes in ascending index order, so
        # the relabelling must assign new ids in that same (sorted, unique)
        # order; unsorted or repeated nodes would map edges to wrong nodes.
        nodes = torch.tensor(sorted(set(graph_restricted_coalition)),
                             dtype=torch.long, device=device)
        node_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        node_mask[nodes] = True
        x = pyg_graph.x[node_mask]
        # Keep only edges with both endpoints inside the coalition.
        edge_mask = node_mask[pyg_graph.edge_index[0]] & node_mask[pyg_graph.edge_index[1]]
        # Lookup table old id -> new id, then relabel via fancy indexing.
        node_idx = torch.zeros(num_nodes, dtype=torch.long, device=device)
        node_idx[nodes] = torch.arange(nodes.numel(), device=device)
        edge_index = node_idx[pyg_graph.edge_index[:, edge_mask]]
        return torch_geometric.data.Data(x=x, edge_index=edge_index)
class MyersonSamplingExplainer(MyersonSampler, MyersonExplainer):
    """Explain GNN predictions with Monte Carlo approximations of Myerson values.

    Sampling replaces the exact enumeration of all coalitions, whose number
    grows as 2^(number of nodes).

    Args:
        graph (torch_geometric.data.Data): Graph whose prediction is
            explained; it is moved to the selected device.
        coalition_function (torch.nn.Module): The GNN.
        seed (None | int, optional): Seed passed to
            ``numpy.random.default_rng``. Defaults to None.
        number_of_samples (int, optional): Number of sampling steps.
            Defaults to 1000.
        disable_tqdm (bool, optional): Disables progress bar. Defaults to True.
    """

    def __init__(self, graph: torch_geometric.data.Data,
                 coalition_function: torch.nn.Module,
                 seed: None | int = None,
                 number_of_samples: int = 1000,
                 disable_tqdm: bool = True) -> None:
        """Store the sampling settings, move data and model to the device and
        check whether the graph is connected."""
        # Sampling settings (presumably consumed by MyersonSampler).
        self.seed = seed
        self.rng = np.random.default_rng(seed)
        self.number_of_samples = number_of_samples
        self.disable_tqdm = disable_tqdm

        logger = logging.getLogger("MyersonSamplingExplainer")
        self.log = logger
        if torch.cuda.is_available():
            self.device = 'cuda:0'
        else:
            self.device = 'cpu'
        logger.info(f"using device {self.device}")

        self.pyg_graph = graph.to(self.device)
        self.coalition_function = coalition_function
        coalition_function.to(self.device)

        nx_graph = torch_geometric.utils.to_networkx(graph, to_undirected=True)
        self.nx_graph = nx_graph
        # alias: set of players / set of nodes / F
        self.grand_coalition = list(nx_graph.nodes())

        n_components = nx.number_connected_components(nx_graph)
        if n_components <= 1:
            return
        # Disconnected graph: prediction and grand-coalition worth may differ.
        logger.warning(f"Your graph has {n_components} individual components. "
                       "The worth of the grand coalition and the prediction of a GNN can differ.")
        prediction = self.calculate_prediction()
        grand_worth = self.calculate_worth_of_grand_coalition()
        logger.warning(f"Prediction={prediction:.4f}, Worth={grand_worth:.4f}")
class MyersonClassExplainer(MyersonExplainer):
    r"""Explains the prediction of a graph neural network (GNN) classifier with
    Myerson values.

    The GNN is treated as the coalition function of a game and its prediction
    as the payoff of the game. The Myerson values show how much each node of
    the graph contributed to the final prediction. Unlike ``MyersonExplainer``,
    worths are tensors (the model output with a leading size-1 dimension
    squeezed away), so multi-output models such as classifiers are supported.

    Args:
        graph (torch_geometric.data.Data): The data instance that is to be
            explained.
        coalition_function (torch.nn.Module): The GNN, called as
            ``coalition_function(x, edge_index, batch)``.
        disable_tqdm (bool, optional): Disables progress bar. Defaults to True.
    """

    def __init__(self, graph: torch_geometric.data.Data,
                 coalition_function: torch.nn.Module,
                 disable_tqdm: bool = True) -> None:
        """Instantiate the class."""
        self.disable_tqdm = disable_tqdm
        self.log = logging.getLogger("MyersonClassExplainer")
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.log.info(f"using device {self.device}")
        self.pyg_graph = graph.to(self.device)
        self.coalition_function = coalition_function
        self.coalition_function.to(self.device)
        self.nx_graph = torch_geometric.utils.to_networkx(graph, to_undirected=True)
        # alias: set of players / set of nodes / F
        self.grand_coalition = list(self.nx_graph.nodes())
        # Stored because its shape/dtype define the worth of the empty coalition.
        self.pred = self.calculate_prediction()
        cc = nx.number_connected_components(self.nx_graph)
        if cc > 1:
            self.log.warning(f"Your graph has {cc} individual components. The worth"
                             " of the grand coalition and the prediction of a GNN can"
                             " differ.")
            # Reuse the stored prediction instead of a second forward pass.
            worth = self.calculate_worth_of_grand_coalition()
            self.log.warning(f"Prediction={self.pred}, Worth={worth}")

    @torch.no_grad()
    def calculate_worth_of_single_graph_restricted_coalition(
            self, graph_restricted_coalition: tuple,
            pyg_graph: torch_geometric.data.Data) -> torch.Tensor:
        """Calculate the worth of a graph restricted coalition, i. e. a single
        connected component.

        Args:
            graph_restricted_coalition (tuple): Graph restricted coalition as
                node indices.
            pyg_graph (torch_geometric.data.Data): Graph from which a subgraph
                of the connected component will be extracted according to the
                graph restricted coalition.

        Returns:
            torch.Tensor: Worth, the (CPU, detached) output of the coalition
            function for the connected subgraph; zeros shaped and typed like
            ``self.pred`` for the empty coalition.
        """
        if not graph_restricted_coalition:
            # Match dtype/shape of real model outputs so sums stay consistent.
            return torch.zeros_like(self.pred)
        subgraph = self.subgraph_from_coalition(graph_restricted_coalition, pyg_graph)
        out = self.coalition_function(subgraph.x, subgraph.edge_index,
                                      self._batch_var(subgraph))
        return out.detach().cpu().squeeze(0)

    @torch.no_grad()
    def calculate_prediction(self) -> torch.Tensor:
        """Calculate the prediction of the GNN for the investigated graph.

        When the graph is disconnected this prediction may differ from the
        worth of the grand coalition. Runs without autograd so the returned
        tensor does not keep a computation graph alive.

        Returns:
            torch.Tensor: Prediction on the CPU, with a leading size-1
            dimension squeezed away.
        """
        out = self.coalition_function(self.pyg_graph.x, self.pyg_graph.edge_index,
                                      self._batch_var(self.pyg_graph))
        return out.cpu().squeeze(0)
class MyersonSamplingClassExplainer(MyersonSamplingExplainer, MyersonClassExplainer):
    """A class explaining a GNN classifier's predictions with approximated
    Myerson values.

    Args:
        graph (torch_geometric.data.Data): The data instance that is to be
            explained.
        coalition_function (torch.nn.Module): The GNN.
        seed (None | int, optional): Seed for randomness. Defaults to None.
        number_of_samples (int, optional): Number of sampling steps.
            Defaults to 1000.
        disable_tqdm (bool, optional): Disables progress bar. Defaults to True.
    """

    def __init__(self, graph: torch_geometric.data.Data,
                 coalition_function: torch.nn.Module,
                 seed: None | int = None,
                 number_of_samples: int = 1000,
                 disable_tqdm: bool = True) -> None:
        """Instantiates the class."""
        self.disable_tqdm = disable_tqdm
        self.log = logging.getLogger("MyersonSamplingClassExplainer")
        self.seed = seed
        self.rng = np.random.default_rng(seed)
        self.number_of_samples = number_of_samples
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.log.info(f"using device {self.device}")
        self.pyg_graph = graph.to(self.device)
        self.coalition_function = coalition_function
        self.coalition_function.to(self.device)
        self.nx_graph = torch_geometric.utils.to_networkx(graph, to_undirected=True)
        # alias: set of players / set of nodes / F
        self.grand_coalition = list(self.nx_graph.nodes())
        # Detached so the stored prediction never pins an autograd graph; its
        # shape/dtype are reused for zero worths and the result array.
        self.pred = self.calculate_prediction().detach()
        cc = nx.number_connected_components(self.nx_graph)
        if cc > 1:
            self.log.warning(f"Your graph has {cc} individual components. The worth"
                             " of the grand coalition and the prediction of a GNN can"
                             " differ.")
            # Reuse the stored prediction instead of a second forward pass.
            worth = self.calculate_worth_of_grand_coalition()
            self.log.warning(f"Prediction={self.pred}, Worth={worth}")

    def map_coalition_to_worth(self, coalitions: list[tuple],
                               coalitions_to_graph_restricted_coalitions: dict,
                               graph_restricted_coalitions_to_worth: dict) -> dict:
        """Map every coalition to its worth.

        The worth of a coalition is the sum of the (tensor) worths of its
        graph restricted coalitions.

        Args:
            coalitions (list): Coalitions to evaluate.
            coalitions_to_graph_restricted_coalitions (dict): Dictionary mapping
                the coalitions to the corresponding graph restricted coalitions.
            graph_restricted_coalitions_to_worth (dict): Dictionary mapping the
                graph restricted coalitions to their worth.

        Returns:
            dict: Dictionary mapping each coalition to its worth tensor, shaped
            and typed like ``self.pred``.
        """
        self.log.info("Mapping coalitions to worth.")
        coalition_to_worth = {}
        for coalition in tqdm(coalitions, desc="Mapping coalitions to worth",
                              disable=self.disable_tqdm):
            worth = torch.zeros_like(self.pred)
            for graph_restricted_coalition in coalitions_to_graph_restricted_coalitions[coalition]:
                worth += graph_restricted_coalitions_to_worth[graph_restricted_coalition]
            coalition_to_worth[coalition] = worth
        return coalition_to_worth

    def sample_all_myerson_values(self) -> np.ndarray:
        """Use Monte Carlo sampling to approximate the Myerson values for every
        node / player in the graph.

        For each sampled permutation and each node, the marginal contribution
        ``worth(S + node) - worth(S)`` is accumulated, where ``S`` is the
        permutation after ``_replace_in_array`` has swapped ``node`` for
        ``self.random_node``. The sums are divided by
        ``self.number_of_samples`` (assumed to equal the number of sampled
        permutations — TODO confirm in MyersonSampler).

        Returns:
            np.ndarray: Sampled Myerson values of shape
            ``(number of nodes, self.pred.shape[0])``.
        """
        self.sample_all_mappings()
        nodes_array = np.array(self.grand_coalition)
        # One row per node, one column per model output (e.g. per class).
        my_values = np.zeros((len(nodes_array), self.pred.shape[0]), dtype=float)
        self.log.info("Calculating sampled Myerson values.")
        for permutation in tqdm(self.permutations_without_random_node,
                                disable=self.disable_tqdm,
                                desc="Calculate sampled Myerson values"):
            for node_idx, node in enumerate(nodes_array):
                # Copy first so the stored permutation is never modified.
                swapped = self._replace_in_array(permutation.copy(), node, self.random_node)
                worth_with_node = self.coalitions_to_worth[tuple(np.sort(np.append(swapped, node)))]
                worth_without_node = self.coalitions_to_worth[tuple(np.sort(swapped))]
                my_values[node_idx] = (my_values[node_idx]
                                       + worth_with_node.numpy().squeeze()
                                       - worth_without_node.numpy().squeeze())
        my_values = my_values / self.number_of_samples
        log_string = "".join(f"\t{node}: {val}\n"
                             for node, val in zip(self.grand_coalition, my_values))
        self.log.info(f"Sampled Myerson Values:\n{log_string}")
        return my_values
def explain(graph: torch_geometric.data.Data, model: torch.nn.Module,
            sample_if_more_nodes_than: int = 20, verbose: bool = False, *,
            seed: None | int = None,
            number_of_samples: int = 1000) -> dict | np.ndarray:
    """Quickly explain a GNN prediction with Myerson values.

    Graphs with more than ``sample_if_more_nodes_than`` nodes are explained
    with ``MyersonSamplingExplainer`` (Monte Carlo approximation); smaller
    graphs with exact values from ``MyersonExplainer``.

    Args:
        graph (torch_geometric.data.Data): The graph.
        model (torch.nn.Module): The graph neural network.
        sample_if_more_nodes_than (int, optional): Barrier for when to start
            sampling instead of exact calculations. Defaults to 20.
        verbose (bool, optional): Whether to log information to the output and
            show progress bars. Note that this reconfigures the root logger
            via ``logging.basicConfig(..., force=True)``. Defaults to False.
        seed (None | int, optional): Seed forwarded to the sampler; only used
            when sampling. Defaults to None.
        number_of_samples (int, optional): Number of sampling steps forwarded
            to the sampler; only used when sampling. Defaults to 1000.

    Returns:
        dict | np.ndarray: When sampling, the result of
        ``MyersonSamplingExplainer.sample_all_myerson_values()`` (the class
        variant in this file returns an ``np.ndarray``); otherwise the result of
        ``MyersonExplainer.calculate_all_myerson_values()`` (presumably a dict
        — verify in MyersonCalculator).
    """
    if verbose:
        logging.basicConfig(level=logging.INFO,
                            format='[%(asctime)s - %(levelname)s] %(message)s',
                            force=True)
    disable_tqdm = not verbose

    node_count = graph.x.size(0)
    if node_count > sample_if_more_nodes_than:
        logging.info("Sampling Myerson values.")
        sampler = MyersonSamplingExplainer(graph, model, seed=seed,
                                           number_of_samples=number_of_samples,
                                           disable_tqdm=disable_tqdm)
        return sampler.sample_all_myerson_values()

    logging.info("Calculating exact Myerson values.")
    explainer = MyersonExplainer(graph, model, disable_tqdm=disable_tqdm)
    return explainer.calculate_all_myerson_values()