TorchSimTrajectory

class torch_sim.trajectory.TorchSimTrajectory(filename, *, mode='r', compress_data=True, coerce_to_float32=True, coerce_to_int32=False, metadata=None)

Bases: object

Trajectory storage and retrieval for molecular dynamics simulations.

This class provides a low-level interface for reading and writing trajectory data to/from HDF5 files. It supports storing SimState objects, raw arrays, and conversion to common molecular modeling formats (ASE, pymatgen).

Variables:
  • _file (tables.File) – The HDF5 file handle

  • array_registry (dict) – Registry mapping array names to (shape, dtype) tuples

  • type_map (dict) – Mapping of numpy/torch dtypes to PyTables atom types
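The type_map, together with the coerce_to_float32 and coerce_to_int32 constructor flags, implies a dtype-normalization step before data reaches PyTables. The following is a minimal stdlib sketch of that idea, assuming the flags downcast 64-bit floats and integers respectively. The function storage_dtype and the entries of TYPE_MAP are hypothetical illustrations (dtype names as strings, PyTables atom class names as labels), not the class's actual mapping.

```python
# Hypothetical sketch: pick the dtype actually stored and a PyTables atom
# name for an incoming dtype, mimicking coerce_to_float32 / coerce_to_int32.
# TYPE_MAP and storage_dtype are illustrative, not the real type_map.
TYPE_MAP = {
    "float64": "Float64Atom",
    "float32": "Float32Atom",
    "int64": "Int64Atom",
    "int32": "Int32Atom",
    "bool": "BoolAtom",
}


def storage_dtype(dtype_name: str, coerce_to_float32: bool = True,
                  coerce_to_int32: bool = False) -> tuple[str, str]:
    """Return (stored dtype name, atom name) for an input dtype name."""
    if coerce_to_float32 and dtype_name == "float64":
        dtype_name = "float32"  # assumed: downcast double precision
    if coerce_to_int32 and dtype_name == "int64":
        dtype_name = "int32"  # assumed: downcast 64-bit integers
    if dtype_name not in TYPE_MAP:
        raise ValueError(f"Unsupported dtype: {dtype_name}")
    return dtype_name, TYPE_MAP[dtype_name]


# With the constructor defaults, float64 is downcast but int64 is kept.
print(storage_dtype("float64"))  # ('float32', 'Float32Atom')
print(storage_dtype("int64"))    # ('int64', 'Int64Atom')
```

Keeping the mapping in one dictionary means an unsupported dtype fails loudly at write time rather than producing an array PyTables cannot store.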

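Parameters:
  • filename (str | Path) – Path to the HDF5 file.

  • mode (str, optional) – File mode, e.g. 'w' to write or 'r' to read. Defaults to 'r'.

  • compress_data (bool, optional) – Whether to compress the stored data. Defaults to True.

  • coerce_to_float32 (bool, optional) – Whether to store floating-point data as float32. Defaults to True.

  • coerce_to_int32 (bool, optional) – Whether to store integer data as int32. Defaults to False.

  • metadata (dict, optional) – Metadata to store with the trajectory. Defaults to None.
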
Examples

>>> # Writing a trajectory
>>> with TorchSimTrajectory('output.hdf5', mode='w') as traj:
...     for step, state in enumerate(simulation):
...         traj.write_state(state, step)

>>> # Reading a trajectory
>>> with TorchSimTrajectory('output.hdf5', mode='r') as traj:
...     state = traj.get_state(frame=10)
...     structure = traj.get_structure(frame=-1)  # Last frame

write_arrays(data, steps)

Write arrays to the trajectory file.

If steps is an integer, each array in data is treated as a single frame and padded with a leading frame dimension of size 1. If steps is a list, the arrays are treated as multiple frames stacked along the first dimension.

The method also checks that each array's shape and dtype match those of any existing array with the same name in the file, and that the steps are monotonically increasing.

Parameters:
  • data (dict[str, np.ndarray | Tensor]) – Dictionary mapping array names to numpy arrays or torch tensors with shapes [n_frames, …]

  • steps (int | list[int]) – Step number(s) for the frame(s) being written. If an integer, the arrays are treated as single-frame data.

Raises:

ValueError – If array shapes or dtypes don’t match existing arrays, or if steps are not monotonically increasing

Return type:

None
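The single-frame padding and the validation rules described above can be sketched on array shapes alone. This is a minimal illustration under stated assumptions, not the method's implementation: prepare_write, registry (per-frame shapes already in the file) and last_step are hypothetical names, the frame-count check is an extra assumption, and "monotonically increasing" is taken to mean strictly increasing.

```python
# Hypothetical sketch of the write_arrays rules, operating on shapes only.
# prepare_write, registry and last_step are illustrative names, not part
# of the TorchSimTrajectory API.


def prepare_write(shapes, steps, registry=None, last_step=None):
    """Return (shapes with a leading frame axis, list of steps).

    shapes: dict mapping array name -> shape of the data passed in.
    steps: int for a single frame, or list[int] with one step per frame.
    registry: dict mapping array name -> per-frame shape already stored.
    last_step: last step already stored in the file, if any.
    """
    registry = registry or {}
    if isinstance(steps, int):
        # Single frame: pad every array with a frame dimension of size 1.
        shapes = {name: (1, *shape) for name, shape in shapes.items()}
        steps = [steps]
    steps = list(steps)
    for name, shape in shapes.items():
        if shape[0] != len(steps):
            raise ValueError(f"{name}: {shape[0]} frames but {len(steps)} steps")
        if name in registry and registry[name] != shape[1:]:
            raise ValueError(f"{name}: shape {shape[1:]} != {registry[name]}")
    # Treat "monotonically increasing" as strictly increasing, both within
    # this call and relative to the last step already in the file.
    sequence = ([last_step] if last_step is not None else []) + steps
    if any(later <= earlier for earlier, later in zip(sequence, sequence[1:])):
        raise ValueError("steps must be monotonically increasing")
    return shapes, steps


# One frame of 8 atoms: shape (8, 3) is stored as (1, 8, 3).
print(prepare_write({"positions": (8, 3)}, steps=10))
# ({'positions': (1, 8, 3)}, [10])
```

Checking everything before touching the file means a bad call leaves previously written frames untouched.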

get_array(name, start=None, stop=None, step=1)

Get an array from the file.

Retrieves a subset of frames from the specified array.

Parameters:
  • name (str) – Name of the array to retrieve

  • start (int, optional) – Starting frame index. Defaults to None (start at the first frame).

  • stop (int, optional) – Ending frame index (exclusive). Defaults to None (read through the last frame).

  • step (int, optional) – Step size between frames. Defaults to 1.

Returns:

Array data as numpy array with shape [n_selected_frames, …]

Return type:

np.ndarray

Raises:

ValueError – If the array name is not found in the registry
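The start, stop and step arguments read like Python slice bounds, with stop exclusive. A minimal sketch of the selection, assuming exactly those slice semantics: arrays is a hypothetical stand-in for the HDF5 file (array name mapped to a list of frames), and the sketch returns a list where the real method returns an np.ndarray.

```python
# Hypothetical sketch: frame selection for get_array, assuming start/stop/step
# follow Python slice semantics (stop exclusive, None = open-ended).
# `arrays` stands in for the HDF5 file: array name -> list of frames.


def get_array_sketch(arrays, name, start=None, stop=None, step=1):
    if name not in arrays:
        # Mirrors the documented ValueError for unknown array names.
        raise ValueError(f"{name!r} not found in registry")
    return arrays[name][start:stop:step]


arrays = {"potential_energy": [0.0, -1.0, -2.0, -3.0, -4.0]}
print(get_array_sketch(arrays, "potential_energy", start=1, stop=4))
# [-1.0, -2.0, -3.0]
print(get_array_sketch(arrays, "potential_energy", step=2))
# [0.0, -2.0, -4.0]
```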

get_steps(name, start=None, stop=None, step=1)

Get the steps for an array.

Retrieves the simulation step numbers associated with frames in an array.

Parameters:
  • name (str) – Name of the array

  • start (int, optional) – Starting frame index. Defaults to None (start at the first frame).

  • stop (int, optional) – Ending frame index (exclusive). Defaults to None (read through the last frame).

  • step (int, optional) – Step size between frames. Defaults to 1.

Returns:

Array of step numbers with shape [n_selected_frames]

Return type:

np.ndarray
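Because get_steps takes the same start, stop and step arguments as get_array, the two results can be paired frame by frame. A minimal sketch of that pairing, assuming one stored step number per frame and slice semantics for the selection; store, get_steps_sketch and get_frames_sketch are hypothetical stand-ins, and the energies are made-up illustration values.

```python
# Hypothetical sketch: each array keeps one step number per frame, so the
# same start/stop/step selection pairs frames with their simulation steps.
# `store` is an illustrative stand-in for the HDF5 file.
store = {
    "potential_energy": {
        "frames": [0.0, -1.5, -2.1, -2.4, -2.5],
        "steps": [0, 10, 20, 30, 40],  # written every 10 MD steps
    }
}


def get_steps_sketch(store, name, start=None, stop=None, step=1):
    return store[name]["steps"][start:stop:step]


def get_frames_sketch(store, name, start=None, stop=None, step=1):
    return store[name]["frames"][start:stop:step]


steps = get_steps_sketch(store, "potential_energy", start=1, stop=4)
values = get_frames_sketch(store, "potential_energy", start=1, stop=4)
print(dict(zip(steps, values)))  # {10: -1.5, 20: -2.1, 30: -2.4}
```

Storing steps explicitly, rather than inferring them from frame indices, keeps the mapping correct when data is written at irregular intervals.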

write_state(state, steps, batch_index=None, *, save_velocities=False, save_forces=False, variable_cell=True, variable_masses=False, variable_atomic_numbers=False)

Write a SimState or list of SimStates to the file.

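Extracts positions and other per-frame data from SimState objects and stores them in the file; velocities and forces are stored only when save_velocities or save_forces is set. Data that does not change between frames (such as masses or atomic numbers) is stored only once unless flagged as variable. Because variable_cell defaults to True, the cell is stored for every frame unless variable_cell=False is passed.

If state is a list, the states are assumed to be successive configurations of the same system, forming a trajectory.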

Parameters:
  • state (SimState | list[SimState]) – SimState or list of SimStates to write

  • steps (int | list[int]) – Step number(s) for the frame(s)

  • batch_index (int, optional) – Index of the batch to save when state holds several batched systems. Defaults to None.

  • save_velocities (bool, optional) – Whether to save velocities. Defaults to False.

  • save_forces (bool, optional) – Whether to save forces. Defaults to False.

  • variable_cell (bool, optional) – Whether the cell varies between frames. Defaults to True.

  • variable_masses (bool, optional) – Whether masses vary between frames. Defaults to False.

  • variable_atomic_numbers (bool, optional) – Whether atomic numbers vary between frames. Defaults to False.

Raises:

ValueError – If the number of states doesn’t match the number of steps, or if required attributes are missing

Return type:

None
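The save_* and variable_* flags decide which data is written every frame and which is written once. A minimal sketch of that split, assuming each flag maps directly onto one array; the function plan_fields and the array names are hypothetical and may differ from what write_state actually writes.

```python
# Hypothetical sketch of how write_state's flags could split SimState data
# into arrays written every frame and arrays written once. plan_fields and
# the array names are illustrative, not the library's internals.


def plan_fields(save_velocities=False, save_forces=False, variable_cell=True,
                variable_masses=False, variable_atomic_numbers=False):
    per_frame = ["positions"]  # positions are always stored per frame
    if save_velocities:
        per_frame.append("velocities")
    if save_forces:
        per_frame.append("forces")
    static = []
    for field, variable in (("cell", variable_cell),
                            ("masses", variable_masses),
                            ("atomic_numbers", variable_atomic_numbers)):
        # Variable data goes in every frame; constant data is stored once.
        (per_frame if variable else static).append(field)
    return per_frame, static


per_frame, static = plan_fields()  # the signature's defaults
print(per_frame)  # ['positions', 'cell']
print(static)     # ['masses', 'atomic_numbers']
```

With the defaults, only the cell joins the positions as per-frame data; passing variable_cell=False moves it to the write-once group and saves space for fixed-cell runs.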

get_structure(frame=-1)

Get a pymatgen Structure object for a given frame.

Converts the state at the specified frame to a pymatgen Structure object for analysis and visualization.

Parameters:

frame (int, optional) – Frame index to retrieve. Defaults to -1 for last frame.

Returns:

Pymatgen Structure object for the specified frame

Return type:

Structure

Raises:

ImportError – If pymatgen is not installed

get_atoms(frame=-1)

Get an ASE Atoms object for a given frame.

Converts the state at the specified frame to an ASE Atoms object for analysis and visualization.

Parameters:

frame (int, optional) – Frame index to retrieve. Defaults to -1 for last frame.

Returns:

ASE Atoms object for the specified frame

Return type:

Atoms

Raises:

ImportError – If ASE is not installed

get_state(frame=-1, device=None, dtype=None)

Get a SimState object for a given frame.

Reconstructs a SimState object from the data stored for a specific frame.

Parameters:
  • frame (int, optional) – Frame index to retrieve. Defaults to -1 for last frame.

  • device (device, optional) – Device to place tensors on. Defaults to None.

  • dtype (dtype, optional) – Data type for tensors. Defaults to None.

Returns:

State object containing all available data for the frame, with shapes matching the originally stored state

Return type:

SimState
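Reconstructing a frame means indexing the per-frame arrays and merging in the data that was stored only once. A minimal sketch of that merge, assuming negative frame indices count from the end as with frame=-1; reconstruct_frame, the plain-dict "state" and the IndexError are hypothetical stand-ins for the real SimState reconstruction.

```python
# Hypothetical sketch: rebuild one frame by combining per-frame arrays
# (indexed by frame) with static arrays that were written only once.
# reconstruct_frame and the dict-based state are illustrative only.


def reconstruct_frame(per_frame, static, frame=-1):
    n_frames = len(next(iter(per_frame.values())))
    if frame < 0:
        frame += n_frames  # -1 selects the last frame
    if not 0 <= frame < n_frames:
        raise IndexError(f"frame {frame} out of range for {n_frames} frames")
    state = {name: frames[frame] for name, frames in per_frame.items()}
    state.update(static)  # static data is shared by every frame
    return state


# Three frames of a single hydrogen atom moving along x.
per_frame = {"positions": [[[0.0, 0.0, 0.0]], [[0.1, 0.0, 0.0]], [[0.2, 0.0, 0.0]]]}
static = {"atomic_numbers": [1]}
print(reconstruct_frame(per_frame, static))
# {'positions': [[0.2, 0.0, 0.0]], 'atomic_numbers': [1]}
```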

property metadata: dict

Metadata for the trajectory.

close()

Close the HDF5 file handle.

Ensures all data is written to disk and releases the file handle.

Return type:

None

flush()

Write all buffered data to disk.

Forces any pending writes to physical storage without closing the file.

Return type:

None
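The class examples open trajectories with a with statement, which relies on close() running when the block exits, even on error. A minimal sketch of that context-manager pattern, assuming close() flushes before releasing the handle; TrajectoryHandle is a hypothetical class that records calls instead of touching an HDF5 file.

```python
# Hypothetical sketch: a context manager guarantees flush() and close()
# run even if the body raises, which the `with` examples rely on.
# TrajectoryHandle records events instead of writing a real file.


class TrajectoryHandle:
    def __init__(self):
        self.events = []
        self.closed = False

    def flush(self):
        self.events.append("flush")

    def close(self):
        if not self.closed:  # closing twice is a no-op
            self.flush()  # assumed: pending data is written first
            self.closed = True
            self.events.append("close")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # do not swallow the exception


handle = TrajectoryHandle()
try:
    with handle:
        raise RuntimeError("simulation crashed")
except RuntimeError:
    pass
print(handle.events)  # ['flush', 'close']
```

Calling flush() periodically during a long run limits how much data a crash can lose; close() is still needed to release the file handle.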

write_ase_trajectory(filename)

Convert trajectory to ASE Trajectory format.

Writes the entire trajectory to a new file in ASE format for compatibility with ASE analysis tools.

Parameters:

filename (str | Path) – Path to the output ASE trajectory file

Returns:

ASE trajectory object

Return type:

ase.io.trajectory.Trajectory

Raises:

ImportError – If ASE is not installed