Source code for mythos.energy.martini.m2.angle

"""Angle potential energy function for Martini 2."""

from typing import ClassVar

import chex
import jax
import jax.numpy as jnp
from typing_extensions import override

from mythos.energy.martini.base import MartiniEnergyConfiguration, MartiniEnergyFunction
from mythos.simulators.io import SimulatorTrajectory
from mythos.utils.types import Arr_States_3, Vector3D

# Parameter-name prefix for per-angle force constants; the full key is
# ANGLE_K_PREFIX + <angle name>, e.g. "angle_k_DMPC_NC3_PO4_GL1".
ANGLE_K_PREFIX = "angle_k_"
# Parameter-name prefix for per-angle equilibrium angles (radians); the full
# key is ANGLE_THETA0_PREFIX + <angle name>.
ANGLE_THETA0_PREFIX = "angle_theta0_"


class AngleConfiguration(MartiniEnergyConfiguration):
    """Configuration for Martini angle energy function.

    Angle params must be provided as "angle_k_NAME" and "angle_theta0_NAME"
    in corresponding pairs for each angle name in the system. NAME should be
    in the format of "MOLTYPE_ATOMNAME1_ATOMNAME2_ATOMNAME3", e.g.,
    "DMPC_NC3_PO4_GL1".
    """

    @override
    def __post_init__(self) -> None:
        """Validate that params hold only matched k/theta0 angle pairs.

        Raises:
            ValueError: If a parameter does not start with an angle prefix,
                if no parameters are given, if the parameter count is odd, or
                if some angle NAME has a k entry without a theta0 entry (or
                vice versa).
        """
        k_names = set()
        theta0_names = set()
        for param in self.params:
            if param.startswith(ANGLE_K_PREFIX):
                k_names.add(param[len(ANGLE_K_PREFIX):])
            elif param.startswith(ANGLE_THETA0_PREFIX):
                theta0_names.add(param[len(ANGLE_THETA0_PREFIX):])
            else:
                raise ValueError(f"Unexpected parameter {param} for AngleConfiguration")
        if len(self.params) == 0 or len(self.params) % 2 != 0:
            raise ValueError("AngleConfiguration requires pairs of k and theta0 parameters")
        # An even count alone does not guarantee pairing (e.g. two k params and
        # no theta0 params); every NAME must appear under both prefixes.
        unpaired = k_names ^ theta0_names
        if unpaired:
            raise ValueError(
                "AngleConfiguration has unpaired k/theta0 parameters for angles: "
                f"{sorted(unpaired)}"
            )
def compute_angle(
    r_ij: Vector3D,
    r_kj: Vector3D,
) -> float:
    """Compute the angle between three particles (angle at j).

    The angle is evaluated as ``arctan2(|a x b|, a . b)`` on the normalized
    vectors, which is better conditioned near 0 and pi than ``arccos(a . b)``.
    The result is differentiable with finite gradients even for exactly
    collinear vectors.

    Args:
        r_ij: Displacement vector from j to i.
        r_kj: Displacement vector from j to k.

    Returns:
        The angle theta_ijk in radians, in the range [0, pi].
    """
    # Normalize the vectors
    r_ij_norm = r_ij / jnp.linalg.norm(r_ij)
    r_kj_norm = r_kj / jnp.linalg.norm(r_kj)

    # Cross and dot products of the unit vectors
    cross_prod = jnp.cross(r_ij_norm, r_kj_norm)
    dot_prod = jnp.dot(r_ij_norm, r_kj_norm)

    # |a x b| via the "double where" safe-sqrt: sqrt has an infinite
    # derivative at 0, so for exactly collinear vectors (theta = 0 or pi) a
    # plain sqrt(sum(c**2)) makes jax.grad produce NaN (inf * 0). Feeding
    # sqrt a dummy positive value on that branch keeps the forward value
    # identical (0 when collinear) while keeping the gradient finite.
    cross_sq = jnp.sum(cross_prod**2)
    is_nonzero = cross_sq > 0
    safe_cross_sq = jnp.where(is_nonzero, cross_sq, 1.0)
    cross_norm = jnp.where(is_nonzero, jnp.sqrt(safe_cross_sq), 0.0)

    # arctan2(|a x b|, a . b) gives the angle between the vectors
    return jnp.arctan2(cross_norm, dot_prod)
def triplet_angle(
    centers: Arr_States_3,
    triplet: Vector3D,
    k_angle: float,
    theta0_angle: float,
    displacement_fn: callable,
    use_G96: bool,  # noqa: FBT001, N803
) -> float:
    """Calculate angle energy for a given triplet of particles.

    Args:
        centers: Positions of all particles, indexed by particle index.
        triplet: Indices [i, j, k] of the three particles forming the angle;
            j is the central particle.
        k_angle: Force constant for the angle.
        theta0_angle: Equilibrium angle in radians.
        displacement_fn: Function to compute displacement between particles.
        use_G96: Whether to use the GROMACS 1996 (G96) cosine-based angle
            potential (as in Martini 2) instead of the standard harmonic
            angle potential.

    Returns:
        The angle energy. With ``use_G96`` this is
        ``0.5 * k * (cos(theta) - cos(theta0))**2``; otherwise it is the
        harmonic ``0.5 * k * (theta - theta0)**2``.
    """
    i = triplet[0]
    j = triplet[1]
    k = triplet[2]

    # Displacement vectors between the central particle j and each outer
    # particle. Both are built with the same argument order, so whatever sign
    # convention displacement_fn uses, it flips both vectors together and the
    # angle between them is unchanged.
    r_ij = displacement_fn(centers[j], centers[i])
    r_kj = displacement_fn(centers[j], centers[k])

    theta = compute_angle(r_ij, r_kj)

    # use_G96 selects the formula with a Python conditional (not jnp.where),
    # so it must be a concrete Python bool rather than a traced value.
    theta_term = (jnp.cos(theta) - jnp.cos(theta0_angle)) if use_G96 else (theta - theta0_angle)
    return 0.5 * k_angle * theta_term**2
@chex.dataclass(frozen=True, kw_only=True)
class Angle(MartiniEnergyFunction):
    """Angle potential energy function for Martini 2.

    The total energy is the sum of ``triplet_angle`` over every triplet in
    ``self.angles``. Each triplet uses the force constant and equilibrium
    angle looked up in ``self.params`` for its entry in ``self.angle_names``.
    """

    params: AngleConfiguration

    # https://manual.gromacs.org/current/reference-manual/functions/bonded-interactions.html#harmonicangle
    # Martini2 uses angle type 2 (G96 Angle) so MSE is defined w.r.t.
    # cos(theta). Martini3 can set this classvar to False and reuse this code.
    use_G96: ClassVar[bool] = True  # noqa: N815

    @override
    def __post_init__(self, topology: None = None) -> None:
        """Cache per-angle parameter arrays aligned with ``self.angle_names``.

        After this runs, ``self._angles_k[n]`` and ``self._angles_theta0[n]``
        hold the k and theta0 values for the n-th angle name.
        """
        k_values = [self.params[ANGLE_K_PREFIX + angle_name] for angle_name in self.angle_names]
        theta0_values = [
            self.params[ANGLE_THETA0_PREFIX + angle_name] for angle_name in self.angle_names
        ]
        # The dataclass is frozen, so derived caches are attached by going
        # around the generated __setattr__.
        for attr_name, values in (("_angles_k", k_values), ("_angles_theta0", theta0_values)):
            object.__setattr__(self, attr_name, jnp.array(values))

    @override
    def compute_energy(self, trajectory: SimulatorTrajectory) -> float:
        """Sum the angle energy over all angle triplets for ``trajectory.center``."""
        disp_fn = self.displacement_fn(trajectory.box_size)
        positions = trajectory.center
        g96 = self.use_G96

        # Only the triplet indices and their cached k/theta0 values are mapped;
        # positions, the displacement function and the G96 flag are shared by
        # every angle and are captured by the closure instead.
        def single_angle_energy(triplet, k_angle, theta0_angle):
            return triplet_angle(positions, triplet, k_angle, theta0_angle, disp_fn, g96)

        per_angle = jax.vmap(single_angle_energy)(
            self.angles,
            self._angles_k,
            self._angles_theta0,
        )
        return per_angle.sum()