Dependencies
/// script dependencies = [ “mace-torch>=0.3.12”, “pymatgen>=2024.11.3”, “ase>=3.24”, “phonopy>=2.37.0”,] ///
Understanding State¶
This tutorial will guide you through the SimState class in TorchSim, which is the core data structure for representing atomistic systems. You’ll learn how to create, manipulate, and work with states.
Introduction¶
The SimState class is the core data structure in TorchSim that represents atomistic systems. A SimState contains all the fundamental properties needed to describe an atomistic system:
Atomic positions
Atomic masses
Unit cell parameters
Periodic boundary conditions
Atomic numbers (elements)
System indices (for processing multiple systems simultaneously)
Understanding the SimState Object¶
A Simple State¶
New SimStates can be either created manually or from existing atomistic objects. Here we’ll start by creating an ase atoms object and converting it to a SimState. The initialize_state
function can take in pymatgen Structure, PhonopyAtoms, or other SimStates and convert them into a single SimState.
[1]:
import torch
import torch_sim as ts
from ase.build import bulk
# Create a crystal structure using ASE
si_atoms = bulk("Si", "diamond", a=5.43, cubic=True)
# Convert to SimState
si_state = ts.initialize_state(si_atoms, device=torch.device("cpu"), dtype=torch.float64)
print(f"State has {si_state.n_atoms} atoms and {si_state.n_systems} systems")
# here we print all the attributes of the SimState
print(f"Positions shape: {si_state.positions.shape}")
print(f"Cell shape: {si_state.cell.shape}")
print(f"Atomic numbers shape: {si_state.atomic_numbers.shape}")
print(f"Masses shape: {si_state.masses.shape}")
print(f"PBC: {si_state.pbc}")
print(f"System indices shape: {si_state.system_idx.shape}")
State has 8 atoms and 1 systems
Positions shape: torch.Size([8, 3])
Cell shape: torch.Size([1, 3, 3])
Atomic numbers shape: torch.Size([8])
Masses shape: torch.Size([8])
PBC: True
System indices shape: torch.Size([8])
SimState attributes fall into three categories: atomwise, systemwise, and global.
Atomwise attributes are tensors with shape (n_atoms, …), these are
positions
,masses
,atomic_numbers
, andsystem_idx
. Names are plural.Systemwise attributes are tensors with shape (n_systems, …), this is just
cell
for the base SimState. Names are singular.Global attributes have any other shape or type, just
pbc
here. Names are singular.
For TorchSim to know which attributes are atomwise, systemwise, and global, each attribute’s name is explicitly defined in the _atom_attributes
, _system_attributes
, and _global_attributes
:
_atom_attributes = {“positions”, “masses”, “atomic_numbers”, “system_idx”} _system_attributes = {“cell”} _global_attributes = {“pbc”}
You can use the get_attrs_for_scope
generator function to analyze a state’s properties. This is mostly used internally but can be useful for debugging.
[2]:
from torch_sim.state import get_attrs_for_scope
# loop through each attribute:
for attr_name, attr_value in get_attrs_for_scope(si_state, "per-atom"):
print(f"per-atom attribute: {attr_name}")
print(f"value: {attr_value}")
# or access the attributes via a dict:
print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system"))) # noqa: E501
print("Global attributes:", dict(get_attrs_for_scope(si_state, "global")))
per-atom attribute: positions
value: tensor([[0.0000, 0.0000, 0.0000],
[1.3575, 1.3575, 1.3575],
[0.0000, 2.7150, 2.7150],
[1.3575, 4.0725, 4.0725],
[2.7150, 0.0000, 2.7150],
[4.0725, 1.3575, 4.0725],
[2.7150, 2.7150, 0.0000],
[4.0725, 4.0725, 1.3575]], dtype=torch.float64)
per-atom attribute: system_idx
value: tensor([0, 0, 0, 0, 0, 0, 0, 0])
per-atom attribute: masses
value: tensor([28.0850, 28.0850, 28.0850, 28.0850, 28.0850, 28.0850, 28.0850, 28.0850],
dtype=torch.float64)
per-atom attribute: atomic_numbers
value: tensor([14, 14, 14, 14, 14, 14, 14, 14], dtype=torch.int32)
Per-system attributes: {'cell': tensor([[[5.4300, 0.0000, 0.0000],
[0.0000, 5.4300, 0.0000],
[0.0000, 0.0000, 5.4300]]], dtype=torch.float64)}
Global attributes: {'pbc': True}
A Batched State¶
A key advantage of TorchSim is its ability to simulate multiple systems simultaneously via batching. To create a batch of multiple systems, you can simply pass a list of atomistic objects to the initialize_state
function.
[3]:
cu_atoms = bulk("Cu", "fcc", a=3.61, cubic=True)
al_atoms = bulk("Al", "fcc", a=4.05, cubic=True)
ag_atoms = bulk("Ag", "fcc", a=4.09, cubic=True)
# Initialize both as a single batched state
multi_state = ts.initialize_state(
[cu_atoms, al_atoms, ag_atoms], device=torch.device("cpu"), dtype=torch.float64
)
print(
f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_systems} systems"
)
# we can see how the shapes of atomwise, systemwise, and global properties change
print(f"Positions shape: {multi_state.positions.shape}")
print(f"Cell shape: {multi_state.cell.shape}")
print(f"PBC: {multi_state.pbc}")
print(f"System indices shape: {multi_state.system_idx.shape}")
Multi-state has 12 total atoms across 3 systems
Positions shape: torch.Size([12, 3])
Cell shape: torch.Size([3, 3, 3])
PBC: True
System indices shape: torch.Size([12])
Manipulating SimState¶
Datatype and Device¶
All SimState tensors are stored on the same device. Further, all floating point tensors are stored as the same datatype, torch.float64
by default.
We can change both the datatype and the state by calling the to
method.
[4]:
if torch.cuda.is_available():
multi_state = multi_state.to(device=torch.device("cuda"), dtype=torch.float32)
Slicing, Splitting, Popping, and more¶
SimState supports many convenience operations for manipulating batched states. Slicing is supported through fancy indexing, e.g. state[[0, 1, 2]]
will return a new state containing only the first three systems. The other operations are available through the pop
, split
, clone
, and to
methods.
[5]:
multi_state_copy = multi_state.clone()
print(f"This state has {multi_state_copy.n_systems} systems")
# we can pop states off while modifying the original state
popped_states = multi_state_copy.pop([0, 2])
print(
f"We popped {len(popped_states)} states, leaving us with "
f"{multi_state_copy.n_systems} systems in the original state"
)
# we can put them back together with concatenate
multi_state_full = ts.concatenate_states([*popped_states, multi_state_copy])
print(f"Again we have {multi_state_full.n_systems} systems in the full state")
# or if we don't want to modify the original state, we can instead index into it
# negative indexing
last_state = multi_state[-1]
# slicing
first_two_states = multi_state[:2]
# fancy indexing
sliced_state = multi_state[[0, 2]]
print(f"Unlike pop, slicing returns a {type(sliced_state)} instead of a list")
# but we could also get a list of states with the split method
list_of_sliced_states = sliced_state.split()
print(f"Which now is a list of {len(list_of_sliced_states)} states")
This state has 3 systems
We popped 2 states, leaving us with 1 systems in the original state
Again we have 3 systems in the full state
Unlike pop, slicing returns a <class 'torch_sim.state.SimState'> instead of a list
Which now is a list of 2 states
You can extract specific systems from a batched state using Python’s slicing syntax. This is extremely useful for analyzing specific systems or for implementing complex workflows where different systems need separate processing:
The slicing interface follows Python’s standard indexing conventions, making it intuitive to use. Behind the scenes, TorchSim is creating a new SimState with only the selected systems, maintaining all the necessary properties and relationships.
Note the difference between these operations:
split()
returns all systems as separate states but doesn’t modify the originalpop()
removes specified systems from the original state and returns them as separate states__getitem__
(slicing) creates a new state with specified systems without modifying the original
This flexibility allows you to structure your simulation workflows in the most efficient way for your specific needs.
Splitting and Popping Batches¶
SimState provides methods to split a batched state into separate states or to remove specific systems:
Converting States to Other Formats¶
SimState objects can be converted back to other atomistic representations. This is useful when you need to use external libraries for analysis or visualization:
[6]:
atoms_list = multi_state.to_atoms()
print(f"Converted to {len(atoms_list)} ASE Atoms objects")
print(f"First atoms object has chemical formula: {atoms_list[0].get_chemical_formula()}")
# Convert to pymatgen Structure
structures = multi_state.to_structures()
print(f"Converted to {len(structures)} pymatgen Structure objects")
print(f"First structure has formula: {structures[0].formula}")
# Convert to PhonopyAtoms (for phonon calculations)
phonopy_atoms = multi_state.to_phonopy()
print(f"Converted to {len(phonopy_atoms)} PhonopyAtoms objects")
print(f"First PhonopyAtoms object has chemical symbols: {phonopy_atoms[0].symbols}")
Converted to 3 ASE Atoms objects
First atoms object has chemical formula: Cu4
Converted to 3 pymatgen Structure objects
First structure has formula: Cu4
Converted to 3 PhonopyAtoms objects
First PhonopyAtoms object has chemical symbols: ['Cu', 'Cu', 'Cu', 'Cu']
Extending SimState: The MDState¶
MDState is defined in the torch_sim.integrators
module. It is a subclass of SimState for molecular dynamics simulations. It includes additional properties like momenta, forces, and energy. Here, we instantiate an MDState from a SimState by zeroing out the additional properties.
Since it inherits from SimState, it supports all the same operations. In general, all state objects in TorchSim support the operations described above.
[7]:
from torch_sim.integrators import MDState
from dataclasses import asdict
# Create an MDState from a SimState
md_state = MDState(
**asdict(si_state), # Copy all SimState properties
momenta=torch.zeros_like(si_state.positions), # Initial 0 momenta
forces=torch.zeros_like(si_state.positions), # Initial 0 forces
energy=torch.zeros((si_state.n_systems,), device=si_state.device), # Initial 0 energy
)
print("MDState properties:")
print("Per-atom attributes:", dict(get_attrs_for_scope(si_state, "per-atom")))
print("Per-system attributes:", dict(get_attrs_for_scope(si_state, "per-system")))
print("Global attributes:", dict(get_attrs_for_scope(si_state, "global")))
MDState properties:
Per-atom attributes: {'positions': tensor([[0.0000, 0.0000, 0.0000],
[1.3575, 1.3575, 1.3575],
[0.0000, 2.7150, 2.7150],
[1.3575, 4.0725, 4.0725],
[2.7150, 0.0000, 2.7150],
[4.0725, 1.3575, 4.0725],
[2.7150, 2.7150, 0.0000],
[4.0725, 4.0725, 1.3575]], dtype=torch.float64), 'system_idx': tensor([0, 0, 0, 0, 0, 0, 0, 0]), 'masses': tensor([28.0850, 28.0850, 28.0850, 28.0850, 28.0850, 28.0850, 28.0850, 28.0850],
dtype=torch.float64), 'atomic_numbers': tensor([14, 14, 14, 14, 14, 14, 14, 14], dtype=torch.int32)}
Per-system attributes: {'cell': tensor([[[5.4300, 0.0000, 0.0000],
[0.0000, 5.4300, 0.0000],
[0.0000, 0.0000, 5.4300]]], dtype=torch.float64)}
Global attributes: {'pbc': True}
Conclusion¶
The SimState class is the foundation of atomistic simulations in TorchSim. It provides:
A flexible, GPU-compatible representation for atomistic systems
Support for batched operations to efficiently process multiple systems
Seamless conversion to and from common atomistic formats
Properties and methods for slicing, combining, and manipulating atomistic data
With this understanding of the SimState, you’re now ready to build complex simulation workflows using TorchSim’s integrators, optimizers, and other modules.