refactor: modernize System and Molecule classes with dataclasses and cleanup
- Convert System and Molecule classes to use @dataclass and field for defaults - Remove unused imports and legacy code from system.py - Move print_charges_and_dipole method from System to Player for better separation of concerns - Minor formatting and import order improvements for consistency
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from diceplayer.utils.ptable import PTable, AtomInfo
|
from diceplayer.utils.ptable import AtomInfo, PTable
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, Field
|
|
||||||
from functools import cached_property
|
|
||||||
|
|
||||||
from diceplayer import logger
|
from diceplayer import logger
|
||||||
from diceplayer.environment import Atom
|
from diceplayer.environment import Atom
|
||||||
from diceplayer.utils.cache import invalidate_computed_properties
|
from diceplayer.utils.cache import invalidate_computed_properties
|
||||||
@@ -12,10 +9,12 @@ from diceplayer.utils.ptable import GHOST_NUMBER
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
from numpy.linalg import linalg
|
from numpy.linalg import linalg
|
||||||
from typing_extensions import List, Tuple, Self
|
from typing_extensions import List, Self, Tuple
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -30,8 +29,9 @@ class Molecule:
|
|||||||
com (npt.NDArray[np.float64]): The center of mass of the molecule
|
com (npt.NDArray[np.float64]): The center of mass of the molecule
|
||||||
inertia_tensor (npt.NDArray[np.float64]): The inertia tensor of the molecule
|
inertia_tensor (npt.NDArray[np.float64]): The inertia tensor of the molecule
|
||||||
"""
|
"""
|
||||||
|
|
||||||
molname: str
|
molname: str
|
||||||
atom: List[Atom] = Field(default_factory=list)
|
atom: List[Atom] = field(default_factory=list)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def total_mass(self) -> float:
|
def total_mass(self) -> float:
|
||||||
@@ -301,12 +301,16 @@ class Molecule:
|
|||||||
Returns:
|
Returns:
|
||||||
float: minimum distance between the two molecules
|
float: minimum distance between the two molecules
|
||||||
"""
|
"""
|
||||||
coords_a = np.array([(a.rx, a.ry, a.rz) for a in self.atom if a.na != GHOST_NUMBER])
|
coords_a = np.array(
|
||||||
coords_b = np.array([(a.rx, a.ry, a.rz) for a in molec.atom if a.na != GHOST_NUMBER])
|
[(a.rx, a.ry, a.rz) for a in self.atom if a.na != GHOST_NUMBER]
|
||||||
|
)
|
||||||
|
coords_b = np.array(
|
||||||
|
[(a.rx, a.ry, a.rz) for a in molec.atom if a.na != GHOST_NUMBER]
|
||||||
|
)
|
||||||
|
|
||||||
if len(coords_a) == 0 or len(coords_b) == 0:
|
if len(coords_a) == 0 or len(coords_b) == 0:
|
||||||
raise ValueError("No real atoms to compare")
|
raise ValueError("No real atoms to compare")
|
||||||
|
|
||||||
diff = coords_a[:, None, :] - coords_b[None, :, :]
|
diff = coords_a[:, None, :] - coords_b[None, :, :]
|
||||||
d2 = np.sum(diff ** 2, axis=-1)
|
d2 = np.sum(diff**2, axis=-1)
|
||||||
return np.sqrt(d2.min())
|
return np.sqrt(d2.min())
|
||||||
|
|||||||
@@ -1,15 +1,11 @@
|
|||||||
from diceplayer import logger
|
|
||||||
from diceplayer.environment.molecule import Molecule
|
from diceplayer.environment.molecule import Molecule
|
||||||
from diceplayer.utils.misc import BOHR2ANG
|
|
||||||
|
|
||||||
import numpy as np
|
from typing_extensions import List
|
||||||
from numpy import linalg
|
|
||||||
from typing_extensions import List, Tuple
|
|
||||||
|
|
||||||
import math
|
from dataclasses import dataclass, field
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class System:
|
class System:
|
||||||
"""
|
"""
|
||||||
System class declaration. This class is used throughout the DicePlayer program to represent the system containing the molecules.
|
System class declaration. This class is used throughout the DicePlayer program to represent the system containing the molecules.
|
||||||
@@ -19,12 +15,8 @@ class System:
|
|||||||
nmols (List[int]): List of number of molecules in the system
|
nmols (List[int]): List of number of molecules in the system
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
nmols: List[int] = field(default_factory=list)
|
||||||
"""
|
molecule: List[Molecule] = field(default_factory=list)
|
||||||
Initializes an empty system object that will be populated afterwards
|
|
||||||
"""
|
|
||||||
self.nmols: List[int] = []
|
|
||||||
self.molecule: List[Molecule] = []
|
|
||||||
|
|
||||||
def add_type(self, m: Molecule) -> None:
|
def add_type(self, m: Molecule) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -36,123 +28,3 @@ class System:
|
|||||||
if not isinstance(m, Molecule):
|
if not isinstance(m, Molecule):
|
||||||
raise TypeError("Error: molecule is not a Molecule instance")
|
raise TypeError("Error: molecule is not a Molecule instance")
|
||||||
self.molecule.append(m)
|
self.molecule.append(m)
|
||||||
|
|
||||||
def update_molecule(self, position: np.ndarray) -> None:
|
|
||||||
"""Updates the position of the molecule in the Output file
|
|
||||||
|
|
||||||
Args:
|
|
||||||
position (np.ndarray): numpy position vector
|
|
||||||
"""
|
|
||||||
|
|
||||||
position_in_ang = (position * BOHR2ANG).tolist()
|
|
||||||
self.add_type(deepcopy(self.molecule[0]))
|
|
||||||
|
|
||||||
for atom in self.molecule[-1].atom:
|
|
||||||
atom.rx = position_in_ang.pop(0)
|
|
||||||
atom.ry = position_in_ang.pop(0)
|
|
||||||
atom.rz = position_in_ang.pop(0)
|
|
||||||
|
|
||||||
rmsd, self.molecule[0] = self.rmsd_fit(-1, 0)
|
|
||||||
self.molecule.pop(-1)
|
|
||||||
|
|
||||||
logger.info("Projected new conformation of reference molecule with RMSD fit")
|
|
||||||
logger.info(f"RMSD = {rmsd:>8.5f} Angstrom")
|
|
||||||
|
|
||||||
def rmsd_fit(self, p_index: int, r_index: int) -> Tuple[float, Molecule]:
|
|
||||||
projecting_mol = self.molecule[p_index]
|
|
||||||
reference_mol = self.molecule[r_index]
|
|
||||||
|
|
||||||
if len(projecting_mol.atom) != len(reference_mol.atom):
|
|
||||||
raise RuntimeError(
|
|
||||||
"Error in RMSD fit procedure: molecules have different number of atoms"
|
|
||||||
)
|
|
||||||
dim = len(projecting_mol.atom)
|
|
||||||
|
|
||||||
new_projecting_mol = deepcopy(projecting_mol)
|
|
||||||
new_reference_mol = deepcopy(reference_mol)
|
|
||||||
|
|
||||||
new_projecting_mol.move_center_of_mass_to_origin()
|
|
||||||
new_reference_mol.move_center_of_mass_to_origin()
|
|
||||||
|
|
||||||
x = []
|
|
||||||
y = []
|
|
||||||
|
|
||||||
for atom in new_projecting_mol.atom:
|
|
||||||
x.extend([atom.rx, atom.ry, atom.rz])
|
|
||||||
|
|
||||||
for atom in new_reference_mol.atom:
|
|
||||||
y.extend([atom.rx, atom.ry, atom.rz])
|
|
||||||
|
|
||||||
x = np.array(x).reshape(dim, 3)
|
|
||||||
y = np.array(y).reshape(dim, 3)
|
|
||||||
|
|
||||||
r = np.matmul(y.T, x)
|
|
||||||
rr = np.matmul(r.T, r)
|
|
||||||
|
|
||||||
try:
|
|
||||||
evals, evecs = linalg.eigh(rr)
|
|
||||||
except Exception as err:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Error: diagonalization of RR matrix did not converge"
|
|
||||||
) from err
|
|
||||||
|
|
||||||
a1 = evecs[:, 2].T
|
|
||||||
a2 = evecs[:, 1].T
|
|
||||||
a3 = np.cross(a1, a2)
|
|
||||||
|
|
||||||
A = np.array([a1[0], a1[1], a1[2], a2[0], a2[1], a2[2], a3[0], a3[1], a3[2]])
|
|
||||||
A = A.reshape(3, 3)
|
|
||||||
|
|
||||||
b1 = np.matmul(r, a1.T).T # or np.dot(r, a1)
|
|
||||||
b1 /= linalg.norm(b1)
|
|
||||||
b2 = np.matmul(r, a2.T).T # or np.dot(r, a2)
|
|
||||||
b2 /= linalg.norm(b2)
|
|
||||||
b3 = np.cross(b1, b2)
|
|
||||||
|
|
||||||
B = np.array([b1[0], b1[1], b1[2], b2[0], b2[1], b2[2], b3[0], b3[1], b3[2]])
|
|
||||||
B = B.reshape(3, 3).T
|
|
||||||
|
|
||||||
rot_matrix = np.matmul(B, A)
|
|
||||||
x = np.matmul(rot_matrix, x.T).T
|
|
||||||
|
|
||||||
rmsd = 0
|
|
||||||
for i in range(dim):
|
|
||||||
rmsd += (
|
|
||||||
(x[i, 0] - y[i, 0]) ** 2
|
|
||||||
+ (x[i, 1] - y[i, 1]) ** 2
|
|
||||||
+ (x[i, 2] - y[i, 2]) ** 2
|
|
||||||
)
|
|
||||||
rmsd = math.sqrt(rmsd / dim)
|
|
||||||
|
|
||||||
for i in range(dim):
|
|
||||||
new_projecting_mol.atom[i].rx = x[i, 0]
|
|
||||||
new_projecting_mol.atom[i].ry = x[i, 1]
|
|
||||||
new_projecting_mol.atom[i].rz = x[i, 2]
|
|
||||||
|
|
||||||
projected_mol = new_projecting_mol.translate(reference_mol.com)
|
|
||||||
|
|
||||||
return rmsd, projected_mol
|
|
||||||
|
|
||||||
def print_charges_and_dipole(self, cycle: int) -> None:
|
|
||||||
"""
|
|
||||||
Print the charges and dipole of the molecule in the Output file
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cycle (int): Number of the cycle
|
|
||||||
fh (TextIO): Output file
|
|
||||||
"""
|
|
||||||
|
|
||||||
logger.info("Cycle # {}\n".format(cycle))
|
|
||||||
logger.info("Number of site: {}\n".format(len(self.molecule[0].atom)))
|
|
||||||
|
|
||||||
chargesAndDipole = self.molecule[0].charges_and_dipole()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"{:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f}\n".format(
|
|
||||||
chargesAndDipole[0],
|
|
||||||
chargesAndDipole[1],
|
|
||||||
chargesAndDipole[2],
|
|
||||||
chargesAndDipole[3],
|
|
||||||
chargesAndDipole[4],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -267,8 +267,6 @@ class Player:
|
|||||||
if "position" not in result:
|
if "position" not in result:
|
||||||
raise RuntimeError("Optimization failed. No position found in result.")
|
raise RuntimeError("Optimization failed. No position found in result.")
|
||||||
|
|
||||||
self.system.update_molecule(result["position"])
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if "charges" not in result:
|
if "charges" not in result:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -277,13 +275,37 @@ class Player:
|
|||||||
|
|
||||||
diff = self.system.molecule[0].update_charges(result["charges"])
|
diff = self.system.molecule[0].update_charges(result["charges"])
|
||||||
|
|
||||||
self.system.print_charges_and_dipole(cycle)
|
self.print_charges_and_dipole(cycle)
|
||||||
self.print_geoms(cycle)
|
self.print_geoms(cycle)
|
||||||
|
|
||||||
if diff < self.config.gaussian.chg_tol:
|
if diff < self.config.gaussian.chg_tol:
|
||||||
logger.info(f"Charges converged after {cycle} cycles.")
|
logger.info(f"Charges converged after {cycle} cycles.")
|
||||||
raise StopIteration()
|
raise StopIteration()
|
||||||
|
|
||||||
|
def print_charges_and_dipole(self, cycle: int) -> None:
|
||||||
|
"""
|
||||||
|
Print the charges and dipole of the molecule in the Output file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cycle (int): Number of the cycle
|
||||||
|
fh (TextIO): Output file
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.info("Cycle # {}\n".format(cycle))
|
||||||
|
logger.info("Number of site: {}\n".format(len(self.system.molecule[0].atom)))
|
||||||
|
|
||||||
|
chargesAndDipole = self.system.molecule[0].charges_and_dipole()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"{:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f}\n".format(
|
||||||
|
chargesAndDipole[0],
|
||||||
|
chargesAndDipole[1],
|
||||||
|
chargesAndDipole[2],
|
||||||
|
chargesAndDipole[3],
|
||||||
|
chargesAndDipole[4],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def print_geoms(self, cycle: int):
|
def print_geoms(self, cycle: int):
|
||||||
with open(self.config.geoms_file, "a") as file:
|
with open(self.config.geoms_file, "a") as file:
|
||||||
file.write(f"Cycle # {cycle}\n")
|
file.write(f"Cycle # {cycle}\n")
|
||||||
|
|||||||
@@ -68,15 +68,10 @@ class TestMolecule(unittest.TestCase):
|
|||||||
Atom(lbl=1, na=1, rx=1.0, ry=1.0, rz=1.0, chg=1.0, eps=1.0, sig=1.0)
|
Atom(lbl=1, na=1, rx=1.0, ry=1.0, rz=1.0, chg=1.0, eps=1.0, sig=1.0)
|
||||||
)
|
)
|
||||||
|
|
||||||
expected = [
|
expected = [[0.0, 1.73205081], [1.73205081, 0.0]]
|
||||||
[0.0, 1.73205081],
|
|
||||||
[1.73205081, 0.0]
|
|
||||||
]
|
|
||||||
actual = mol.distances_between_atoms()
|
actual = mol.distances_between_atoms()
|
||||||
|
|
||||||
npt.assert_almost_equal(
|
npt.assert_almost_equal(expected, actual)
|
||||||
expected, actual
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_inertia_tensor(self):
|
def test_inertia_tensor(self):
|
||||||
mol = Molecule("test")
|
mol = Molecule("test")
|
||||||
|
|||||||
Reference in New Issue
Block a user