# Source code for mythos.energy.martini.m2.lj

"""Lennard-Jones potential energy function for Martini 2."""


from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp
from jax_md import space
from typing_extensions import override

from mythos.energy.martini.base import MartiniEnergyConfiguration, MartiniEnergyFunction
from mythos.simulators.io import SimulatorTrajectory
from mythos.utils.types import Arr_N, Arr_States_3, MatrixSq, Vector2D

# Name prefixes accepted by LJConfiguration. A full parameter name is the
# prefix followed by the two bead types of the pair, e.g. "lj_sigma_A_B".
LJ_SIGMA_PREFIX: str = "lj_sigma_"
LJ_EPSILON_PREFIX: str = "lj_epsilon_"


class LJConfiguration(MartiniEnergyConfiguration):
    """Configuration for Martini Lennard-Jones energy function.

    All parameters provided must be of the form "lj_sigma_A_B" or
    "lj_epsilon_A_B", where A and B are bead types. Pair order is ignored
    unless both orderings are provided. It is required that sigma and epsilon
    parameters are provided for any bead type pairs present in the system.

    Couplings are supported (see :class:`MartiniEnergyConfiguration` for
    details).
    """

    @override
    def __post_init__(self) -> None:
        """Validate parameter names and build sigma/epsilon lookup matrices.

        Sets ``self.bead_types`` to the sorted tuple of every bead type named
        in ``self.params``, and ``self.sigmas`` / ``self.epsilons`` to square
        matrices whose rows and columns follow that same ordering.

        Raises:
            ValueError: If a parameter does not start with ``lj_sigma_`` or
                ``lj_epsilon_``, if it does not name exactly two non-empty
                bead types, or if a sigma/epsilon value is missing for some
                pair of bead types.
        """
        bead_types = set()
        for param in self.params:
            if not param.startswith((LJ_SIGMA_PREFIX, LJ_EPSILON_PREFIX)):
                raise ValueError(f"Unexpected parameter {param} for LJConfiguration")
            # A well-formed name splits into exactly ["lj", kind, A, B]. Without
            # this check extra segments would be silently dropped and short
            # names would only fail later with a misleading "missing" error.
            parts = param.split("_")
            if len(parts) != 4 or not all(parts[2:]):
                raise ValueError(
                    f"Malformed LJ parameter {param}: expected the form "
                    "'lj_sigma_A_B' or 'lj_epsilon_A_B'"
                )
            bead_types.update(parts[2:])
        self.bead_types = tuple(sorted(bead_types))

        # Construct lookup tables for the values for use in vmapped energy
        # calculations. These should be symmetric matrices, but we do not
        # explicitly enforce that. At least one of the pair orderings must exist
        # or an exception is raised.
        def get_param(prefix: str, a: str, b: str) -> float:
            """Return the ``prefix`` value for pair (a, b), preferring "a_b" over "b_a"."""
            param = self.params.get(
                f"lj_{prefix}_{a}_{b}", self.params.get(f"lj_{prefix}_{b}_{a}")
            )
            if param is None:
                raise ValueError(f"Missing LJ {prefix} parameter for pair {a}_{b} ({b}_{a})")
            return param

        self.sigmas: MatrixSq = jnp.array(
            [[get_param("sigma", i, j) for j in self.bead_types] for i in self.bead_types]
        )
        self.epsilons: MatrixSq = jnp.array(
            [[get_param("epsilon", i, j) for j in self.bead_types] for i in self.bead_types]
        )
def lennard_jones(r: float, eps: float, sigma: float, cutoff: float = 1.1) -> float:
    """Calculate the shifted Lennard-Jones potential.

    The standard 12-6 potential ``V(r) = 4 eps [(sigma/r)^12 - (sigma/r)^6]``
    is shifted by its value at the cutoff so that it reaches zero there:
    ``V_s(r) = V(r) - V(r_c)`` for ``r < r_c`` and ``0`` otherwise.

    Args:
        r: Inter-particle distance (a scalar, or an array evaluated
            elementwise).
        eps: Depth of the potential well (epsilon).
        sigma: Distance at which the unshifted potential crosses zero.
        cutoff: Cutoff distance ``r_c``, in the same length units as ``r``
            and ``sigma``. Defaults to 1.1.

    Returns:
        The shifted potential energy; zero for ``r >= cutoff``.
    """

    def _unshifted(dist):
        # Compute (sigma/dist)^6 once; the repulsive term is its square.
        sr6 = (sigma / dist) ** 6
        return 4 * eps * (sr6 * sr6 - sr6)

    # Subtracting V(r_c) makes the potential continuous at the cutoff.
    v_c = _unshifted(cutoff)
    return jnp.where(r < cutoff, _unshifted(r) - v_c, 0.0)
def pair_lj(
    centers: Arr_States_3,
    pair: Vector2D,
    sigmas: MatrixSq,
    epsilons: MatrixSq,
    types: Arr_N,
    displacement_fn: Callable,
) -> float:
    """Calculate LJ energy for a given pair of particles.

    Args:
        centers: Particle positions, indexed by particle id.
        pair: The two particle indices ``(i, j)`` of the interacting pair.
        sigmas: Matrix of sigma values indexed by ``(type_i, type_j)``.
        epsilons: Matrix of epsilon values indexed by ``(type_i, type_j)``.
        types: Per-particle row/column index into ``sigmas`` and ``epsilons``.
        displacement_fn: Callable returning the displacement vector between
            two positions (presumably boundary-aware; supplied by the caller).

    Returns:
        The shifted Lennard-Jones energy of the pair, as computed by
        :func:`lennard_jones`.
    """
    i = pair[0]
    j = pair[1]
    i_type = types[i]
    j_type = types[j]
    sigma = sigmas[i_type, j_type]
    eps = epsilons[i_type, j_type]
    r = space.distance(displacement_fn(centers[i], centers[j]))
    return lennard_jones(r, eps, sigma)
@chex.dataclass(frozen=True, kw_only=True)
class LJ(MartiniEnergyFunction):
    """Lennard-Jones potential energy function for Martini 2.

    Sums the shifted LJ pair energy over ``unbonded_neighbors``, looking up
    sigma and epsilon for each pair from the bead-type matrices held in
    ``params``.
    """

    params: LJConfiguration

    @override
    def __post_init__(self, topology: None = None) -> None:
        # For every atom, precompute the index of its bead type within the
        # sigma/epsilon matrices so the vmapped kernel can gather directly.
        ordered_types = self.params.bead_types
        index_of_type = dict(zip(ordered_types, range(len(ordered_types))))
        type_indices = jnp.array([index_of_type[bead] for bead in self.atom_types])
        # The dataclass is frozen, so bypass its attribute-assignment guard.
        object.__setattr__(self, "_atom_type_map", type_indices)

    @override
    def compute_energy(self, trajectory: SimulatorTrajectory) -> float:
        """Return the total LJ energy of ``trajectory`` over unbonded pairs."""
        displacement = self.displacement_fn(trajectory.box_size)
        # Only the neighbor-pair argument is mapped (axis 0); every other
        # argument is shared across all pairs.
        per_pair_energy = jax.vmap(pair_lj, in_axes=(None, 0, None, None, None, None))(
            trajectory.center,
            self.unbonded_neighbors,
            self.params.sigmas,
            self.params.epsilons,
            self._atom_type_map,
            displacement,
        )
        return per_pair_energy.sum()