From 636c65c07caf13f81d4ec1b57424fbf935f17dfc Mon Sep 17 00:00:00 2001 From: Vitor Hideyoshi Date: Sat, 28 Feb 2026 15:54:46 -0300 Subject: [PATCH] refactor: modernize System and Molecule classes with dataclasses and cleanup - Convert System and Molecule classes to use @dataclass and field for defaults - Remove unused imports and legacy code from system.py - Move print_charges_and_dipole method from System to Player for better separation of concerns - Minor formatting and import order improvements for consistency --- diceplayer/environment/atom.py | 2 +- diceplayer/environment/molecule.py | 22 ++-- diceplayer/environment/system.py | 138 +--------------------- diceplayer/player.py | 28 ++++- tests/shared/environment/test_molecule.py | 9 +- 5 files changed, 46 insertions(+), 153 deletions(-) diff --git a/diceplayer/environment/atom.py b/diceplayer/environment/atom.py index 34adea5..c5c9dee 100644 --- a/diceplayer/environment/atom.py +++ b/diceplayer/environment/atom.py @@ -1,4 +1,4 @@ -from diceplayer.utils.ptable import PTable, AtomInfo +from diceplayer.utils.ptable import AtomInfo, PTable from dataclasses import dataclass diff --git a/diceplayer/environment/molecule.py b/diceplayer/environment/molecule.py index 95f975c..9b1b865 100644 --- a/diceplayer/environment/molecule.py +++ b/diceplayer/environment/molecule.py @@ -1,8 +1,5 @@ from __future__ import annotations -from dataclasses import dataclass, Field -from functools import cached_property - from diceplayer import logger from diceplayer.environment import Atom from diceplayer.utils.cache import invalidate_computed_properties @@ -12,10 +9,12 @@ from diceplayer.utils.ptable import GHOST_NUMBER import numpy as np import numpy.typing as npt from numpy.linalg import linalg -from typing_extensions import List, Tuple, Self +from typing_extensions import List, Self, Tuple import math from copy import deepcopy +from dataclasses import dataclass, field +from functools import cached_property @dataclass @@ -30,8 +29,9 @@ class Molecule: com (npt.NDArray[np.float64]): The center of mass of the molecule inertia_tensor (npt.NDArray[np.float64]): The inertia tensor of the molecule """ + molname: str - atom: List[Atom] = Field(default_factory=list) + atom: List[Atom] = field(default_factory=list) @cached_property def total_mass(self) -> float: @@ -301,12 +301,16 @@ class Molecule: Returns: float: minimum distance between the two molecules """ - coords_a = np.array([(a.rx, a.ry, a.rz) for a in self.atom if a.na != GHOST_NUMBER]) - coords_b = np.array([(a.rx, a.ry, a.rz) for a in molec.atom if a.na != GHOST_NUMBER]) + coords_a = np.array( + [(a.rx, a.ry, a.rz) for a in self.atom if a.na != GHOST_NUMBER] + ) + coords_b = np.array( + [(a.rx, a.ry, a.rz) for a in molec.atom if a.na != GHOST_NUMBER] + ) if len(coords_a) == 0 or len(coords_b) == 0: raise ValueError("No real atoms to compare") diff = coords_a[:, None, :] - coords_b[None, :, :] - d2 = np.sum(diff ** 2, axis=-1) - return np.sqrt(d2.min()) \ No newline at end of file + d2 = np.sum(diff**2, axis=-1) + return np.sqrt(d2.min()) diff --git a/diceplayer/environment/system.py b/diceplayer/environment/system.py index f84e656..d35523c 100644 --- a/diceplayer/environment/system.py +++ b/diceplayer/environment/system.py @@ -1,15 +1,11 @@ -from diceplayer import logger from diceplayer.environment.molecule import Molecule -from diceplayer.utils.misc import BOHR2ANG -import numpy as np -from numpy import linalg -from typing_extensions import List, Tuple +from typing_extensions import List -import math -from copy import deepcopy +from dataclasses import dataclass, field +@dataclass class System: """ System class declaration. This class is used throughout the DicePlayer program to represent the system containing the molecules. @@ -19,12 +15,8 @@ class System: nmols (List[int]): List of number of molecules in the system """ - def __init__(self) -> None: - """ - Initializes an empty system object that will be populated afterwards - """ - self.nmols: List[int] = [] - self.molecule: List[Molecule] = [] + nmols: List[int] = field(default_factory=list) + molecule: List[Molecule] = field(default_factory=list) def add_type(self, m: Molecule) -> None: """ @@ -36,123 +28,3 @@ class System: if not isinstance(m, Molecule): raise TypeError("Error: molecule is not a Molecule instance") self.molecule.append(m) - - def update_molecule(self, position: np.ndarray) -> None: - """Updates the position of the molecule in the Output file - - Args: - position (np.ndarray): numpy position vector - """ - - position_in_ang = (position * BOHR2ANG).tolist() - self.add_type(deepcopy(self.molecule[0])) - - for atom in self.molecule[-1].atom: - atom.rx = position_in_ang.pop(0) - atom.ry = position_in_ang.pop(0) - atom.rz = position_in_ang.pop(0) - - rmsd, self.molecule[0] = self.rmsd_fit(-1, 0) - self.molecule.pop(-1) - - logger.info("Projected new conformation of reference molecule with RMSD fit") - logger.info(f"RMSD = {rmsd:>8.5f} Angstrom") - - def rmsd_fit(self, p_index: int, r_index: int) -> Tuple[float, Molecule]: - projecting_mol = self.molecule[p_index] - reference_mol = self.molecule[r_index] - - if len(projecting_mol.atom) != len(reference_mol.atom): - raise RuntimeError( - "Error in RMSD fit procedure: molecules have different number of atoms" - ) - dim = len(projecting_mol.atom) - - new_projecting_mol = deepcopy(projecting_mol) - new_reference_mol = deepcopy(reference_mol) - - new_projecting_mol.move_center_of_mass_to_origin() - new_reference_mol.move_center_of_mass_to_origin() - - x = [] - y = [] - - for atom in new_projecting_mol.atom: - x.extend([atom.rx, atom.ry, atom.rz]) - - for atom in new_reference_mol.atom: - y.extend([atom.rx, atom.ry, atom.rz]) - - x = np.array(x).reshape(dim, 3) - y = np.array(y).reshape(dim, 3) - - r = np.matmul(y.T, x) - rr = np.matmul(r.T, r) - - try: - evals, evecs = linalg.eigh(rr) - except Exception as err: - raise RuntimeError( - "Error: diagonalization of RR matrix did not converge" - ) from err - - a1 = evecs[:, 2].T - a2 = evecs[:, 1].T - a3 = np.cross(a1, a2) - - A = np.array([a1[0], a1[1], a1[2], a2[0], a2[1], a2[2], a3[0], a3[1], a3[2]]) - A = A.reshape(3, 3) - - b1 = np.matmul(r, a1.T).T # or np.dot(r, a1) - b1 /= linalg.norm(b1) - b2 = np.matmul(r, a2.T).T # or np.dot(r, a2) - b2 /= linalg.norm(b2) - b3 = np.cross(b1, b2) - - B = np.array([b1[0], b1[1], b1[2], b2[0], b2[1], b2[2], b3[0], b3[1], b3[2]]) - B = B.reshape(3, 3).T - - rot_matrix = np.matmul(B, A) - x = np.matmul(rot_matrix, x.T).T - - rmsd = 0 - for i in range(dim): - rmsd += ( - (x[i, 0] - y[i, 0]) ** 2 - + (x[i, 1] - y[i, 1]) ** 2 - + (x[i, 2] - y[i, 2]) ** 2 - ) - rmsd = math.sqrt(rmsd / dim) - - for i in range(dim): - new_projecting_mol.atom[i].rx = x[i, 0] - new_projecting_mol.atom[i].ry = x[i, 1] - new_projecting_mol.atom[i].rz = x[i, 2] - - projected_mol = new_projecting_mol.translate(reference_mol.com) - - return rmsd, projected_mol - - def print_charges_and_dipole(self, cycle: int) -> None: - """ - Print the charges and dipole of the molecule in the Output file - - Args: - cycle (int): Number of the cycle - fh (TextIO): Output file - """ - - logger.info("Cycle # {}\n".format(cycle)) - logger.info("Number of site: {}\n".format(len(self.molecule[0].atom))) - - chargesAndDipole = self.molecule[0].charges_and_dipole() - - logger.info( - "{:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f}\n".format( - chargesAndDipole[0], - chargesAndDipole[1], - chargesAndDipole[2], - chargesAndDipole[3], - chargesAndDipole[4], - ) - ) diff --git a/diceplayer/player.py b/diceplayer/player.py index b646fb4..70fd415 100644 --- a/diceplayer/player.py +++ b/diceplayer/player.py @@ -267,8 +267,6 @@ class Player: if "position" not in result: raise RuntimeError("Optimization failed. No position found in result.") - self.system.update_molecule(result["position"]) - else: if "charges" not in result: raise RuntimeError( @@ -277,13 +275,37 @@ class Player: diff = self.system.molecule[0].update_charges(result["charges"]) - self.system.print_charges_and_dipole(cycle) + self.print_charges_and_dipole(cycle) self.print_geoms(cycle) if diff < self.config.gaussian.chg_tol: logger.info(f"Charges converged after {cycle} cycles.") raise StopIteration() + def print_charges_and_dipole(self, cycle: int) -> None: + """ + Print the charges and dipole of the molecule in the Output file + + Args: + cycle (int): Number of the cycle + fh (TextIO): Output file + """ + + logger.info("Cycle # {}\n".format(cycle)) + logger.info("Number of site: {}\n".format(len(self.system.molecule[0].atom))) + + chargesAndDipole = self.system.molecule[0].charges_and_dipole() + + logger.info( + "{:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f}\n".format( + chargesAndDipole[0], + chargesAndDipole[1], + chargesAndDipole[2], + chargesAndDipole[3], + chargesAndDipole[4], + ) + ) + def print_geoms(self, cycle: int): with open(self.config.geoms_file, "a") as file: file.write(f"Cycle # {cycle}\n") diff --git a/tests/shared/environment/test_molecule.py b/tests/shared/environment/test_molecule.py index ad1670f..1f0826b 100644 --- a/tests/shared/environment/test_molecule.py +++ b/tests/shared/environment/test_molecule.py @@ -68,15 +68,10 @@ class TestMolecule(unittest.TestCase): Atom(lbl=1, na=1, rx=1.0, ry=1.0, rz=1.0, chg=1.0, eps=1.0, sig=1.0) ) - expected = [ - [0.0, 1.73205081], - [1.73205081, 0.0] - ] + expected = [[0.0, 1.73205081], [1.73205081, 0.0]] actual = mol.distances_between_atoms() - npt.assert_almost_equal( - expected, actual - ) + npt.assert_almost_equal(expected, actual) def test_inertia_tensor(self): mol = Molecule("test")