
"""Optimizers for geometry relaxations.

This module provides optimization algorithms for atomic structures in a batched format,
enabling efficient relaxation of multiple atomic structures simultaneously. It includes
several gradient-based methods with support for both atomic position and unit cell
optimization.

The module offers:

* Standard gradient descent for atomic positions
* Gradient descent with unit cell optimization
* FIRE (Fast Inertial Relaxation Engine) optimization with unit cell parameters
* FIRE optimization with Frechet cell parameterization for improved cell relaxation

ASE-style FIRE: https://gitlab.com/ase/ase/-/blob/master/ase/optimize/fire.py?ref_type=heads
Velocity Verlet-style FIRE: https://doi.org/10.1103/PhysRevLett.97.170201

"""

import functools
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Literal, get_args

import torch

import torch_sim.math as tsm
from torch_sim.state import DeformGradMixin, SimState
from torch_sim.typing import StateDict


MdFlavor = Literal["vv_fire", "ase_fire"]
vv_fire_key, ase_fire_key = get_args(MdFlavor)
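
# Usage note: every public optimizer below is a factory that returns an
# (init_fn, step_fn) pair. A minimal driver sketch (illustrative only; `model` is
# assumed to be a torch-sim model wrapper exposing `.device` and `.dtype`, and
# `sim_state` any SimState):
#
#     init_fn, step_fn = fire(model)  # or gradient_descent(model), ...
#     opt_state = init_fn(sim_state)
#     for _ in range(max_steps):
#         opt_state = step_fn(opt_state)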


@dataclass
class GDState(SimState):
    """State class for batched gradient descent optimization.

    This class extends SimState to store and track the evolution of system state
    during gradient descent optimization. It maintains the energies and forces
    needed to perform gradient-based structure relaxation in a batched manner.

    Attributes:
        positions (torch.Tensor): Atomic positions with shape [n_atoms, 3]
        masses (torch.Tensor): Atomic masses with shape [n_atoms]
        cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3]
        pbc (bool): Whether to use periodic boundary conditions
        atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]
        batch (torch.Tensor): Batch indices with shape [n_atoms]
        forces (torch.Tensor): Forces acting on atoms with shape [n_atoms, 3]
        energy (torch.Tensor): Potential energy with shape [n_batches]
    """

    forces: torch.Tensor
    energy: torch.Tensor


def gradient_descent(
    model: torch.nn.Module, *, lr: torch.Tensor | float = 0.01
) -> tuple[Callable[[StateDict | SimState], GDState], Callable[[GDState], GDState]]:
    """Initialize a batched gradient descent optimization.

    Creates an optimizer that performs standard gradient descent on atomic positions
    for multiple systems in parallel. The optimizer updates atomic positions based
    on forces computed by the provided model. The cell is not optimized with this
    optimizer.

    Args:
        model (torch.nn.Module): Model that computes energies and forces
        lr (torch.Tensor | float): Learning rate(s) for optimization. Can be a
            single float applied to all batches or a tensor with shape [n_batches]
            for batch-specific rates

    Returns:
        tuple: A pair of functions:
            - Initialization function that creates the initial GDState
            - Update function that performs one gradient descent step

    Notes:
        The learning rate controls the step size during optimization. Larger values
        can speed up convergence but may cause instability in the optimization
        process.
    """
    device, dtype = model.device, model.dtype

    def gd_init(
        state: SimState | StateDict,
        **kwargs: Any,
    ) -> GDState:
        """Initialize the batched gradient descent optimization state.

        Args:
            state: SimState containing positions, masses, cell, etc.
            kwargs: Additional keyword arguments to override state attributes

        Returns:
            Initialized GDState with forces and energy
        """
        if not isinstance(state, SimState):
            state = SimState(**state)

        atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers)

        # Get initial forces and energy from model
        model_output = model(state)
        energy = model_output["energy"]
        forces = model_output["forces"]

        return GDState(
            positions=state.positions,
            forces=forces,
            energy=energy,
            masses=state.masses,
            cell=state.cell,
            pbc=state.pbc,
            atomic_numbers=atomic_numbers,
            batch=state.batch,
        )

    def gd_step(state: GDState, lr: torch.Tensor = lr) -> GDState:
        """Perform one gradient descent step to update the atomic positions.

        The cell is not optimized.

        Args:
            state: Current optimization state
            lr: Learning rate(s) to use for this step, overriding the default

        Returns:
            Updated GDState after one optimization step
        """
        # Get per-atom learning rates by mapping batch learning rates to atoms
        if isinstance(lr, float):
            lr = torch.full((state.n_batches,), lr, device=device, dtype=dtype)
        atom_lr = lr[state.batch].unsqueeze(-1)  # shape: (total_atoms, 1)

        # Update positions using forces and per-atom learning rates
        state.positions = state.positions + atom_lr * state.forces

        # Get updated forces and energy from model
        model_output = model(state)
        state.forces = model_output["forces"]
        state.energy = model_output["energy"]

        return state

    return gd_init, gd_step


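# Example (illustrative sketch): batch-specific learning rates can be passed to the
# returned step function; `model` and `sim_state` are assumed as in the usage note
# above:
#
#     init_fn, step_fn = gradient_descent(model, lr=0.01)
#     gd_state = init_fn(sim_state)
#     lrs = torch.linspace(0.005, 0.02, gd_state.n_batches, device=model.device)
#     for _ in range(100):
#         gd_state = step_fn(gd_state, lr=lrs)

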
@dataclass
class UnitCellGDState(GDState, DeformGradMixin):
    """State class for batched gradient descent optimization with unit cell.

    Extends GDState to include unit cell optimization parameters and stress
    information. This class maintains the state variables needed for simultaneously
    optimizing atomic positions and unit cell parameters.

    Attributes:
        # Inherited from GDState
        positions (torch.Tensor): Atomic positions with shape [n_atoms, 3]
        masses (torch.Tensor): Atomic masses with shape [n_atoms]
        cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3]
        pbc (bool): Whether to use periodic boundary conditions
        atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]
        batch (torch.Tensor): Batch indices with shape [n_atoms]
        forces (torch.Tensor): Forces acting on atoms with shape [n_atoms, 3]
        energy (torch.Tensor): Potential energy with shape [n_batches]

        # Additional attributes for cell optimization
        stress (torch.Tensor): Stress tensor with shape [n_batches, 3, 3]
        reference_cell (torch.Tensor): Reference unit cells with shape
            [n_batches, 3, 3]
        cell_factor (torch.Tensor): Scaling factor for cell optimization with shape
            [n_batches, 1, 1]
        hydrostatic_strain (bool): Whether to only allow hydrostatic deformation
        constant_volume (bool): Whether to maintain constant volume
        pressure (torch.Tensor): Applied pressure tensor with shape [n_batches, 3, 3]
        cell_positions (torch.Tensor): Cell positions with shape [n_batches, 3, 3]
        cell_forces (torch.Tensor): Cell forces with shape [n_batches, 3, 3]
        cell_masses (torch.Tensor): Cell masses with shape [n_batches, 3]
    """

    # Required attributes not in GDState
    reference_cell: torch.Tensor
    cell_factor: torch.Tensor
    hydrostatic_strain: bool
    constant_volume: bool
    pressure: torch.Tensor
    stress: torch.Tensor

    # Cell attributes
    cell_positions: torch.Tensor
    cell_forces: torch.Tensor
    cell_masses: torch.Tensor


def unit_cell_gradient_descent(  # noqa: PLR0915, C901
    model: torch.nn.Module,
    *,
    positions_lr: float = 0.01,
    cell_lr: float = 0.1,
    cell_factor: float | torch.Tensor | None = None,
    hydrostatic_strain: bool = False,
    constant_volume: bool = False,
    scalar_pressure: float = 0.0,
) -> tuple[
    Callable[[SimState | StateDict], UnitCellGDState],
    Callable[[UnitCellGDState], UnitCellGDState],
]:
    """Initialize a batched gradient descent optimization with unit cell parameters.

    Creates an optimizer that performs gradient descent on both atomic positions
    and unit cell parameters for multiple systems in parallel. Supports constraints
    on cell deformation and applied external pressure.

    This optimizer extends standard gradient descent to simultaneously optimize
    both atomic coordinates and unit cell parameters based on forces and stress
    computed by the provided model.

    Args:
        model (torch.nn.Module): Model that computes energies, forces, and stress
        positions_lr (float): Learning rate for atomic positions optimization.
            Default is 0.01.
        cell_lr (float): Learning rate for unit cell optimization. Default is 0.1.
        cell_factor (float | torch.Tensor | None): Scaling factor for cell
            optimization. If None, defaults to number of atoms per batch
        hydrostatic_strain (bool): Whether to only allow hydrostatic deformation
            (isotropic scaling). Default is False.
        constant_volume (bool): Whether to maintain constant volume during
            optimization. Default is False.
        scalar_pressure (float): Applied external pressure in GPa. Default is 0.0.

    Returns:
        tuple: A pair of functions:
            - Initialization function that creates a UnitCellGDState
            - Update function that performs one gradient descent step with cell
              optimization

    Notes:
        - To fix the cell and only optimize atomic positions, set both
          constant_volume=True and hydrostatic_strain=True
        - The cell_factor parameter controls the relative scale of atomic vs cell
          optimization
        - Larger values for positions_lr and cell_lr can speed up convergence but
          may cause instability in the optimization process
    """
    device, dtype = model.device, model.dtype

    def gd_init(
        state: SimState,
        cell_factor: float | torch.Tensor | None = cell_factor,
        hydrostatic_strain: bool = hydrostatic_strain,  # noqa: FBT001
        constant_volume: bool = constant_volume,  # noqa: FBT001
        scalar_pressure: float = scalar_pressure,
    ) -> UnitCellGDState:
        """Initialize the batched gradient descent optimization state with unit cell.

        Args:
            state: Initial system state containing positions, masses, cell, etc.
            cell_factor: Scaling factor for cell optimization (default: number of
                atoms per batch)
            hydrostatic_strain: Whether to only allow hydrostatic deformation
            constant_volume: Whether to maintain constant volume
            scalar_pressure: Applied pressure in GPa

        Returns:
            Initial UnitCellGDState with system configuration and forces
        """
        if not isinstance(state, SimState):
            state = SimState(**state)

        n_batches = state.n_batches

        # Setup cell_factor
        if cell_factor is None:
            # Count atoms per batch
            _, counts = torch.unique(state.batch, return_counts=True)
            cell_factor = counts.to(dtype=dtype)

        if isinstance(cell_factor, int | float):
            # Use same factor for all batches
            cell_factor = torch.full(
                (state.n_batches,), cell_factor, device=device, dtype=dtype
            )

        # Reshape to (n_batches, 1, 1) for broadcasting
        cell_factor = cell_factor.view(n_batches, 1, 1)

        scalar_pressure = torch.full(
            (state.n_batches, 1, 1), scalar_pressure, device=device, dtype=dtype
        )
        # Setup pressure tensor
        pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype)

        # Get initial forces and energy from model
        model_output = model(state)
        energy = model_output["energy"]
        forces = model_output["forces"]
        stress = model_output["stress"]  # Already shape: (n_batches, 3, 3)

        # Create cell masses
        cell_masses = torch.ones(
            (state.n_batches, 3), device=device, dtype=dtype
        )  # One mass per cell DOF

        # Get current deformation gradient (identity at initialization)
        cur_deform_grad = DeformGradMixin._deform_grad(  # noqa: SLF001
            state.row_vector_cell, state.row_vector_cell
        )

        # Calculate cell positions
        cell_factor_expanded = cell_factor.expand(
            state.n_batches, 3, 1
        )  # shape: (n_batches, 3, 1)
        cell_positions = (
            cur_deform_grad.reshape(state.n_batches, 3, 3) * cell_factor_expanded
        )  # shape: (n_batches, 3, 3)

        # Calculate virial
        volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1)
        virial = -volumes * (stress + pressure)

        if hydrostatic_strain:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
                0
            ).expand(state.n_batches, -1, -1)

        if constant_volume:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
                3, device=device
            ).unsqueeze(0).expand(state.n_batches, -1, -1)

        return UnitCellGDState(
            positions=state.positions,
            forces=forces,
            energy=energy,
            stress=stress,
            masses=state.masses,
            cell=state.cell,
            pbc=state.pbc,
            reference_cell=state.cell.clone(),
            cell_factor=cell_factor,
            hydrostatic_strain=hydrostatic_strain,
            constant_volume=constant_volume,
            pressure=pressure,
            atomic_numbers=state.atomic_numbers,
            batch=state.batch,
            cell_positions=cell_positions,
            cell_forces=virial / cell_factor,
            cell_masses=cell_masses,
        )

    def gd_step(
        state: UnitCellGDState,
        positions_lr: torch.Tensor = positions_lr,
        cell_lr: torch.Tensor = cell_lr,
    ) -> UnitCellGDState:
        """Perform one gradient descent optimization step with unit cell.

        Updates both atomic positions and cell parameters based on forces and
        stress.

        Args:
            state: Current optimization state
            positions_lr: Learning rate for atomic positions optimization
            cell_lr: Learning rate for unit cell optimization

        Returns:
            Updated UnitCellGDState after one optimization step
        """
        # Get dimensions
        n_batches = state.n_batches

        # Get per-atom learning rates by mapping batch learning rates to atoms
        if isinstance(positions_lr, float):
            positions_lr = torch.full(
                (state.n_batches,), positions_lr, device=device, dtype=dtype
            )

        if isinstance(cell_lr, float):
            cell_lr = torch.full((state.n_batches,), cell_lr, device=device, dtype=dtype)

        # Get current deformation gradient
        cur_deform_grad = state.deform_grad()

        # Calculate cell positions from deformation gradient
        cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1)
        cell_positions = (
            cur_deform_grad.reshape(n_batches, 3, 3) * cell_factor_expanded
        )  # shape: (n_batches, 3, 3)

        # Get per-atom and per-cell learning rates
        atom_wise_lr = positions_lr[state.batch].unsqueeze(-1)
        cell_wise_lr = cell_lr.view(n_batches, 1, 1)  # shape: (n_batches, 1, 1)

        # Update atomic and cell positions
        atomic_positions_new = state.positions + atom_wise_lr * state.forces
        cell_positions_new = cell_positions + cell_wise_lr * state.cell_forces

        # Update cell with deformation gradient
        cell_update = cell_positions_new / cell_factor_expanded
        new_row_vector_cell = torch.bmm(state.reference_row_vector_cell, cell_update.mT)

        # Update state
        state.positions = atomic_positions_new
        state.row_vector_cell = new_row_vector_cell

        # Get new forces and energy
        model_output = model(state)
        state.energy = model_output["energy"]
        state.forces = model_output["forces"]
        state.stress = model_output["stress"]

        # Calculate virial for cell forces
        volumes = torch.linalg.det(new_row_vector_cell).view(n_batches, 1, 1)
        virial = -volumes * (state.stress + state.pressure)
        if state.hydrostatic_strain:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
                0
            ).expand(n_batches, -1, -1)
        if state.constant_volume:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
                3, device=device
            ).unsqueeze(0).expand(n_batches, -1, -1)

        # Update cell positions and forces
        state.cell_positions = cell_positions_new
        state.cell_forces = virial / state.cell_factor

        return state

    return gd_init, gd_step


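# Example (illustrative sketch): relax positions and cell together under an
# external pressure; setting both hydrostatic_strain=True and constant_volume=True
# instead would keep the cell fixed:
#
#     init_fn, step_fn = unit_cell_gradient_descent(
#         model, positions_lr=0.01, cell_lr=0.1, scalar_pressure=1.0
#     )
#     ucgd_state = init_fn(sim_state)
#     for _ in range(200):
#         ucgd_state = step_fn(ucgd_state)

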
@dataclass
class FireState(SimState):
    """State information for batched FIRE optimization.

    This class extends SimState to store and track the system state during FIRE
    (Fast Inertial Relaxation Engine) optimization. It maintains the atomic
    parameters along with their velocities and forces for structure relaxation
    using the FIRE algorithm.

    Attributes:
        # Inherited from SimState
        positions (torch.Tensor): Atomic positions with shape [n_atoms, 3]
        masses (torch.Tensor): Atomic masses with shape [n_atoms]
        cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3]
        pbc (bool): Whether to use periodic boundary conditions
        atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]
        batch (torch.Tensor): Batch indices with shape [n_atoms]

        # Atomic quantities
        forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3]
        velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3]
        energy (torch.Tensor): Energy per batch with shape [n_batches]

        # FIRE optimization parameters
        dt (torch.Tensor): Current timestep per batch with shape [n_batches]
        alpha (torch.Tensor): Current mixing parameter per batch with shape
            [n_batches]
        n_pos (torch.Tensor): Number of positive power steps per batch with shape
            [n_batches]

    Properties:
        momenta (torch.Tensor): Atomwise momenta of the system with shape
            [n_atoms, 3], calculated as velocities * masses
    """

    # Required attributes not in SimState
    forces: torch.Tensor
    energy: torch.Tensor
    velocities: torch.Tensor | None

    # FIRE algorithm parameters
    dt: torch.Tensor
    alpha: torch.Tensor
    n_pos: torch.Tensor


def fire(
    model: torch.nn.Module,
    *,
    dt_max: float = 1.0,
    dt_start: float = 0.1,
    n_min: int = 5,
    f_inc: float = 1.1,
    f_dec: float = 0.5,
    alpha_start: float = 0.1,
    f_alpha: float = 0.99,
    max_step: float = 0.2,
    md_flavor: MdFlavor = ase_fire_key,
) -> tuple[
    Callable[[SimState | StateDict], FireState],
    Callable[[FireState], FireState],
]:
    """Initialize a batched FIRE optimization.

    Creates an optimizer that performs FIRE (Fast Inertial Relaxation Engine)
    optimization on atomic positions.

    Args:
        model (torch.nn.Module): Model that computes energies, forces, and stress
        dt_max (float): Maximum allowed timestep
        dt_start (float): Initial timestep
        n_min (int): Minimum steps before timestep increase
        f_inc (float): Factor for timestep increase when power is positive
        f_dec (float): Factor for timestep decrease when power is negative
        alpha_start (float): Initial velocity mixing parameter
        f_alpha (float): Factor for mixing parameter decrease
        max_step (float): Maximum distance an atom can move per iteration (default
            value is 0.2). Only used when md_flavor='ase_fire'.
        md_flavor ("vv_fire" | "ase_fire"): Optimization flavor. Default is
            "ase_fire".

    Returns:
        tuple[Callable, Callable]:
            - Initialization function that creates a FireState
            - Update function (either _vv_fire_step or _ase_fire_step) that
              performs one FIRE optimization step.

    Notes:
        - md_flavor="vv_fire" follows the original paper closely, including
          integration with Velocity Verlet steps. See
          https://doi.org/10.1103/PhysRevLett.97.170201 and
          https://github.com/Radical-AI/torch-sim/issues/90#issuecomment-2826179997
          for details.
        - md_flavor="ase_fire" mimics the implementation in ASE, which differs
          slightly in the update steps and does not explicitly use atomic masses in
          the velocity update step. See
          https://gitlab.com/ase/ase/-/blob/66963e6e38/ase/optimize/fire.py#L164-214
          for details.
        - FIRE is generally more efficient than standard gradient descent for
          atomic structure optimization.
        - The algorithm adaptively adjusts step sizes and mixing parameters based
          on the dot product of forces and velocities (power).
    """
    if md_flavor not in get_args(MdFlavor):
        raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}")

    device, dtype = model.device, model.dtype
    eps = 1e-8 if dtype == torch.float32 else 1e-16

    # Setup parameters, added max_step for ASE style
    dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = (
        torch.as_tensor(p, device=device, dtype=dtype)
        for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step)
    )

    def fire_init(
        state: SimState | StateDict,
        dt_start: float = dt_start,
        alpha_start: float = alpha_start,
    ) -> FireState:
        """Initialize a batched FIRE optimization state.

        Args:
            state: Input state as SimState object or state parameter dict
            dt_start: Initial timestep per batch
            alpha_start: Initial mixing parameter per batch

        Returns:
            FireState with initialized optimization tensors
        """
        if not isinstance(state, SimState):
            state = SimState(**state)

        # Get dimensions
        n_batches = state.n_batches

        # Get initial forces and energy from model
        model_output = model(state)
        energy = model_output["energy"]  # [n_batches]
        forces = model_output["forces"]  # [n_total_atoms, 3]

        # Setup parameters
        dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype)
        alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype)
        n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32)

        return FireState(  # Create initial state
            # Copy SimState attributes
            positions=state.positions.clone(),
            masses=state.masses.clone(),
            cell=state.cell.clone(),
            atomic_numbers=state.atomic_numbers.clone(),
            batch=state.batch.clone(),
            pbc=state.pbc,
            velocities=None,
            forces=forces,
            energy=energy,
            # Optimization attributes
            dt=dt_start,
            alpha=alpha_start,
            n_pos=n_pos,
        )

    step_func_kwargs = dict(
        model=model,
        dt_max=dt_max,
        n_min=n_min,
        f_inc=f_inc,
        f_dec=f_dec,
        alpha_start=alpha_start,
        f_alpha=f_alpha,
        eps=eps,
        is_cell_optimization=False,
        is_frechet=False,
    )
    if md_flavor == ase_fire_key:
        step_func_kwargs["max_step"] = max_step
    step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor]
    return fire_init, functools.partial(step_func, **step_func_kwargs)


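# Example (illustrative sketch): run ASE-flavored FIRE until the largest per-atom
# force norm drops below a tolerance (in the model's force units):
#
#     init_fn, step_fn = fire(model, max_step=0.2, md_flavor="ase_fire")
#     fire_state = init_fn(sim_state)
#     while fire_state.forces.norm(dim=-1).max() > 0.02:
#         fire_state = step_fn(fire_state)

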
@dataclass
class UnitCellFireState(SimState, DeformGradMixin):
    """State information for batched FIRE optimization with unit cell degrees of
    freedom.

    This class extends SimState to store and track the system state during FIRE
    (Fast Inertial Relaxation Engine) optimization. It maintains both atomic and
    cell parameters along with their velocities and forces for structure relaxation
    using the FIRE algorithm.

    Attributes:
        # Inherited from SimState
        positions (torch.Tensor): Atomic positions with shape [n_atoms, 3]
        masses (torch.Tensor): Atomic masses with shape [n_atoms]
        cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3]
        pbc (bool): Whether to use periodic boundary conditions
        atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]
        batch (torch.Tensor): Batch indices with shape [n_atoms]

        # Atomic quantities
        forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3]
        velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3]
        energy (torch.Tensor): Energy per batch with shape [n_batches]
        stress (torch.Tensor): Stress tensor with shape [n_batches, 3, 3]

        # Cell quantities
        cell_positions (torch.Tensor): Cell positions with shape [n_batches, 3, 3]
        cell_velocities (torch.Tensor): Cell velocities with shape [n_batches, 3, 3]
        cell_forces (torch.Tensor): Cell forces with shape [n_batches, 3, 3]
        cell_masses (torch.Tensor): Cell masses with shape [n_batches, 3]

        # Cell optimization parameters
        reference_cell (torch.Tensor): Original unit cells with shape
            [n_batches, 3, 3]
        cell_factor (torch.Tensor): Cell optimization scaling factor with shape
            [n_batches, 1, 1]
        pressure (torch.Tensor): Applied pressure tensor with shape
            [n_batches, 3, 3]
        hydrostatic_strain (bool): Whether to only allow hydrostatic deformation
        constant_volume (bool): Whether to maintain constant volume

        # FIRE optimization parameters
        dt (torch.Tensor): Current timestep per batch with shape [n_batches]
        alpha (torch.Tensor): Current mixing parameter per batch with shape
            [n_batches]
        n_pos (torch.Tensor): Number of positive power steps per batch with shape
            [n_batches]

    Properties:
        momenta (torch.Tensor): Atomwise momenta of the system with shape
            [n_atoms, 3], calculated as velocities * masses
    """

    # Required attributes not in SimState
    forces: torch.Tensor
    energy: torch.Tensor
    stress: torch.Tensor
    velocities: torch.Tensor | None

    # Cell attributes
    cell_positions: torch.Tensor
    cell_velocities: torch.Tensor | None
    cell_forces: torch.Tensor
    cell_masses: torch.Tensor

    # Optimization-specific attributes
    reference_cell: torch.Tensor
    cell_factor: torch.Tensor
    pressure: torch.Tensor
    hydrostatic_strain: bool
    constant_volume: bool

    # FIRE algorithm parameters
    dt: torch.Tensor
    alpha: torch.Tensor
    n_pos: torch.Tensor


def unit_cell_fire(
    model: torch.nn.Module,
    *,
    dt_max: float = 1.0,
    dt_start: float = 0.1,
    n_min: int = 5,
    f_inc: float = 1.1,
    f_dec: float = 0.5,
    alpha_start: float = 0.1,
    f_alpha: float = 0.99,
    cell_factor: float | None = None,
    hydrostatic_strain: bool = False,
    constant_volume: bool = False,
    scalar_pressure: float = 0.0,
    max_step: float = 0.2,
    md_flavor: MdFlavor = ase_fire_key,
) -> tuple[
    Callable[[SimState | StateDict], UnitCellFireState],
    Callable[[UnitCellFireState], UnitCellFireState],
]:
    """Initialize a batched FIRE optimization with unit cell degrees of freedom.

    Creates an optimizer that performs FIRE (Fast Inertial Relaxation Engine)
    optimization on both atomic positions and unit cell parameters for multiple
    systems in parallel. FIRE combines molecular dynamics with velocity damping and
    adjustment of time steps to efficiently find local minima.

    Args:
        model (torch.nn.Module): Model that computes energies, forces, and stress
        dt_max (float): Maximum allowed timestep
        dt_start (float): Initial timestep
        n_min (int): Minimum steps before timestep increase
        f_inc (float): Factor for timestep increase when power is positive
        f_dec (float): Factor for timestep decrease when power is negative
        alpha_start (float): Initial velocity mixing parameter
        f_alpha (float): Factor for mixing parameter decrease
        cell_factor (float | None): Scaling factor for cell optimization. If None,
            defaults to number of atoms per batch
        hydrostatic_strain (bool): Whether to only allow hydrostatic deformation
            (isotropic scaling)
        constant_volume (bool): Whether to maintain constant volume during
            optimization
        scalar_pressure (float): Applied external pressure in GPa
        max_step (float): Maximum allowed step size for ase_fire
        md_flavor ("vv_fire" | "ase_fire"): Optimization flavor. Default is
            "ase_fire".

    Returns:
        tuple: A pair of functions:
            - Initialization function that creates a UnitCellFireState
            - Update function that performs one FIRE optimization step

    Notes:
        - md_flavor="vv_fire" follows the original paper closely, including
          integration with Velocity Verlet steps. See
          https://doi.org/10.1103/PhysRevLett.97.170201 and
          https://github.com/Radical-AI/torch-sim/issues/90#issuecomment-2826179997
          for details.
        - md_flavor="ase_fire" mimics the implementation in ASE, which differs
          slightly in the update steps and does not explicitly use atomic masses in
          the velocity update step. See
          https://gitlab.com/ase/ase/-/blob/66963e6e38/ase/optimize/fire.py#L164-214
          for details.
        - FIRE is generally more efficient than standard gradient descent for
          atomic structure optimization
        - The algorithm adaptively adjusts step sizes and mixing parameters based
          on the dot product of forces and velocities
        - To fix the cell and only optimize atomic positions, set both
          constant_volume=True and hydrostatic_strain=True
        - The cell_factor parameter controls the relative scale of atomic vs cell
          optimization
    """
    if md_flavor not in get_args(MdFlavor):
        raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}")

    device, dtype = model.device, model.dtype
    eps = 1e-8 if dtype == torch.float32 else 1e-16

    # Setup parameters
    dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = (
        torch.as_tensor(p, device=device, dtype=dtype)
        for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step)
    )

    def fire_init(
        state: SimState | StateDict,
        cell_factor: torch.Tensor | None = cell_factor,
        scalar_pressure: float = scalar_pressure,
        dt_start: float = dt_start,
        alpha_start: float = alpha_start,
    ) -> UnitCellFireState:
        """Initialize a batched FIRE optimization state with unit cell.

        Args:
            state: Input state as SimState object or state parameter dict
            cell_factor: Cell optimization scaling factor. If None, uses atoms per
                batch. Single value or tensor of shape [n_batches].
            scalar_pressure: Applied pressure in energy units
            dt_start: Initial timestep per batch
            alpha_start: Initial mixing parameter per batch

        Returns:
            UnitCellFireState with initialized optimization tensors
        """
        if not isinstance(state, SimState):
            state = SimState(**state)

        # Get dimensions
        n_batches = state.n_batches

        # Setup cell_factor
        if cell_factor is None:
            # Count atoms per batch
            _, counts = torch.unique(state.batch, return_counts=True)
            cell_factor = counts.to(dtype=dtype)

        if isinstance(cell_factor, int | float):
            # Use same factor for all batches
            cell_factor = torch.full(
                (state.n_batches,), cell_factor, device=device, dtype=dtype
            )

        # Reshape to (n_batches, 1, 1) for broadcasting
        cell_factor = cell_factor.view(n_batches, 1, 1)

        # Setup pressure tensor
        pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype)
        pressure = pressure.unsqueeze(0).expand(n_batches, -1, -1)

        # Get initial forces and energy from model
        model_output = model(state)

        energy = model_output["energy"]  # [n_batches]
        forces = model_output["forces"]  # [n_total_atoms, 3]
        stress = model_output["stress"]  # [n_batches, 3, 3]

        volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1)
        virial = -volumes * (stress + pressure)  # P is P_ext * I

        if hydrostatic_strain:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
                0
            ).expand(n_batches, -1, -1)

        if constant_volume:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
                3, device=device
            ).unsqueeze(0).expand(n_batches, -1, -1)

        cell_forces = virial / cell_factor

        # Sum masses per batch using segment_reduce
        # TODO (AG): check this
        batch_counts = torch.bincount(state.batch)
        cell_masses = torch.segment_reduce(
            state.masses, reduce="sum", lengths=batch_counts
        )  # shape: (n_batches,)
        cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3)  # shape: (n_batches, 3)

        # Setup parameters
        dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype)
        alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype)
        n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32)

        return UnitCellFireState(  # Create initial state
            # Copy SimState attributes
            positions=state.positions.clone(),
            masses=state.masses.clone(),
            cell=state.cell.clone(),
            atomic_numbers=state.atomic_numbers.clone(),
            batch=state.batch.clone(),
            pbc=state.pbc,
            velocities=None,
            forces=forces,
            energy=energy,
            stress=stress,
            # Cell attributes
            cell_positions=torch.zeros(n_batches, 3, 3, device=device, dtype=dtype),
            cell_velocities=None,
            cell_forces=cell_forces,
            cell_masses=cell_masses,
            # Optimization attributes
            reference_cell=state.cell.clone(),
            cell_factor=cell_factor,
            pressure=pressure,
            dt=dt_start,
            alpha=alpha_start,
            n_pos=n_pos,
            hydrostatic_strain=hydrostatic_strain,
            constant_volume=constant_volume,
        )

    step_func_kwargs = dict(
        model=model,
        dt_max=dt_max,
        n_min=n_min,
        f_inc=f_inc,
        f_dec=f_dec,
        alpha_start=alpha_start,
        f_alpha=f_alpha,
        eps=eps,
        is_cell_optimization=True,
        is_frechet=False,
    )
    if md_flavor == ase_fire_key:
        step_func_kwargs["max_step"] = max_step
    step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor]
    return fire_init, functools.partial(step_func, **step_func_kwargs)


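# Example (illustrative sketch): FIRE with unit cell degrees of freedom, restricted
# to isotropic cell scaling via hydrostatic_strain=True:
#
#     init_fn, step_fn = unit_cell_fire(model, hydrostatic_strain=True)
#     ucf_state = init_fn(sim_state)
#     for _ in range(300):
#         ucf_state = step_fn(ucf_state)

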
@dataclass
class FrechetCellFIREState(SimState, DeformGradMixin):
    """State class for batched FIRE optimization with Frechet cell derivatives.

    This class extends SimState to store and track the system state during FIRE
    optimization with matrix logarithm parameterization for cell degrees of
    freedom. This parameterization provides improved handling of cell deformations
    during optimization.

    Attributes:
        # Inherited from SimState
        positions (torch.Tensor): Atomic positions with shape [n_atoms, 3]
        masses (torch.Tensor): Atomic masses with shape [n_atoms]
        cell (torch.Tensor): Unit cell vectors with shape [n_batches, 3, 3]
        pbc (bool): Whether to use periodic boundary conditions
        atomic_numbers (torch.Tensor): Atomic numbers with shape [n_atoms]
        batch (torch.Tensor): Batch indices with shape [n_atoms]

        # Additional atomic quantities
        forces (torch.Tensor): Forces on atoms with shape [n_atoms, 3]
        energy (torch.Tensor): Energy per batch with shape [n_batches]
        velocities (torch.Tensor): Atomic velocities with shape [n_atoms, 3]
        stress (torch.Tensor): Stress tensor with shape [n_batches, 3, 3]

        # Optimization-specific attributes
        reference_cell (torch.Tensor): Original unit cell with shape
            [n_batches, 3, 3]
        cell_factor (torch.Tensor): Scaling factor for cell optimization with shape
            [n_batches, 1, 1]
        pressure (torch.Tensor): Applied pressure tensor with shape
            [n_batches, 3, 3]
        hydrostatic_strain (bool): Whether to only allow hydrostatic deformation
        constant_volume (bool): Whether to maintain constant volume

        # Cell attributes using log parameterization
        cell_positions (torch.Tensor): Cell positions using log parameterization
            with shape [n_batches, 3, 3]
        cell_velocities (torch.Tensor): Cell velocities with shape [n_batches, 3, 3]
        cell_forces (torch.Tensor): Cell forces with shape [n_batches, 3, 3]
        cell_masses (torch.Tensor): Cell masses with shape [n_batches, 3]

        # FIRE algorithm parameters
        dt (torch.Tensor): Current timestep per batch with shape [n_batches]
        alpha (torch.Tensor): Current mixing parameter per batch with shape
            [n_batches]
        n_pos (torch.Tensor): Number of positive power steps per batch with shape
            [n_batches]

    Properties:
        momenta (torch.Tensor): Atomwise momenta of the system with shape
            [n_atoms, 3], calculated as velocities * masses
    """

    # Required attributes not in SimState
    forces: torch.Tensor
    energy: torch.Tensor
    velocities: torch.Tensor | None
    stress: torch.Tensor

    # Optimization-specific attributes
    reference_cell: torch.Tensor
    cell_factor: torch.Tensor
    pressure: torch.Tensor
    hydrostatic_strain: bool
    constant_volume: bool

    # Cell attributes
    cell_positions: torch.Tensor
    cell_velocities: torch.Tensor | None
    cell_forces: torch.Tensor
    cell_masses: torch.Tensor

    # FIRE algorithm parameters
    dt: torch.Tensor
    alpha: torch.Tensor
    n_pos: torch.Tensor


def frechet_cell_fire(
    model: torch.nn.Module,
    *,
    dt_max: float = 1.0,
    dt_start: float = 0.1,
    n_min: int = 5,
    f_inc: float = 1.1,
    f_dec: float = 0.5,
    alpha_start: float = 0.1,
    f_alpha: float = 0.99,
    cell_factor: float | None = None,
    hydrostatic_strain: bool = False,
    constant_volume: bool = False,
    scalar_pressure: float = 0.0,
    max_step: float = 0.2,
    md_flavor: MdFlavor = ase_fire_key,
) -> tuple[
    Callable[[SimState | StateDict], FrechetCellFIREState],
    Callable[[FrechetCellFIREState], FrechetCellFIREState],
]:
    """Initialize a batched FIRE optimization with Frechet cell parameterization.

    Creates an optimizer that performs FIRE optimization on both atomic positions
    and unit cell parameters using matrix logarithm parameterization for cell
    degrees of freedom. This parameterization provides forces consistent with
    numerical derivatives of the potential energy with respect to cell variables,
    resulting in more robust cell optimization.

    Args:
        model (torch.nn.Module): Model that computes energies, forces, and stress.
        dt_max (float): Maximum allowed timestep
        dt_start (float): Initial timestep
        n_min (int): Minimum steps before timestep increase
        f_inc (float): Factor for timestep increase when power is positive
        f_dec (float): Factor for timestep decrease when power is negative
        alpha_start (float): Initial velocity mixing parameter
        f_alpha (float): Factor for mixing parameter decrease
        cell_factor (float | None): Scaling factor for cell optimization. If None,
            defaults to number of atoms per batch
        hydrostatic_strain (bool): Whether to only allow hydrostatic deformation
            (isotropic scaling)
        constant_volume (bool): Whether to maintain constant volume during
            optimization
        scalar_pressure (float): Applied external pressure in GPa
        max_step (float): Maximum allowed step size for ase_fire
        md_flavor ("vv_fire" | "ase_fire"): Optimization flavor. Default is
            "ase_fire".

    Returns:
        tuple: A pair of functions:
            - Initialization function that creates a FrechetCellFIREState
            - Update function that performs one FIRE step with Frechet derivatives

    Notes:
        - md_flavor="vv_fire" follows the original paper closely, including
          integration with Velocity Verlet steps. See
          https://doi.org/10.1103/PhysRevLett.97.170201 and
          https://github.com/Radical-AI/torch-sim/issues/90#issuecomment-2826179997
          for details.
        - md_flavor="ase_fire" mimics the implementation in ASE, which differs
          slightly in the update steps and does not explicitly use atomic masses in
          the velocity update step. See
          https://gitlab.com/ase/ase/-/blob/66963e6e38/ase/optimize/fire.py#L164-214
          for details.
        - Frechet cell parameterization uses matrix logarithm to represent cell
          deformations, which provides improved numerical properties for cell
          optimization
        - This method generally performs better than standard unit cell
          optimization for cases with large cell deformations
        - To fix the cell and only optimize atomic positions, set both
          constant_volume=True and hydrostatic_strain=True
    """
    if md_flavor not in get_args(MdFlavor):
        raise ValueError(f"Unknown {md_flavor=}, must be one of {get_args(MdFlavor)}")

    device, dtype = model.device, model.dtype
    eps = 1e-8 if dtype == torch.float32 else 1e-16

    # Setup parameters
    dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step = (
        torch.as_tensor(p, device=device, dtype=dtype)
        for p in (dt_max, dt_start, alpha_start, f_inc, f_dec, f_alpha, n_min, max_step)
    )

    def fire_init(
        state: SimState | StateDict,
        cell_factor: torch.Tensor | None = cell_factor,
        scalar_pressure: float = scalar_pressure,
        dt_start: float = dt_start,
        alpha_start: float = alpha_start,
    ) -> FrechetCellFIREState:
        """Initialize a batched FIRE optimization state with Frechet cell
        parameterization.

        Args:
            state: Input state as SimState object or state parameter dict
            cell_factor: Cell optimization scaling factor. If None, uses atoms per
                batch. Single value or tensor of shape [n_batches].
            scalar_pressure: Applied pressure in energy units
            dt_start: Initial timestep per batch
            alpha_start: Initial mixing parameter per batch

        Returns:
            FrechetCellFIREState with initialized optimization tensors
        """
        if not isinstance(state, SimState):
            state = SimState(**state)

        # Get dimensions
        n_batches = state.n_batches

        # Setup cell_factor
        if cell_factor is None:
            # Count atoms per batch
            _, counts = torch.unique(state.batch, return_counts=True)
            cell_factor = counts.to(dtype=dtype)

        if isinstance(cell_factor, int | float):
            # Use same factor for all batches
            cell_factor = torch.full(
                (state.n_batches,), cell_factor, device=device, dtype=dtype
            )

        # Reshape to (n_batches, 1, 1) for broadcasting
        cell_factor = cell_factor.view(n_batches, 1, 1)

        # Setup pressure tensor
        pressure = scalar_pressure * torch.eye(3, device=device, dtype=dtype)
        pressure = pressure.unsqueeze(0).expand(n_batches, -1, -1)

        # Get initial forces and energy from model
        model_output = model(state)
        energy = model_output["energy"]  # [n_batches]
        forces = model_output["forces"]  # [n_total_atoms, 3]
        stress = model_output["stress"]  # [n_batches, 3, 3]

        # Calculate initial cell positions using matrix logarithm
        # Calculate current deformation gradient (identity matrix at start)
        cur_deform_grad = DeformGradMixin._deform_grad(  # noqa: SLF001
            state.row_vector_cell, state.row_vector_cell
        )  # shape: (n_batches, 3, 3)

        # For identity matrix, logm gives zero matrix
        # Initialize cell positions to zeros
        cell_positions = torch.zeros((n_batches, 3, 3), device=device, dtype=dtype)

        # Calculate virial for cell forces
        volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1)
        virial = -volumes * (stress + pressure)  # P is P_ext * I

        if hydrostatic_strain:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = diag_mean.unsqueeze(-1) * torch.eye(3, device=device).unsqueeze(
                0
            ).expand(n_batches, -1, -1)

        if constant_volume:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
                3, device=device
            ).unsqueeze(0).expand(n_batches, -1, -1)

        # Calculate UCF-style cell gradient
        ucf_cell_grad = torch.zeros_like(virial)
        for b in range(n_batches):
            ucf_cell_grad[b] = virial[b] @ torch.linalg.inv(cur_deform_grad[b].T)

        # Calculate cell forces; at the identity deformation gradient the Frechet
        # derivative of expm is the identity map, so no extra correction is needed
        cell_forces = ucf_cell_grad / cell_factor

        # Sum masses per batch
        batch_counts = torch.bincount(state.batch)
        cell_masses = torch.segment_reduce(
            state.masses, reduce="sum", lengths=batch_counts
        )  # shape: (n_batches,)
        cell_masses = cell_masses.unsqueeze(-1).expand(-1, 3)  # shape: (n_batches, 3)

        # Setup parameters
        dt_start = torch.full((n_batches,), dt_start, device=device, dtype=dtype)
        alpha_start = torch.full((n_batches,), alpha_start, device=device, dtype=dtype)
        n_pos = torch.zeros((n_batches,), device=device, dtype=torch.int32)

        return FrechetCellFIREState(  # Create initial state
            # Copy SimState attributes
            positions=state.positions,
            masses=state.masses,
            cell=state.cell,
            atomic_numbers=state.atomic_numbers,
            batch=state.batch,
            pbc=state.pbc,
            velocities=None,
            forces=forces,
            energy=energy,
            stress=stress,
            # Cell attributes
            cell_positions=cell_positions,
            cell_velocities=None,
            cell_forces=cell_forces,
            cell_masses=cell_masses,
            # Optimization attributes
            reference_cell=state.cell.clone(),
            cell_factor=cell_factor,
            pressure=pressure,
            dt=dt_start,
            alpha=alpha_start,
            n_pos=n_pos,
            hydrostatic_strain=hydrostatic_strain,
            constant_volume=constant_volume,
        )

    step_func_kwargs = dict(
        model=model,
        dt_max=dt_max,
        n_min=n_min,
        f_inc=f_inc,
        f_dec=f_dec,
        alpha_start=alpha_start,
        f_alpha=f_alpha,
        eps=eps,
        is_cell_optimization=True,
        is_frechet=True,
    )
    if md_flavor == ase_fire_key:
        step_func_kwargs["max_step"] = max_step
    step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[md_flavor]
    return fire_init, functools.partial(step_func, **step_func_kwargs)


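# Example (illustrative sketch): the Frechet variant is a drop-in replacement for
# unit_cell_fire and tends to be more robust for large cell deformations:
#
#     init_fn, step_fn = frechet_cell_fire(model, md_flavor="vv_fire")
#     fcf_state = init_fn(sim_state)
#     for _ in range(300):
#         fcf_state = step_fn(fcf_state)

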
AnyFireCellState = UnitCellFireState | FrechetCellFIREState


def _vv_fire_step(  # noqa: C901, PLR0915
    state: FireState | AnyFireCellState,
    model: torch.nn.Module,
    *,
    dt_max: torch.Tensor,
    n_min: torch.Tensor,
    f_inc: torch.Tensor,
    f_dec: torch.Tensor,
    alpha_start: torch.Tensor,
    f_alpha: torch.Tensor,
    eps: float,
    is_cell_optimization: bool = False,
    is_frechet: bool = False,
) -> FireState | AnyFireCellState:
    """Perform one Velocity-Verlet based FIRE optimization step.

    Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm for
    optimizing atomic positions and optionally unit cell parameters in a batched
    setting. Uses velocity Verlet integration with adaptive velocity mixing.

    Args:
        state: Current optimization state (FireState, UnitCellFireState, or
            FrechetCellFIREState).
        model: Model that computes energies, forces, and potentially stress.
        dt_max: Maximum allowed timestep.
        n_min: Minimum steps before timestep increase.
        f_inc: Factor for timestep increase when power is positive.
        f_dec: Factor for timestep decrease when power is negative.
        alpha_start: Initial mixing parameter for velocity update.
        f_alpha: Factor for mixing parameter decrease.
        eps: Small epsilon value for numerical stability.
        is_cell_optimization: Flag indicating if cell optimization is active.
        is_frechet: Flag indicating if Frechet cell parameterization is used.

    Returns:
        Updated state after performing one VV-FIRE step.
    """
    n_batches = state.n_batches
    device = state.positions.device
    dtype = state.positions.dtype
    deform_grad_new: torch.Tensor | None = None

    if state.velocities is None:
        state.velocities = torch.zeros_like(state.positions)
        if is_cell_optimization:
            if not isinstance(state, AnyFireCellState):
                raise ValueError(
                    f"Cell optimization requires one of {get_args(AnyFireCellState)}."
                )
            state.cell_velocities = torch.zeros(
                (n_batches, 3, 3), device=device, dtype=dtype
            )

    alpha_start_batch = torch.full(
        (n_batches,), alpha_start.item(), device=device, dtype=dtype
    )

    # First velocity half-step
    atom_wise_dt = state.dt[state.batch].unsqueeze(-1)
    state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)

    if is_cell_optimization:
        cell_wise_dt = state.dt.unsqueeze(-1).unsqueeze(-1)
        state.cell_velocities += (
            0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1)
        )

    # Position full step
    state.positions = state.positions + atom_wise_dt * state.velocities

    if is_cell_optimization:
        cell_factor_reshaped = state.cell_factor.view(n_batches, 1, 1)
        if is_frechet:
            if not isinstance(state, expected_cls := FrechetCellFIREState):
                raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
            cur_deform_grad = state.deform_grad()
            deform_grad_log = torch.zeros_like(cur_deform_grad)
            for b in range(n_batches):
                deform_grad_log[b] = tsm.matrix_log_33(cur_deform_grad[b])

            cell_positions_log_scaled = deform_grad_log * cell_factor_reshaped
            cell_positions_log_scaled_new = (
                cell_positions_log_scaled + cell_wise_dt * state.cell_velocities
            )
            deform_grad_log_new = cell_positions_log_scaled_new / cell_factor_reshaped
            deform_grad_new = torch.matrix_exp(deform_grad_log_new)

            new_row_vector_cell = torch.bmm(
                state.reference_row_vector_cell, deform_grad_new.transpose(1, 2)
            )
            state.row_vector_cell = new_row_vector_cell
            state.cell_positions = cell_positions_log_scaled_new
        else:
            if not isinstance(state, expected_cls := UnitCellFireState):
                raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
            cur_deform_grad = state.deform_grad()
            cell_factor_expanded = state.cell_factor.expand(n_batches, 3, 1)
            current_cell_positions_scaled = (
                cur_deform_grad.view(n_batches, 3, 3) * cell_factor_expanded
            )
            cell_positions_scaled_new = (
                current_cell_positions_scaled + cell_wise_dt * state.cell_velocities
            )
            cell_update = cell_positions_scaled_new / cell_factor_expanded
            new_cell = torch.bmm(
                state.reference_row_vector_cell, cell_update.transpose(1, 2)
            )
            state.row_vector_cell = new_cell
            state.cell_positions = cell_positions_scaled_new

    # Force / energy refresh
    results = model(state)
    state.forces = results["forces"]
    state.energy = results["energy"]

    if is_cell_optimization:
        state.stress = results["stress"]
        volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1)
        virial = -volumes * (state.stress + state.pressure)
        if state.hydrostatic_strain:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = diag_mean.unsqueeze(-1) * torch.eye(
                3, device=device, dtype=dtype
            ).unsqueeze(0).expand(n_batches, -1, -1)
        if state.constant_volume:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
                3, device=device, dtype=dtype
            ).unsqueeze(0).expand(n_batches, -1, -1)

        if is_frechet:
            if not isinstance(state, expected_cls := FrechetCellFIREState):
                raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
            ucf_cell_grad = torch.bmm(
                virial, torch.linalg.inv(torch.transpose(deform_grad_new, 1, 2))
            )
            directions = torch.zeros((9, 3, 3), device=device, dtype=dtype)
            for idx, (mu, nu) in enumerate([(i, j) for i in range(3) for j in range(3)]):
                directions[idx, mu, nu] = 1.0

            new_cell_forces = torch.zeros_like(ucf_cell_grad)
            for b in range(n_batches):
                expm_derivs = torch.stack(
                    [
                        tsm.expm_frechet(
                            deform_grad_log_new[b], direction, compute_expm=False
                        )
                        for direction in directions
                    ]
                )
                forces_flat = torch.sum(
                    expm_derivs * ucf_cell_grad[b].unsqueeze(0), dim=(1, 2)
                )
                new_cell_forces[b] = forces_flat.reshape(3, 3)
            state.cell_forces = new_cell_forces / cell_factor_reshaped
        else:
            if not isinstance(state, expected_cls := UnitCellFireState):
                raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
            state.cell_forces = virial / cell_factor_reshaped

    # Second velocity half-step
    state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1)
    if is_cell_optimization:
        state.cell_velocities += (
            0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1)
        )

    # 1. Current power (F·v) per batch (atoms + cell)
    batch_power = tsm.batched_vdot(state.forces, state.velocities, state.batch)
    if is_cell_optimization:
        batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2))

    # 2. Update dt, alpha, n_pos
    pos_mask_batch = batch_power > 0.0
    neg_mask_batch = ~pos_mask_batch

    state.n_pos[pos_mask_batch] += 1
    inc_mask = (state.n_pos > n_min) & pos_mask_batch
    state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max)
    state.alpha[inc_mask] *= f_alpha

    state.dt[neg_mask_batch] *= f_dec
    state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch]
    state.n_pos[neg_mask_batch] = 0

    # 3. Velocity mixing
    v_scaling_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch)
    f_scaling_batch = tsm.batched_vdot(state.forces, state.forces, state.batch)

    if is_cell_optimization:
        v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2))
        f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2))

        v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1))
        f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1))
        v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell
        alpha_cell_bc = state.alpha.view(n_batches, 1, 1)
        state.cell_velocities = torch.where(
            pos_mask_batch.view(n_batches, 1, 1),
            (1.0 - alpha_cell_bc) * state.cell_velocities
            + alpha_cell_bc * v_mixing_cell,
            torch.zeros_like(state.cell_velocities),
        )

    v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1))
    f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1))
    v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps))

    alpha_atom = state.alpha[state.batch].unsqueeze(-1)  # per-atom alpha
    state.velocities = torch.where(
        pos_mask_batch[state.batch].unsqueeze(-1),
        (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom,
        torch.zeros_like(state.velocities),
    )

    return state


def _ase_fire_step(  # noqa: C901, PLR0915
    state: FireState | AnyFireCellState,
    model: torch.nn.Module,
    *,
    dt_max: torch.Tensor,
    n_min: torch.Tensor,
    f_inc: torch.Tensor,
    f_dec: torch.Tensor,
    alpha_start: torch.Tensor,
    f_alpha: torch.Tensor,
    max_step: torch.Tensor,
    eps: float,
    is_cell_optimization: bool = False,
    is_frechet: bool = False,
) -> FireState | AnyFireCellState:
    """Perform one ASE-style FIRE optimization step.

    Implements one step of the Fast Inertial Relaxation Engine (FIRE) algorithm
    mimicking the ASE implementation. It can handle atomic position optimization
    only, or combined position and cell optimization (standard or Frechet).

    Args:
        state: Current optimization state.
        model: Model that computes energies, forces, and potentially stress.
        dt_max: Maximum allowed timestep.
        n_min: Minimum steps before timestep increase.
        f_inc: Factor for timestep increase when power is positive.
        f_dec: Factor for timestep decrease when power is negative.
        alpha_start: Initial mixing parameter for velocity update.
        f_alpha: Factor for mixing parameter decrease.
        max_step: Maximum allowed step size.
        eps: Small epsilon value for numerical stability.
        is_cell_optimization: Flag indicating if cell optimization is active.
        is_frechet: Flag indicating if Frechet cell parameterization is used.

    Returns:
        Updated state after performing one ASE-FIRE step.
    """
    device, dtype = state.positions.device, state.positions.dtype
    n_batches = state.n_batches
    cur_deform_grad = None  # Initialize cur_deform_grad to prevent UnboundLocalError

    if state.velocities is None:
        state.velocities = torch.zeros_like(state.positions)
        forces = state.forces
        if is_cell_optimization:
            if not isinstance(state, AnyFireCellState):
                raise ValueError(
                    f"Cell optimization requires one of {get_args(AnyFireCellState)}."
                )
            state.cell_velocities = torch.zeros(
                (n_batches, 3, 3), device=device, dtype=dtype
            )
            cur_deform_grad = state.deform_grad()
    elif is_cell_optimization:
        cur_deform_grad = state.deform_grad()
        forces = torch.bmm(
            state.forces.unsqueeze(1), cur_deform_grad[state.batch]
        ).squeeze(1)
    else:
        forces = state.forces

    alpha_start_batch = torch.full(
        (n_batches,), alpha_start.item(), device=device, dtype=dtype
    )

    # 1. Current power (F·v) per batch (atoms + cell)
    batch_power = tsm.batched_vdot(forces, state.velocities, state.batch)
    if is_cell_optimization:
        batch_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2))

    # 2. Update dt, alpha, n_pos
    pos_mask_batch = batch_power > 0.0
    neg_mask_batch = ~pos_mask_batch

    inc_mask = (state.n_pos > n_min) & pos_mask_batch
    state.dt[inc_mask] = torch.minimum(state.dt[inc_mask] * f_inc, dt_max)
    state.alpha[inc_mask] *= f_alpha
    state.n_pos[pos_mask_batch] += 1

    state.dt[neg_mask_batch] *= f_dec
    state.alpha[neg_mask_batch] = alpha_start_batch[neg_mask_batch]
    state.n_pos[neg_mask_batch] = 0

    # 3. Velocity mixing BEFORE acceleration (ASE ordering)
    v_scaling_batch = tsm.batched_vdot(state.velocities, state.velocities, state.batch)
    f_scaling_batch = tsm.batched_vdot(forces, forces, state.batch)

    if is_cell_optimization:
        v_scaling_batch += state.cell_velocities.pow(2).sum(dim=(1, 2))
        f_scaling_batch += state.cell_forces.pow(2).sum(dim=(1, 2))

        v_scaling_cell = torch.sqrt(v_scaling_batch.view(n_batches, 1, 1))
        f_scaling_cell = torch.sqrt(f_scaling_batch.view(n_batches, 1, 1))
        v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell
        alpha_cell_bc = state.alpha.view(n_batches, 1, 1)
        state.cell_velocities = torch.where(
            pos_mask_batch.view(n_batches, 1, 1),
            (1.0 - alpha_cell_bc) * state.cell_velocities
            + alpha_cell_bc * v_mixing_cell,
            torch.zeros_like(state.cell_velocities),
        )

    v_scaling_atom = torch.sqrt(v_scaling_batch[state.batch].unsqueeze(-1))
    f_scaling_atom = torch.sqrt(f_scaling_batch[state.batch].unsqueeze(-1))
    v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps))
    alpha_atom = state.alpha[state.batch].unsqueeze(-1)  # per-atom alpha
    state.velocities = torch.where(
        pos_mask_batch[state.batch].unsqueeze(-1),
        (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom,
        torch.zeros_like(state.velocities),
    )

    # 4. Acceleration (single forward-Euler, no mass for ASE FIRE)
    state.velocities += forces * state.dt[state.batch].unsqueeze(-1)
    dr_atom = state.velocities * state.dt[state.batch].unsqueeze(-1)
    dr_scaling_batch = tsm.batched_vdot(dr_atom, dr_atom, state.batch)

    if is_cell_optimization:
        state.cell_velocities += state.cell_forces * state.dt.view(n_batches, 1, 1)
        dr_cell = state.cell_velocities * state.dt.view(n_batches, 1, 1)
        dr_scaling_batch += dr_cell.pow(2).sum(dim=(1, 2))

        dr_scaling_cell = torch.sqrt(dr_scaling_batch).view(n_batches, 1, 1)
        dr_cell = torch.where(
            dr_scaling_cell > max_step,
            max_step * dr_cell / (dr_scaling_cell + eps),
            dr_cell,
        )

    dr_scaling_atom = torch.sqrt(dr_scaling_batch)[state.batch].unsqueeze(-1)
    dr_atom = torch.where(
        dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), dr_atom
    )

    # 5. Position / cell update
    if is_cell_optimization:
        state.positions = (
            torch.linalg.solve(
                cur_deform_grad[state.batch], state.positions.unsqueeze(-1)
            ).squeeze(-1)
            + dr_atom
        )

        if is_frechet:
            if not isinstance(state, expected_cls := FrechetCellFIREState):
                raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
            new_logm_F_scaled = state.cell_positions + dr_cell
            state.cell_positions = new_logm_F_scaled
            logm_F_new = new_logm_F_scaled / (state.cell_factor + eps)
            F_new = torch.matrix_exp(logm_F_new)
            new_row_vector_cell = torch.bmm(state.reference_row_vector_cell, F_new.mT)
            state.row_vector_cell = new_row_vector_cell
        else:
            if not isinstance(state, expected_cls := UnitCellFireState):
                raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
            F_current = state.deform_grad()
            cell_factor_exp_mult = state.cell_factor.expand(n_batches, 3, 1)
            current_F_scaled = F_current * cell_factor_exp_mult
            F_new_scaled = current_F_scaled + dr_cell
            state.cell_positions = F_new_scaled
            F_new = F_new_scaled / (cell_factor_exp_mult + eps)
            new_row_vector_cell = torch.bmm(state.reference_row_vector_cell, F_new.mT)
            state.row_vector_cell = new_row_vector_cell

        state.positions = torch.bmm(
            state.positions.unsqueeze(1), F_new[state.batch].mT
        ).squeeze(1)
    else:
        state.positions = state.positions + dr_atom

    # 6. Force / stress refresh & new cell forces
    results = model(state)
    state.forces = results["forces"]
    state.energy = results["energy"]

    if is_cell_optimization:
        state.stress = results["stress"]
        volumes = torch.linalg.det(state.cell).view(n_batches, 1, 1)
        if torch.any(volumes <= 0):
            bad_indices = torch.where(volumes <= 0)[0].tolist()
            print(  # noqa: T201
                f"WARNING: Non-positive volume(s) detected during _ase_fire_step: "
                f"{volumes[bad_indices].tolist()} at {bad_indices=} ({is_frechet=})"
            )

        virial = -volumes * (state.stress + state.pressure)
        if state.hydrostatic_strain:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = diag_mean.unsqueeze(-1) * torch.eye(
                3, device=device, dtype=dtype
            ).unsqueeze(0).expand(n_batches, -1, -1)
        if state.constant_volume:
            diag_mean = torch.diagonal(virial, dim1=1, dim2=2).mean(dim=1, keepdim=True)
            virial = virial - diag_mean.unsqueeze(-1) * torch.eye(
                3, device=device, dtype=dtype
            ).unsqueeze(0).expand(n_batches, -1, -1)

        if is_frechet:
            if not isinstance(state, expected_cls := FrechetCellFIREState):
                raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
            if F_new is None:
                raise ValueError(
                    "F_new should be defined for Frechet cell force calculation"
                )
            if logm_F_new is None:
                raise ValueError(
                    "logm_F_new should be defined for Frechet cell force calculation"
                )
            ucf_cell_grad = torch.bmm(
                virial, torch.linalg.inv(torch.transpose(F_new, 1, 2))
            )
            directions = torch.zeros((9, 3, 3), device=device, dtype=dtype)
            for idx, (mu, nu) in enumerate(
                [(i_idx, j_idx) for i_idx in range(3) for j_idx in range(3)]
            ):
                directions[idx, mu, nu] = 1.0

            new_cell_forces_log_space = torch.zeros_like(state.cell_forces)
            for b_idx in range(n_batches):
                expm_derivs = torch.stack(
                    [
                        tsm.expm_frechet(logm_F_new[b_idx], direction, compute_expm=False)
                        for direction in directions
                    ]
                )
                forces_flat = torch.sum(
                    expm_derivs * ucf_cell_grad[b_idx].unsqueeze(0), dim=(1, 2)
                )
                new_cell_forces_log_space[b_idx] = forces_flat.reshape(3, 3)
            state.cell_forces = new_cell_forces_log_space / (state.cell_factor + eps)
        else:
            if not isinstance(state, expected_cls := UnitCellFireState):
                raise ValueError(f"{type(state)=} must be a {expected_cls.__name__}")
            state.cell_forces = virial / state.cell_factor

    return state
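

# Both step functions above share the FIRE velocity-mixing rule: with power
# P = F · v computed per batch, velocities are updated as
#
#     v <- (1 - alpha) * v + alpha * |v| * F / |F|   if P > 0
#     v <- 0                                         if P <= 0
#
# A minimal single-structure sketch of that rule (illustrative only, independent of
# the batched tensors used above):
#
#     def fire_mix(v: torch.Tensor, f: torch.Tensor, alpha: float) -> torch.Tensor:
#         power = (f * v).sum()
#         if power <= 0:
#             return torch.zeros_like(v)
#         return (1 - alpha) * v + alpha * v.norm() * f / (f.norm() + 1e-8)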