TorchSimTrajectory
- class torch_sim.trajectory.TorchSimTrajectory(filename, *, mode='r', compress_data=True, coerce_to_float32=True, coerce_to_int32=False, metadata=None)
Bases: object
Trajectory storage and retrieval for molecular dynamics simulations.
This class provides a low-level interface for reading and writing trajectory data to and from HDF5 files. It supports storing SimState objects and raw arrays, and converting frames to common molecular modeling formats (ASE, pymatgen).
- Variables:
- Parameters:
Examples
>>> # Writing a trajectory
>>> with TorchSimTrajectory('output.hdf5', mode='w') as traj:
...     for step, state in enumerate(simulation):
...         traj.write_state(state, step)
>>>
>>> # Reading a trajectory
>>> with TorchSimTrajectory('output.hdf5', mode='r') as traj:
...     state = traj.get_state(frame=10)
...     structure = traj.get_structure(frame=-1)  # Last frame
- write_arrays(data, steps)
Write arrays to the trajectory file.
If steps is an integer, the arrays in data are treated as a single frame, and each is padded with a leading dimension of size 1. If steps is a list, the arrays are treated as multiple frames, so their first dimension already indexes frames.
The method also validates that the arrays are compatible with the arrays already in the file (matching shapes and dtypes) and that the steps are monotonically increasing.
- Parameters:
- Raises:
ValueError – If array shapes or dtypes don’t match existing arrays, or if steps are not monotonically increasing
- Return type:
None
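The frame-handling rules above can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation: the helper name `prepare_frames`, the assumption that `data` is a dict mapping array names to arrays, the `last_step` argument (the largest step already in the file), the one-step-per-frame check, and the reading of "monotonically increasing" as strictly increasing are all assumptions made for this sketch.

```python
import numpy as np


def prepare_frames(data, steps, last_step=None):
    """Hypothetical helper mirroring the documented write_arrays rules.

    Assumes `data` maps array names to arrays and `last_step` is the
    largest step already stored in the file (None for an empty file).
    """
    if isinstance(steps, int):
        # Single frame: add a leading frame dimension of size 1.
        steps = [steps]
        data = {name: np.asarray(arr)[np.newaxis, ...] for name, arr in data.items()}
    else:
        # Multiple frames: the first dimension already indexes frames.
        steps = list(steps)
        data = {name: np.asarray(arr) for name, arr in data.items()}

    # Assumed check: exactly one step number per frame.
    for name, arr in data.items():
        if arr.shape[0] != len(steps):
            raise ValueError(f"{name}: {arr.shape[0]} frames but {len(steps)} steps")

    # Assumed: steps must be strictly increasing, including across the
    # boundary with steps already stored in the file.
    checked = ([] if last_step is None else [last_step]) + steps
    if any(b <= a for a, b in zip(checked, checked[1:])):
        raise ValueError("steps must be monotonically increasing")
    return data, steps


positions = np.zeros((8, 3))  # one frame of 8 atoms
padded, steps = prepare_frames({"positions": positions}, steps=10)
print(padded["positions"].shape, steps)  # (1, 8, 3) [10]
```

Passing an integer step is what triggers the padding; passing a list leaves the arrays as they are and only checks that their frame axis lines up with the steps.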
- get_array(name, start=None, stop=None, step=1)
Get an array from the file.
Retrieves a subset of frames from the specified array.
- Parameters:
- Returns:
Array data as numpy array with shape [n_selected_frames, …]
- Return type:
np.ndarray
- Raises:
ValueError – If array name not found in registry
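Assuming that start, stop, and step select frames the way a Python slice does along the first (frame) axis, which is an assumption since the parameter descriptions are not shown here, frame selection behaves like the NumPy stand-in below. The helper `select_frames` and the toy `positions` array are hypothetical.

```python
import numpy as np


def select_frames(stored, start=None, stop=None, step=1):
    """Hypothetical stand-in for get_array's frame selection.

    Assumes start/stop/step act like a Python slice on the frame axis,
    so the result has shape [n_selected_frames, ...].
    """
    return stored[slice(start, stop, step)]


# Toy array: 10 frames of 4 atoms x 3 coordinates.
positions = np.arange(10 * 4 * 3, dtype=np.float32).reshape(10, 4, 3)

print(select_frames(positions, step=2).shape)    # (5, 4, 3)
print(select_frames(positions, start=-3).shape)  # (3, 4, 3)
```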
- get_steps(name, start=None, stop=None, step=1)
Get the steps for an array.
Retrieves the simulation step numbers associated with frames in an array.
- Parameters:
- Returns:
Array of step numbers with shape [n_selected_frames]
- Return type:
np.ndarray
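Because write_arrays rejects steps that are not monotonically increasing, the step array for a given name is sorted, so a simulation step can be mapped to a frame index by binary search with `numpy.searchsorted`. In this sketch the helper `frame_for_step` is hypothetical and the toy `steps` array stands in for a get_steps result; the returned index is the kind of value get_state and get_atoms accept as frame.

```python
import numpy as np


def frame_for_step(steps, target):
    """Return the frame index whose step number equals `target`.

    `steps` must be sorted ascending, which holds for steps written
    in monotonically increasing order.
    """
    idx = int(np.searchsorted(steps, target))  # leftmost insertion point
    if idx == len(steps) or steps[idx] != target:
        raise KeyError(f"step {target} not stored")
    return idx


# Stand-in for a get_steps result: frames saved every 50 steps.
steps = np.array([0, 50, 100, 150, 200])
print(frame_for_step(steps, 150))  # 3
```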
- write_state(state, steps, batch_index=None, *, save_velocities=False, save_forces=False, variable_cell=True, variable_masses=False, variable_atomic_numbers=False)
Write a SimState or list of SimStates to the file.
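Extracts and stores positions, velocities, forces, and other data from SimState objects; velocities and forces are saved only when save_velocities or save_forces is True. Static data (such as the cell, masses, or atomic numbers) is stored only once unless flagged as variable; note that variable_cell defaults to True, so the cell is stored for every frame unless that flag is turned off.
If state is a list, the states are assumed to be different configurations of the same system, representing a trajectory.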
- Parameters:
state (SimState | list[SimState]) – SimState or list of SimStates to write
steps (int | list[int]) – Step number(s) for the state(s); a list must have one entry per state
batch_index (int, optional) – Batch index to save.
save_velocities (bool, optional) – Whether to save velocities. Defaults to False.
save_forces (bool, optional) – Whether to save forces. Defaults to False.
variable_cell (bool, optional) – Whether the cell varies between frames. Defaults to True.
variable_masses (bool, optional) – Whether masses vary between frames. Defaults to False.
variable_atomic_numbers (bool, optional) – Whether atomic numbers vary between frames. Defaults to False.
- Raises:
ValueError – If number of states doesn’t match number of steps or if required attributes are missing
- Return type:
None
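The static-versus-variable distinction can be illustrated with a toy in-memory store. This is a hypothetical sketch of the idea only, not how TorchSimTrajectory lays out its HDF5 file: `ToyStore`, its methods, and the array names are invented for illustration. Variable arrays gain one entry per frame, while static arrays are kept once and reused when a frame is read back, loosely mirroring how get_state reconstructs a frame. In write_state itself, the variable_cell, variable_masses, and variable_atomic_numbers flags decide which category the cell, masses, and atomic numbers fall into.

```python
import numpy as np


class ToyStore:
    """Hypothetical store separating per-frame arrays from static ones."""

    def __init__(self):
        self.frames = {}  # name -> list of per-frame arrays
        self.static = {}  # name -> single array shared by all frames

    def write(self, arrays, variable):
        for name, arr in arrays.items():
            if name in variable:
                # Variable data: append one entry for this frame.
                self.frames.setdefault(name, []).append(np.asarray(arr))
            elif name not in self.static:
                # Static data: keep the first copy, ignore later ones.
                self.static[name] = np.asarray(arr)

    def read_frame(self, frame=-1):
        # Per-frame data is indexed by frame; static data is reused as-is.
        out = {name: rows[frame] for name, rows in self.frames.items()}
        out.update(self.static)
        return out


store = ToyStore()
masses = np.array([1.0, 16.0])
for step in range(3):
    positions = np.full((2, 3), float(step))
    store.write({"positions": positions, "masses": masses}, variable={"positions"})

print(len(store.frames["positions"]))             # 3
print(store.read_frame(-1)["positions"][0, 0])    # 2.0
```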
- get_structure(frame=-1)
Get a pymatgen Structure object for a given frame.
Converts the state at the specified frame to a pymatgen Structure object for analysis and visualization.
- Parameters:
frame (int, optional) – Frame index to retrieve. Defaults to -1 for last frame.
- Returns:
Pymatgen Structure object for the specified frame
- Return type:
Structure
- Raises:
ImportError – If pymatgen is not installed
- get_atoms(frame=-1)
Get an ASE Atoms object for a given frame.
Converts the state at the specified frame to an ASE Atoms object for analysis and visualization.
- Parameters:
frame (int, optional) – Frame index to retrieve. Defaults to -1 for last frame.
- Returns:
ASE Atoms object for the specified frame
- Return type:
Atoms
- Raises:
ImportError – If ASE is not installed
- get_state(frame=-1, device=None, dtype=None)
Get a SimState object for a given frame.
Reconstructs a SimState object from the data stored for a specific frame.
- Parameters:
- Returns:
State object containing all available data for the frame, with shapes matching the original stored state
- Return type:
SimState
- close()
Close the HDF5 file handle.
Ensures all data is written to disk and releases the file handle.
- Return type:
None
- flush()
Write all buffered data to disk.
Forces any pending data to be written to the physical storage.
- Return type:
None
- write_ase_trajectory(filename)
Convert trajectory to ASE Trajectory format.
Writes the entire trajectory to a new file in ASE format for compatibility with ASE analysis tools.
- Parameters:
filename (str | Path) – Path to the output ASE trajectory file
- Returns:
ASE trajectory object
- Return type:
ase.io.trajectory.Trajectory
- Raises:
ImportError – If ASE is not installed