state_to_atom_graphs

torch_sim.models.orb.state_to_atom_graphs(state, *, wrap=True, edge_method=None, system_config=None, max_num_neighbors=None, system_id=None, half_supercell=False, device=None, output_dtype=None, graph_construction_dtype=None)

Convert a SimState object into AtomGraphs format, ready for use in an ORB model.

Parameters:
  • state (ts.SimState) – SimState object containing atomic positions, cell, and atomic numbers

  • wrap (bool) – Whether to wrap atomic positions into the central unit cell (if there is one).

  • edge_method (EdgeCreationMethod, optional) – The method to use for graph edge construction. If None, the edge method is chosen automatically based on device and system size.

  • system_config (SystemConfig | None) – The system configuration to use for graph construction.

  • max_num_neighbors (int | None) – Maximum number of neighbors each node can send messages to. If None, system_config.max_num_neighbors is used.

  • system_id (int | None) – Optional index that is relative to a particular dataset.

  • half_supercell (bool) – Whether to use half the supercell for graph construction. This can improve performance for large systems.

  • device (device | None) – The device to put the tensors on.

  • output_dtype (dtype | None) – The dtype to use for all floating point tensors stored on the AtomGraphs object.

  • graph_construction_dtype (dtype | None) – The dtype to use for floating point tensors in the graph construction.
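The max_num_neighbors parameter caps how many edges each node sends. As an illustration only, the sketch below builds a brute-force radius graph and keeps at most max_num_neighbors nearest neighbors per sender. The function radius_graph and its cutoff argument are hypothetical; this sketch ignores periodic images and does not reproduce the library's edge construction (selected via edge_method), and keeping the nearest neighbors with ties broken by index is an assumption about how the cap is applied.

```python
import math
from typing import List, Optional, Sequence, Tuple


def radius_graph(
    positions: Sequence[Sequence[float]],
    cutoff: float,
    max_num_neighbors: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    """Return (senders, receivers) for all pairs within `cutoff`.

    Non-periodic, O(N^2) sketch. If max_num_neighbors is set, each sender
    keeps only its nearest neighbors (ties broken by receiver index).
    """
    senders: List[int] = []
    receivers: List[int] = []
    for i, ri in enumerate(positions):
        # Collect (distance, index) for every other atom inside the cutoff.
        candidates = []
        for j, rj in enumerate(positions):
            if i == j:
                continue
            d = math.dist(ri, rj)
            if d <= cutoff:
                candidates.append((d, j))
        # Sort by distance so truncation keeps the closest neighbors.
        candidates.sort()
        if max_num_neighbors is not None:
            candidates = candidates[:max_num_neighbors]
        for _, j in candidates:
            senders.append(i)
            receivers.append(j)
    return senders, receivers


line = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.5, 0.0, 0.0]]
print(radius_graph(line, cutoff=2.0, max_num_neighbors=1))
# ([0, 1, 2, 3], [1, 0, 1, 2])
```

Without the cap, the same call produces eight edges; the cap bounds memory and message-passing cost per node, which is why a default comes from system_config when max_num_neighbors is None.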

Returns:

AtomGraphs object containing the graph representation of the atomic system

Return type:

AtomGraphs