Source code for aiidalab_widgets_base.utils

"""Some utility functions used across the repository."""

import threading
from enum import Enum
from typing import Any

import ipywidgets as ipw
import more_itertools as mit
import numpy as np
import traitlets
from aiida.plugins import DataFactory
from ase import Atoms
from import read

CifData = DataFactory("core.cif")  # pylint: disable=invalid-name
StructureData = DataFactory("core.structure")  # pylint: disable=invalid-name
TrajectoryData = DataFactory("core.array.trajectory")  # pylint: disable=invalid-name

def valid_arguments(arguments, valid_args):
    """Return the entries of ``arguments`` whose keys appear in ``valid_args``.

    List and tuple values are collapsed into one newline-separated string;
    every other value is passed through unchanged.
    """

    def _normalise(item):
        # Sequences of strings become a single multi-line string.
        return "\n".join(item) if isinstance(item, (tuple, list)) else item

    return {
        name: _normalise(item)
        for name, item in arguments.items()
        if name in valid_args
    }
def predefine_settings(obj, **kwargs):
    """Assign each keyword argument to the attribute of ``obj`` with the same name.

    Only attributes that already exist on ``obj`` may be set. Attributes are
    processed in keyword order, so the ones preceding an unknown name have
    already been assigned when the error is raised.

    :raises AttributeError: if ``obj`` has no attribute named like a keyword.
    """
    for name in kwargs:
        if not hasattr(obj, name):
            raise AttributeError(f"{obj!r} object has no attribute {name!r}")
        setattr(obj, name, kwargs[name])
def get_ase_from_file(fname, file_format=None):  # pylint: disable=redefined-builtin
    """Read every frame of a structure file through ``read``.

    :param fname: file to read.
    :param file_format: format name forwarded as ``format``; ``None`` leaves
        the choice to the reader.
    :returns: whatever ``read`` returns for ``index=":"`` (all frames).
    :raises ValueError: if the reader returned an empty (falsy) result.
    """
    # The store_tags parameter is useful for CIF files, so it is only passed
    # for that format.
    extra_options = {"store_tags": True} if file_format == "cif" else {}
    frames = read(fname, format=file_format, index=":", **extra_options)
    if frames:
        return frames
    raise ValueError(f"Could not read any information from the file {fname}")
def find_ranges(iterable):
    """Yield the runs of consecutive numbers found in ``iterable``.

    A run of length one is yielded as the bare number; a longer run is yielded
    as a ``(first, last)`` tuple.
    """
    for run in mit.consecutive_groups(iterable):
        first, *rest = run
        yield (first, rest[-1]) if rest else first
def list_to_string_range(lst, shift=1):
    """Convert a list like ``[0, 2, 3, 4]`` into a string like ``'1 3..5'``.

    ``shift`` is added to every number, e.g. when a user interface numbers
    items from 1 while the list is 0-based. The input is sorted first.
    """
    pieces = []
    for item in find_ranges(sorted(lst)):
        if isinstance(item, tuple):
            # A run of consecutive numbers is rendered as "low..high".
            low, high = item
            pieces.append(f"{low + shift}..{high + shift}")
        else:
            pieces.append(str(item + shift))
    return " ".join(pieces)
def string_range_to_list(strng, shift=-1):
    """Convert a string like ``'1 3..5'`` into a list like ``[0, 2, 3, 4]``.

    ``shift`` is added to every parsed number, e.g. when a user interface
    numbers items from 1 while the list is 0-based.

    :param strng: whitespace-separated tokens, each either a non-negative
        integer or an inclusive range written ``start..end``.
    :param shift: offset added to every number.
    :returns: ``(numbers, True)`` on success, with all single numbers first
        (in input order) followed by the expanded ranges; ``([], False)`` if
        any token cannot be parsed.
    """
    tokens = strng.split()
    # ``isdecimal`` (not ``isdigit``): characters such as "²" satisfy
    # ``isdigit`` but are rejected by ``int``, which used to raise ValueError
    # here instead of reporting invalid input.
    singles = [int(token) + shift for token in tokens if token.isdecimal()]
    ranges = [token for token in tokens if ".." in token]

    # Every token must be either a plain number or a range.
    if len(singles) + len(ranges) != len(tokens):
        return [], False

    for rng in ranges:
        try:
            # Raises ValueError for malformed ranges such as "1..2..3" or "1..x".
            start, end = rng.split("..")
            singles += [i + shift for i in range(int(start), int(end) + 1)]
        except ValueError:
            return [], False
    return singles, True
def get_formula(data_node):
    """Return the chemical formula of a trajectory, structure or CIF data node.

    :raises TypeError: if ``data_node`` is none of the supported node types.
    """
    if isinstance(data_node, TrajectoryData):
        # TrajectoryData can only hold structures with the same chemical
        # formula, so reading it from the first step is sound.
        first_step = data_node.get_stepids()[0]
        return data_node.get_step_structure(first_step).get_formula()

    if isinstance(data_node, StructureData):
        return data_node.get_formula()

    if isinstance(data_node, CifData):
        return data_node.get_ase().get_chemical_formula()

    raise TypeError(f"Cannot get formula from node {type(data_node)}")
class PinholeCamera:
    """Camera described by a 4x4 matrix, used to map screen moves into 3D."""

    def __init__(self, matrix):
        """Store ``matrix`` (16 values or a 4x4 array) as a transposed 4x4 array."""
        self.matrix = np.reshape(matrix, (4, 4)).transpose()

    def screen_to_vector(self, move_vector):
        """Converts vector from the screen coordinates to the normalized vector in 3D.

        :param move_vector: three screen-space components; left unmodified.
        :returns: array with the first three components of the transformed,
            unit-length vector.
        """
        # Work on a copy so the caller's vector is not negated in place
        # (repeated calls with the same list would otherwise flip the sign).
        screen = np.array(move_vector, dtype=float)
        screen[0] = -screen[0]  # the x axis seem to be reverted in nglview.

        # Pad with a zero fourth component so the 4x4 inverse camera matrix
        # can be applied to the vector.
        res = self.inverse_matrix.dot(np.append(screen, [0]))
        res /= np.linalg.norm(res)
        return res[0:3]

    @property
    def inverse_matrix(self):
        """Inverse of the stored camera matrix, recomputed on every access."""
        return np.linalg.inv(self.matrix)
class _StatusWidgetMixin(traitlets.HasTraits):
    """Show temporary messages for example for status updates.

    This is a mixin class that is meant to be part of an inheritance tree of an
    actual widget with a 'value' traitlet that is used to convey a status
    message. See the non-private classes below for examples.
    """

    # Setting this trait is how a message is submitted; see the subclasses
    # for the observer that forwards it to ``show_temporary_message``.
    message = traitlets.Unicode(default_value="", allow_none=True)
    # Separator placed between stacked messages; subclasses may override it.
    new_line = "\n"

    def __init__(self, clear_after=3, *args, **kwargs):
        """Initialise the mixin.

        :param clear_after: default interval (passed to ``threading.Timer``)
            after which a shown message is removed again.
        """
        self._clear_timer = None
        self._clear_after = clear_after
        self._message_stack = []
        super().__init__(*args, **kwargs)

    def _clear_value(self):
        """Drop the oldest message and refresh the widget ``value``.

        Used as the timer callback; sets ``value`` to an empty string when no
        message is left.
        """
        if self._message_stack:
            self._message_stack.pop(0)
            self.value = self.new_line.join(self._message_stack)
        else:
            self.value = ""

    def show_temporary_message(self, value, clear_after=None):
        """Show a temporary message and clear it after the given interval.

        :param value: message to display; falsy values are ignored.
        :param clear_after: interval for this message; ``None`` selects the
            default given to the constructor.
        """
        if clear_after is None:
            clear_after = self._clear_after
        if value:
            self._message_stack.append(value)
            self.value = self.new_line.join(self._message_stack)

            # Start new timer that will clear the value after the specified
            # interval. The per-call ``clear_after`` is honoured here; it was
            # previously computed but never used.
            self._clear_timer = threading.Timer(clear_after, self._clear_value)
            self._clear_timer.start()
            # Reset the trait so that assigning the same text again registers
            # as a change.
            self.message = None
class StatusHTML(_StatusWidgetMixin, ipw.HTML):
    """HTML widget that displays temporary status messages."""

    # Stacked messages are separated by an HTML line break.
    new_line = "<br>"

    # This method should be part of _StatusWidgetMixin, but that does not work
    # for an unknown reason.
    @traitlets.observe("message")
    def _observe_message(self, change):
        """Display the new content of the ``message`` trait as a temporary message."""
        new_message = change["new"]
        self.show_temporary_message(new_message)
# Define the message levels as Enum
class MessageLevel(Enum):
    """Severity levels of a status message.

    The member value is the suffix that ``wrap_message`` inserts into the
    ``alert-<value>`` CSS class of the generated HTML.
    """

    INFO = "info"
    WARNING = "warning"
    ERROR = "danger"  # note: the value is "danger", not "error"
    SUCCESS = "success"
def wrap_message(message, level=MessageLevel.INFO):
    """Wrap message into HTML code with the given level.

    :param message: text (inserted verbatim, without escaping) to display.
    :param level: ``MessageLevel`` member selecting the alert class and icon.
    :returns: HTML string with an ``alert`` div containing an icon and the message.
    :raises KeyError: if ``level`` is not one of the mapped ``MessageLevel`` members.
    """
    # mapping level to fa icon (the name is used in the ``fa fa-<name>`` class below)
    mapping = {
        MessageLevel.INFO: "info-circle",
        MessageLevel.WARNING: "exclamation-triangle",
        MessageLevel.ERROR: "exclamation-circle",
        MessageLevel.SUCCESS: "check-circle",
    }

    # The message is wrapped into a div with the class "alert" and the icon of the given level
    return f"""
        <div class="alert alert-{level.value}" role="alert" style="margin-bottom: 0px; padding: 6px 12px;">
            <i class="fa fa-{mapping[level]}"></i>{message}
        </div>
    """
def ase2spglib(ase_structure: Atoms) -> tuple[Any, Any, Any]:
    """Convert an ASE ``Atoms`` instance into an spglib-style cell tuple.

    :returns: ``(lattice, positions, numbers)`` taken from the structure's
        cell, scaled positions and atomic numbers, in that order.
    """
    return (
        ase_structure.get_cell(),
        ase_structure.get_scaled_positions(),
        ase_structure.get_atomic_numbers(),
    )