state_to_atom_graphs
- torch_sim.models.orb.state_to_atom_graphs(state, *, wrap=True, edge_method=None, system_config=None, max_num_neighbors=None, system_id=None, half_supercell=False, device=None, output_dtype=None, graph_construction_dtype=None)
Convert a SimState object into AtomGraphs format, ready for use in an ORB model.
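To give a rough picture of the kind of graph this conversion produces, the following minimal NumPy sketch connects every pair of atoms that are closer than a cutoff. Mirroring the `max_num_neighbors` parameter, it keeps at most that many receivers per sending atom. It is brute force and ignores periodic images. `radius_graph` and `cutoff` are hypothetical names, and keeping the *closest* neighbors when the cap applies is an assumption. This is not the edge construction that torch_sim or ORB actually perform (see `edge_method`).

```python
import numpy as np


def radius_graph(positions: np.ndarray, cutoff: float, max_num_neighbors: int):
    """Brute-force, non-periodic radius graph with a per-sender neighbor cap.

    Hypothetical helper for illustration only. Returns (senders, receivers)
    index arrays. For each sender, at most `max_num_neighbors` receivers are
    kept, choosing the closest ones (an assumption about how the cap applies).
    """
    # Pairwise distances, shape (n_atoms, n_atoms).
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Exclude self-edges by pushing the diagonal beyond any cutoff.
    np.fill_diagonal(dist, np.inf)

    senders, receivers = [], []
    for i in range(len(positions)):
        # Candidate receivers within the cutoff, ordered by distance.
        candidates = np.flatnonzero(dist[i] <= cutoff)
        order = np.argsort(dist[i, candidates], kind="stable")
        # Keep only the closest `max_num_neighbors` candidates.
        for j in candidates[order][:max_num_neighbors]:
            senders.append(i)
            receivers.append(int(j))
    return np.array(senders, dtype=np.int64), np.array(receivers, dtype=np.int64)


# Four atoms on a line with unit spacing.
positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
senders, receivers = radius_graph(positions, cutoff=2.5, max_num_neighbors=2)
# senders   -> [0, 0, 1, 1, 2, 2, 3, 3]
# receivers -> [1, 2, 0, 2, 1, 3, 2, 1]
```

Atom 1 has three atoms within the cutoff (0, 2 and 3), but the cap of 2 drops the farthest one (atom 3). This keeps the number of edges per atom bounded in dense regions.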
- Parameters:
state (ts.SimState) – SimState object containing atomic positions, cell, and atomic numbers.
wrap (bool) – Whether to wrap atomic positions into the central unit cell (if there is one).
edge_method (EdgeCreationMethod, optional) – The method to use for graph edge construction. If None, the edge method is chosen automatically based on device and system size.
system_config (SystemConfig | None) – The system configuration to use for graph construction.
max_num_neighbors (int | None) – Maximum number of neighbors each node can send messages to. If None, system_config.max_num_neighbors is used.
system_id (int | None) – Optional index that is relative to a particular dataset.
half_supercell (bool) – Whether to use half the supercell for graph construction. This can improve performance for large systems.
device (device | None) – The device to put the tensors on.
output_dtype (dtype | None) – The dtype to use for all floating point tensors stored on the AtomGraphs object.
graph_construction_dtype (dtype | None) – The dtype to use for floating point tensors in the graph construction.
- Returns:
AtomGraphs object containing the graph representation of the atomic system
- Return type:
AtomGraphs
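The minimal NumPy sketch below shows what the `wrap` option typically amounts to. It converts Cartesian positions to fractional coordinates, takes each one modulo 1, and converts back. It assumes periodicity along all three lattice vectors and the row-vector cell convention used by ASE (Cartesian = fractional @ cell). `wrap_positions` is a hypothetical helper, not part of torch_sim, and the library's own wrapping may differ in details such as handling non-periodic directions.

```python
import numpy as np


def wrap_positions(positions: np.ndarray, cell: np.ndarray) -> np.ndarray:
    """Wrap Cartesian positions into the unit cell spanned by the rows of `cell`.

    Hypothetical helper for illustration. Assumes the row-vector convention
    used by ASE, i.e. cartesian = fractional @ cell, and full periodicity.
    """
    # Cartesian -> fractional coordinates.
    fractional = positions @ np.linalg.inv(cell)
    # Map each fractional coordinate into [0, 1) (up to floating-point rounding).
    fractional = np.mod(fractional, 1.0)
    # Fractional -> Cartesian coordinates.
    return fractional @ cell


# A triclinic cell (rows are lattice vectors) and one atom outside it.
cell = np.array([[2.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 3.0]])
positions = np.array([[2.1, -0.6, 1.5]])  # fractional (1.2, -0.3, 0.5)
wrapped = wrap_positions(positions, cell)
# wrapped is approximately [[1.1, 1.4, 1.5]], i.e. fractional (0.2, 0.7, 0.5)
```

Each atom moves by an integer combination of lattice vectors. As a result, wrapping changes only which periodic image of each atom is stored, not the periodic structure itself.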