refactor: modernize System and Molecule classes with dataclasses and cleanup

- Convert System and Molecule classes to use @dataclass and field for defaults
- Remove unused imports and legacy code from system.py
- Move print_charges_and_dipole method from System to Player for better separation of concerns
- Minor formatting and import order improvements for consistency
This commit is contained in:
2026-02-28 15:54:46 -03:00
parent a5504b0435
commit 636c65c07c
5 changed files with 46 additions and 153 deletions

View File

@@ -1,4 +1,4 @@
from diceplayer.utils.ptable import PTable, AtomInfo from diceplayer.utils.ptable import AtomInfo, PTable
from dataclasses import dataclass from dataclasses import dataclass

View File

@@ -1,8 +1,5 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, Field
from functools import cached_property
from diceplayer import logger from diceplayer import logger
from diceplayer.environment import Atom from diceplayer.environment import Atom
from diceplayer.utils.cache import invalidate_computed_properties from diceplayer.utils.cache import invalidate_computed_properties
@@ -12,10 +9,12 @@ from diceplayer.utils.ptable import GHOST_NUMBER
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from numpy.linalg import linalg from numpy.linalg import linalg
from typing_extensions import List, Tuple, Self from typing_extensions import List, Self, Tuple
import math import math
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass, field
from functools import cached_property
@dataclass @dataclass
@@ -30,8 +29,9 @@ class Molecule:
com (npt.NDArray[np.float64]): The center of mass of the molecule com (npt.NDArray[np.float64]): The center of mass of the molecule
inertia_tensor (npt.NDArray[np.float64]): The inertia tensor of the molecule inertia_tensor (npt.NDArray[np.float64]): The inertia tensor of the molecule
""" """
molname: str molname: str
atom: List[Atom] = Field(default_factory=list) atom: List[Atom] = field(default_factory=list)
@cached_property @cached_property
def total_mass(self) -> float: def total_mass(self) -> float:
@@ -301,12 +301,16 @@ class Molecule:
Returns: Returns:
float: minimum distance between the two molecules float: minimum distance between the two molecules
""" """
coords_a = np.array([(a.rx, a.ry, a.rz) for a in self.atom if a.na != GHOST_NUMBER]) coords_a = np.array(
coords_b = np.array([(a.rx, a.ry, a.rz) for a in molec.atom if a.na != GHOST_NUMBER]) [(a.rx, a.ry, a.rz) for a in self.atom if a.na != GHOST_NUMBER]
)
coords_b = np.array(
[(a.rx, a.ry, a.rz) for a in molec.atom if a.na != GHOST_NUMBER]
)
if len(coords_a) == 0 or len(coords_b) == 0: if len(coords_a) == 0 or len(coords_b) == 0:
raise ValueError("No real atoms to compare") raise ValueError("No real atoms to compare")
diff = coords_a[:, None, :] - coords_b[None, :, :] diff = coords_a[:, None, :] - coords_b[None, :, :]
d2 = np.sum(diff ** 2, axis=-1) d2 = np.sum(diff**2, axis=-1)
return np.sqrt(d2.min()) return np.sqrt(d2.min())

View File

@@ -1,15 +1,11 @@
from diceplayer import logger
from diceplayer.environment.molecule import Molecule from diceplayer.environment.molecule import Molecule
from diceplayer.utils.misc import BOHR2ANG
import numpy as np from typing_extensions import List
from numpy import linalg
from typing_extensions import List, Tuple
import math from dataclasses import dataclass, field
from copy import deepcopy
@dataclass
class System: class System:
""" """
System class declaration. This class is used throughout the DicePlayer program to represent the system containing the molecules. System class declaration. This class is used throughout the DicePlayer program to represent the system containing the molecules.
@@ -19,12 +15,8 @@ class System:
nmols (List[int]): List of number of molecules in the system nmols (List[int]): List of number of molecules in the system
""" """
def __init__(self) -> None: nmols: List[int] = field(default_factory=list)
""" molecule: List[Molecule] = field(default_factory=list)
Initializes an empty system object that will be populated afterwards
"""
self.nmols: List[int] = []
self.molecule: List[Molecule] = []
def add_type(self, m: Molecule) -> None: def add_type(self, m: Molecule) -> None:
""" """
@@ -36,123 +28,3 @@ class System:
if not isinstance(m, Molecule): if not isinstance(m, Molecule):
raise TypeError("Error: molecule is not a Molecule instance") raise TypeError("Error: molecule is not a Molecule instance")
self.molecule.append(m) self.molecule.append(m)
def update_molecule(self, position: np.ndarray) -> None:
"""Updates the position of the molecule in the Output file
Args:
position (np.ndarray): numpy position vector
"""
position_in_ang = (position * BOHR2ANG).tolist()
self.add_type(deepcopy(self.molecule[0]))
for atom in self.molecule[-1].atom:
atom.rx = position_in_ang.pop(0)
atom.ry = position_in_ang.pop(0)
atom.rz = position_in_ang.pop(0)
rmsd, self.molecule[0] = self.rmsd_fit(-1, 0)
self.molecule.pop(-1)
logger.info("Projected new conformation of reference molecule with RMSD fit")
logger.info(f"RMSD = {rmsd:>8.5f} Angstrom")
def rmsd_fit(self, p_index: int, r_index: int) -> Tuple[float, Molecule]:
projecting_mol = self.molecule[p_index]
reference_mol = self.molecule[r_index]
if len(projecting_mol.atom) != len(reference_mol.atom):
raise RuntimeError(
"Error in RMSD fit procedure: molecules have different number of atoms"
)
dim = len(projecting_mol.atom)
new_projecting_mol = deepcopy(projecting_mol)
new_reference_mol = deepcopy(reference_mol)
new_projecting_mol.move_center_of_mass_to_origin()
new_reference_mol.move_center_of_mass_to_origin()
x = []
y = []
for atom in new_projecting_mol.atom:
x.extend([atom.rx, atom.ry, atom.rz])
for atom in new_reference_mol.atom:
y.extend([atom.rx, atom.ry, atom.rz])
x = np.array(x).reshape(dim, 3)
y = np.array(y).reshape(dim, 3)
r = np.matmul(y.T, x)
rr = np.matmul(r.T, r)
try:
evals, evecs = linalg.eigh(rr)
except Exception as err:
raise RuntimeError(
"Error: diagonalization of RR matrix did not converge"
) from err
a1 = evecs[:, 2].T
a2 = evecs[:, 1].T
a3 = np.cross(a1, a2)
A = np.array([a1[0], a1[1], a1[2], a2[0], a2[1], a2[2], a3[0], a3[1], a3[2]])
A = A.reshape(3, 3)
b1 = np.matmul(r, a1.T).T # or np.dot(r, a1)
b1 /= linalg.norm(b1)
b2 = np.matmul(r, a2.T).T # or np.dot(r, a2)
b2 /= linalg.norm(b2)
b3 = np.cross(b1, b2)
B = np.array([b1[0], b1[1], b1[2], b2[0], b2[1], b2[2], b3[0], b3[1], b3[2]])
B = B.reshape(3, 3).T
rot_matrix = np.matmul(B, A)
x = np.matmul(rot_matrix, x.T).T
rmsd = 0
for i in range(dim):
rmsd += (
(x[i, 0] - y[i, 0]) ** 2
+ (x[i, 1] - y[i, 1]) ** 2
+ (x[i, 2] - y[i, 2]) ** 2
)
rmsd = math.sqrt(rmsd / dim)
for i in range(dim):
new_projecting_mol.atom[i].rx = x[i, 0]
new_projecting_mol.atom[i].ry = x[i, 1]
new_projecting_mol.atom[i].rz = x[i, 2]
projected_mol = new_projecting_mol.translate(reference_mol.com)
return rmsd, projected_mol
def print_charges_and_dipole(self, cycle: int) -> None:
"""
Print the charges and dipole of the molecule in the Output file
Args:
cycle (int): Number of the cycle
fh (TextIO): Output file
"""
logger.info("Cycle # {}\n".format(cycle))
logger.info("Number of site: {}\n".format(len(self.molecule[0].atom)))
chargesAndDipole = self.molecule[0].charges_and_dipole()
logger.info(
"{:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f}\n".format(
chargesAndDipole[0],
chargesAndDipole[1],
chargesAndDipole[2],
chargesAndDipole[3],
chargesAndDipole[4],
)
)

View File

@@ -267,8 +267,6 @@ class Player:
if "position" not in result: if "position" not in result:
raise RuntimeError("Optimization failed. No position found in result.") raise RuntimeError("Optimization failed. No position found in result.")
self.system.update_molecule(result["position"])
else: else:
if "charges" not in result: if "charges" not in result:
raise RuntimeError( raise RuntimeError(
@@ -277,13 +275,37 @@ class Player:
diff = self.system.molecule[0].update_charges(result["charges"]) diff = self.system.molecule[0].update_charges(result["charges"])
self.system.print_charges_and_dipole(cycle) self.print_charges_and_dipole(cycle)
self.print_geoms(cycle) self.print_geoms(cycle)
if diff < self.config.gaussian.chg_tol: if diff < self.config.gaussian.chg_tol:
logger.info(f"Charges converged after {cycle} cycles.") logger.info(f"Charges converged after {cycle} cycles.")
raise StopIteration() raise StopIteration()
def print_charges_and_dipole(self, cycle: int) -> None:
"""
Print the charges and dipole of the molecule in the Output file
Args:
cycle (int): Number of the cycle
fh (TextIO): Output file
"""
logger.info("Cycle # {}\n".format(cycle))
logger.info("Number of site: {}\n".format(len(self.system.molecule[0].atom)))
chargesAndDipole = self.system.molecule[0].charges_and_dipole()
logger.info(
"{:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f}\n".format(
chargesAndDipole[0],
chargesAndDipole[1],
chargesAndDipole[2],
chargesAndDipole[3],
chargesAndDipole[4],
)
)
def print_geoms(self, cycle: int): def print_geoms(self, cycle: int):
with open(self.config.geoms_file, "a") as file: with open(self.config.geoms_file, "a") as file:
file.write(f"Cycle # {cycle}\n") file.write(f"Cycle # {cycle}\n")

View File

@@ -68,15 +68,10 @@ class TestMolecule(unittest.TestCase):
Atom(lbl=1, na=1, rx=1.0, ry=1.0, rz=1.0, chg=1.0, eps=1.0, sig=1.0) Atom(lbl=1, na=1, rx=1.0, ry=1.0, rz=1.0, chg=1.0, eps=1.0, sig=1.0)
) )
expected = [ expected = [[0.0, 1.73205081], [1.73205081, 0.0]]
[0.0, 1.73205081],
[1.73205081, 0.0]
]
actual = mol.distances_between_atoms() actual = mol.distances_between_atoms()
npt.assert_almost_equal( npt.assert_almost_equal(expected, actual)
expected, actual
)
def test_inertia_tensor(self): def test_inertia_tensor(self):
mol = Molecule("test") mol = Molecule("test")