Source code for torch_sim.state

"""The core state representation.

The main SimState class represents atomistic systems with support for batched
operations and conversion to/from various atomistic formats.
"""

import copy
import importlib
import typing
import warnings
from collections import defaultdict
from collections.abc import Generator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast

import torch

import torch_sim as ts
from torch_sim.typing import SimStateVar, StateLike


if TYPE_CHECKING:
    from ase import Atoms
    from phonopy.structure.atoms import PhonopyAtoms
    from pymatgen.core import Structure


[docs] @dataclass(init=False) class SimState: """State representation for atomistic systems with batched operations support. Contains the fundamental properties needed to describe an atomistic system: positions, masses, unit cell, periodic boundary conditions, and atomic numbers. Supports batched operations where multiple atomistic systems can be processed simultaneously, managed through system indices. States support slicing, cloning, splitting, popping, and movement to other data structures or devices. Slicing is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state containing only the first three systems. The other operations are available through the `pop`, `split`, `clone`, and `to` methods. Attributes: positions (torch.Tensor): Atomic positions with shape (n_atoms, 3) masses (torch.Tensor): Atomic masses with shape (n_atoms,) cell (torch.Tensor): Unit cell vectors with shape (n_systems, 3, 3). Note that we use a column vector convention, i.e. the cell vectors are stored as `[[a1, b1, c1], [a2, b2, c2], [a3, b3, c3]]` as opposed to the row vector convention `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]` used by ASE. pbc (bool): Boolean indicating whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. Properties: wrap_positions (torch.Tensor): Positions wrapped according to periodic boundary conditions device (torch.device): Device of the positions tensor dtype (torch.dtype): Data type of the positions tensor n_atoms (int): Total number of atoms across all systems n_systems (int): Number of unique systems in the system Notes: - positions, masses, and atomic_numbers must have shape (n_atoms, 3). - cell must be in the conventional matrix form. - system indices must be unique consecutive integers starting from 0. Examples: >>> state = initialize_state( ... [ase_atoms_1, ase_atoms_2, ase_atoms_3], device, dtype ... ) >>> state.n_systems 3 >>> new_state = state[[0, 1]] >>> new_state.n_systems 2 >>> cloned_state = state.clone() """ positions: torch.Tensor masses: torch.Tensor cell: torch.Tensor pbc: bool # TODO: do all calculators support mixed pbc? atomic_numbers: torch.Tensor system_idx: torch.Tensor _atom_attributes: ClassVar[set[str]] = { "positions", "masses", "atomic_numbers", "system_idx", } _system_attributes: ClassVar[set[str]] = {"cell"} _global_attributes: ClassVar[set[str]] = {"pbc"} def __init__( self, positions: torch.Tensor, masses: torch.Tensor, cell: torch.Tensor, pbc: bool, # noqa: FBT001 # TODO(curtis): maybe make the constructor be keyword-only (it can be easy to confuse positions vs masses, etc.) atomic_numbers: torch.Tensor, system_idx: torch.Tensor | None = None, ) -> None: """Initialize the SimState and validate the arguments. Args: positions (torch.Tensor): Atomic positions with shape (n_atoms, 3) masses (torch.Tensor): Atomic masses with shape (n_atoms,) cell (torch.Tensor): Unit cell vectors with shape (n_systems, 3, 3). pbc (bool): Boolean indicating whether to use periodic boundary conditions atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor | None): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. If not provided, it is initialized to zeros. """ self.positions = positions self.masses = masses self.cell = cell self.pbc = pbc self.atomic_numbers = atomic_numbers # Validate and process the state after initialization. # data validation and fill system_idx # should make pbc a tensor here # if devices aren't all the same, raise an error, in a clean way devices = { attr: getattr(self, attr).device for attr in ("positions", "masses", "cell", "atomic_numbers") } if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") # Check that positions, masses and atomic numbers have compatible shapes shapes = [ getattr(self, attr).shape[0] for attr in ("positions", "masses", "atomic_numbers") ] if len(set(shapes)) > 1: raise ValueError( f"Incompatible shapes: positions {shapes[0]}, " f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) if system_idx is None: self.system_idx = torch.zeros( self.n_atoms, device=self.device, dtype=torch.int64 ) else: # assert that system indices are unique consecutive integers # TODO(curtis): I feel like this logic is not reliable. # I'll come up with something better later. _, counts = torch.unique_consecutive(system_idx, return_counts=True) if not torch.all(counts == torch.bincount(system_idx)): raise ValueError("System indices must be unique consecutive integers") self.system_idx = system_idx if self.cell.ndim != 3 and system_idx is None: self.cell = self.cell.unsqueeze(0) if self.cell.shape[-2:] != (3, 3): raise ValueError("Cell must have shape (n_systems, 3, 3)") if self.cell.shape[0] != self.n_systems: raise ValueError( f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" ) @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, otherwise returns unwrapped positions with shape (n_atoms, 3). """ # TODO: implement a wrapping method return self.positions @property def device(self) -> torch.device: """The device where the tensor data is located.""" return self.positions.device @property def dtype(self) -> torch.dtype: """The data type of the positions tensor.""" return self.positions.dtype @property def n_atoms(self) -> int: """Total number of atoms in the system across all systems.""" return self.positions.shape[0] @property def n_atoms_per_system(self) -> torch.Tensor: """Number of atoms per system.""" return ( self.system_idx.bincount() if self.system_idx is not None else torch.tensor([self.n_atoms], device=self.device) ) @property def n_atoms_per_batch(self) -> torch.Tensor: """Number of atoms per batch. deprecated:: Use :attr:`n_atoms_per_system` instead. """ warnings.warn( "n_atoms_per_batch is deprecated, use n_atoms_per_system instead", DeprecationWarning, stacklevel=2, ) return self.n_atoms_per_system @property def batch(self) -> torch.Tensor: """System indices. deprecated:: Use :attr:`system_idx` instead. """ warnings.warn( "batch is deprecated, use system_idx instead", DeprecationWarning, stacklevel=2, ) return self.system_idx @batch.setter def batch(self, system_idx: torch.Tensor) -> None: """Set the system indices from a batch index. deprecated:: Use :attr:`system_idx` instead. """ warnings.warn( "Setting batch is deprecated, use system_idx instead", DeprecationWarning, stacklevel=2, ) self.system_idx = system_idx @property def n_batches(self) -> int: """Number of batches in the system. deprecated:: Use :attr:`n_systems` instead. """ warnings.warn( "n_batches is deprecated, use n_systems instead", DeprecationWarning, stacklevel=2, ) return self.n_systems @property def n_systems(self) -> int: """Number of systems in the system.""" return torch.unique(self.system_idx).shape[0] @property def volume(self) -> torch.Tensor: """Volume of the system.""" if not self.pbc: raise ValueError("Volume is only defined for periodic systems") return torch.det(self.cell) @property def column_vector_cell(self) -> torch.Tensor: """Unit cell following the column vector convention.""" return self.cell @column_vector_cell.setter def column_vector_cell(self, value: torch.Tensor) -> None: """Set the unit cell from value following the column vector convention. Args: value: The unit cell as a column vector """ self.cell = value @property def row_vector_cell(self) -> torch.Tensor: """Unit cell following the row vector convention.""" return self.cell.mT @row_vector_cell.setter def row_vector_cell(self, value: torch.Tensor) -> None: """Set the unit cell from value following the row vector convention. Args: value: The unit cell as a row vector """ self.cell = value.mT
[docs] def clone(self) -> Self: """Create a deep copy of the SimState. Creates a new SimState object with identical but independent tensors, allowing modification without affecting the original. Returns: SimState: A new SimState object with the same properties as the original """ attrs = {} for attr_name, attr_value in vars(self).items(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.clone() else: attrs[attr_name] = copy.deepcopy(attr_value) return self.__class__(**attrs)
[docs] def to_atoms(self) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. Returns: list[Atoms]: A list of ASE Atoms objects, one per system """ return ts.io.state_to_atoms(self)
[docs] def to_structures(self) -> list["Structure"]: """Convert the SimState to a list of pymatgen Structure objects. Returns: list[Structure]: A list of pymatgen Structure objects, one per system """ return ts.io.state_to_structures(self)
[docs] def to_phonopy(self) -> list["PhonopyAtoms"]: """Convert the SimState to a list of PhonopyAtoms objects. Returns: list[PhonopyAtoms]: A list of PhonopyAtoms objects, one per system """ return ts.io.state_to_phonopy(self)
[docs] def split(self) -> list[Self]: """Split the SimState into a list of single-system SimStates. Divides the current state into separate states, each containing a single system, preserving all properties appropriately for each system. Returns: list[SimState]: A list of SimState objects, one per system """ return _split_state(self)
[docs] def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Self]: """Pop off states with the specified system indices. This method modifies the original state object by removing the specified systems and returns the removed systems as separate SimState objects. Args: system_indices (int | list[int] | slice | torch.Tensor): The system indices to pop Returns: list[SimState]: Popped SimState objects, one per system index Notes: This method modifies the original SimState in-place. """ system_indices = _normalize_system_indices( system_indices, self.n_systems, self.device ) # Get the modified state and popped states modified_state, popped_states = _pop_states(self, system_indices) # Update all attributes of self with the modified state's attributes for attr_name, attr_value in vars(modified_state).items(): setattr(self, attr_name, attr_value) return cast("list[Self]", popped_states)
[docs] def to( self, device: torch.device | None = None, dtype: torch.dtype | None = None ) -> Self: """Convert the SimState to a new device and/or data type. Args: device (torch.device, optional): The target device. Defaults to current device. dtype (torch.dtype, optional): The target data type. Defaults to current dtype. Returns: SimState: A new SimState with tensors on the specified device and dtype """ return state_to_device(self, device, dtype)
def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> Self: """Enable standard Python indexing syntax for slicing batches. Args: system_indices (int | list[int] | slice | torch.Tensor): The system indices to include Returns: SimState: A new SimState containing only the specified systems """ # TODO: need to document that slicing is supported # Reuse the existing slice method system_indices = _normalize_system_indices( system_indices, self.n_systems, self.device ) return _slice_state(self, system_indices) def __init_subclass__(cls, **kwargs) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. See https://github.com/Radical-AI/torch-sim/pull/219 for more details. Also enforce all of child classes's attributes are specified in _atom_attributes, _system_attributes, or _global_attributes. """ cls._assert_no_tensor_attributes_can_be_none() cls._assert_all_attributes_have_defined_scope() super().__init_subclass__(**kwargs) @classmethod def _assert_no_tensor_attributes_can_be_none(cls) -> None: # We need to use get_type_hints to correctly inspect the types type_hints = typing.get_type_hints(cls) for attr_name, attr_typehint in type_hints.items(): origin = typing.get_origin(attr_typehint) is_union = origin is typing.Union if not is_union and origin is not None: # For Python 3.10+ `|` syntax, origin is types.UnionType # We check by name to be robust against module reloading/patching issues is_union = origin.__module__ == "types" and origin.__name__ == "UnionType" if is_union: args = typing.get_args(attr_typehint) if torch.Tensor in args and type(None) in args: raise TypeError( f"Attribute '{attr_name}' in class '{cls.__name__}' is not " "allowed to be of type 'torch.Tensor | None' because torch.cat " "cannot concatenate between a tensor and a None. Please default " "the tensor with dummy values and track the 'None' case." ) @classmethod def _assert_all_attributes_have_defined_scope(cls) -> None: all_defined_attributes = ( cls._atom_attributes | cls._system_attributes | cls._global_attributes ) # 1) assert that no attribute is defined twice in all_defined_attributes duplicates = ( (cls._atom_attributes & cls._system_attributes) | (cls._atom_attributes & cls._global_attributes) | (cls._system_attributes & cls._global_attributes) ) if duplicates: raise TypeError( f"Attributes {duplicates} are declared multiple times in {cls.__name__} " "in _atom_attributes, _system_attributes, or _global_attributes" ) # 2) assert that all attributes are defined in all_defined_attributes all_annotations = {} for c in cls.mro(): if hasattr(c, "__annotations__"): all_annotations.update(c.__annotations__) attributes_to_check = set(vars(cls).keys()) | set(all_annotations.keys()) for attr_name in attributes_to_check: is_special_attribute = attr_name.startswith("__") is_property = attr_name in vars(cls) and isinstance( vars(cls).get(attr_name), property ) is_method = hasattr(cls, attr_name) and callable(getattr(cls, attr_name)) is_class_variable = ( # Note: _atom_attributes, _system_attributes, and _global_attributes # are all class variables typing.get_origin(all_annotations.get(attr_name)) is typing.ClassVar ) if is_special_attribute or is_property or is_method or is_class_variable: continue if attr_name not in all_defined_attributes: raise TypeError( f"Attribute '{attr_name}' is not defined in {cls.__name__} in any " "of _atom_attributes, _system_attributes, or _global_attributes" )
[docs] @dataclass(kw_only=True) class DeformGradMixin: """Mixin for states that support deformation gradients.""" reference_cell: torch.Tensor _system_attributes: ClassVar[set[str]] = {"reference_cell"} if TYPE_CHECKING: # define this under a TYPE_CHECKING block to avoid it being included in the # dataclass __init__ during runtime row_vector_cell: torch.Tensor @property def reference_row_vector_cell(self) -> torch.Tensor: """Get the original unit cell in terms of row vectors.""" return self.reference_cell.mT @reference_row_vector_cell.setter def reference_row_vector_cell(self, value: torch.Tensor) -> None: """Set the original unit cell in terms of row vectors.""" self.reference_cell = value.mT @staticmethod def _deform_grad( reference_row_vector_cell: torch.Tensor, row_vector_cell: torch.Tensor ) -> torch.Tensor: """Calculate the deformation gradient from original cell to current cell. Returns: The deformation gradient """ return torch.linalg.solve(reference_row_vector_cell, row_vector_cell).transpose( -2, -1 )
[docs] def deform_grad(self) -> torch.Tensor: """Calculate the deformation gradient from original cell to current cell. Returns: The deformation gradient """ return self._deform_grad(self.reference_row_vector_cell, self.row_vector_cell)
def _normalize_system_indices( system_indices: int | list[int] | slice | torch.Tensor, n_systems: int, device: torch.device, ) -> torch.Tensor: """Normalize system indices to handle negative indices and different input types. Converts various system index representations to a consistent tensor format, handling negative indices in the Python style (counting from the end). Args: system_indices (int | list[int] | slice | torch.Tensor): The system indices to normalize n_systems (int): Total number of systems in the system device (torch.device): Device to place the output tensor on Returns: torch.Tensor: Normalized system indices as a tensor Raises: TypeError: If system_indices is of an unsupported type """ if isinstance(system_indices, int): # Handle negative integer indexing if system_indices < 0: system_indices = n_systems + system_indices return torch.tensor([system_indices], device=device) if isinstance(system_indices, list): # Handle negative indices in lists normalized = [idx if idx >= 0 else n_systems + idx for idx in system_indices] return torch.tensor(normalized, device=device) if isinstance(system_indices, slice): # Let PyTorch handle the slice conversion with negative indices return torch.arange(n_systems, device=device)[system_indices] if isinstance(system_indices, torch.Tensor): # Handle negative indices in tensors return torch.where(system_indices < 0, n_systems + system_indices, system_indices) raise TypeError(f"Unsupported index type: {type(system_indices)}")
[docs] def state_to_device( state: SimStateVar, device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> SimStateVar: """Convert the SimState to a new device and dtype. Creates a new SimState with all tensors moved to the specified device and with the specified data type. Args: state (SimState): The state to convert device (torch.device, optional): The target device. Defaults to current device. dtype (torch.dtype, optional): The target data type. Defaults to current dtype. Returns: SimState: A new SimState with tensors on the specified device and dtype """ if device is None: device = state.device if dtype is None: dtype = state.dtype attrs = vars(state) for attr_name, attr_value in attrs.items(): if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.to(device=device) if dtype is not None: attrs["positions"] = attrs["positions"].to(dtype=dtype) attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) return type(state)(**attrs)
[docs] def get_attrs_for_scope( state: SimState, scope: Literal["per-atom", "per-system", "global"] ) -> Generator[tuple[str, Any], None, None]: """Get attributes for a given scope. Args: state (SimState): The state to get attributes for scope (Literal["per-atom", "per-system", "global"]): The scope to get attributes for Returns: Generator[tuple[str, Any], None, None]: A generator of attribute names and values """ match scope: case "per-atom": attr_names = state._atom_attributes # noqa: SLF001 case "per-system": attr_names = state._system_attributes # noqa: SLF001 case "global": attr_names = state._global_attributes # noqa: SLF001 case _: raise ValueError(f"Unknown scope: {scope!r}") for attr_name in attr_names: yield attr_name, getattr(state, attr_name)
def _filter_attrs_by_mask( state: SimState, atom_mask: torch.Tensor, system_mask: torch.Tensor, ) -> dict: """Filter attributes by atom and system masks. Selects subsets of attributes based on boolean masks for atoms and systems. Args: state (SimState): The state to filter atom_mask (torch.Tensor): Boolean mask for atoms to include with shape (n_atoms,) system_mask (torch.Tensor): Boolean mask for systems to include with shape (n_systems,) Returns: dict: Filtered attributes with appropriate handling for each scope """ # Copy global attributes directly filtered_attrs = dict(get_attrs_for_scope(state, "global")) # Filter per-atom attributes for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name == "system_idx": # Get the old system indices for the selected atoms old_system_idxs = attr_value[atom_mask] # Get the system indices that are kept kept_indices = torch.arange(attr_value.max() + 1, device=attr_value.device)[ system_mask ] # Create a mapping from old system indices to new consecutive indices system_idx_map = {idx.item(): i for i, idx in enumerate(kept_indices)} # Create new system tensor with remapped indices new_system_idxs = torch.tensor( [system_idx_map[b.item()] for b in old_system_idxs], device=attr_value.device, dtype=attr_value.dtype, ) filtered_attrs[attr_name] = new_system_idxs else: filtered_attrs[attr_name] = attr_value[atom_mask] # Filter per-system attributes for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): filtered_attrs[attr_name] = attr_value[system_mask] return filtered_attrs def _split_state( state: SimStateVar, ) -> list[SimStateVar]: """Split a SimState into a list of states, each containing a single system. Divides a multi-system state into individual single-system states, preserving appropriate properties for each system. Args: state (SimState): The SimState to split Returns: list[SimState]: A list of SimState objects, each containing a single system """ system_sizes = torch.bincount(state.system_idx).tolist() split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name != "system_idx": split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) global_attrs = dict(get_attrs_for_scope(state, "global")) # Create a state for each system states = [] n_systems = len(system_sizes) for i in range(n_systems): system_attrs = { # Create a system tensor with all zeros for this system "system_idx": torch.zeros( system_sizes[i], device=state.device, dtype=torch.int64 ), # Add the split per-atom attributes **{attr_name: split_per_atom[attr_name][i] for attr_name in split_per_atom}, # Add the split per-system attributes **{ attr_name: split_per_system[attr_name][i] for attr_name in split_per_system }, # Add the global attributes **global_attrs, } states.append(type(state)(**system_attrs)) return states def _pop_states( state: SimState, pop_indices: list[int] | torch.Tensor, ) -> tuple[SimState, list[SimState]]: """Pop off the states with the specified indices. Extracts and removes the specified system indices from the state. Args: state (SimState): The SimState to modify pop_indices (list[int] | torch.Tensor): The system indices to extract and remove Returns: tuple[SimState, list[SimState]]: A tuple containing: - The modified original state with specified systems removed - A list of the extracted SimStates, one per popped system Notes: Unlike the pop method, this function does not modify the input state. """ if len(pop_indices) == 0: return state, [] if isinstance(pop_indices, list): pop_indices = torch.tensor(pop_indices, device=state.device, dtype=torch.int64) # Create masks for the atoms and systems to keep and pop system_range = torch.arange(state.n_systems, device=state.device) pop_system_mask = torch.isin(system_range, pop_indices) keep_system_mask = ~pop_system_mask pop_atom_mask = torch.isin(state.system_idx, pop_indices) keep_atom_mask = ~pop_atom_mask # Filter attributes for keep and pop states keep_attrs = _filter_attrs_by_mask(state, keep_atom_mask, keep_system_mask) pop_attrs = _filter_attrs_by_mask(state, pop_atom_mask, pop_system_mask) # Create the keep state keep_state = type(state)(**keep_attrs) # Create and split the pop state pop_state = type(state)(**pop_attrs) pop_states = _split_state(pop_state) return keep_state, pop_states def _slice_state( state: SimStateVar, system_indices: list[int] | torch.Tensor, ) -> SimStateVar: """Slice a substate from the SimState containing only the specified system indices. Creates a new SimState containing only the specified systems, preserving all relevant properties. Args: state (SimState): The state to slice system_indices (list[int] | torch.Tensor): System indices to include in the sliced state Returns: SimState: A new SimState object containing only the specified systems Raises: ValueError: If system_indices is empty """ if isinstance(system_indices, list): system_indices = torch.tensor( system_indices, device=state.device, dtype=torch.int64 ) if len(system_indices) == 0: raise ValueError("system_indices cannot be empty") # Create masks for the atoms and systems to include system_range = torch.arange(state.n_systems, device=state.device) system_mask = torch.isin(system_range, system_indices) atom_mask = torch.isin(state.system_idx, system_indices) # Filter attributes filtered_attrs = _filter_attrs_by_mask(state, atom_mask, system_mask) # Create the sliced state return type(state)(**filtered_attrs)
[docs] def concatenate_states( states: list[SimState], device: torch.device | None = None ) -> SimState: """Concatenate a list of SimStates into a single SimState. Combines multiple states into a single state with multiple systems. Global properties are taken from the first state, and per-atom and per-system properties are concatenated. Args: states (list[SimState]): A list of SimState objects to concatenate device (torch.device, optional): The device to place the concatenated state on. Defaults to the device of the first state. Returns: SimState: A new SimState containing all input states as separate systems Raises: ValueError: If states is empty TypeError: If not all states are of the same type """ if not states: raise ValueError("Cannot concatenate an empty list of states") # Get the first state to determine properties first_state = states[0] # Ensure all states are of the same class state_class = type(first_state) if not all(isinstance(state, state_class) for state in states): raise TypeError("All states must be of the same type") # Use the target device or default to the first state's device target_device = device or first_state.device # Initialize result with global properties from first state concatenated = dict(get_attrs_for_scope(first_state, "global")) # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) new_system_indices = [] system_offset = 0 # Process all states in a single pass for state in states: # Move state to target device if needed if state.device != target_device: state = state_to_device(state, target_device) # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): if prop == "system_idx": # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): per_system_tensors[prop].append(val) # Update system indices num_systems = state.n_systems new_indices = state.system_idx + system_offset new_system_indices.append(new_indices) system_offset += num_systems # Concatenate collected tensors for prop, tensors in per_atom_tensors.items(): # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) for prop, tensors in per_system_tensors.items(): # if tensors: concatenated[prop] = torch.cat(tensors, dim=0) # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) # Create a new instance of the same class return state_class(**concatenated)
[docs] def initialize_state( system: StateLike, device: torch.device, dtype: torch.dtype, ) -> SimState: """Initialize state tensors from a atomistic system representation. Converts various atomistic system representations (ASE Atoms, pymatgen Structure, PhonopyAtoms, or existing SimState) to a SimState object. Args: system (StateLike): Input system to convert to state tensors device (torch.device): Device to create tensors on dtype (torch.dtype): Data type for tensor values Returns: SimState: State representation initialized from input system Raises: ValueError: If system type is not supported or if list items have inconsistent types """ # TODO: create a way to pass velocities from pmg and ase if isinstance(system, SimState): return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): if not all(cast("SimState", state).n_systems == 1 for state in system): raise ValueError( "When providing a list of states, to the initialize_state function, " "all states must have n_systems == 1. To fix this, you can split the " "states into individual states with the split_state function." ) return concatenate_states(system) converters = [ ("pymatgen.core", "Structure", ts.io.structures_to_state), ("ase", "Atoms", ts.io.atoms_to_state), ("phonopy.structure.atoms", "PhonopyAtoms", ts.io.phonopy_to_state), ] # Try each converter for module_path, class_name, converter_func in converters: try: module = importlib.import_module(module_path) cls = getattr(module, class_name) if isinstance(system, cls) or ( isinstance(system, list) and all(isinstance(s, cls) for s in system) ): return converter_func(system, device, dtype) except ImportError: continue # remaining code just for informative error is_list = isinstance(system, list) all_same_type = ( is_list and all(isinstance(s, type(system[0])) for s in system) and system ) if is_list and not all_same_type: raise ValueError( f"All items in list must be of the same type, " f"found {type(system[0])} and {type(system[1])}" ) system_type = f"list[{type(system[0])}]" if is_list else type(system) raise ValueError(f"Unsupported system type, {system_type}")