UnitCellGDState
- class torch_sim.optimizers.UnitCellGDState(positions, masses, cell, pbc, atomic_numbers, forces, energy, reference_cell, cell_factor, hydrostatic_strain, constant_volume, pressure, stress, cell_positions, cell_forces, cell_masses, *, batch=None)
Bases: GDState, DeformGradMixin
State class for batched gradient-descent optimization of atomic positions together with the unit cell.
Extends GDState with unit-cell optimization parameters and stress information. It holds the state variables needed to optimize atomic positions and unit-cell parameters simultaneously.
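The state tracks two groups of degrees of freedom: atomic positions driven by `forces`, and cell degrees of freedom (`cell_positions`) driven by `cell_forces`. The following is a minimal sketch, not the library's own update step, of how both groups could be advanced by one plain steepest-descent step. Plain Python lists stand in for tensors, and the helper name `gd_step` and the step sizes `pos_lr` and `cell_lr` are hypothetical.

```python
# Minimal sketch (not torch_sim's implementation) of one steepest-descent step
# on the two groups of degrees of freedom held by UnitCellGDState.
# Nested lists stand in for torch tensors; `gd_step`, `pos_lr` and `cell_lr`
# are hypothetical names chosen for illustration.

Vector = list[float]
Matrix = list[list[float]]


def gd_step(
    positions: list[Vector],
    forces: list[Vector],
    cell_positions: Matrix,
    cell_forces: Matrix,
    pos_lr: float = 0.01,
    cell_lr: float = 0.1,
) -> tuple[list[Vector], Matrix]:
    """Move atoms and cell DOFs a small step along their (generalized) forces.

    Forces point downhill in energy, so a small step along them lowers it.
    """
    # Update every atomic coordinate: x_new = x + pos_lr * f
    new_positions = [
        [x + pos_lr * f for x, f in zip(row, frow)]
        for row, frow in zip(positions, forces)
    ]
    # Update every cell degree of freedom: c_new = c + cell_lr * f_cell
    new_cell_positions = [
        [c + cell_lr * f for c, f in zip(row, frow)]
        for row, frow in zip(cell_positions, cell_forces)
    ]
    return new_positions, new_cell_positions
```

The sketch gives each group its own step size because atomic displacements and cell deformations generally live on different scales; how the real cell is rebuilt from `cell_positions` after such a step is not shown here.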
- Variables:
Inherited from GDState:
positions (Tensor) – Atomic positions with shape [n_atoms, 3]
masses (Tensor) – Atomic masses with shape [n_atoms]
cell (Tensor) – Unit cell vectors with shape [n_batches, 3, 3]
pbc (bool) – Whether to use periodic boundary conditions
atomic_numbers (Tensor) – Atomic numbers with shape [n_atoms]
batch (Tensor) – Batch indices with shape [n_atoms]
forces (Tensor) – Forces acting on atoms with shape [n_atoms, 3]
energy (Tensor) – Potential energy with shape [n_batches]
Additional attributes for cell optimization:
stress (Tensor) – Stress tensor with shape [n_batches, 3, 3]
reference_cell (Tensor) – Reference unit cells with shape [n_batches, 3, 3]
cell_factor (Tensor) – Scaling factor for cell optimization with shape [n_batches, 1, 1]
hydrostatic_strain (bool) – Whether to allow only hydrostatic (uniform) deformation of the cell
constant_volume (bool) – Whether to maintain constant volume
pressure (Tensor) – Applied pressure tensor with shape [n_batches, 3, 3]
cell_positions (Tensor) – Cell positions (the cell's degrees of freedom in the optimizer) with shape [n_batches, 3, 3]
cell_forces (Tensor) – Generalized forces on the cell degrees of freedom with shape [n_batches, 3, 3]
cell_masses (Tensor) – Cell masses with shape [n_batches, 3]
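The `hydrostatic_strain` and `constant_volume` flags constrain how the cell may deform. A minimal sketch of one common way to enforce them, modeled on the approach in ASE's `UnitCellFilter` rather than taken from torch_sim, projects a 3x3 cell force matrix before the cell is updated. The function name `constrain_cell_force` is hypothetical, and plain lists stand in for one batch entry of `cell_forces`.

```python
# Minimal sketch of constraining a 3x3 cell force matrix with the
# `hydrostatic_strain` and `constant_volume` flags. Modeled on ASE's
# UnitCellFilter; an illustration, not torch_sim's exact code.
# `constrain_cell_force` is a hypothetical name.

Matrix = list[list[float]]


def constrain_cell_force(
    cell_force: Matrix,
    hydrostatic_strain: bool = False,
    constant_volume: bool = False,
) -> Matrix:
    out = [row[:] for row in cell_force]  # copy so the input is untouched
    if hydrostatic_strain:
        # Keep only the isotropic part: the cell may only scale uniformly.
        trace = sum(out[i][i] for i in range(3))
        out = [[trace / 3.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
    if constant_volume:
        # Remove the isotropic part; a trace-free step leaves the volume
        # unchanged to first order, since det(I + e) ~ 1 + tr(e).
        mean_diag = sum(out[i][i] for i in range(3)) / 3.0
        for i in range(3):
            out[i][i] -= mean_diag
    return out


f = [[3.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
hydro = constrain_cell_force(f, hydrostatic_strain=True)
fixed_vol = constrain_cell_force(f, constant_volume=True)
```

Here `hydro` keeps only a uniform diagonal of 1.0, while `fixed_vol` keeps the off-diagonal shear terms but has zero trace. With both flags set, this sketch removes every cell-force component, so the cell would not move at all.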
- Parameters:
positions (Tensor)
masses (Tensor)
cell (Tensor)
pbc (bool)
atomic_numbers (Tensor)
forces (Tensor)
energy (Tensor)
reference_cell (Tensor)
cell_factor (Tensor)
hydrostatic_strain (bool)
constant_volume (bool)
pressure (Tensor)
stress (Tensor)
cell_positions (Tensor)
cell_forces (Tensor)
cell_masses (Tensor)
batch (Tensor | None)