Source code for aiidalab_widgets_base.nodes

"""Widgets to work with AiiDA nodes."""

import functools

import ipytree
import ipywidgets as ipw
import traitlets as tl
from aiida import common, engine, orm
from aiida.cmdline.utils.ascii_vis import calc_info
from IPython.display import clear_output, display


class AiidaNodeTreeNode(ipytree.Node):
    """Tree node that stands for a single AiiDA node, identified by its pk.

    Attributes:
        pk: The pk handed to the constructor.
        nodes_registry: Dictionary, empty on creation, used by
            ``NodesTreeWidget`` to cache child tree nodes keyed by pk.
    """

    def __init__(self, pk, name, **kwargs):
        """Store the pk, create an empty registry and initialise the tree node.

        Args:
            pk: Primary key of the AiiDA node this tree node represents.
            name: Label forwarded to ``ipytree.Node``.
            **kwargs: Further keyword arguments forwarded to ``ipytree.Node``.
        """
        # The plain attributes are assigned before the widget initialiser runs.
        self.pk = pk
        self.nodes_registry = {}
        super().__init__(name=name, **kwargs)

    @tl.default("opened")
    def _default_opened(self):
        """Provide ``True`` as the default value of the ``opened`` trait."""
        return True
class AiidaProcessNodeTreeNode(AiidaNodeTreeNode):
    """Tree node for an AiiDA process node.

    Attributes:
        outputs_node: An ``AiidaOutputsTreeNode`` named ``"outputs"`` whose
            ``parent_pk`` is the pk of this process.
    """

    def __init__(self, pk, **kwargs):
        """Create the outputs folder node, then initialise the base tree node.

        Args:
            pk: Primary key of the process node.
            **kwargs: Forwarded to ``AiidaNodeTreeNode`` (must include ``name``).
        """
        # The outputs folder is built first, before the base initialiser runs.
        outputs_folder = AiidaOutputsTreeNode(name="outputs", parent_pk=pk)
        self.outputs_node = outputs_folder
        super().__init__(pk=pk, **kwargs)
class WorkChainProcessTreeNode(AiidaProcessNodeTreeNode):
    """Process tree node shown with the ``chain`` icon.

    ``NodesTreeWidget.NODE_TYPE`` maps ``orm.WorkChainNode`` to this class.
    """

    icon = tl.Unicode("chain").tag(sync=True)
class CalcJobTreeNode(AiidaProcessNodeTreeNode):
    """Process tree node shown with the ``gears`` icon.

    ``NodesTreeWidget.NODE_TYPE`` maps ``orm.CalcJobNode`` to this class.
    """

    icon = tl.Unicode("gears").tag(sync=True)
class CalcFunctionTreeNode(AiidaProcessNodeTreeNode):
    """Process tree node shown with the ``gear`` icon.

    ``NodesTreeWidget.NODE_TYPE`` maps ``orm.CalcFunctionNode`` to this class.
    """

    icon = tl.Unicode("gear").tag(sync=True)
class AiidaOutputsTreeNode(ipytree.Node):
    """Folder-style tree node that groups the outputs of a process.

    Attributes:
        parent_pk: The pk of the process whose outputs are grouped here.
        namespaces: Tuple of namespace keys leading from the process outputs
            to this folder; empty for the top-level outputs folder.
        nodes_registry: Dictionary, empty on creation, used by
            ``NodesTreeWidget`` to cache output tree nodes keyed by pk.
    """

    icon = tl.Unicode("folder").tag(sync=True)
    disabled = tl.Bool(False).tag(sync=True)

    def __init__(self, name, parent_pk, namespaces: tuple[str, ...] = (), **kwargs):
        """Record the parent process and namespace path, then initialise.

        Args:
            name: Label forwarded to ``ipytree.Node``.
            parent_pk: Primary key of the owning process node.
            namespaces: Namespace path below the process outputs.
            **kwargs: Further keyword arguments forwarded to ``ipytree.Node``.
        """
        # The plain attributes are assigned before the widget initialiser runs.
        self.namespaces = namespaces
        self.parent_pk = parent_pk
        self.nodes_registry = {}
        super().__init__(name=name, **kwargs)
class UnknownTypeTreeNode(AiidaNodeTreeNode):
    """Tree node shown with the ``file`` icon.

    ``NodesTreeWidget._to_tree_node`` falls back to this class for any node
    type that is not listed in ``NodesTreeWidget.NODE_TYPE``.
    """

    icon = tl.Unicode("file").tag(sync=True)
class NodesTreeWidget(ipw.Output):
    """A tree widget for the structured representation of a nodes graph.

    The widget wraps an ``ipytree.Tree``. Setting the ``nodes`` trait replaces
    the root nodes of the tree; selecting nodes in the tree expands them and
    publishes the corresponding AiiDA nodes through ``selected_nodes``.
    """

    nodes = tl.Tuple().tag(trait=tl.Instance(orm.Node))
    selected_nodes = tl.Tuple(read_only=True).tag(trait=tl.Instance(orm.Node))

    # Icon style used for each process state; anything else gets the default.
    PROCESS_STATE_STYLE = {
        engine.ProcessState.EXCEPTED: "danger",
        engine.ProcessState.FINISHED: "success",
        engine.ProcessState.KILLED: "warning",
        engine.ProcessState.RUNNING: "info",
        engine.ProcessState.WAITING: "info",
    }
    PROCESS_STATE_STYLE_DEFAULT = "default"

    # Tree node class used for each (exact) AiiDA node class.
    NODE_TYPE = {
        orm.WorkChainNode: WorkChainProcessTreeNode,
        orm.CalcFunctionNode: CalcFunctionTreeNode,
        orm.CalcJobNode: CalcJobTreeNode,
    }

    def __init__(self, **kwargs):
        """Create the inner tree and subscribe to its selection changes."""
        self._tree = ipytree.Tree()
        self._tree.observe(self._observe_tree_selected_nodes, ["selected_nodes"])
        super().__init__(**kwargs)

    def _refresh_output(self):
        """Clear this output widget and display the tree again."""
        # There appears to be a bug in the ipytree implementation that sometimes
        # causes the output to not be properly cleared. We therefore refresh the
        # displayed tree whenever the nodes trait changes.
        with self:
            clear_output()
            display(self._tree)

    @staticmethod
    def _tree_node_pk(tree_node):
        """Return the pk that identifies ``tree_node`` within the tree.

        Outputs folders are identified by the pk of their parent process, any
        other node by its own ``pk`` attribute (``None`` if it has none).
        """
        if isinstance(tree_node, AiidaOutputsTreeNode):
            return tree_node.parent_pk
        return getattr(tree_node, "pk", None)

    @classmethod
    def _process_icon_style(cls, process_node):
        """Return the icon style for the current state of ``process_node``.

        A failed process is styled like an excepted one, regardless of its
        actual process state.
        """
        process_state = (
            engine.ProcessState.EXCEPTED
            if process_node.is_failed
            else process_node.process_state
        )
        return cls.PROCESS_STATE_STYLE.get(
            process_state, cls.PROCESS_STATE_STYLE_DEFAULT
        )

    def _observe_tree_selected_nodes(self, change):
        """Expand the newly selected tree nodes and update ``selected_nodes``.

        Raises:
            KeyError: Propagated from ``find_node`` if a selected node cannot
                be located in the tree.
        """
        selected = change["new"]
        for node in selected:
            # Find the selected node and build the tree from it, so that users
            # can expand and explore the tree.
            self._build_tree(
                self.find_node(
                    self._tree_node_pk(node), getattr(node, "namespaces", None)
                )
            )

        # Outputs folders carry no ``pk`` and are therefore left out.
        return self.set_trait(
            "selected_nodes",
            tuple(
                orm.load_node(pk=node.pk) for node in selected if hasattr(node, "pk")
            ),
        )

    def _convert_to_tree_nodes(self, old_nodes, new_nodes):
        """Convert nodes into tree nodes while re-using already converted nodes.

        Args:
            old_nodes: Tree nodes currently shown as roots.
            new_nodes: AiiDA nodes that should become the roots.

        Yields:
            One tree node per entry of ``new_nodes``, in the same order.
        """
        old_nodes_ = {node.pk: node for node in old_nodes}
        assert len(old_nodes_) == len(old_nodes)  # no duplicated nodes

        for node in new_nodes:
            if node.pk in old_nodes_:
                yield old_nodes_[node.pk]
            else:
                yield self._to_tree_node(node, opened=True)

    @tl.observe("nodes")
    def _observe_nodes(self, change):
        """Rebuild the root nodes (sorted by pk) and refresh the display."""
        self._tree.nodes = sorted(
            self._convert_to_tree_nodes(
                old_nodes=self._tree.nodes, new_nodes=change["new"]
            ),
            key=lambda node: node.pk,
        )
        self.update()
        self._refresh_output()

    @classmethod
    def _to_tree_node(cls, node, name=None, **kwargs):
        """Convert an AiiDA node to a tree node.

        Args:
            node: The AiiDA node to represent.
            name: Label of the tree node; defaults to ``calc_info(node)`` for
                process nodes and ``str(node)`` otherwise.
            **kwargs: Forwarded to the tree node constructor.

        Returns:
            An instance of the class registered in ``NODE_TYPE`` for the exact
            type of ``node``, or an ``UnknownTypeTreeNode``.
        """
        is_process = isinstance(node, orm.ProcessNode)
        if name is None:
            name = calc_info(node) if is_process else str(node)

        tree_node_class = cls.NODE_TYPE.get(type(node), UnknownTypeTreeNode)
        tree_node = tree_node_class(pk=node.pk, name=name, **kwargs)

        # Set the style based on the process state of the node.
        if is_process:
            tree_node.icon_style = cls._process_icon_style(node)

        return tree_node

    @classmethod
    def _find_called(cls, root):
        """Yield tree nodes for the processes called by ``root``.

        The called processes are ordered by creation time. Tree nodes are
        cached in ``root.nodes_registry`` so repeated calls re-use them.
        """
        assert isinstance(root, AiidaProcessNodeTreeNode)
        process_node = orm.load_node(root.pk)
        # ``sorted`` leaves the list returned by ``called`` untouched.
        called = sorted(process_node.called, key=lambda p: p.ctime)
        for node in called:
            if node.pk not in root.nodes_registry:
                try:
                    name = calc_info(node)
                except AttributeError:
                    name = str(node)

                root.nodes_registry[node.pk] = cls._to_tree_node(node, name=name)
            yield root.nodes_registry[node.pk]

    @classmethod
    def _find_outputs(cls, root):
        """Yield tree nodes for the outputs directly below ``root``.

        ``root`` is an ``AiidaOutputsTreeNode``; its ``namespaces`` path is
        followed from the outputs of the parent process. For every nested
        namespace found at that level a new ``AiidaOutputsTreeNode`` carrying
        the extended namespace path is yielded, so deeper levels are resolved
        only when that folder is itself expanded. Output nodes are cached in
        ``root.nodes_registry``.
        """
        process_node = orm.load_node(root.parent_pk)

        # Gather outputs from node and its namespaces:
        outputs = functools.reduce(
            lambda attr_dict, namespace: attr_dict[namespace],
            root.namespaces or [],
            process_node.outputs,
        )

        # Convert aiida.orm.LinkManager or AttributeDict (if a namespace is
        # present) to a plain dict.
        output_nodes = {key: outputs[key] for key in outputs}

        # Entries without a ``pk`` (namespaces) sort first.
        for key in sorted(
            output_nodes, key=lambda k: getattr(output_nodes[k], "pk", -1)
        ):
            node = output_nodes[key]
            if isinstance(node, common.AttributeDict):
                # For a namespace, attach the nested namespace name and let the
                # new folder resolve its own content later.
                yield AiidaOutputsTreeNode(
                    name=key,
                    parent_pk=root.parent_pk,
                    namespaces=(*root.namespaces, key),
                )
            else:
                if node.pk not in root.nodes_registry:
                    root.nodes_registry[node.pk] = cls._to_tree_node(
                        node, name=f"{key}<{node.pk}>"
                    )
                yield root.nodes_registry[node.pk]

    @classmethod
    def _find_children(cls, root):
        """Find all children of the provided tree node.

        Process nodes yield their outputs folder followed by the called
        processes; outputs folders yield their content; any other node has no
        children.
        """
        if isinstance(root, AiidaProcessNodeTreeNode):
            yield root.outputs_node
            yield from cls._find_called(root)
        elif isinstance(root, AiidaOutputsTreeNode):
            yield from cls._find_outputs(root)

    @classmethod
    def _build_tree(cls, root):
        """Replace the children of ``root`` with its current children and return it."""
        root.nodes = list(cls._find_children(root))
        return root

    @classmethod
    def _walk_tree(cls, root):
        """Depth-first (pre-order) traversal of the tree below and including ``root``."""
        yield root
        for node in root.nodes:
            yield from cls._walk_tree(node)

    def _update_tree_node(self, tree_node):
        """Refresh label and icon style of a process tree node.

        Nodes that are not process tree nodes, and processes without a process
        state, are left unchanged.
        """
        if not isinstance(tree_node, AiidaProcessNodeTreeNode):
            return

        process_node = orm.load_node(tree_node.pk)
        if process_node.process_state is None:
            return

        tree_node.name = calc_info(process_node)
        tree_node.icon_style = self._process_icon_style(process_node)

    def update(self, _=None):
        """Refresh nodes based on the latest state of the root process and its children."""
        for root_node in self._tree.nodes:
            self._build_tree(root_node)
            for tree_node in self._walk_tree(root_node):
                self._update_tree_node(tree_node)

    def find_node(self, pk, namespaces=None):
        """Find a node by its pk and namespaces.

        If the node is an outputs folder, it is identified by the parent pk and
        its namespaces, otherwise by the pk. The first match in depth-first
        order is returned. The tree container itself is never a candidate.

        Args:
            pk: Primary key to look for.
            namespaces: Namespace path of an outputs folder (``()`` for the
                top-level folder), or ``None`` for any other kind of node.

        Returns:
            The matching tree node.

        Raises:
            KeyError: If no tree node matches.
        """
        # Walk the root nodes rather than the ``ipytree.Tree`` container: the
        # container has neither ``pk`` nor ``namespaces`` and would otherwise
        # match ``find_node(None)``.
        for root_node in self._tree.nodes:
            for node in self._walk_tree(root_node):
                if (
                    self._tree_node_pk(node) == pk
                    and getattr(node, "namespaces", None) == namespaces
                ):
                    return node

        raise KeyError(pk)