# Source code for torch_sim.io

"""Input/output utilities for atomistic systems.

This module provides functions for converting between different structural
representations. It includes utilities for converting ASE Atoms objects,
Pymatgen Structures, and PhonopyAtoms objects to SimState objects and vice versa.

The module handles:

* Converting between ASE Atoms and SimState
* Converting between Pymatgen Structure and SimState
* Converting between PhonopyAtoms and SimState
* Batched conversions for multiple structures
"""

from typing import TYPE_CHECKING

import numpy as np
import torch


if TYPE_CHECKING:
    from ase import Atoms
    from phonopy.structure.atoms import PhonopyAtoms
    from pymatgen.core import Structure

    from torch_sim.state import SimState


def state_to_atoms(state: "SimState") -> list["Atoms"]:
    """Build one ASE Atoms object for every batch held in a SimState.

    Args:
        state (SimState): Batched state. Its ``positions``, ``cell``,
            ``atomic_numbers``, ``batch`` and ``pbc`` attributes are read.

    Returns:
        list[Atoms]: One ASE Atoms object per distinct batch index, in the
            order returned by ``np.unique`` (ascending batch index).

    Raises:
        ImportError: If ASE is not installed

    Notes:
        - No unit conversion is performed; positions and cell are assumed to
          already be in Å
        - Each per-batch cell matrix is transposed before being given to ASE
        - Only symbols, positions, cell and pbc are passed to ``Atoms``; the
          masses stored on the state are not forwarded
    """
    try:
        from ase import Atoms
        from ase.data import chemical_symbols
    except ImportError:
        raise ImportError("ASE is required for state_to_atoms conversion") from None

    # Pull every tensor off the device once; all slicing below is plain numpy.
    coords, cells, numbers, batch_ids = (
        tensor.detach().cpu().numpy()
        for tensor in (state.positions, state.cell, state.atomic_numbers, state.batch)
    )

    result = []
    for idx in np.unique(batch_ids):
        selected = batch_ids == idx  # boolean mask picking this batch's atoms
        result.append(
            Atoms(
                # ASE is given element symbols looked up from the atomic numbers
                symbols=[chemical_symbols[z] for z in numbers[selected]],
                positions=coords[selected],
                cell=cells[idx].T,  # transpose into the ASE convention
                pbc=state.pbc,
            )
        )
    return result
def state_to_structures(state: "SimState") -> list["Structure"]:
    """Convert a SimState to a list of Pymatgen Structure objects.

    Args:
        state (SimState): Batched state. Its ``positions``, ``cell``,
            ``atomic_numbers`` and ``batch`` attributes are read.

    Returns:
        list[Structure]: Pymatgen Structure objects, one per distinct batch
            index, in the order returned by ``np.unique``.

    Raises:
        ImportError: If Pymatgen is not installed

    Notes:
        - No unit conversion is performed; positions and cell are assumed to
          already be in Å
        - Positions are passed to Pymatgen as Cartesian coordinates
        - Each per-batch cell matrix is transposed before building the Lattice
        - ``state.pbc`` is not consulted; the Structure is built with
          Pymatgen's default lattice periodicity
    """
    try:
        from pymatgen.core import Lattice, Structure
        from pymatgen.core.periodic_table import Element
    except ImportError:
        raise ImportError(
            "Pymatgen is required for state_to_structures conversion"
        ) from None

    # Convert tensors to numpy arrays on CPU
    positions = state.positions.detach().cpu().numpy()
    cell = state.cell.detach().cpu().numpy()  # Shape: (n_batches, 3, 3)
    atomic_numbers = state.atomic_numbers.detach().cpu().numpy()
    batch = state.batch.detach().cpu().numpy()

    structures = []
    for batch_idx in np.unique(batch):
        # Boolean mask selecting the atoms that belong to this batch
        mask = batch == batch_idx

        # Cast to plain int so Element.from_Z never sees a numpy scalar type
        species = [Element.from_Z(int(z)) for z in atomic_numbers[mask]]

        structures.append(
            Structure(
                lattice=Lattice(cell[batch_idx].T),  # transpose for Pymatgen
                species=species,
                coords=positions[mask],
                coords_are_cartesian=True,
            )
        )

    return structures
def state_to_phonopy(state: "SimState") -> list["PhonopyAtoms"]:
    """Convert a SimState to a list of PhonopyAtoms objects.

    Args:
        state (SimState): Batched state. Its ``positions``, ``cell``,
            ``atomic_numbers``, ``batch`` and ``pbc`` attributes are read.

    Returns:
        list[PhonopyAtoms]: PhonopyAtoms objects, one per distinct batch index,
            in the order returned by ``np.unique``.

    Raises:
        ImportError: If Phonopy is not installed

    Notes:
        - No unit conversion is performed; positions and cell are assumed to
          already be in Å
        - Each per-batch cell matrix is transposed before being given to
          Phonopy
        - Atomic numbers are handed to PhonopyAtoms directly, so ASE is not
          needed for this conversion
        - The masses stored on the state are not forwarded
    """
    try:
        from phonopy.structure.atoms import PhonopyAtoms
    except ImportError:
        raise ImportError(
            "Phonopy is required for state_to_phonopy conversion"
        ) from None

    # Convert tensors to numpy arrays on CPU
    positions = state.positions.detach().cpu().numpy()
    cell = state.cell.detach().cpu().numpy()  # Shape: (n_batches, 3, 3)
    atomic_numbers = state.atomic_numbers.detach().cpu().numpy()
    batch = state.batch.detach().cpu().numpy()

    phonopy_atoms_list = []
    for batch_idx in np.unique(batch):
        # Boolean mask selecting the atoms that belong to this batch
        mask = batch == batch_idx

        phonopy_atoms_list.append(
            PhonopyAtoms(
                # PhonopyAtoms derives the element symbols from the numbers
                numbers=atomic_numbers[mask],
                positions=positions[mask],
                cell=cell[batch_idx].T,  # transpose for Phonopy convention
                pbc=state.pbc,
            )
        )

    return phonopy_atoms_list
def atoms_to_state(
    atoms: "Atoms | list[Atoms]",
    device: torch.device,
    dtype: torch.dtype,
) -> "SimState":
    """Convert an ASE Atoms object or list of Atoms objects to a SimState.

    Args:
        atoms (Atoms | list[Atoms]): Single ASE Atoms object or list of Atoms
            objects
        device (torch.device): Device to create tensors on
        dtype (torch.dtype): Data type for the floating-point tensors
            (typically torch.float32 or torch.float64). Atomic numbers are
            always stored as ``torch.int``.

    Returns:
        SimState: TorchSim SimState object with one batch per input system.

    Raises:
        ImportError: If ASE is not installed
        ValueError: If no systems are given, or if systems have inconsistent
            periodic boundary conditions

    Notes:
        - No unit conversion is performed; input positions and cell should be
          in Å and input masses in amu
        - Each ASE cell matrix is transposed before being stored
        - Periodicity is reduced to a single flag per system (True only when
          every axis is periodic); systems are compared on that flag and the
          flag of the first system is stored on the state
    """
    # Check the optional dependency first so a missing ASE is reported clearly.
    try:
        from ase import Atoms
    except ImportError:
        raise ImportError("ASE is required for atoms_to_state conversion") from None

    from torch_sim.state import SimState

    atoms_list = [atoms] if isinstance(atoms, Atoms) else list(atoms)
    if not atoms_list:
        raise ValueError("At least one Atoms object is required to build a SimState")

    # Verify consistent pbc up front, before any tensors are allocated.
    pbc = all(atoms_list[0].pbc)
    if any(all(a.pbc) != pbc for a in atoms_list):
        raise ValueError("All systems must have the same periodic boundary conditions")

    # Stack all per-atom properties across systems in one go
    positions = torch.tensor(
        np.concatenate([a.positions for a in atoms_list]), dtype=dtype, device=device
    )
    masses = torch.tensor(
        np.concatenate([a.get_masses() for a in atoms_list]),
        dtype=dtype,
        device=device,
    )
    atomic_numbers = torch.tensor(
        np.concatenate([a.get_atomic_numbers() for a in atoms_list]),
        dtype=torch.int,
        device=device,
    )
    cell = torch.tensor(
        # Transpose cell from ASE convention to torchsim convention
        np.stack([a.cell.array.T for a in atoms_list]),
        dtype=dtype,
        device=device,
    )

    # Batch index of every atom: system i contributes len(atoms_list[i]) entries
    atoms_per_batch = torch.tensor([len(a) for a in atoms_list], device=device)
    batch = torch.repeat_interleave(
        torch.arange(len(atoms_list), device=device), atoms_per_batch
    )

    return SimState(
        positions=positions,
        masses=masses,
        cell=cell,
        pbc=pbc,
        atomic_numbers=atomic_numbers,
        batch=batch,
    )
def structures_to_state(
    structure: "Structure | list[Structure]",
    device: torch.device,
    dtype: torch.dtype,
) -> "SimState":
    """Create a SimState from pymatgen Structure(s).

    Args:
        structure (Structure | list[Structure]): Single Structure or list of
            Structure objects
        device (torch.device): Device to create tensors on
        dtype (torch.dtype): Data type for the floating-point tensors
            (typically torch.float32 or torch.float64). Atomic numbers are
            always stored as ``torch.int``.

    Returns:
        SimState: TorchSim SimState object with one batch per input structure.

    Raises:
        ImportError: If Pymatgen is not installed
        ValueError: If no structures are given

    Notes:
        - No unit conversion is performed; input positions and cell should be
          in Å
        - Positions are taken from the Cartesian coordinates of each Structure
        - Each lattice matrix is transposed before being stored on the state
        - The state is always created with ``pbc=True``
        - Masses and atomic numbers are read from ``site.specie``
    """
    # Check the optional dependency first so a missing Pymatgen is reported
    # clearly.
    try:
        from pymatgen.core import Structure
    except ImportError:
        raise ImportError(
            "Pymatgen is required for structures_to_state conversion"
        ) from None

    from torch_sim.state import SimState

    struct_list = [structure] if isinstance(structure, Structure) else list(structure)
    if not struct_list:
        raise ValueError("At least one Structure is required to build a SimState")

    # Stack all properties across structures
    cell = torch.tensor(
        # Transpose each lattice matrix into the torchsim convention
        np.stack([s.lattice.matrix.T for s in struct_list]),
        dtype=dtype,
        device=device,
    )
    positions = torch.tensor(
        np.concatenate([s.cart_coords for s in struct_list]),
        dtype=dtype,
        device=device,
    )
    masses = torch.tensor(
        np.concatenate([[site.specie.atomic_mass for site in s] for s in struct_list]),
        dtype=dtype,
        device=device,
    )
    atomic_numbers = torch.tensor(
        np.concatenate([[site.specie.number for site in s] for s in struct_list]),
        dtype=torch.int,
        device=device,
    )

    # Batch index of every atom: structure i contributes len(struct_list[i])
    # entries
    atoms_per_batch = torch.tensor([len(s) for s in struct_list], device=device)
    batch = torch.repeat_interleave(
        torch.arange(len(struct_list), device=device), atoms_per_batch
    )

    return SimState(
        positions=positions,
        masses=masses,
        cell=cell,
        pbc=True,  # Structures are always periodic
        atomic_numbers=atomic_numbers,
        batch=batch,
    )
def phonopy_to_state(
    phonopy_atoms: "PhonopyAtoms | list[PhonopyAtoms]",
    device: torch.device,
    dtype: torch.dtype,
) -> "SimState":
    """Create a SimState from a PhonopyAtoms object or list of PhonopyAtoms objects.

    Args:
        phonopy_atoms (PhonopyAtoms | list[PhonopyAtoms]): Single PhonopyAtoms
            object or list of PhonopyAtoms objects
        device (torch.device): Device to create tensors on
        dtype (torch.dtype): Data type for the floating-point tensors
            (typically torch.float32 or torch.float64). Atomic numbers are
            always stored as ``torch.int``.

    Returns:
        SimState: TorchSim SimState object with one batch per input system.

    Raises:
        ImportError: If Phonopy is not installed
        ValueError: If no systems are given

    Notes:
        - No unit conversion is performed; input positions and cell should be
          in Å and input masses in amu
        - Each Phonopy cell matrix is transposed before being stored
        - The state is always created with ``pbc=True``; the inputs'
          periodicity is not inspected
    """
    # Check the optional dependency first so a missing Phonopy is reported
    # clearly.
    try:
        from phonopy.structure.atoms import PhonopyAtoms
    except ImportError:
        raise ImportError("Phonopy is required for phonopy_to_state conversion") from None

    from torch_sim.state import SimState

    phonopy_atoms_list = (
        [phonopy_atoms]
        if isinstance(phonopy_atoms, PhonopyAtoms)
        else list(phonopy_atoms)
    )
    if not phonopy_atoms_list:
        raise ValueError(
            "At least one PhonopyAtoms object is required to build a SimState"
        )

    # Stack all per-atom properties across systems in one go
    positions = torch.tensor(
        np.concatenate([a.positions for a in phonopy_atoms_list]),
        dtype=dtype,
        device=device,
    )
    masses = torch.tensor(
        np.concatenate([a.masses for a in phonopy_atoms_list]),
        dtype=dtype,
        device=device,
    )
    atomic_numbers = torch.tensor(
        np.concatenate([a.numbers for a in phonopy_atoms_list]),
        dtype=torch.int,
        device=device,
    )
    cell = torch.tensor(
        # Transpose each cell matrix into the torchsim convention
        np.stack([a.cell.T for a in phonopy_atoms_list]),
        dtype=dtype,
        device=device,
    )

    # Batch index of every atom: system i contributes len(phonopy_atoms_list[i])
    # entries
    atoms_per_batch = torch.tensor([len(a) for a in phonopy_atoms_list], device=device)
    batch = torch.repeat_interleave(
        torch.arange(len(phonopy_atoms_list), device=device), atoms_per_batch
    )

    # NOTE: PhonopyAtoms does not have a pbc attribute for Supercells, so no
    # pbc consistency check is performed here and pbc is assumed to be True.
    return SimState(
        positions=positions,
        masses=masses,
        cell=cell,
        pbc=True,
        atomic_numbers=atomic_numbers,
        batch=batch,
    )