Source code for torch_sim.neighbors

"""Utilities for neighbor list calculations."""

import torch
from vesin import NeighborList as VesinNeighborList
from vesin.torch import NeighborList as VesinNeighborList_ts

import torch_sim.math as tsm
from torch_sim.transforms import (
    build_linked_cell_neighborhood,
    build_naive_neighborhood,
    compute_cell_shifts,
)


@torch.jit.script
def primitive_neighbor_list(  # noqa: C901, PLR0915
    quantities: str,
    pbc: tuple[bool, bool, bool],
    cell: torch.Tensor,
    positions: torch.Tensor,
    cutoff: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
    self_interaction: bool = False,  # noqa: FBT001, FBT002
    use_scaled_positions: bool = False,  # noqa: FBT001, FBT002
    max_n_bins: int = int(1e6),
) -> list[torch.Tensor]:
    """Compute a neighbor list for an atomic configuration.

    TorchScript port of ASE's periodic neighbor list implementation. Atoms
    outside periodic boundaries are mapped into the unit cell. Atoms outside
    non-periodic boundaries are included in the neighbor list, but they are
    clipped into the outermost bins, so the search for them can become O(n^2).
    The neighbor list is sorted by first atom index 'i', but not by second
    atom index 'j'.

    Args:
        quantities: Quantities to compute by the neighbor list algorithm. Each
            character in this string defines a quantity. They are returned in a
            list of the same order. Possible quantities are
                * 'i' : first atom index
                * 'j' : second atom index
                * 'd' : absolute distance
                * 'D' : distance vector
                * 'S' : shift vector (number of cell boundaries crossed by the bond
                  between atom i and j). With the shift vector S, the
                  distances D between atoms can be computed from:
                  D = positions[j] - positions[i] + S @ cell
        pbc: 3-tuple indicating whether each of the three cell directions is
            periodic.
        cell: Unit cell vectors according to the row vector convention, i.e.
            `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. For a fully
            non-periodic system the cell may be singular (e.g. all zeros).
        positions: Atomic positions tensor of shape (n_atoms, 3). If
            use_scaled_positions is True, these must be scaled (fractional)
            positions.
        cutoff: Global cutoff for the neighbor search, as a scalar tensor.
            Unlike ASE, per-element-pair and per-atom cutoffs are not supported.
        device: PyTorch device to use for computations.
        dtype: PyTorch floating point dtype used for the internal scalar
            constants (minimum bin size, fallback face distance).
        self_interaction: Return the atom itself as its own neighbor if set to
            true. Default: False
        use_scaled_positions: If set to true, positions are expected to be
            scaled positions.
        max_n_bins: Maximum number of bins used in neighbor search. This is used
            to limit the maximum amount of memory required by the neighbor list.

    Returns:
        list[torch.Tensor]: One tensor for each item in `quantities`. Indices in
            `i` are returned in ascending order 0..len(a)-1, but the order of
            (i,j) pairs is not guaranteed.

    Raises:
        AssertionError: If `positions` contains no atoms.
        ValueError: If `quantities` contains an unsupported character.

    References:
        - This code is modified version of the github gist
        https://gist.github.com/Linux-cpp-lisp/692018c74b3906b63529e60619f5a207
    """
    # Naming conventions: Suffixes indicate the dimension of an array. The
    # following convention is used here:
    #     c: Cartesian index, can have values 0, 1, 2
    #     i: Global atom index, can have values 0..len(a)-1
    #     xyz: Bin index, three values identifying x-, y- and z-component of a
    #          spatial bin that is used to make neighbor search O(n)
    #     b: Linearized version of the 'xyz' bin index
    #     a: Bin-local atom index, i.e. index identifying an atom *within* a
    #        bin
    #     p: Pair index, can have value 0 or 1
    #     n: (Linear) neighbor index

    # An empty system cannot be binned; reject it explicitly.
    if len(positions) == 0:
        raise AssertionError("No atoms provided")

    # Compute reciprocal lattice vectors. The pseudo-inverse also works for
    # singular cells (e.g. an all-zero cell for an isolated system).
    recip_cell = torch.linalg.pinv(cell).T
    b1_c, b2_c, b3_c = recip_cell[0], recip_cell[1], recip_cell[2]

    # Compute distances of cell faces; degenerate directions fall back to 1.
    l1 = torch.linalg.norm(b1_c)
    l2 = torch.linalg.norm(b2_c)
    l3 = torch.linalg.norm(b3_c)
    pytorch_scalar_1 = torch.as_tensor(1.0, device=device, dtype=dtype)
    face_dist_c = torch.hstack(
        [
            1 / l1 if l1 > 0 else pytorch_scalar_1,
            1 / l2 if l2 > 0 else pytorch_scalar_1,
            1 / l3 if l3 > 0 else pytorch_scalar_1,
        ]
    )
    assert face_dist_c.shape == (3,)

    # Only a single global cutoff is supported.
    max_cutoff: torch.Tensor = cutoff

    # We use a minimum bin size of 3 A
    bin_size = torch.maximum(max_cutoff, torch.tensor(3.0, device=device, dtype=dtype))
    # Compute number of bins such that a sphere of radius cutoff fits into
    # eight neighboring bins.
    n_bins_c = torch.maximum(
        (face_dist_c / bin_size).to(dtype=torch.long, device=device),
        torch.ones(3, dtype=torch.long, device=device),
    )
    n_bins = torch.prod(n_bins_c)
    # Make sure we limit the amount of memory used by the explicit bins.
    while n_bins > max_n_bins:
        n_bins_c = torch.maximum(
            n_bins_c // 2, torch.ones(3, dtype=torch.long, device=device)
        )
        n_bins = torch.prod(n_bins_c)

    # Compute over how many bins we need to loop in the neighbor list search.
    neigh_search = torch.ceil(bin_size * n_bins_c / face_dist_c).to(
        dtype=torch.long, device=device
    )
    neigh_search_x, neigh_search_y, neigh_search_z = (
        neigh_search[0],
        neigh_search[1],
        neigh_search[2],
    )

    # If we only have a single bin and the system is not periodic, then we
    # do not need to search neighboring bins
    pytorch_scalar_int_0 = torch.as_tensor(0, dtype=torch.long, device=device)
    neigh_search_x = (
        pytorch_scalar_int_0 if n_bins_c[0] == 1 and not pbc[0] else neigh_search_x
    )
    neigh_search_y = (
        pytorch_scalar_int_0 if n_bins_c[1] == 1 and not pbc[1] else neigh_search_y
    )
    neigh_search_z = (
        pytorch_scalar_int_0 if n_bins_c[2] == 1 and not pbc[2] else neigh_search_z
    )

    # Sort atoms into bins. Bins live in fractional (scaled) coordinates, so
    # Cartesian positions must be converted first, whatever the pbc.
    if use_scaled_positions:
        scaled_positions_ic = positions
        # Row-vector convention: cartesian = scaled @ cell. (torch.dot only
        # accepts 1-D tensors, so a matrix product is required here.)
        positions = torch.matmul(
            scaled_positions_ic, cell.to(dtype=scaled_positions_ic.dtype)
        )
    elif not any(pbc):
        # Non-periodic systems may use a singular (e.g. zero) cell, for which
        # solve() fails; the pseudo-inverse computed above still yields the
        # fractional coordinates for a regular cell.
        scaled_positions_ic = torch.matmul(
            positions, recip_cell.T.to(dtype=positions.dtype)
        )
    else:
        scaled_positions_ic = torch.linalg.solve(cell.T, positions.T).T

    bin_index_ic = torch.floor(scaled_positions_ic * n_bins_c).to(
        dtype=torch.long, device=device
    )
    cell_shift_ic = torch.zeros_like(bin_index_ic, device=device)

    for c in range(3):
        if pbc[c]:
            # Wrap periodic directions; the quotient is the image shift.
            cell_shift_ic[:, c], bin_index_ic[:, c] = tsm.torch_divmod(
                bin_index_ic[:, c], n_bins_c[c]
            )
        else:
            # Non-periodic directions: atoms outside the cell go to edge bins.
            bin_index_ic[:, c] = torch.clip(bin_index_ic[:, c], 0, n_bins_c[c] - 1)

    # Convert Cartesian bin index to unique scalar bin index.
    bin_index_i = bin_index_ic[:, 0] + n_bins_c[0] * (
        bin_index_ic[:, 1] + n_bins_c[1] * bin_index_ic[:, 2]
    )

    # atom_i contains atom index in new sort order.
    atom_i = torch.argsort(bin_index_i)
    bin_index_i = bin_index_i[atom_i]

    # Find max number of atoms per bin
    max_n_atoms_per_bin = torch.bincount(bin_index_i).max()
    max_n_atoms = int(max_n_atoms_per_bin.item())
    # Number of candidate pairs between two (padded) bins.
    n_pairs_per_bin = max_n_atoms * max_n_atoms

    # Sort atoms into bins: atoms_in_bin_ba contains for each bin (identified
    # by its scalar bin index) a list of atoms inside that bin. This list is
    # homogeneous, i.e. has the same size *max_n_atoms_per_bin* for all bins.
    # The list is padded with -1 values.
    atoms_in_bin_ba = -torch.ones(
        n_bins, max_n_atoms_per_bin.item(), dtype=torch.long, device=device
    )
    for bin_cnt in range(max_n_atoms):
        # Create a mask array that identifies the first atom of each bin.
        mask = torch.cat(
            (
                torch.ones(1, dtype=torch.bool, device=device),
                bin_index_i[:-1] != bin_index_i[1:],
            ),
            dim=0,
        )
        # Assign all first atoms.
        atoms_in_bin_ba[bin_index_i[mask], bin_cnt] = atom_i[mask]

        # Remove atoms that we just sorted into atoms_in_bin_ba. The next
        # "first" atom will be the second and so on.
        mask = torch.logical_not(mask)
        atom_i = atom_i[mask]
        bin_index_i = bin_index_i[mask]

    # Make sure that all atoms have been sorted into bins.
    assert len(atom_i) == 0
    assert len(bin_index_i) == 0

    # Now we construct neighbor pairs by pairing up all atoms within a bin or
    # between bin and neighboring bin. atom_pairs_pn is a helper buffer that
    # contains all potential pairs of atoms between two bins, i.e. it is a list
    # of length max_n_atoms_per_bin**2.
    atom_pairs_pn = torch.cartesian_prod(
        torch.arange(max_n_atoms_per_bin, device=device),
        torch.arange(max_n_atoms_per_bin, device=device),
    )
    atom_pairs_pn = atom_pairs_pn.T.reshape(2, -1)

    # Initialized empty neighbor list buffers.
    first_at_neigh_tuple_nn = []
    second_at_neigh_tuple_nn = []
    cell_shift_vector_x_n = []
    cell_shift_vector_y_n = []
    cell_shift_vector_z_n = []

    # This is the main neighbor list search. We loop over neighboring bins and
    # then construct all possible pairs of atoms between two bins, assuming
    # that each bin contains exactly max_n_atoms_per_bin atoms. We then throw
    # out pairs involving pad atoms with atom index -1 below.
    binz_xyz, biny_xyz, binx_xyz = torch.meshgrid(
        torch.arange(n_bins_c[2], device=device),
        torch.arange(n_bins_c[1], device=device),
        torch.arange(n_bins_c[0], device=device),
        indexing="ij",
    )
    # The memory layout of binx_xyz, biny_xyz, binz_xyz is such that computing
    # the respective bin index leads to a linearly increasing consecutive list:
    #     b_b = (binx_xyz + n_bins_c[0] * (biny_xyz + n_bins_c[1] *
    #                                     binz_xyz)).ravel()
    #     assert (b_b == torch.arange(torch.prod(n_bins_c))).all()

    n_search_x = int(neigh_search_x.item())
    n_search_y = int(neigh_search_y.item())
    n_search_z = int(neigh_search_z.item())

    # First atoms in pair.
    _first_at_neigh_tuple_n = atoms_in_bin_ba[:, atom_pairs_pn[0]]
    for dz in range(-n_search_z, n_search_z + 1):
        for dy in range(-n_search_y, n_search_y + 1):
            for dx in range(-n_search_x, n_search_x + 1):
                # Bin index of neighboring bin and shift vector.
                shiftx_xyz, neighbinx_xyz = tsm.torch_divmod(binx_xyz + dx, n_bins_c[0])
                shifty_xyz, neighbiny_xyz = tsm.torch_divmod(biny_xyz + dy, n_bins_c[1])
                shiftz_xyz, neighbinz_xyz = tsm.torch_divmod(binz_xyz + dz, n_bins_c[2])
                neighbin_b = (
                    neighbinx_xyz
                    + n_bins_c[0] * (neighbiny_xyz + n_bins_c[1] * neighbinz_xyz)
                ).ravel()

                # Second atom in pair.
                _second_at_neigh_tuple_n = atoms_in_bin_ba[neighbin_b][
                    :, atom_pairs_pn[1]
                ]

                # Shift vectors: tile each bin's shift over all of its candidate
                # pairs (equivalent to the np.resize(...).T of the ASE code).
                _cell_shift_vector_x_n = shiftx_xyz.reshape(-1, 1).repeat(
                    (1, n_pairs_per_bin)
                )
                _cell_shift_vector_y_n = shifty_xyz.reshape(-1, 1).repeat(
                    (1, n_pairs_per_bin)
                )
                _cell_shift_vector_z_n = shiftz_xyz.reshape(-1, 1).repeat(
                    (1, n_pairs_per_bin)
                )

                # We have created too many pairs because we assumed each bin
                # has exactly max_n_atoms_per_bin atoms. Remove all superfluous
                # pairs. Those are pairs that involve an atom with index -1.
                mask = torch.logical_and(
                    _first_at_neigh_tuple_n != -1, _second_at_neigh_tuple_n != -1
                )
                if mask.sum() > 0:
                    first_at_neigh_tuple_nn += [_first_at_neigh_tuple_n[mask]]
                    second_at_neigh_tuple_nn += [_second_at_neigh_tuple_n[mask]]
                    cell_shift_vector_x_n += [_cell_shift_vector_x_n[mask]]
                    cell_shift_vector_y_n += [_cell_shift_vector_y_n[mask]]
                    cell_shift_vector_z_n += [_cell_shift_vector_z_n[mask]]

    # Flatten overall neighbor list.
    first_at_neigh_tuple_n = torch.cat(first_at_neigh_tuple_nn)
    second_at_neigh_tuple_n = torch.cat(second_at_neigh_tuple_nn)
    cell_shift_vector_n = torch.vstack(
        [
            torch.cat(cell_shift_vector_x_n),
            torch.cat(cell_shift_vector_y_n),
            torch.cat(cell_shift_vector_z_n),
        ]
    ).T

    # Add global cell shift to shift vectors
    cell_shift_vector_n += (
        cell_shift_ic[first_at_neigh_tuple_n] - cell_shift_ic[second_at_neigh_tuple_n]
    )

    # Remove all self-pairs that do not cross the cell boundary.
    if not self_interaction:
        m = torch.logical_not(
            torch.logical_and(
                first_at_neigh_tuple_n == second_at_neigh_tuple_n,
                (cell_shift_vector_n == 0).all(dim=1),
            )
        )
        first_at_neigh_tuple_n = first_at_neigh_tuple_n[m]
        second_at_neigh_tuple_n = second_at_neigh_tuple_n[m]
        cell_shift_vector_n = cell_shift_vector_n[m]

    # For non-periodic directions, remove any bonds that cross the domain
    # boundary.
    for c in range(3):
        if not pbc[c]:
            m = cell_shift_vector_n[:, c] == 0
            first_at_neigh_tuple_n = first_at_neigh_tuple_n[m]
            second_at_neigh_tuple_n = second_at_neigh_tuple_n[m]
            cell_shift_vector_n = cell_shift_vector_n[m]

    # Sort neighbor list by first atom index.
    order = torch.argsort(first_at_neigh_tuple_n)
    first_at_neigh_tuple_n = first_at_neigh_tuple_n[order]
    second_at_neigh_tuple_n = second_at_neigh_tuple_n[order]
    cell_shift_vector_n = cell_shift_vector_n[order]

    # Compute distance vectors: D = positions[j] - positions[i] + S @ cell.
    distance_vector_nc = (
        positions[second_at_neigh_tuple_n]
        - positions[first_at_neigh_tuple_n]
        + cell_shift_vector_n.to(cell.dtype).matmul(cell)
    )
    abs_distance_vector_n = torch.sqrt(
        torch.sum(distance_vector_nc * distance_vector_nc, dim=1)
    )

    # We have still created too many pairs. Only keep those with distance
    # smaller than max_cutoff.
    mask = abs_distance_vector_n < max_cutoff
    first_at_neigh_tuple_n = first_at_neigh_tuple_n[mask]
    second_at_neigh_tuple_n = second_at_neigh_tuple_n[mask]
    cell_shift_vector_n = cell_shift_vector_n[mask]
    distance_vector_nc = distance_vector_nc[mask]
    abs_distance_vector_n = abs_distance_vector_n[mask]

    # Assemble return list in the order requested by `quantities`.
    ret_vals = []
    for quant in quantities:
        if quant == "i":
            ret_vals += [first_at_neigh_tuple_n]
        elif quant == "j":
            ret_vals += [second_at_neigh_tuple_n]
        elif quant == "D":
            ret_vals += [distance_vector_nc]
        elif quant == "d":
            ret_vals += [abs_distance_vector_n]
        elif quant == "S":
            ret_vals += [cell_shift_vector_n]
        else:
            raise ValueError("Unsupported quantity specified.")

    return ret_vals


@torch.jit.script
def standard_nl(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: bool,  # noqa: FBT001
    cutoff: torch.Tensor,
    sort_id: bool = False,  # noqa: FBT001, FBT002
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute neighbor lists using primitive neighbor list algorithm.

    This function provides a standardized interface for computing neighbor lists
    in atomic systems, wrapping the more general primitive_neighbor_list implementation.
    It handles both periodic and non-periodic boundary conditions and returns
    neighbor pairs along with their periodic shifts.

    The function follows ASE's neighbor list conventions (see ASE:
    https://gitlab.com/ase/ase/-/blob/master/ase/neighborlist.py?ref_type=heads#L152)
    but provides a simplified interface focused on the most common use case of
    getting neighbor pairs and shifts.

    Key Features:
    - Handles both periodic and non-periodic systems
    - Returns both neighbor indices and shift vectors for periodic systems
    - Optional re-sorting of neighbors by first index

    Args:
        positions: Atomic positions tensor of shape (num_atoms, 3)
        cell: Unit cell vectors according to the row vector convention, i.e.
            `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
        pbc: Whether to use periodic boundary conditions (applied to all directions)
        cutoff: Maximum distance for considering atoms as neighbors, as a scalar
            tensor (TorchScript does not accept a Python float here).
        sort_id: If True, sort neighbors by first atom index. Note that
            primitive_neighbor_list already returns pairs sorted by first index.

    Returns:
        tuple containing:
            - mapping: Tensor of shape (2, num_neighbors) containing pairs of
              atom indices that are neighbors. Each column (i,j) represents a
              neighbor pair.
            - shifts: Tensor of shape (num_neighbors, 3) containing the periodic
              shift vectors (in units of cell vectors, cast to the positions
              dtype) needed to get the correct periodic image for each pair.

    Example:
        >>> # Get neighbors for a periodic system
        >>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
        >>> cell = torch.eye(3) * 10.0
        >>> mapping, shifts = standard_nl(positions, cell, True, torch.tensor(1.5))
        >>> print(mapping)  # Shows pairs of neighboring atoms
        >>> print(shifts)  # Shows corresponding periodic shifts

    Notes:
        - The function uses primitive_neighbor_list internally but provides a simpler
          interface
        - For non-periodic systems (pbc=False), shifts will be zero vectors
        - The neighbor list includes both (i,j) and (j,i) pairs for complete force
          computation
        - Memory usage scales with system size and number of neighbors per atom

    References:
        - https://gist.github.com/Linux-cpp-lisp/692018c74b3906b63529e60619f5a207
    """
    device = positions.device
    dtype = positions.dtype
    i, j, S = primitive_neighbor_list(
        quantities="ijS",
        positions=positions,
        cell=cell,
        pbc=(pbc, pbc, pbc),
        cutoff=cutoff,
        device=device,
        dtype=dtype,
        self_interaction=False,
        use_scaled_positions=False,
        # max_n_bins is declared as a plain int; pass one instead of a tensor.
        max_n_bins=1000000,
    )

    mapping = torch.stack((i, j), dim=0)
    mapping = mapping.to(dtype=torch.long)
    shifts = S.to(dtype=dtype)

    if sort_id:
        idx = torch.argsort(mapping[0])
        mapping = mapping[:, idx]
        shifts = shifts[idx, :]

    return mapping, shifts


@torch.jit.script
def vesin_nl_ts(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: bool,  # noqa: FBT001
    cutoff: torch.Tensor,
    sort_id: bool = False,  # noqa: FBT001, FBT002
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute neighbor lists with the TorchScript-compatible Vesin backend.

    Wraps ``VesinNeighborList_ts`` so it can be called from scripted code. The
    search is run on float64 CPU copies of the inputs, and the results are moved
    back to the device/dtype of ``positions``.

    Args:
        positions: Atomic positions tensor of shape (num_atoms, 3)
        cell: Unit cell vectors according to the row vector convention, i.e.
            `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
        pbc: Whether to use periodic boundary conditions (applied to all directions)
        cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors
        sort_id: If True, reorder the pairs by first atom index

    Returns:
        tuple containing:
            - mapping: long tensor of shape (2, num_neighbors); each column (i, j)
              is one neighbor pair.
            - shifts: tensor of shape (num_neighbors, 3) with the periodic image
              shift for each pair, in the dtype of ``positions``.

    Notes:
        - A full list is requested, so both (i, j) and (j, i) appear
        - Inputs are copied to CPU as float64 before calling Vesin
        - Outputs live on the same device as ``positions``

    References:
          https://github.com/Luthaf/vesin
    """
    out_device = positions.device
    out_dtype = positions.dtype

    calculator = VesinNeighborList_ts(cutoff.item(), full_list=True)

    # The search itself runs on float64 CPU copies of the inputs.
    first, second, shift_vecs = calculator.compute(
        points=positions.cpu().to(dtype=torch.float64),
        box=cell.cpu().to(dtype=torch.float64),
        periodic=pbc,
        quantities="ijS",
    )

    pairs = torch.stack((first, second), dim=0).to(
        dtype=torch.long, device=out_device
    )
    image_shifts = shift_vecs.to(dtype=out_dtype, device=out_device)

    if not sort_id:
        return pairs, image_shifts

    order = torch.argsort(pairs[0])
    return pairs[:, order], image_shifts[order, :]


def vesin_nl(
    *,
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: bool,
    cutoff: torch.Tensor,
    sort_id: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute neighbor lists using the standard Vesin implementation.

    This function provides an interface to the standard Vesin neighbor list
    algorithm using VesinNeighborList. It handles both periodic and non-periodic
    systems and returns neighbor pairs along with their periodic shifts.

    Args:
        positions: Atomic positions tensor of shape (num_atoms, 3)
        cell: Unit cell vectors according to the row vector convention, i.e.
            `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
        pbc: Whether to use periodic boundary conditions (applied to all directions)
        cutoff: Maximum distance for considering atoms as neighbors. A scalar
            tensor or a plain float; it is converted with ``float()`` before
            being passed to Vesin.
        sort_id: If True, ask Vesin to return pairs sorted (passed as the
            ``sorted`` argument of VesinNeighborList)

    Returns:
        tuple containing:
            - mapping: Tensor of shape (2, num_neighbors) containing pairs of
              atom indices that are neighbors. Each column (i,j) represents a
              neighbor pair.
            - shifts: Tensor of shape (num_neighbors, 3) containing the periodic
              shift vectors needed to get the correct periodic image for each
              neighbor pair.

    Notes:
        - Uses standard VesinNeighborList implementation
        - Inputs are detached and copied to CPU as float64, so no gradients
          flow through this function
        - Returns tensors on the same device as input with original precision
        - For non-periodic systems (pbc=False), shifts will be zero vectors
        - The neighbor list includes both (i,j) and (j,i) pairs

    References:
        - https://github.com/Luthaf/vesin
    """
    device = positions.device
    dtype = positions.dtype

    # Hand Vesin a plain Python float (works for both 0-dim tensors and floats),
    # matching vesin_nl_ts which converts with .item().
    neighbor_list_fn = VesinNeighborList(
        float(cutoff), full_list=True, sorted=sort_id
    )

    # Convert tensors to CPU and float64 without gradients
    positions_cpu = positions.detach().cpu().to(dtype=torch.float64)
    cell_cpu = cell.detach().cpu().to(dtype=torch.float64)

    # Only works on CPU and returns numpy arrays
    i, j, S = neighbor_list_fn.compute(
        points=positions_cpu,
        box=cell_cpu,
        periodic=pbc,
        quantities="ijS",
    )

    i, j = (
        torch.tensor(i, dtype=torch.long, device=device),
        torch.tensor(j, dtype=torch.long, device=device),
    )
    mapping = torch.stack((i, j), dim=0)
    shifts = torch.tensor(S, dtype=dtype, device=device)

    return mapping, shifts
def strict_nl(
    cutoff: float,
    positions: torch.Tensor,
    cell: torch.Tensor,
    mapping: torch.Tensor,
    batch_mapping: torch.Tensor,
    shifts_idx: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply a strict cutoff to the neighbor list defined in the mapping.

    Every candidate pair in ``mapping`` is kept only if its (periodic-image
    corrected) separation is strictly shorter than ``cutoff``. Squared lengths
    are compared against ``cutoff**2``, so no square root is taken.

    Args:
        cutoff (float): The maximum distance for considering two atoms as
            neighbors.
        positions (torch.Tensor): Tensor of shape (n_atoms, 3) with the atomic
            positions.
        cell (torch.Tensor): Unit cell vectors according to the row vector
            convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
        mapping (torch.Tensor): Tensor of shape (2, n_pairs) holding the indices
            into ``positions`` of each candidate pair.
        batch_mapping (torch.Tensor): Tensor of shape (n_pairs,) giving the
            structure of each pair; used together with ``shifts_idx`` to
            compute the periodic shifts.
        shifts_idx (torch.Tensor): Tensor of shape (n_pairs, 3) with the cell
            shift indices of each pair.

    Returns:
        tuple: A tuple containing:
            - mapping (torch.Tensor): shape (2, n_filtered_pairs), the pairs
              within the cutoff.
            - mapping_batch (torch.Tensor): shape (n_filtered_pairs,), the
              structure index of each kept pair.
            - shifts_idx (torch.Tensor): shape (n_filtered_pairs, 3), the cell
              shift indices of each kept pair.

    Notes:
        - When ``compute_cell_shifts`` returns None, the plain position
          difference is used without any periodic correction.

    References:
        - https://github.com/felixmusil/torch_nl
    """
    cell_shifts = compute_cell_shifts(cell, shifts_idx, batch_mapping)

    # Separation of each candidate pair, corrected by its periodic image
    # offset when cell shifts are available.
    separation = positions[mapping[0]] - positions[mapping[1]]
    if cell_shifts is not None:
        separation = separation - cell_shifts

    within_cutoff = separation.square().sum(dim=1) < cutoff * cutoff
    return (
        mapping[:, within_cutoff],
        batch_mapping[within_cutoff],
        shifts_idx[within_cutoff],
    )
@torch.jit.script
def torch_nl_n2(
    cutoff: float,
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    batch: torch.Tensor,
    self_interaction: bool = False,  # noqa: FBT001, FBT002
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute a batched neighbor list with a naive all-pairs search.

    Candidate pairs come from ``build_naive_neighborhood`` and are then
    filtered with ``strict_nl`` using the exact ``cutoff``. Positions are
    expected to be wrapped inside their respective unit cells.

    Args:
        cutoff (float): The cutoff radius used for the neighbor search.
        positions (torch.Tensor [n_atom, 3]): Atomic positions, wrapped inside
            their respective unit cells.
        cell (torch.Tensor [3*n_structure, 3]): Unit cell vectors according to
            the row vector convention, i.e.
            `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
        pbc (torch.Tensor [n_structure, 3] bool): Periodic boundary conditions
            per structure. Partial PBC are not supported yet.
        batch (torch.Tensor [n_atom,] torch.long): Index of the structure each
            atom belongs to.
        self_interaction (bool, optional): Whether to keep the center atoms as
            their own neighbors. Default is False.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            mapping (torch.Tensor [2, n_neighbors]): ``mapping[0]`` holds the
                central atom indices and ``mapping[1]`` the neighbor indices.
            batch_mapping (torch.Tensor [n_neighbors]): Structure index of each
                neighbor pair.
            shifts_idx (torch.Tensor [n_neighbors, 3]): Cell shift indices used
                to reconstruct the neighbor atom positions.

    References:
        - https://github.com/felixmusil/torch_nl
    """
    atoms_per_structure = torch.bincount(batch)
    candidate_mapping, candidate_batch, candidate_shifts = build_naive_neighborhood(
        positions, cell, pbc, cutoff, atoms_per_structure, self_interaction
    )
    # Keep only the candidate pairs that lie strictly within the cutoff.
    return strict_nl(
        cutoff, positions, cell, candidate_mapping, candidate_batch, candidate_shifts
    )


@torch.jit.script
def torch_nl_linked_cell(
    cutoff: float,
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    batch: torch.Tensor,
    self_interaction: bool = False,  # noqa: FBT001, FBT002
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute a batched neighbor list with the linked cell algorithm.

    Candidate pairs come from ``build_linked_cell_neighborhood`` and are then
    filtered with ``strict_nl`` using the exact ``cutoff``. Positions are
    expected to be wrapped inside their respective unit cells.

    Args:
        cutoff (float): The cutoff radius used for the neighbor search.
        positions (torch.Tensor [n_atom, 3]): Atomic positions, wrapped inside
            their respective unit cells.
        cell (torch.Tensor [3*n_structure, 3]): Unit cell vectors according to
            the row vector convention, i.e.
            `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`.
        pbc (torch.Tensor [n_structure, 3] bool): Periodic boundary conditions
            per structure. Partial PBC are not supported yet.
        batch (torch.Tensor [n_atom,] torch.long): Index of the structure each
            atom belongs to.
        self_interaction (bool, optional): Whether to keep the center atoms as
            their own neighbors. Default is False.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
            - mapping (torch.Tensor [2, n_neighbors]): ``mapping[0]`` holds the
              central atom indices and ``mapping[1]`` the neighbor indices.
            - batch_mapping (torch.Tensor [n_neighbors]): Structure index of
              each neighbor pair.
            - shifts_idx (torch.Tensor [n_neighbors, 3]): Cell shift indices
              used to reconstruct the neighbor atom positions.

    References:
        - https://github.com/felixmusil/torch_nl
    """
    atoms_per_structure = torch.bincount(batch)
    candidate_mapping, candidate_batch, candidate_shifts = (
        build_linked_cell_neighborhood(
            positions, cell, pbc, cutoff, atoms_per_structure, self_interaction
        )
    )
    # Keep only the candidate pairs that lie strictly within the cutoff.
    return strict_nl(
        cutoff, positions, cell, candidate_mapping, candidate_batch, candidate_shifts
    )