# Source code for mythos.energy.martini.m2.bond

"""Bond potential energy function for Martini 2."""
from collections.abc import Callable

import chex
import jax
import jax.numpy as jnp
from jax_md import space
from typing_extensions import override

from mythos.energy.martini.base import MartiniEnergyConfiguration, MartiniEnergyFunction
from mythos.simulators.io import SimulatorTrajectory
from mythos.utils.types import Arr_States_3, Vector2D

# Parameter-name prefixes for bond parameters. Every bond NAME needs both a
# "bond_k_NAME" entry (harmonic force constant) and a "bond_r0_NAME" entry
# (equilibrium bond length); see BondConfiguration and pair_bond.
BOND_K_PREFIX, BOND_R0_PREFIX = "bond_k_", "bond_r0_"

class BondConfiguration(MartiniEnergyConfiguration):
    """Configuration for Martini bond energy function.

    Bond params must be provided as "bond_k_NAME" and "bond_r0_NAME" in
    corresponding pairs for each bond name in the system. NAME should be in the
    format of "MOLTYPE_ATOMNAME1_ATOMNAME2", e.g., "DMPC_NC3_PO4".
    """

    @override
    def __post_init__(self) -> None:
        """Validate that ``params`` consists solely of matched k/r0 bond pairs.

        Raises:
            ValueError: if a parameter name starts with neither ``BOND_K_PREFIX``
                nor ``BOND_R0_PREFIX``, if no parameters are given, if the
                parameter count is odd, or if some bond NAME has a
                "bond_k_NAME" entry without a "bond_r0_NAME" entry (or vice
                versa).
        """
        k_names = set()
        r0_names = set()
        for param in self.params:
            if param.startswith(BOND_K_PREFIX):
                k_names.add(param.removeprefix(BOND_K_PREFIX))
            elif param.startswith(BOND_R0_PREFIX):
                r0_names.add(param.removeprefix(BOND_R0_PREFIX))
            else:
                raise ValueError(f"Unexpected parameter {param} for BondConfiguration")

        if len(self.params) == 0 or len(self.params) % 2 != 0:
            raise ValueError("BondConfiguration requires pairs of k and r0 parameters")

        # An even count is not enough: e.g. {"bond_k_A", "bond_r0_B"} would
        # pass the check above yet leave both bonds without a complete
        # parameter pair. Require the k and r0 NAME sets to coincide exactly.
        unmatched = sorted(k_names ^ r0_names)
        if unmatched:
            raise ValueError(
                "BondConfiguration requires pairs of k and r0 parameters; "
                f"unmatched bond names: {unmatched}"
            )
def pair_bond(
    centers: Arr_States_3,
    pair: Vector2D,
    k_bond: float,
    r0_bond: float,
    displacement_fn: Callable,
) -> float:
    """Calculate the harmonic bond energy for a single pair of particles.

    The energy is ``0.5 * k_bond * (r - r0_bond) ** 2``, where ``r`` is the
    length of ``displacement_fn(centers[i], centers[j])`` for ``(i, j) = pair``.

    Args:
        centers: particle positions, indexed by particle index.
        pair: the two particle indices ``(i, j)`` forming the bond.
        k_bond: harmonic force constant of the bond.
        r0_bond: equilibrium bond length.
        displacement_fn: function of two positions returning their
            displacement vector (presumably box-aware; confirm against the
            caller's ``displacement_fn``).

    Returns:
        The scalar bond energy.
    """
    i = pair[0]
    j = pair[1]
    r = space.distance(displacement_fn(centers[i], centers[j]))
    return 0.5 * k_bond * (r - r0_bond) ** 2
@chex.dataclass(frozen=True, kw_only=True)
class Bond(MartiniEnergyFunction):
    """Bond potential energy function for Martini 2.

    Each bonded pair contributes a harmonic term computed by ``pair_bond``,
    using the k and r0 values configured for that bond's name.
    """

    params: BondConfiguration

    @override
    def __post_init__(self, topology: None = None) -> None:
        """Cache per-bond k and r0 arrays aligned with ``self.bond_names``.

        Element ``n`` of ``_bonds_k`` / ``_bonds_r0`` holds the parameter for
        ``self.bond_names[n]``, so the arrays can be mapped over together with
        ``self.bonded_neighbors`` in ``compute_energy``.
        """

        def gather(prefix):
            # One lookup per bond name, in bond_names order.
            return [self.params[prefix + name] for name in self.bond_names]

        k_per_bond = gather(BOND_K_PREFIX)
        r0_per_bond = gather(BOND_R0_PREFIX)
        # The dataclass is frozen, so bypass its __setattr__ to store caches.
        for attr_name, values in (("_bonds_k", k_per_bond), ("_bonds_r0", r0_per_bond)):
            object.__setattr__(self, attr_name, jnp.array(values))

    @override
    def compute_energy(self, trajectory: SimulatorTrajectory) -> float:
        """Return the total bond energy for ``trajectory``.

        Maps ``pair_bond`` over bonded pairs together with the cached per-bond
        k and r0 values; positions and the displacement function are shared.
        """
        displacement = self.displacement_fn(trajectory.box_size)
        batched_pair_bond = jax.vmap(pair_bond, in_axes=(None, 0, 0, 0, None))
        energy_per_bond = batched_pair_bond(
            trajectory.center,
            self.bonded_neighbors,
            self._bonds_k,
            self._bonds_r0,
            displacement,
        )
        return energy_per_bond.sum()