Source code for torch_sim.trajectory

"""Trajectory format and reporting.

This module provides classes for reading and writing trajectory data in HDF5 format.
The core classes (TorchSimTrajectory and TrajectoryReporter) allow efficient storage
and retrieval of atomic positions, forces, energies, and other properties from
molecular dynamics simulations.

The TorchSimTrajectory does not aim to be a new trajectory standard, but rather
a simple interface for storing and retrieving trajectory data from HDF5 files.
It supports arbitrary user-provided arrays in a natural way, so it can be
seamlessly extended to whatever attributes are important to the user.

Example:
    Reading and writing a trajectory file::

        # Writing to multiple trajectory files with a reporter
        reporter = TrajectoryReporter(["traj1.hdf5", "traj2.hdf5"], state_frequency=100)
        reporter.report(state, step=0, model=model)

        # Reading the file with a TorchSimTrajectory
        with TorchSimTrajectory("simulation.hdf5", mode="r") as traj:
            state = traj.get_state(frame=0)

Notes:
    This module uses PyTables (HDF5) for efficient I/O operations and supports
    compression to reduce file sizes. It can interoperate with ASE and pymatgen
    for visualization and analysis.
"""

import copy
import inspect
import pathlib
from collections.abc import Callable
from functools import partial
from typing import Any, Literal

import numpy as np
import tables
import torch

from torch_sim.state import SimState


_DATA_TYPE_MAP = {
    np.dtype("float32"): tables.Float32Atom(),
    np.dtype("float64"): tables.Float64Atom(),
    np.dtype("int32"): tables.Int32Atom(),
    np.dtype("int64"): tables.Int64Atom(),
    np.dtype("bool"): tables.BoolAtom(),
    torch.float32: tables.Float32Atom(),
    torch.float64: tables.Float64Atom(),
    torch.int32: tables.Int32Atom(),
    torch.int64: tables.Int64Atom(),
    torch.bool: tables.BoolAtom(),
    bool: tables.BoolAtom(),
}
# ruff: noqa: SLF001
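
For illustration, a minimal sketch of how this map resolves a PyTables atom for
an incoming tensor; `example_energies` is a hypothetical array, not part of the
module::

    example_energies = torch.zeros(10, dtype=torch.float64)
    atom = _DATA_TYPE_MAP[example_energies.dtype]
    assert isinstance(atom, tables.Float64Atom)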


class TrajectoryReporter:
    """Trajectory reporter for saving simulation data at specified intervals.

    This class manages writing multiple trajectory files simultaneously. It
    handles periodic saving of full system states and custom property
    calculations.

    Attributes:
        state_frequency (int): How often to save full states (in simulation steps)
        prop_calculators (dict): Map of frequencies to property calculators
        state_kwargs (dict): Additional arguments for state writing
        metadata (dict): Metadata to save in trajectory files
        trajectories (list): TorchSimTrajectory instances
        filenames (list): Trajectory file paths
        array_registry (dict): Map of array names to (shape, dtype) tuples
        shape_warned (bool): Whether a shape warning has been issued

    Examples:
        >>> reporter = TrajectoryReporter(
        ...     ["system1.h5", "system2.h5"],
        ...     state_frequency=100,
        ...     prop_calculators={10: {"energy": calculate_energy}},
        ... )
        >>> for step in range(1000):
        ...     # Run simulation step
        ...     state = step_fn(state)
        ...     reporter.report(state, step, model)
        >>> reporter.close()
    """

    def __init__(
        self,
        filenames: str | pathlib.Path | list[str | pathlib.Path] | None,
        state_frequency: int = 100,
        *,
        prop_calculators: dict[int, dict[str, Callable]] | None = None,
        state_kwargs: dict | None = None,
        metadata: dict[str, str] | None = None,
        trajectory_kwargs: dict | None = None,
    ) -> None:
        """Initialize a TrajectoryReporter.

        Args:
            filenames (str | pathlib.Path | list[str | pathlib.Path]): Path(s) to
                save trajectory file(s). If None, the reporter will not save any
                trajectories, but `TrajectoryReporter.report` can still be used to
                compute properties directly.
            state_frequency (int): How often to save state (in steps)
            prop_calculators (dict[int, dict[str, Callable]], optional): Dictionary
                mapping frequencies to property calculators, where each calculator
                is a function that takes a state and optionally a model and returns
                a tensor. Defaults to None.
            state_kwargs (dict, optional): Additional arguments for state writing.
                Passed to the `TorchSimTrajectory.write_state` method. These can be
                set to save velocities and forces, or to allow variable masses and
                atomic numbers across the trajectory.
            metadata (dict[str, str], optional): Metadata to save in trajectory
                file.
            trajectory_kwargs (dict, optional): Additional arguments for trajectory
                initialization. Passed to the `TorchSimTrajectory.__init__` method.

        Raises:
            ValueError: If filenames are not unique
        """
        self.state_frequency = state_frequency
        self.trajectory_kwargs = trajectory_kwargs or {}
        # default is to force overwrite
        self.trajectory_kwargs["mode"] = self.trajectory_kwargs.get("mode", "w")
        self.prop_calculators = prop_calculators or {}
        self.state_kwargs = state_kwargs or {}
        self.shape_warned = False
        self.metadata = metadata
        self.trajectories = []

        if filenames is None:
            self.filenames = None
            self.trajectories = []
        else:
            self.load_new_trajectories(filenames)

        self._add_model_arg_to_prop_calculators()
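
A sketch of constructing a reporter with both a single-argument and a
(state, model) property calculator; the calculator bodies and the model's
`"energy"` output key are assumptions, not part of this module::

    def kinetic_energy(state):  # single-argument: auto-wrapped to accept a model
        return 0.5 * (state.masses.unsqueeze(-1) * state.velocities**2).sum()

    def potential_energy(state, model):  # assumes model(state) returns "energy"
        return model(state)["energy"]

    reporter = TrajectoryReporter(
        "traj.h5",
        state_frequency=100,
        prop_calculators={10: {"ke": kinetic_energy, "pe": potential_energy}},
    )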
    def load_new_trajectories(
        self, filenames: str | pathlib.Path | list[str | pathlib.Path]
    ) -> None:
        """Load new trajectories into the reporter.

        Closes any existing trajectory files and initializes new ones.

        Args:
            filenames (str | pathlib.Path | list[str | pathlib.Path]): Path(s) to
                save trajectory file(s)

        Raises:
            ValueError: If filenames are not unique
        """
        self.finish()

        filenames = [filenames] if not isinstance(filenames, list) else filenames
        self.filenames = [pathlib.Path(filename) for filename in filenames]
        if len(set(self.filenames)) != len(self.filenames):
            raise ValueError("All filenames must be unique.")

        self.trajectories = []
        for filename in self.filenames:
            self.trajectories.append(
                TorchSimTrajectory(
                    filename=filename,
                    metadata=self.metadata,
                    **self.trajectory_kwargs,
                )
            )
    @property
    def array_registry(self) -> dict[str, tuple[tuple[int, ...], np.dtype]]:
        """Registry of array shapes and dtypes."""
        # Return the registry from the first trajectory
        if self.trajectories:
            return self.trajectories[0].array_registry
        return {}

    def _add_model_arg_to_prop_calculators(self) -> None:
        """Add model argument to property calculators that only accept state.

        Transforms single-argument (state) property calculators to accept the
        dual-argument (state, model) interface by creating partial functions with
        an optional second argument.
        """
        for frequency in self.prop_calculators:
            for name, prop_fn in self.prop_calculators[frequency].items():
                # Get function signature
                sig = inspect.signature(prop_fn)
                # If function only takes one parameter, wrap it to accept two
                if len(sig.parameters) == 1:
                    # we partially evaluate the function to create a new function
                    # with an optional second argument, which can later be set to
                    # the model
                    new_fn = partial(
                        lambda state, _=None, fn=None: fn(state), fn=prop_fn
                    )
                    self.prop_calculators[frequency][name] = new_fn
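
To see the wrapping in isolation, a hypothetical single-argument calculator
`volume` behaves identically whether or not a model is passed::

    def volume(state):
        return torch.det(state.cell)

    wrapped = partial(lambda state, _=None, fn=None: fn(state), fn=volume)
    # wrapped(state) and wrapped(state, model) both return torch.det(state.cell)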
    def report(
        self,
        state: SimState,
        step: int,
        model: torch.nn.Module | None = None,
    ) -> list[dict[str, torch.Tensor]]:
        """Report a state and step to the trajectory files.

        Writes states and calculated properties to all trajectory files at the
        specified frequencies. Splits multi-batch states across separate trajectory
        files. The number of batches must match the number of trajectory files.

        Args:
            state (SimState): Current system state with n_batches equal to
                len(filenames)
            step (int): Current simulation step. Setting step to 0 will write the
                state and all properties.
            model (torch.nn.Module, optional): Model used for simulation. Defaults
                to None. Must be provided if any prop_calculators are provided.

        Returns:
            list[dict[str, torch.Tensor]]: Map of property names to tensors for
                each batch

        Raises:
            ValueError: If number of batches doesn't match number of trajectory
                files
        """
        # Get unique batch indices
        batch_indices = range(state.n_batches)
        # batch_indices = torch.unique(state.batch).cpu().tolist()

        # Ensure we have the right number of trajectories
        if self.filenames is not None and len(batch_indices) != len(self.trajectories):
            raise ValueError(
                f"Number of batches ({len(batch_indices)}) doesn't match "
                f"number of trajectory files ({len(self.trajectories)})"
            )

        split_states = state.split()
        all_props: list[dict[str, torch.Tensor]] = []

        # Process each batch separately
        for idx, substate in enumerate(split_states):
            self.shape_warned = True

            # Write state to trajectory if it's time
            if (
                self.state_frequency
                and step % self.state_frequency == 0
                and self.filenames is not None
            ):
                self.trajectories[idx].write_state(substate, step, **self.state_kwargs)

            all_state_props = {}
            # Process property calculators for this batch
            for report_frequency, calculators in self.prop_calculators.items():
                # check report_frequency == 0 first to avoid a modulo-by-zero error
                if report_frequency == 0 or step % report_frequency != 0:
                    continue

                # Calculate properties for this substate
                props = {}
                for prop_name, prop_fn in calculators.items():
                    prop = prop_fn(substate, model)
                    if len(prop.shape) == 0:
                        prop = prop.unsqueeze(0)
                    props[prop_name] = prop

                # Write properties to this trajectory
                if props:
                    all_state_props.update(props)
                    if self.filenames is not None:
                        self.trajectories[idx].write_arrays(props, step)

            all_props.append(all_state_props)
        return all_props
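
A sketch of property-only collection with no files on disk, relying on the
documented `filenames=None` behavior; `state`, `model`, and the lambda's
`"energy"` key are assumptions::

    reporter = TrajectoryReporter(
        None, prop_calculators={1: {"pe": lambda s, m: m(s)["energy"]}}
    )
    props = reporter.report(state, step=0, model=model)
    # props is a list with one dict per batch, e.g. props[0]["pe"]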
    def finish(self) -> None:
        """Finish writing the trajectory files.

        Closes all open trajectory files.
        """
        for trajectory in self.trajectories:
            trajectory.close()
    def close(self) -> None:
        """Close all trajectory files.

        Ensures all data is written to disk and releases the file handles.
        """
        for trajectory in self.trajectories:
            trajectory.close()
    def __enter__(self) -> "TrajectoryReporter":
        """Support the context manager protocol.

        Returns:
            TrajectoryReporter: The reporter instance
        """
        return self

    def __exit__(self, *exc_info) -> None:
        """Support the context manager protocol.

        Closes all trajectory files when exiting the context.

        Args:
            *exc_info: Exception information
        """
        self.close()
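
A sketch of the context-manager pattern, which guarantees files are closed even
if a step raises; `step_fn` stands in for any integrator step::

    with TrajectoryReporter("traj.h5", state_frequency=10) as reporter:
        for step in range(100):
            state = step_fn(state)
            reporter.report(state, step)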
class TorchSimTrajectory:
    """Trajectory storage and retrieval for molecular dynamics simulations.

    This class provides a low-level interface for reading and writing trajectory
    data to/from HDF5 files. It supports storing SimState objects, raw arrays,
    and conversion to common molecular modeling formats (ASE, pymatgen).

    Attributes:
        _file (tables.File): The HDF5 file handle
        array_registry (dict): Registry mapping array names to (shape, dtype)
            tuples
        type_map (dict): Mapping of numpy/torch dtypes to PyTables atom types

    Examples:
        >>> # Writing a trajectory
        >>> with TorchSimTrajectory('output.hdf5', mode='w') as traj:
        ...     for step, state in enumerate(simulation):
        ...         traj.write_state(state, step)
        >>> # Reading a trajectory
        >>> with TorchSimTrajectory('output.hdf5', mode='r') as traj:
        ...     state = traj.get_state(frame=10)
        ...     structure = traj.get_structure(frame=-1)  # Last frame
    """

    def __init__(
        self,
        filename: str | pathlib.Path,
        *,
        mode: Literal["w", "a", "r"] = "r",
        compress_data: bool = True,
        coerce_to_float32: bool = True,
        coerce_to_int32: bool = False,
        metadata: dict[str, str] | None = None,
    ) -> None:
        """Initialize the trajectory file.

        Args:
            filename (str | pathlib.Path): Path to the HDF5 file
            mode ("w" | "a" | "r"): Mode to open the file in. "w" will create a new
                file and overwrite any existing file, "a" will append to an
                existing file, and "r" will open the file for reading only.
                Defaults to "r".
            compress_data (bool): Whether to compress the data using zlib
                compression. Defaults to True.
            coerce_to_float32 (bool): Whether to coerce float64 data to float32.
                Defaults to True.
            coerce_to_int32 (bool): Whether to coerce int64 data to int32.
                Defaults to False.
            metadata (dict[str, str], optional): Additional metadata to save in the
                trajectory.

        Raises:
            ValueError: If the file cannot be opened or initialized
        """
        filename = pathlib.Path(filename)
        if compress_data:
            compression = tables.Filters(complib="zlib", shuffle=True, complevel=1)
        else:
            compression = None

        # TODO: fix this workaround for stale open file handles
        if handles := tables.file._open_files.get_handlers_by_name(str(filename)):
            list(handles)[-1].close()

        # create parent directory if it doesn't exist
        filename.parent.mkdir(parents=True, exist_ok=True)

        self._file = tables.open_file(str(filename), mode=mode, filters=compression)
        self.array_registry: dict[str, tuple[tuple[int, ...], np.dtype]] = {}

        # check if the header has already been written
        if "header" not in [node._v_name for node in self._file.list_nodes("/")]:
            self._initialize_header(metadata)

        self._initialize_registry()
        self.type_map = self._initialize_type_map(
            coerce_to_float32=coerce_to_float32, coerce_to_int32=coerce_to_int32
        )

    def _initialize_header(self, metadata: dict[str, str] | None = None) -> None:
        """Initialize the HDF5 file header with metadata.

        Creates the basic structure of the HDF5 file with header, metadata, data,
        and steps groups.

        Args:
            metadata (dict[str, str], optional): Metadata to store in the header.
        """
        self._file.create_group("/", "header")
        self._file.root.header._v_attrs.program = "TorchSim"
        self._file.root.header._v_attrs.title = "TorchSim Trajectory"

        self._file.create_group("/", "metadata")
        if metadata:
            for key, value in metadata.items():
                setattr(self._file.root.metadata._v_attrs, key, value)

        self._file.create_group("/", "data")
        self._file.create_group("/", "steps")

    def _initialize_registry(self) -> None:
        """Initialize the array registry from an existing file.

        Scans the HDF5 file to build a registry of array names, shapes, and data
        types for validation of subsequent write operations.
        """
        for node in self._file.list_nodes("/data/"):
            name = node.name
            dtype = node.dtype
            shape = tuple(int(ix) for ix in node.shape)[1:]
            self.array_registry[name] = (shape, dtype)

    def _initialize_type_map(
        self, *, coerce_to_float32: bool, coerce_to_int32: bool
    ) -> dict:
        """Initialize the type map for data type coercion.

        Creates a mapping from numpy/torch data types to PyTables atom types, with
        optional type coercion for reduced file size.

        Args:
            coerce_to_float32 (bool): Whether to coerce float64 data to float32
            coerce_to_int32 (bool): Whether to coerce int64 data to int32

        Returns:
            dict: Dictionary mapping numpy/torch dtypes to PyTables atom types
        """
        type_map = copy.copy(_DATA_TYPE_MAP)
        if coerce_to_int32:
            type_map[torch.int64] = tables.Int32Atom()
            type_map[np.dtype("int64")] = tables.Int32Atom()
        if coerce_to_float32:
            type_map[torch.float64] = tables.Float32Atom()
            type_map[np.dtype("float64")] = tables.Float32Atom()
        return type_map
    def write_arrays(
        self,
        data: dict[str, np.ndarray | torch.Tensor],
        steps: int | list[int],
    ) -> None:
        """Write arrays to the trajectory file.

        If steps is an integer, the arrays in data are assumed to hold a single
        frame; if steps is a list, they are assumed to hold multiple frames. This
        determines whether a leading frame axis of size 1 is added. The arrays are
        also validated against any existing arrays in the file, and the steps must
        be monotonically increasing.

        Args:
            data (dict[str, np.ndarray | torch.Tensor]): Dictionary mapping array
                names to numpy arrays or torch tensors with shapes [n_frames, ...]
            steps (int | list[int]): Step number(s) for the frame(s) being written.
                If steps is an integer, arrays will be treated as single frame
                data.

        Raises:
            ValueError: If array shapes or dtypes don't match existing arrays,
                or if steps are not monotonically increasing
        """
        if isinstance(steps, int):
            pad_first_dim = True
            steps = [steps]
        else:
            pad_first_dim = False

        for name, array in data.items():
            # TODO: coerce dtypes to numpy
            if isinstance(array, torch.Tensor):
                array = array.cpu().detach().numpy()

            if pad_first_dim:
                # add a leading frame axis of size 1
                array = array[np.newaxis, ...]

            if name not in self.array_registry:
                self._initialize_array(name, array)

            self._validate_array(name, array, steps)
            self._serialize_array(name, array, steps)

        self.flush()
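
A sketch of storing arbitrary per-frame arrays; the array name and values here
are user-chosen, not a fixed schema::

    with TorchSimTrajectory("traj.h5", mode="w") as traj:
        traj.write_arrays({"temperature": torch.tensor([300.0])}, steps=0)
        traj.write_arrays({"temperature": torch.tensor([305.0])}, steps=10)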
    def _initialize_array(self, name: str, array: np.ndarray) -> None:
        """Initialize a single array and add it to the registry.

        Creates a new array in the HDF5 file and registers its shape and dtype.

        Args:
            name (str): Name of the array
            array (np.ndarray): Array data to initialize with shape [n_frames, ...]

        Raises:
            ValueError: If the array dtype is not supported
        """
        if array.dtype not in self.type_map:
            raise ValueError(f"Unsupported {array.dtype=}")

        self._file.create_earray(
            where="/data/",
            name=name,
            atom=self.type_map[array.dtype],
            shape=(0, *array.shape[1:]),
        )
        self._file.create_earray(
            where="/steps/", name=name, atom=tables.Int32Atom(), shape=(0,)
        )
        # in the registry we store the shape of the single-frame array
        # because the multi-frame array shape will change over time
        self.array_registry[name] = (array.shape[1:], array.dtype)

    def _validate_array(self, name: str, data: np.ndarray, steps: list[int]) -> None:
        """Validate that the data is compatible with the existing array.

        Checks that the array shape, dtype, and step numbers are compatible with
        the existing array in the file.

        Args:
            name (str): Name of the array
            data (np.ndarray): Array data to validate with shape [n_frames, ...]
            steps (list[int]): Step numbers to validate

        Raises:
            ValueError: If array shape or dtype doesn't match, or if steps aren't
                monotonically increasing
        """
        # Get the registered shape and dtype
        registered_shape, registered_dtype = self.array_registry[name]

        # Validate the per-frame shape
        if data.shape[1:] != registered_shape:
            raise ValueError(
                f"Array {name} shape mismatch. Expected per-frame shape "
                f"{registered_shape}, got {data.shape[1:]}"
            )

        # Get the expected dtype from our type map
        expected_atom = self.type_map[data.dtype]
        stored_atom = self.type_map[registered_dtype]

        # Compare the PyTables atoms instead of numpy dtypes
        if type(expected_atom) is not type(stored_atom):
            raise ValueError(
                f"Array {name} dtype mismatch. Cannot convert {data.dtype} "
                f"to match stored dtype {registered_dtype}"
            )

        # Validate steps are monotonically increasing by checking the HDF5 file
        steps_node = self._file.get_node("/steps/", name=name)
        if len(steps_node) > 0:
            last_step = steps_node[-1]  # Get the last recorded step
            if steps[0] <= last_step:
                raise ValueError(
                    f"{steps[0]=} must be greater than the last recorded "
                    f"step {last_step} for array {name}"
                )

    def _serialize_array(self, name: str, data: np.ndarray, steps: list[int]) -> None:
        """Append frames to an array already in the registry.

        Appends frames to an existing array and its associated step numbers.

        Args:
            name (str): Name of the array
            data (np.ndarray): Array data to serialize with shape [n_frames, ...]
            steps (list[int]): Step numbers for the frames

        Raises:
            ValueError: If number of steps doesn't match number of frames
        """
        if len(steps) > 1 and data.shape[0] != len(steps):
            raise ValueError(
                f"Number of steps {len(steps)} must match the number of frames "
                f"{data.shape[0]} for array {name}"
            )

        self._file.get_node(where="/data/", name=name).append(data)
        self._file.get_node(where="/steps/", name=name).append(steps)
    def get_array(
        self,
        name: str,
        start: int | None = None,
        stop: int | None = None,
        step: int = 1,
    ) -> np.ndarray:
        """Get an array from the file.

        Retrieves a subset of frames from the specified array.

        Args:
            name (str): Name of the array to retrieve
            start (int, optional): Starting frame index. Defaults to None.
            stop (int, optional): Ending frame index (exclusive). Defaults to None.
            step (int, optional): Step size between frames. Defaults to 1.

        Returns:
            np.ndarray: Array data as numpy array with shape
                [n_selected_frames, ...]

        Raises:
            ValueError: If array name not found in registry
        """
        if name not in self.array_registry:
            raise ValueError(f"Array {name} not found in registry")

        return self._file.root.data.__getitem__(name).read(
            start=start, stop=stop, step=step
        )
    def get_steps(
        self,
        name: str,
        start: int | None = None,
        stop: int | None = None,
        step: int = 1,
    ) -> np.ndarray:
        """Get the steps for an array.

        Retrieves the simulation step numbers associated with frames in an array.

        Args:
            name (str): Name of the array
            start (int, optional): Starting frame index. Defaults to None.
            stop (int, optional): Ending frame index (exclusive). Defaults to None.
            step (int, optional): Step size between frames. Defaults to 1.

        Returns:
            np.ndarray: Array of step numbers with shape [n_selected_frames]
        """
        return self._file.root.steps.__getitem__(name).read(
            start=start, stop=stop, step=step
        )
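
A sketch of reading back every other frame of a stored array together with its
step numbers; the file and array name assume a prior write::

    with TorchSimTrajectory("traj.h5", mode="r") as traj:
        positions = traj.get_array("positions", start=0, stop=10, step=2)
        steps = traj.get_steps("positions", start=0, stop=10, step=2)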
    def __str__(self) -> str:
        """Get a string representation of the trajectory.

        Returns:
            str: Summary of arrays in the file including shapes and dtypes
        """
        # summarize arrays and steps in the file
        summary = ["Arrays in file:"]
        for node in self._file.list_nodes("/data/"):
            shape_ints = tuple(int(ix) for ix in node.shape)
            steps = shape_ints[0]
            shape = shape_ints[1:]
            dtype = node.dtype
            summary.append(f"  {node.name}: {steps=} with {shape=} and {dtype=}")
        return "\n".join(summary)
    def write_state(  # noqa: C901
        self,
        state: SimState | list[SimState],
        steps: int | list[int],
        batch_index: int | None = None,
        *,
        save_velocities: bool = False,
        save_forces: bool = False,
        variable_cell: bool = True,
        variable_masses: bool = False,
        variable_atomic_numbers: bool = False,
    ) -> None:
        """Write a SimState or list of SimStates to the file.

        Extracts and stores position, velocity, force, and other data from SimState
        objects. Static data (like cell parameters) is stored only once unless
        flagged as variable. If a list is given, the states are assumed to be
        different configurations of the same system, representing a trajectory.

        Args:
            state (SimState | list[SimState]): SimState or list of SimStates to
                write
            steps (int | list[int]): Step number(s) for the frame(s)
            batch_index (int, optional): Batch index to save.
            save_velocities (bool, optional): Whether to save velocities.
            save_forces (bool, optional): Whether to save forces.
            variable_cell (bool, optional): Whether the cell varies between frames.
            variable_masses (bool, optional): Whether masses vary between frames.
            variable_atomic_numbers (bool, optional): Whether atomic numbers vary
                between frames.

        Raises:
            ValueError: If number of states doesn't match number of steps or if
                required attributes are missing
        """
        # TODO: consider changing this reporting later

        # wrap bare arguments in lists
        if isinstance(state, SimState):
            state = [state]
        if isinstance(steps, int):
            steps = [steps]

        if isinstance(batch_index, int):
            batch_index = [batch_index]
            sub_states = [s[batch_index] for s in state]
        elif batch_index is None and torch.unique(state[0].batch).numel() == 1:
            batch_index = 0
            sub_states = state
        else:
            raise ValueError(
                "Batch index must be specified if there are multiple batches"
            )

        if len(sub_states) != len(steps):
            raise ValueError(f"{len(sub_states)=} must match the {len(steps)=}")

        # Initialize data dictionary with required arrays
        data = {
            "positions": torch.stack([s.positions for s in sub_states]),
        }

        # Define optional arrays to save based on flags
        optional_arrays = {
            "velocities": save_velocities,
            "forces": save_forces,
        }

        # Loop through optional arrays and add them if requested
        for array_name, should_save in optional_arrays.items():
            if should_save:
                if not hasattr(sub_states[0], array_name):
                    raise ValueError(
                        f"{array_name.capitalize()} can only be saved "
                        f"if included in the state being reported."
                    )
                data[array_name] = torch.stack(
                    [getattr(s, array_name) for s in sub_states]
                )

        # Handle cell and masses based on variable flags
        if variable_cell:
            data["cell"] = torch.cat([s.cell for s in sub_states])
        elif "cell" not in self.array_registry:
            # Save cell only for the first frame. Steps is passed as a list
            # because the cell already has a leading dimension and doesn't
            # need to be padded.
            self.write_arrays({"cell": sub_states[0].cell}, [0])

        if variable_masses:
            data["masses"] = torch.stack([s.masses for s in sub_states])
        elif "masses" not in self.array_registry:
            # Save masses only for the first frame
            self.write_arrays({"masses": sub_states[0].masses}, 0)

        if variable_atomic_numbers:
            data["atomic_numbers"] = torch.stack(
                [s.atomic_numbers for s in sub_states]
            )
        elif "atomic_numbers" not in self.array_registry:
            # Save atomic numbers only for the first frame
            self.write_arrays({"atomic_numbers": sub_states[0].atomic_numbers}, 0)

        if "pbc" not in self.array_registry:
            self.write_arrays({"pbc": np.array(sub_states[0].pbc)}, 0)

        # Write all arrays to file
        self.write_arrays(data, steps)
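
A sketch of writing frames with velocities and a fixed cell; the loop and
`step_fn` are illustrative::

    with TorchSimTrajectory("traj.h5", mode="w") as traj:
        for step in range(0, 1000, 100):
            state = step_fn(state)
            traj.write_state(state, step, save_velocities=True, variable_cell=False)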
    def _get_state_arrays(self, frame: int) -> dict[str, torch.Tensor]:
        """Get all available state tensors for a given frame.

        Retrieves all state-related arrays (positions, cell, masses, etc.) for a
        specific frame.

        Args:
            frame (int): Frame index to retrieve (-1 for last frame)

        Returns:
            dict[str, torch.Tensor]: Dictionary of tensor names to their values

        Raises:
            ValueError: If required arrays are missing from trajectory or frame
                is out of range
        """
        arrays: dict[str, np.ndarray] = {}

        # Get required data
        if "positions" not in self.array_registry:
            keys = list(self.array_registry)
            raise ValueError(
                f"Positions not found in trajectory so cannot get structure. "
                f"Have {keys=}"
            )

        # check length of positions array
        n_frames = self._file.root.data.positions.shape[0]
        if frame < 0:
            frame = n_frames + frame
        if frame >= n_frames:
            raise ValueError(f"{frame=} is out of range. Total frames: {n_frames:,}")

        arrays["positions"] = self.get_array(
            "positions", start=frame, stop=frame + 1
        )[0]

        def return_prop(prop: str) -> np.ndarray:
            if getattr(self._file.root.data, prop).shape[0] > 1:  # Variable prop
                start, stop = frame, frame + 1
            else:  # Static prop
                start, stop = 0, 1
            return self.get_array(prop, start=start, stop=stop)[0]

        arrays["cell"] = np.expand_dims(return_prop("cell"), axis=0)
        arrays["atomic_numbers"] = return_prop("atomic_numbers")
        arrays["masses"] = return_prop("masses")
        arrays["pbc"] = return_prop("pbc")

        return arrays
    def get_structure(self, frame: int = -1) -> Any:
        """Get a pymatgen Structure object for a given frame.

        Converts the state at the specified frame to a pymatgen Structure object
        for analysis and visualization.

        Args:
            frame (int, optional): Frame index to retrieve. Defaults to -1 for the
                last frame.

        Returns:
            Structure: Pymatgen Structure object for the specified frame

        Raises:
            ImportError: If pymatgen is not installed
        """
        from pymatgen.core import Structure

        arrays = self._get_state_arrays(frame)

        # Create pymatgen Structure
        # TODO: check if this is correct
        lattice = arrays["cell"][0].T  # pymatgen expects the lattice matrix as rows
        species = [str(num) for num in arrays["atomic_numbers"]]

        return Structure(
            lattice=np.ascontiguousarray(lattice),
            species=species,
            coords=np.ascontiguousarray(arrays["positions"]),
            coords_are_cartesian=True,
            validate_proximity=False,
        )
    def get_atoms(self, frame: int = -1) -> Any:
        """Get an ASE Atoms object for a given frame.

        Converts the state at the specified frame to an ASE Atoms object for
        analysis and visualization.

        Args:
            frame (int): Frame index to retrieve (-1 for last frame)

        Returns:
            Atoms: ASE Atoms object for the specified frame

        Raises:
            ImportError: If ASE is not installed
        """
        from ase import Atoms

        arrays = self._get_state_arrays(frame)
        pbc = arrays.get("pbc", True)

        return Atoms(
            numbers=np.ascontiguousarray(arrays["atomic_numbers"]),
            positions=np.ascontiguousarray(arrays["positions"]),
            cell=np.ascontiguousarray(arrays["cell"])[0],
            pbc=pbc,
        )
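
A sketch of exporting a frame to the interop formats; both imports are optional
dependencies of this module::

    with TorchSimTrajectory("traj.h5", mode="r") as traj:
        atoms = traj.get_atoms(frame=-1)  # ase.Atoms
        structure = traj.get_structure(frame=-1)  # pymatgen Structure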
    def get_state(
        self,
        frame: int = -1,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> SimState:
        """Get a SimState object for a given frame.

        Reconstructs a SimState object from the data stored for a specific frame.

        Args:
            frame (int): Frame index to retrieve (-1 for last frame)
            device (torch.device, optional): Device to place tensors on.
                Defaults to None.
            dtype (torch.dtype, optional): Data type for tensors. Defaults to None.

        Returns:
            SimState: State object containing all available data for the frame
                with shapes matching the original stored state
        """
        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dtype = dtype or torch.float64

        arrays = self._get_state_arrays(frame)

        # Create SimState with required attributes
        return SimState(
            positions=torch.tensor(arrays["positions"], device=device, dtype=dtype),
            masses=torch.tensor(arrays["masses"], device=device, dtype=dtype),
            cell=torch.tensor(arrays["cell"], device=device, dtype=dtype),
            pbc=arrays.get("pbc", True),
            atomic_numbers=torch.tensor(
                arrays["atomic_numbers"], device=device, dtype=torch.int
            ),
        )
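
A sketch of restoring a frame as a SimState; the CPU device and float32 dtype
are illustrative choices::

    with TorchSimTrajectory("traj.h5", mode="r") as traj:
        state = traj.get_state(
            frame=0, device=torch.device("cpu"), dtype=torch.float32
        )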
    @property
    def metadata(self) -> dict:
        """Metadata for the trajectory."""
        attrs = self._file.root.metadata._v_attrs
        return {name: getattr(attrs, name) for name in attrs._f_list()}
    def close(self) -> None:
        """Close the HDF5 file handle.

        Ensures all data is written to disk and releases the file handle.
        """
        if self._file.isopen:
            self._file.close()
    def __enter__(self) -> "TorchSimTrajectory":
        """Support the context manager protocol.

        Returns:
            TorchSimTrajectory: The trajectory instance
        """
        return self

    def __exit__(self, *exc_info) -> None:
        """Support the context manager protocol.

        Closes the file when exiting the context.

        Args:
            *exc_info: Exception information
        """
        self.close()
    def flush(self) -> None:
        """Write all buffered data to the disk file.

        Forces any pending data to be written to physical storage.
        """
        if self._file.isopen:
            self._file.flush()
    def __len__(self) -> int:
        """Get the number of frames in the trajectory.

        Returns:
            int: Number of frames in the trajectory
        """
        return self._file.root.data.positions.shape[0]
    def write_ase_trajectory(self, filename: str | pathlib.Path) -> Any:
        """Convert trajectory to ASE Trajectory format.

        Writes the entire trajectory to a new file in ASE format for compatibility
        with ASE analysis tools.

        Args:
            filename (str | pathlib.Path): Path to the output ASE trajectory file

        Returns:
            ase.io.trajectory.Trajectory: ASE trajectory object

        Raises:
            ImportError: If ASE is not installed
        """
        try:
            from ase.io.trajectory import Trajectory
        except ImportError:
            raise ImportError(
                "ASE is required to convert to ASE trajectory. Run `pip install ase`"
            ) from None

        # Create ASE trajectory
        traj = Trajectory(filename, mode="w")

        # Write each frame
        for frame in range(len(self)):
            atoms = self.get_atoms(frame)
            traj.write(atoms)

        traj.close()
        return Trajectory(filename)  # Reopen in read mode
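
A sketch of converting a finished run for ASE's analysis tools; the file names
are placeholders::

    with TorchSimTrajectory("traj.h5", mode="r") as traj:
        ase_traj = traj.write_ase_trajectory("traj.traj")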