Source code for BNMPy.model_compressor

import numpy as np
import networkx as nx
from typing import Set, List, Dict, Tuple, Union, Optional
from collections import deque
import copy
import re


class ModelCompressor:
    """
    Model compression for Boolean Networks.

    Provides methods to compress models by removing non-observable/non-controllable
    nodes and collapsing linear paths to simplify the network structure.
    """

    # Regex pattern to match single-identifier (alias) rules
    _alias_pat = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')
    def __init__(self, network, measured_nodes: Set[str] = None,
                 perturbed_nodes: Set[str] = None):
        """
        Initialize the model compressor.

        Parameters:
        -----------
        network : BooleanNetwork
            The Boolean network to compress
        measured_nodes : Set[str], optional
            Set of node names that are measured/observed.
            If None, will auto-detect output nodes.
        perturbed_nodes : Set[str], optional
            Set of node names that are perturbed/controlled
        """
        # Reject PBNs: compression is only supported for Boolean networks
        if hasattr(network, 'cij'):
            raise ValueError("PBN compression is not currently supported. "
                             "Please use Boolean Networks only.")

        self.network = network
        self.perturbed_nodes = perturbed_nodes or set()

        # Build directed graph representation
        self.graph = self._build_graph()

        # Auto-detect measured nodes if not provided
        if measured_nodes is None:
            self.measured_nodes = self._detect_output_nodes()
        else:
            self.measured_nodes = measured_nodes

        # Track original network structure for restoration
        self.original_structure = self._save_network_structure()

        # Track removed elements for visualization
        self.removed_nodes = set()
        self.collapsed_paths = []
        self.removed_edges = set()  # populated by compress()
    def _build_graph(self) -> nx.DiGraph:
        """
        Build a NetworkX directed graph from the Boolean network connectivity matrix.

        Returns:
        --------
        nx.DiGraph
            Directed graph representation of the network
        """
        graph = nx.DiGraph()

        # Add all nodes
        for name, idx in self.network.nodeDict.items():
            graph.add_node(name, index=idx)

        # Add edges based on the connectivity matrix of the BooleanNetwork
        for node_idx in range(self.network.N):
            for j in range(self.network.K[node_idx]):
                input_idx = self.network.varF[node_idx, j]
                if input_idx != -1:
                    input_name = self._get_node_name(input_idx)
                    output_name = self._get_node_name(node_idx)
                    if input_name and output_name:
                        graph.add_edge(input_name, output_name)

        return graph

    def _get_node_name(self, index: int) -> Optional[str]:
        """Get node name from index."""
        for name, idx in self.network.nodeDict.items():
            if idx == index:
                return name
        return None

    def _save_network_structure(self) -> Dict:
        """Save current network structure for potential restoration."""
        structure = {
            'varF': self.network.varF.copy(),
            'F': self.network.F.copy(),
            'nodeDict': self.network.nodeDict.copy(),
            'N': self.network.N,
            'K': self.network.K.copy(),
        }
        return structure
    def find_non_observable_nodes(self) -> Set[str]:
        """
        Find nodes that are non-observable (no path to measured species).

        Returns:
        --------
        Set[str]
            Set of non-observable node names
        """
        if not self.measured_nodes:
            return set()

        # Find all nodes that can reach measured nodes
        observable_nodes = set()
        for measured_node in self.measured_nodes:
            if measured_node in self.graph:
                # Find all predecessors (nodes that can reach this measured node)
                predecessors = nx.ancestors(self.graph, measured_node)
                observable_nodes.update(predecessors)
                observable_nodes.add(measured_node)  # measured node is observable

        # Non-observable nodes are those not in the observable set
        all_nodes = set(self.graph.nodes())
        non_observable = all_nodes - observable_nodes

        return non_observable
    def find_non_controllable_nodes(self) -> Set[str]:
        """
        Find nodes that are non-controllable (no path from perturbed species).

        Returns:
        --------
        Set[str]
            Set of non-controllable node names
        """
        if not self.perturbed_nodes:
            return set()

        # Find all nodes reachable from perturbed nodes
        controllable_nodes = set()
        for perturbed_node in self.perturbed_nodes:
            if perturbed_node in self.graph:
                # Find all successors (nodes reachable from this perturbed node)
                successors = nx.descendants(self.graph, perturbed_node)
                controllable_nodes.update(successors)
                controllable_nodes.add(perturbed_node)  # perturbed node is controllable

        # Non-controllable nodes are those not in the controllable set
        all_nodes = set(self.graph.nodes())
        non_controllable = all_nodes - controllable_nodes

        return non_controllable
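    # Illustrative sketch (comment only; the rules and node names below are
    # hypothetical, not part of BNMPy): with measured_nodes={'Out'} and
    # perturbed_nodes={'Drug'} for the rule set
    #
    #     A = Drug
    #     B = A
    #     Out = B
    #     C = Out
    #     D = D
    #
    # find_non_observable_nodes() would return {'C', 'D'} (neither has a path
    # to 'Out'), and find_non_controllable_nodes() would return {'D'} (not
    # reachable from 'Drug').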
    def find_collapsible_paths(self) -> List[List[str]]:
        """
        Find linear paths that can be collapsed.

        A collapsible path is a series of nodes that form a linear cascade,
        where intermediate nodes can be removed without losing connectivity.

        Returns:
        --------
        List[List[str]]
            List of paths, where each path is a list of node names
        """
        collapsible_paths = []
        visited = set()

        # Important nodes that should not be removed
        important_nodes = self.measured_nodes | self.perturbed_nodes

        # Nodes used in complex (non-alias) rules are also protected
        complex_nodes = set()
        for node in self.graph.nodes():
            if self._is_used_in_complex_rule(node):
                complex_nodes.add(node)

        protected_nodes = important_nodes | complex_nodes

        # Find all linear paths starting from any node
        for node in self.graph.nodes():
            if node in visited:
                continue

            # Look for the start of a potential path
            path = self._trace_linear_path(node, protected_nodes, visited)

            if len(path) > 1:  # Only paths with at least one node to remove
                # Check if this path has collapsible nodes
                collapsible_nodes = []
                for i, n in enumerate(path):
                    if n in protected_nodes:
                        # If this is the last node, it's okay (endpoint)
                        if i == len(path) - 1:
                            break
                        else:
                            # Protected node in the middle - can't collapse this path
                            collapsible_nodes = []
                            break
                    else:
                        collapsible_nodes.append(n)

                if collapsible_nodes:
                    collapsible_paths.append(path)
                    visited.update(path[:-1] if path[-1] in protected_nodes else path)

        return collapsible_paths
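    # Illustrative sketch (comment only; hypothetical rules): given
    #
    #     A = Drug
    #     B = A
    #     C = B
    #     Out = C and Other
    #
    # with measured_nodes={'Out'} and perturbed_nodes={'Drug'},
    # find_collapsible_paths() would report a linear cascade such as
    # ['A', 'B', 'C']: 'A' and 'B' are removable relay nodes, while 'C' is a
    # protected endpoint because it appears in the complex rule of 'Out'.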
    def _trace_linear_path(self, start_node: str, important_nodes: Set[str],
                           visited: Set[str]) -> List[str]:
        """
        Trace a linear path starting from the given node.

        A linear path is a sequence of nodes where each node has exactly one
        output (except the last) and exactly one input (except the first).
        Important nodes and nodes used in complex rules can be endpoints but
        not intermediate nodes.
        """
        path = [start_node]
        current = start_node

        # Follow the path forward
        while True:
            successors = list(self.graph.successors(current))

            # Check if we can continue the path
            if len(successors) != 1:
                break

            next_node = successors[0]

            # Check if next node is suitable for the path
            if (next_node in visited or
                    next_node in path or  # Avoid cycles
                    self.graph.has_edge(next_node, next_node)):  # Avoid self-loops
                break

            # If we've reached an important node
            if next_node in important_nodes:
                path.append(next_node)
                break

            # If we've reached a node used in complex rules
            if self._is_used_in_complex_rule(next_node):
                path.append(next_node)
                break

            # Check if the next node is linear (for continuing the path)
            if not self._is_linear_node(next_node):
                break

            # Add to path and continue
            path.append(next_node)
            current = next_node

        return path

    def _is_linear_node(self, node: str) -> bool:
        """
        Check if a node is part of a linear cascade.

        A linear node has exactly one input and one output, and is not
        involved in self-loops.
        """
        # Check for self-loops
        if self.graph.has_edge(node, node):
            return False

        # Check input/output constraints for intermediate nodes
        in_degree = self.graph.in_degree(node)
        out_degree = self.graph.out_degree(node)

        # A linear intermediate node has exactly one input and one output
        return in_degree == 1 and out_degree == 1
    def collapse_paths(self, paths: List[List[str]]) -> None:
        """
        Collapse linear paths by removing intermediate nodes and creating
        direct connections.

        For each path, update all references to any relay node in the path
        to the ultimate source.
        """
        self.collapsed_paths = paths.copy()

        important_nodes = self.measured_nodes | self.perturbed_nodes

        # Also consider nodes used in complex rules as important
        nodes_in_complex_rules = set()
        for node in self.graph.nodes():
            if self._is_used_in_complex_rule(node):
                nodes_in_complex_rules.add(node)

        all_important_nodes = important_nodes | nodes_in_complex_rules

        # Process each path
        for path in paths:
            if len(path) < 2:
                continue

            # Get the ultimate source (the node that should replace all others in the path)
            src = self._ultimate_source(path[0])

            # Determine which nodes in the path should be removed and which preserved
            nodes_to_replace = []
            preserved_nodes = []
            for node in path:
                if node not in all_important_nodes:
                    nodes_to_replace.append(node)
                else:
                    preserved_nodes.append(node)

            # Update rules for preserved nodes at path endpoints
            for preserved_node in preserved_nodes:
                # If this preserved node was getting its input from a node in the path,
                # update it to get input from the ultimate source instead
                current_rule = self._get_rule(preserved_node).strip()

                # Replace any nodes from this path with the ultimate source
                new_rule = current_rule
                for node in nodes_to_replace:
                    pattern = r'\b' + re.escape(node) + r'\b'
                    new_rule = re.sub(pattern, src, new_rule)

                if new_rule != current_rule:
                    self._update_node_rule(preserved_node, new_rule)

            # Update all other equations that reference nodes to be replaced
            for node in nodes_to_replace:
                # Update all equations that reference this node
                for i, eq in enumerate(self.network.equations):
                    lhs, rhs = eq.split('=', 1)
                    lhs = lhs.strip()
                    rule = rhs.strip()

                    # Skip equations for nodes we're removing
                    if lhs in nodes_to_replace:
                        continue

                    # Replace this node with the ultimate source
                    pattern = r'\b' + re.escape(node) + r'\b'
                    new_rule = re.sub(pattern, src, rule)

                    if new_rule != rule:
                        self.network.equations[i] = f"{lhs} = {new_rule}"
                        self._update_node_rule(lhs, new_rule)

        # Collect all nodes that should be removed from all paths
        all_nodes_to_remove = set()
        for path in paths:
            if len(path) < 2:
                continue
            for node in path:
                if node not in all_important_nodes:
                    all_nodes_to_remove.add(node)

        # Update self.removed_nodes with the actual nodes being removed
        self.removed_nodes.update(all_nodes_to_remove)

        if all_nodes_to_remove:
            self.remove_nodes(all_nodes_to_remove)

            # Remove equations for removed nodes
            if hasattr(self.network, 'equations'):
                self.network.equations = [
                    eq for eq in self.network.equations
                    if eq.split('=')[0].strip() not in all_nodes_to_remove
                ]

            # Remove nodes from nodeDict
            for node in all_nodes_to_remove:
                if node in self.network.nodeDict:
                    del self.network.nodeDict[node]

        self.graph = self._build_graph()
        self._collapse_alias_nodes()
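    # Continuing the sketch above (hypothetical rules): collapsing the path
    # ['A', 'B', 'C'] removes the relay nodes and rewrites every remaining
    # reference to them onto the ultimate source, so 'C' would end up reading
    # its input directly from 'Drug':
    #
    #     C = Drug
    #     Out = C and Other
    #
    # with the removed relays dropped from varF, K, nodeDict and the
    # equations list.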
    def _update_node_rule(self, node_name: str, new_rule: str) -> None:
        """Update the rule for a specific node and its connectivity matrix."""
        if not hasattr(self.network, 'equations') or not new_rule:
            return

        # Update the equations list
        for i, equation in enumerate(self.network.equations):
            eq_parts = equation.split('=', 1)
            if len(eq_parts) == 2 and eq_parts[0].strip() == node_name:
                self.network.equations[i] = f"{node_name} = {new_rule}"
                break

        # Update connectivity matrix
        if node_name in self.network.nodeDict:
            node_idx = self.network.nodeDict[node_name]

            # Clear existing connections
            for j in range(self.network.varF.shape[1]):
                self.network.varF[node_idx, j] = -1

            # Extract new dependencies from the rule
            dependencies = re.findall(r'\b[A-Za-z0-9_]+\b', new_rule)

            # Add new connections (deduplicated, preserving first-seen order)
            conn_count = 0
            for dep in dict.fromkeys(dependencies):
                if dep in self.network.nodeDict and dep != node_name:
                    dep_idx = self.network.nodeDict[dep]
                    if conn_count < self.network.varF.shape[1]:
                        self.network.varF[node_idx, conn_count] = dep_idx
                        conn_count += 1

            # Update connection count
            self.network.K[node_idx] = conn_count
    def remove_nodes(self, nodes_to_remove: Set[str]) -> None:
        """
        Remove specified nodes from the network.

        Parameters:
        -----------
        nodes_to_remove : Set[str]
            Set of node names to remove
        """
        if not nodes_to_remove:
            return

        # Track removed nodes for visualization
        self.removed_nodes.update(nodes_to_remove)

        # Get indices of nodes to remove
        indices_to_remove = [self.network.nodeDict[name] for name in nodes_to_remove
                             if name in self.network.nodeDict]

        if not indices_to_remove:
            return

        # Sort in descending order to maintain correct indices when removing
        indices_to_remove.sort(reverse=True)

        # Create new nodeDict mapping
        new_nodeDict = {}
        new_index = 0
        for name, old_index in self.network.nodeDict.items():
            if name not in nodes_to_remove:
                new_nodeDict[name] = new_index
                new_index += 1

        # Update network structure
        self._update_network_structure(indices_to_remove, new_nodeDict)
    def _update_network_structure(self, indices_to_remove: List[int],
                                  new_nodeDict: Dict[str, int]) -> None:
        """Update network structure after removing nodes."""
        # Create index mapping from old to new
        old_to_new = {}
        new_idx = 0
        for old_idx in range(self.network.N):
            if old_idx not in indices_to_remove:
                old_to_new[old_idx] = new_idx
                new_idx += 1

        # Update network properties
        new_N = len(new_nodeDict)
        self._update_bn_structure(indices_to_remove, old_to_new, new_N, new_nodeDict)

    def _update_bn_structure(self, indices_to_remove: List[int],
                             old_to_new: Dict[int, int], new_N: int,
                             new_nodeDict: Dict[str, int]) -> None:
        """Update BooleanNetwork structure."""
        # Create new arrays
        new_varF = []
        new_F = []
        new_K = []
        new_equations = []

        for old_idx in range(self.network.N):
            if old_idx not in indices_to_remove:
                # Update connectivity matrix row
                old_row = self.network.varF[old_idx]
                new_row = []
                for conn in old_row:
                    if conn == -1:
                        new_row.append(-1)
                    elif conn in old_to_new:
                        new_row.append(old_to_new[conn])
                    # Skip connections to removed nodes

                # Pad with -1 if necessary
                while len(new_row) < self.network.varF.shape[1]:
                    new_row.append(-1)

                new_varF.append(new_row[:self.network.varF.shape[1]])
                new_F.append(self.network.F[old_idx])

                # Update K (count non-(-1) connections)
                new_K.append(sum(1 for x in new_row if x != -1))

                # Update equations if they exist
                if hasattr(self.network, 'equations') and self.network.equations:
                    if old_idx < len(self.network.equations):
                        new_equations.append(self.network.equations[old_idx])

        # Update network properties
        self.network.varF = np.array(new_varF)
        self.network.F = np.array(new_F)
        self.network.K = np.array(new_K)
        self.network.N = new_N
        self.network.nodeDict = new_nodeDict

        # Update equations if they exist
        if hasattr(self.network, 'equations') and new_equations:
            self.network.equations = new_equations

        # Update other arrays if they exist
        if hasattr(self.network, 'nodes'):
            new_nodes = []
            for old_idx in range(len(self.network.nodes)):
                if old_idx not in indices_to_remove:
                    new_nodes.append(self.network.nodes[old_idx])
            self.network.nodes = np.array(new_nodes)
    def compress(self, remove_non_observable: bool = True,
                 remove_non_controllable: bool = True,
                 collapse_linear_paths: bool = True) -> Dict[str, Set[str]]:
        """
        Compress the model by removing non-observable/non-controllable nodes
        and collapsing linear paths.

        Parameters:
        -----------
        remove_non_observable : bool, default=True
            Whether to remove non-observable nodes
        remove_non_controllable : bool, default=True
            Whether to remove non-controllable nodes
        collapse_linear_paths : bool, default=True
            Whether to collapse linear paths

        Returns:
        --------
        Dict[str, Set[str]]
            Dictionary containing information about the compression:
            - 'removed_non_observable': Set of removed non-observable nodes
            - 'removed_non_controllable': Set of removed non-controllable nodes
            - 'collapsed_paths': List of collapsed paths
            - 'removed_nodes': Set of all removed nodes
            - 'removed_edges': Set of all removed edges
        """
        # Store original structure for edge comparison
        original_graph = self._build_graph()

        compression_info = {
            'removed_non_observable': set(),
            'removed_non_controllable': set(),
            'collapsed_paths': [],
            'measured_nodes': self.measured_nodes,
            'perturbed_nodes': self.perturbed_nodes,
        }

        # Find nodes to remove
        nodes_to_remove = set()

        if remove_non_observable:
            non_observable = self.find_non_observable_nodes()
            nodes_to_remove.update(non_observable)
            compression_info['removed_non_observable'] = non_observable

        if remove_non_controllable:
            non_controllable = self.find_non_controllable_nodes()
            nodes_to_remove.update(non_controllable)
            compression_info['removed_non_controllable'] = non_controllable

        # Remove non-observable/non-controllable nodes first
        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)
            # Rebuild graph after removal
            self.graph = self._build_graph()

        # Collapse linear paths
        if collapse_linear_paths:
            paths = self.find_collapsible_paths()
            if paths:
                self.collapse_paths(paths)
                compression_info['collapsed_paths'] = paths
                # Rebuild graph after collapsing
                self.graph = self._build_graph()
            self._collapse_alias_nodes()

        # Identify removed edges by comparing original and final structures
        compression_info['removed_edges'] = self._identify_removed_edges(original_graph)
        self.removed_edges = compression_info['removed_edges']  # kept for visualize_compression()
        compression_info['removed_nodes'] = self.removed_nodes

        return compression_info
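    # Usage sketch (comment only; `bn` stands for a BooleanNetwork already
    # built with BNMPy's loading utilities, and the node names are
    # hypothetical):
    #
    #     compressor = ModelCompressor(bn,
    #                                  measured_nodes={'Out'},
    #                                  perturbed_nodes={'Drug'})
    #     info = compressor.compress()
    #     print(compressor.get_compression_summary(info))
    #     print(info['removed_nodes'], info['removed_edges'])
    #
    # Note that compress() mutates the network passed to the constructor; use
    # the compress_model() convenience function below to work on a deep copy.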
    def _identify_removed_edges(self, original_graph: nx.DiGraph) -> Set[Tuple[str, str]]:
        """
        Identify removed edges by comparing original and current graph structures.

        Parameters:
        -----------
        original_graph : nx.DiGraph
            The original graph before compression

        Returns:
        --------
        Set[Tuple[str, str]]
            Set of removed edges as (source, target) tuples
        """
        removed_edges = set()

        # Get current graph edges
        current_edges = set(self.graph.edges())

        # Find edges that existed in original but not in current
        for edge in original_graph.edges():
            if edge not in current_edges:
                removed_edges.add(edge)

        # Also include edges involving removed nodes
        for node in self.removed_nodes:
            # Add all edges from/to removed nodes that were in the original
            for pred in original_graph.predecessors(node):
                removed_edges.add((pred, node))
            for succ in original_graph.successors(node):
                removed_edges.add((node, succ))

        return removed_edges
    def get_compression_summary(self, compression_info: Dict) -> str:
        """
        Generate a summary of the compression results.

        Parameters:
        -----------
        compression_info : Dict
            Compression information returned by compress()

        Returns:
        --------
        str
            Human-readable summary of compression
        """
        summary = ["Model Compression Summary:"]
        summary.append("=" * 30)

        # Add information about measured and perturbed nodes
        measured_nodes = compression_info.get('measured_nodes', set())
        perturbed_nodes = compression_info.get('perturbed_nodes', set())
        important_nodes = measured_nodes | perturbed_nodes

        if measured_nodes:
            summary.append(f" Measured nodes: {', '.join(sorted(measured_nodes))}")
        else:
            summary.append("No measured nodes provided.")

        if perturbed_nodes:
            summary.append(f" Perturbed nodes: {', '.join(sorted(perturbed_nodes))}")
        else:
            summary.append("No perturbed nodes provided.")

        summary.append("")  # Empty line for separation

        non_obs = compression_info.get('removed_non_observable', set())
        non_ctrl = compression_info.get('removed_non_controllable', set())
        paths = compression_info.get('collapsed_paths', [])

        summary.append(f"Removed {len(non_obs)} non-observable nodes")
        if non_obs:
            summary.append(f" Non-observable: {', '.join(sorted(non_obs))}")

        summary.append(f"Removed {len(non_ctrl)} non-controllable nodes")
        if non_ctrl:
            summary.append(f" Non-controllable: {', '.join(sorted(non_ctrl))}")

        summary.append(f"Collapsed {len(paths)} linear paths")
        if paths:
            for i, path in enumerate(paths, 1):
                # Show the complete path including the ultimate source
                if path:
                    src = self._ultimate_source(path[0])
                    # Create the full path display: source -> path nodes -> target
                    full_path = [src] + path if src not in path else path
                    summary.append(f" Path {i}: {' -> '.join(full_path)}")

        # For each path, count only the nodes that are actually removed
        total_removed_paths = 0
        for path in paths:
            if path and path[-1] in important_nodes:
                # Path ends at an important node - only intermediate nodes removed
                total_removed_paths += len(path) - 1
            else:
                # All nodes in the path removed
                total_removed_paths += len(path)

        total_removed = len(non_obs) + len(non_ctrl) + total_removed_paths
        summary.append(f"\nTotal nodes removed/collapsed: {total_removed}")
        summary.append(f"Final network size: {self.network.N} nodes")

        return "\n".join(summary)
    def visualize_compression(self, original_network,
                              output_html="compression_visualization.html",
                              interactive=False):
        """
        Visualize the compression results showing removed nodes and edges.

        Parameters:
        -----------
        original_network : BooleanNetwork
            The original network before compression
        output_html : str
            Output HTML file name
        interactive : bool
            If True, create an interactive HTML visualization; if False,
            return a matplotlib figure

        Returns:
        --------
        Network object if interactive=True, matplotlib figure if
        interactive=False, None on error
        """
        try:
            from .vis import vis_compression

            compression_info = {
                'removed_non_observable': set(),
                'removed_non_controllable': set(),
                'collapsed_paths': self.collapsed_paths,
                'removed_nodes': self.removed_nodes,
                'removed_edges': self.removed_edges,
                'measured_nodes': self.measured_nodes,
                'perturbed_nodes': self.perturbed_nodes,
            }

            return vis_compression(
                original_network,
                self.network,
                compression_info,
                output_html,
                interactive
            )
        except ImportError:
            print("Visualization module not available. "
                  "Please ensure BNMPy.vis is properly installed.")
            return None
    def _detect_output_nodes(self) -> Set[str]:
        """
        Automatically detect output nodes (nodes with no outgoing edges).

        Returns:
        --------
        Set[str]
            Set of output node names
        """
        output_nodes = set()

        for node_name in self.graph.nodes():
            # Check if node has no outgoing edges
            if self.graph.out_degree(node_name) == 0:
                output_nodes.add(node_name)

        return output_nodes

    def _get_rule(self, node_name: str) -> str:
        """Get the rule (right-hand side) for a given node."""
        if not hasattr(self.network, 'equations'):
            return ""

        for equation in self.network.equations:
            # Handle equations with extra whitespace around the equals sign
            eq_parts = equation.split('=', 1)
            if len(eq_parts) == 2 and eq_parts[0].strip() == node_name:
                return eq_parts[1].strip()
        return ""

    def _is_used_in_complex_rule(self, node: str) -> bool:
        """
        Return True if node appears in any rule as part of a complex expression
        (not a pure alias), or if the node's own rule is complex.
        """
        # Check if the node's own rule is complex
        own_rule = self._get_rule(node).strip()
        if own_rule and not self._alias_pat.fullmatch(own_rule):
            return True

        if not hasattr(self.network, 'equations'):
            return False

        # Check if the node appears in other nodes' complex rules
        node_pat = re.compile(r'\b' + re.escape(node) + r'\b')
        for eq in self.network.equations:
            lhs, rhs = eq.split('=', 1)
            rule = rhs.strip()
            # Node must appear in the rule as a whole word, and the rule must
            # not be a pure alias
            if node_pat.search(rule) and not self._alias_pat.fullmatch(rule):
                return True
        return False

    def _ultimate_source(self, var: str) -> str:
        """
        Follow alias chains (A = B = C ...) until a node is used in a complex
        rule or is not a pure alias.
        """
        seen = set()
        while self._alias_pat.fullmatch(self._get_rule(var).strip()) and var not in seen:
            # Stop if this node is used in a complex rule
            if self._is_used_in_complex_rule(var):
                return var
            seen.add(var)
            var = self._get_rule(var).strip()
        return var

    def _collapse_alias_nodes(self) -> None:
        """
        Collapse alias nodes (nodes whose rule is a single identifier) by
        rewiring their children directly to their parent and removing the
        alias node. Important nodes (measured/perturbed) and nodes used in
        complex rules are never collapsed.
        """
        nodes_to_remove = set()

        # Important nodes that should never be collapsed
        important_nodes = self.measured_nodes | self.perturbed_nodes

        # Also preserve nodes used in complex rules
        complex_nodes = set()
        for node in self.graph.nodes():
            if self._is_used_in_complex_rule(node):
                complex_nodes.add(node)

        # Combined set of nodes that should never be collapsed
        protected_nodes = important_nodes | complex_nodes

        for node in list(self.graph.nodes()):
            if node in nodes_to_remove or node in protected_nodes:
                continue

            rule = self._get_rule(node).strip()

            # Check if this is an alias node (rule is a single identifier)
            if not self._alias_pat.fullmatch(rule):
                continue

            # Check if node has exactly one input and one output
            if self.graph.in_degree(node) != 1 or self.graph.out_degree(node) != 1:
                continue

            # Get parent and children
            parent = rule  # The node this alias points to
            children = list(self.graph.successors(node))

            if not children or parent not in self.graph.nodes():
                continue

            child = children[0]

            # Don't collapse if the child is also a protected node
            if child in protected_nodes:
                continue

            # Update child's rule to point directly to parent
            child_rule = self._get_rule(child)
            if child_rule:
                # Replace occurrences of the alias node with the parent
                pattern = r'\b' + re.escape(node) + r'\b'
                new_rule = re.sub(pattern, parent, child_rule)
                self._update_node_rule(child, new_rule)

            # Mark alias node for removal
            nodes_to_remove.add(node)

        # Remove all alias nodes
        if nodes_to_remove:
            self.remove_nodes(nodes_to_remove)

            # Remove equations for alias nodes
            if hasattr(self.network, 'equations'):
                self.network.equations = [
                    eq for eq in self.network.equations
                    if eq.split('=')[0].strip() not in nodes_to_remove
                ]

            # Remove nodes from nodeDict
            for node in nodes_to_remove:
                if node in self.network.nodeDict:
                    del self.network.nodeDict[node]

            # Rebuild graph after removal
            self.graph = self._build_graph()
def compress_model(network, measured_nodes: Set[str] = None,
                   perturbed_nodes: Set[str] = None,
                   remove_non_observable: bool = True,
                   remove_non_controllable: bool = True,
                   collapse_linear_paths: bool = True):
    """
    Convenience function to compress a model in one step.

    Parameters:
    -----------
    network : BooleanNetwork
        The Boolean network to compress
    measured_nodes : Set[str], optional
        Set of node names that are measured/observed
    perturbed_nodes : Set[str], optional
        Set of node names that are perturbed/controlled
    remove_non_observable : bool, default=True
        Whether to remove non-observable nodes
    remove_non_controllable : bool, default=True
        Whether to remove non-controllable nodes
    collapse_linear_paths : bool, default=True
        Whether to collapse linear paths

    Returns:
    --------
    Tuple[network, Dict]
        - Compressed network
        - Compression information dictionary
    """
    # Check if it's a PBN
    if hasattr(network, 'cij'):
        raise ValueError("PBN compression is not currently supported. "
                         "Please use Boolean Networks only.")

    # Make a deep copy to preserve the original network
    compressed_network = copy.deepcopy(network)

    compressor = ModelCompressor(compressed_network, measured_nodes, perturbed_nodes)
    compression_info = compressor.compress(
        remove_non_observable=remove_non_observable,
        remove_non_controllable=remove_non_controllable,
        collapse_linear_paths=collapse_linear_paths
    )

    summary = compressor.get_compression_summary(compression_info)
    print(summary)

    return compressed_network, compression_info
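# Usage sketch (comment only; assumes a BooleanNetwork `bn` built elsewhere
# with BNMPy, and hypothetical node names):
#
#     from BNMPy.model_compressor import compress_model
#
#     compressed_bn, info = compress_model(
#         bn,
#         measured_nodes={'Out'},
#         perturbed_nodes={'Drug'},
#     )
#     # `bn` is left untouched; `compressed_bn` is the reduced deep copy.
#     print(info['collapsed_paths'])
#     print(info['removed_non_observable'], info['removed_non_controllable'])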