"""Some utility functions used acrross the repository."""
import threading
from enum import Enum
from typing import Any
import ipywidgets as ipw
import more_itertools as mit
import numpy as np
import traitlets
from aiida.plugins import DataFactory
from ase import Atoms
from ase.io import read
# AiiDA data classes looked up once, at import time, through the
# `aiida.plugins.DataFactory` entry-point loader. They are not used in the
# visible part of this module — presumably used elsewhere; verify before removing.
# The pylint directives silence the module-level naming check, because these
# names are bound to classes and therefore use PascalCase.
CifData = DataFactory("core.cif") # pylint: disable=invalid-name
StructureData = DataFactory("core.structure") # pylint: disable=invalid-name
TrajectoryData = DataFactory("core.array.trajectory") # pylint: disable=invalid-name
[docs]
def valid_arguments(arguments, valid_args):
    """Return only the allowed entries of ``arguments``.

    Entries whose key is not contained in ``valid_args`` are dropped.  A
    tuple or list value is collapsed into one newline-separated string (so
    its items must be strings); every other value is kept unchanged.
    """

    def _normalize(value):
        # Sequences (tuple/list) are joined line by line.
        if isinstance(value, (tuple, list)):
            return "\n".join(value)
        return value

    return {
        key: _normalize(value)
        for key, value in arguments.items()
        if key in valid_args
    }
[docs]
def predefine_settings(obj, **kwargs):
    """Apply pre-defined attribute values to ``obj``.

    Every keyword must name an attribute that already exists on ``obj``
    (checked with ``hasattr``).  All names are validated before anything is
    assigned, so an unknown name leaves ``obj`` untouched instead of
    half-updated.  Valid settings are then assigned in keyword order.

    Raises:
        AttributeError: if ``obj`` has no attribute for one of the keywords;
            the message names the first such keyword.
    """
    # Validate first: setting attributes one by one and failing midway would
    # leave obj with only some of the requested settings applied.
    for key in kwargs:
        if not hasattr(obj, key):
            raise AttributeError(f"{obj!r} object has no attribute {key!r}")
    for key, value in kwargs.items():
        setattr(obj, key, value)
[docs]
def get_ase_from_file(fname, file_format=None):
    """Read every frame stored in ``fname`` with ASE.

    ``ase.io.read`` is called with ``index=":"`` so that all frames are
    returned as a list.  For CIF files the ``store_tags`` option is enabled
    as well, which keeps the CIF tags
    (see https://wiki.fysik.dtu.dk/ase/ase/io/formatoptions.html#cif).

    Raises:
        ValueError: if nothing could be read from the file.
    """
    read_options = {"format": file_format, "index": ":"}
    if file_format == "cif":
        read_options["store_tags"] = True

    frames = read(fname, **read_options)
    if not frames:
        raise ValueError(f"Could not read any information from the file {fname}")
    return frames
[docs]
def find_ranges(iterable):
    """Yield the runs of consecutive integers found in ``iterable``.

    Runs are detected in iteration order by
    ``more_itertools.consecutive_groups``.  A run holding a single number is
    yielded as that bare number.  A longer run is yielded as a
    ``(first, last)`` tuple.
    """
    for run in mit.consecutive_groups(iterable):
        members = list(run)
        first, last = members[0], members[-1]
        yield first if len(members) == 1 else (first, last)
[docs]
def list_to_string_range(lst, shift=1):
    """Convert a list like [0, 2, 3, 4] into a string like '1 3..5'.

    The list is sorted and split into runs of consecutive numbers.  A run is
    rendered as ``start..end`` and a lone number as itself, separated by
    spaces.  ``shift`` is added to every number, e.g. to display 0-based
    indices with the 1-based numbering of a user interface.
    """
    parts = []
    for item in find_ranges(sorted(lst)):
        if isinstance(item, tuple):
            start, stop = item
            parts.append(f"{start + shift}..{stop + shift}")
        else:
            parts.append(str(item + shift))
    return " ".join(parts)
[docs]
def string_range_to_list(strng, shift=-1):
    """Convert a string like '1 3..5' into a list like [0, 2, 3, 4].

    The string holds whitespace-separated tokens.  Each token is either a
    non-negative integer ``n`` or an inclusive range ``a..b``.  ``shift`` is
    added to every number, e.g. to go from the 1-based numbering of a user
    interface to 0-based indices.

    Returns a ``(indices, ok)`` pair.  On success ``ok`` is ``True``, and
    ``indices`` holds the single numbers first (in input order) followed by
    the expanded ranges (in input order).  If any token is malformed, the
    result is ``([], False)``.  A range whose end is below its start expands
    to nothing.
    """
    tokens = strng.split()
    # str.isdecimal() is used rather than str.isdigit(): isdigit() also
    # accepts characters such as "²" on which int() raises ValueError.
    singles = [int(tok) + shift for tok in tokens if tok.isdecimal()]
    ranges = [tok for tok in tokens if ".." in tok]
    if len(singles) + len(ranges) != len(tokens):
        return [], False
    for rng in ranges:
        bounds = rng.split("..")
        # Require exactly two plain non-negative integers.  This rejects
        # "1..2..3", "..5" and "-1..3", the same way "-1" is rejected as a
        # single token.
        if len(bounds) != 2 or not all(bound.isdecimal() for bound in bounds):
            return [], False
        start, end = int(bounds[0]), int(bounds[1])
        singles += [i + shift for i in range(start, end + 1)]
    return singles, True
[docs]
class PinholeCamera:
    """Camera helper built from a flat 4x4 view matrix.

    ``matrix`` (16 numbers) is reshaped to 4x4 and transposed, i.e. the flat
    input is read column by column.  NOTE(review): presumably this is the
    column-major camera matrix reported by nglview — confirm with the caller.
    """

    def __init__(self, matrix):
        self.matrix = np.reshape(matrix, (4, 4)).transpose()

    def screen_to_vector(self, move_vector):
        """Convert a screen-space move vector into a normalized 3D vector.

        ``move_vector`` must have three components.  A fourth, zero component
        is appended before multiplying by the inverse of ``self.matrix``.  The
        argument is copied and never modified, so lists, tuples and arrays
        are all accepted.

        Returns the first three components of the transformed 4-vector,
        divided by the norm of that 4-vector.  If that norm is zero (e.g. a
        zero move vector), a zero vector is returned instead of NaNs.
        """
        # Work on a float copy so the caller's sequence is never mutated.
        vec = np.array(move_vector, dtype=float)
        vec[0] = -vec[0]  # the x axis seems to be reverted in nglview.
        res = self.inverse_matrix.dot(np.append(vec, 0.0))
        norm = np.linalg.norm(res)
        if norm == 0:
            return np.zeros(3)
        return res[0:3] / norm

    @property
    def inverse_matrix(self):
        """Inverse of ``self.matrix``, recomputed on every access."""
        return np.linalg.inv(self.matrix)
[docs]
class StatusHTML(_StatusWidgetMixin, ipw.HTML):
    """Show temporary HTML messages, for example for status updates.

    Whenever the ``message`` trait changes, its new value is passed to
    ``show_temporary_message``.  NOTE(review): both the ``message`` trait and
    ``show_temporary_message`` are presumably provided by
    ``_StatusWidgetMixin``, which is not visible here — confirm.
    """

    # HTML line break; NOTE(review): presumably used by _StatusWidgetMixin
    # when it composes messages — not visible here.
    new_line = "<br>"

    # This method should be part of _StatusWidgetMixin, but that does not work
    # for an unknown reason.
    @traitlets.observe("message")
    def _observe_message(self, change):
        # change["new"] is the updated value of the observed "message" trait.
        self.show_temporary_message(change["new"])
# Define the message levels as Enum
[docs]
class MessageLevel(Enum):
    """Severity levels understood by ``wrap_message``.

    Each value is appended to ``alert-`` to form the CSS class of the
    rendered message box (e.g. ``ERROR`` -> ``alert-danger``).
    """

    INFO = "info"
    WARNING = "warning"
    # "danger" rather than "error": presumably matches the Bootstrap-style
    # ``alert-danger`` class name — confirm against the stylesheet in use.
    ERROR = "danger"
    SUCCESS = "success"


def wrap_message(message, level=MessageLevel.INFO):
    """Wrap a message into an HTML alert box of the given level.

    Args:
        message: Content placed right after the level icon.  It is inserted
            verbatim (not escaped), so it may itself contain HTML.
        level: A ``MessageLevel`` member, or one of its string values
            (``"info"``, ``"warning"``, ``"danger"``, ``"success"``).
            Defaults to ``MessageLevel.INFO``.

    Returns:
        The HTML snippet as a string.

    Raises:
        ValueError: if ``level`` is not a ``MessageLevel`` member or value.
    """
    # Enum lookup by value; for a member this returns the member itself, so
    # existing callers passing MessageLevel.X are unaffected.
    level = MessageLevel(level)
    # mapping level to fa icon
    # https://fontawesome.com/v4.7.0/icons/
    mapping = {
        MessageLevel.INFO: "info-circle",
        MessageLevel.WARNING: "exclamation-triangle",
        MessageLevel.ERROR: "exclamation-circle",
        MessageLevel.SUCCESS: "check-circle",
    }
    # The message is wrapped into a div with the class "alert" and the icon of the given level
    return f"""
<div class="alert alert-{level.value}" role="alert" style="margin-bottom: 0px; padding: 6px 12px;">
<i class="fa fa-{mapping[level]}"></i>{message}
</div>
"""
[docs]
def ase2spglib(ase_structure: Atoms) -> tuple[Any, Any, Any]:
    """Build an spglib ``cell`` tuple from an ASE ``Atoms`` object.

    The tuple is ``(lattice, scaled_positions, atomic_numbers)``, the layout
    defined at
    https://spglib.github.io/spglib/python-spglib.html#crystal-structure-cell
    """
    return (
        ase_structure.get_cell(),
        ase_structure.get_scaled_positions(),
        ase_structure.get_atomic_numbers(),
    )