From c59f0d6516e972ee38c4c2ba9fdd6b9072133a79 Mon Sep 17 00:00:00 2001 From: Vitor Hideyoshi Date: Sun, 1 Mar 2026 11:01:04 -0300 Subject: [PATCH] refactor: replace Logger with RunLogger and streamline logging setup --- diceplayer/__init__.py | 7 +- diceplayer/__main__.py | 63 +- diceplayer/environment/molecule.py | 4 +- diceplayer/interface/__init__.py | 6 - diceplayer/interface/__interface.py | 26 - diceplayer/interface/dice_interface.py | 389 ----------- diceplayer/interface/gaussian_interface.py | 359 ---------- diceplayer/logger.py | 4 + diceplayer/player.py | 464 +------------ diceplayer/utils/__init__.py | 5 +- diceplayer/utils/dataclass_protocol.py | 6 - diceplayer/utils/logger.py | 88 +-- tests/{shared => environment}/__init__.py | 0 tests/{shared => }/environment/test_atom.py | 0 .../{shared => }/environment/test_molecule.py | 0 tests/{shared => }/environment/test_system.py | 0 tests/shared/environment/__init__.py | 0 tests/shared/interface/__init__.py | 0 tests/shared/interface/test_dice_interface.py | 643 ------------------ .../interface/test_gaussian_interface.py | 115 ---- tests/shared/utils/__init__.py | 0 tests/shared/utils/test_logger.py | 132 ---- tests/test_player.py | 396 ----------- 23 files changed, 71 insertions(+), 2636 deletions(-) delete mode 100644 diceplayer/interface/__init__.py delete mode 100644 diceplayer/interface/__interface.py delete mode 100644 diceplayer/interface/dice_interface.py delete mode 100644 diceplayer/interface/gaussian_interface.py create mode 100644 diceplayer/logger.py delete mode 100644 diceplayer/utils/dataclass_protocol.py rename tests/{shared => environment}/__init__.py (100%) rename tests/{shared => }/environment/test_atom.py (100%) rename tests/{shared => }/environment/test_molecule.py (100%) rename tests/{shared => }/environment/test_system.py (100%) delete mode 100644 tests/shared/environment/__init__.py delete mode 100644 tests/shared/interface/__init__.py delete mode 100644 tests/shared/interface/test_dice_interface.py delete mode 100644 tests/shared/interface/test_gaussian_interface.py delete mode 100644 tests/shared/utils/__init__.py delete mode 100644 tests/shared/utils/test_logger.py delete mode 100644 tests/test_player.py diff --git a/diceplayer/__init__.py b/diceplayer/__init__.py index 5e16c04..36e8f16 100644 --- a/diceplayer/__init__.py +++ b/diceplayer/__init__.py @@ -1,8 +1,3 @@ -from diceplayer.utils import Logger - -from importlib import metadata +from diceplayer.utils.logger import RunLogger -VERSION = metadata.version("diceplayer") - -logger = Logger(__name__) diff --git a/diceplayer/__main__.py b/diceplayer/__main__.py index 53480c8..acafd25 100644 --- a/diceplayer/__main__.py +++ b/diceplayer/__main__.py @@ -1,22 +1,34 @@ -from diceplayer import VERSION, logger -from diceplayer.player import Player +import yaml + +from diceplayer.config.player_config import PlayerConfig +from diceplayer.logger import logger import argparse -import logging +from importlib import metadata + +from diceplayer.player import Player + +VERSION = metadata.version("diceplayer") + + +def read_input(infile) -> PlayerConfig: + try: + with open(infile, "r") as f: + return PlayerConfig.model_validate( + yaml.safe_load(f) + ) + except Exception as e: + logger.exception("Failed to read input file") + raise e def main(): - """ - Read and store the arguments passed to the program - and set the usage and help messages - """ - parser = argparse.ArgumentParser(prog="Diceplayer") parser.add_argument( - "-c", "--continue", dest="opt_continue", default=False, action="store_true" + "-v", "--version", action="version", version="diceplayer-" + VERSION ) parser.add_argument( - "-v", "--version", action="version", version="diceplayer-" + VERSION + "-c", "--continue", dest="continuation", default=False, action="store_true" ) parser.add_argument( "-i", @@ -36,35 +48,12 @@ def main(): ) args = parser.parse_args() - # Open OUTFILE for writing and print keywords and initial info - logger.set_logger(args.outfile, logging.INFO) + logger.set_output_file(args.outfile) - if args.opt_continue: - player = Player.from_save() - else: - player = Player.from_file(args.infile) + config = read_input(args.infile) - player.read_potentials() - - player.create_simulation_dir() - player.create_geoms_file() - - player.print_keywords() - - player.print_potentials() - - player.prepare_system() - - player.start() - - logger.info("\n+" + 88 * "-" + "+\n") - - player.print_results() - - logger.info("\n+" + 88 * "-" + "+\n") - - logger.info("Diceplayer finished successfully \n") + Player(config).play(continuation=args.continuation) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/diceplayer/environment/molecule.py b/diceplayer/environment/molecule.py index f5946e1..4338599 100644 --- a/diceplayer/environment/molecule.py +++ b/diceplayer/environment/molecule.py @@ -1,14 +1,14 @@ from __future__ import annotations -from diceplayer import logger +from diceplayer.logger import logger from diceplayer.environment import Atom from diceplayer.utils.cache import invalidate_computed_properties from diceplayer.utils.misc import BOHR2ANG, EA_2_DEBYE from diceplayer.utils.ptable import GHOST_NUMBER import numpy as np -import numpy.typing as npt import numpy.linalg as linalg +import numpy.typing as npt from typing_extensions import List, Self, Tuple import math diff --git a/diceplayer/interface/__init__.py b/diceplayer/interface/__init__.py deleted file mode 100644 index aaedaa1..0000000 --- a/diceplayer/interface/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .__interface import Interface -from .dice_interface import DiceInterface -from .gaussian_interface import GaussianInterface - - -__all__ = ["Interface", "DiceInterface", "GaussianInterface"] diff --git a/diceplayer/interface/__interface.py b/diceplayer/interface/__interface.py deleted file mode 100644 index 95829e8..0000000 --- a/diceplayer/interface/__interface.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -from diceplayer.config.player_config import PlayerConfig -from diceplayer.environment.system import System - -from abc import ABC, abstractmethod - - -class Interface(ABC): - __slots__ = ["step", "system"] - - def __init__(self): - self.system: System | None = None - self.step: PlayerConfig | None = None - - @abstractmethod - def configure(self, step: PlayerConfig, system: System): - pass - - @abstractmethod - def start(self, cycle: int): - pass - - @abstractmethod - def reset(self): - pass diff --git a/diceplayer/interface/dice_interface.py b/diceplayer/interface/dice_interface.py deleted file mode 100644 index 1f23c9c..0000000 --- a/diceplayer/interface/dice_interface.py +++ /dev/null @@ -1,389 +0,0 @@ -from __future__ import annotations - -from diceplayer import logger -from diceplayer.config.player_config import PlayerConfig -from diceplayer.environment.system import System -from diceplayer.interface import Interface - -from setproctitle import setproctitle -from typing_extensions import Final, TextIO - -import os -import random -import shutil -import subprocess -import sys -import time -from multiprocessing import Process, connection -from pathlib import Path - - -DICE_END_FLAG: Final[str] = "End of simulation" -DICE_FLAG_LINE: Final[int] = -2 -UMAANG3_TO_GCM3: Final[float] = 1.6605 - -MAX_SEED: Final[int] = 4294967295 - - -class DiceInterface(Interface): - title = "Diceplayer run" - - def configure(self, step: PlayerConfig, system: System): - self.step = step - self.system = system - - def start(self, cycle: int): - procs = [] - sentinels = [] - - for proc in range(1, self.step.nprocs + 1): - p = Process(target=self._simulation_process, args=(cycle, proc)) - p.start() - - procs.append(p) - sentinels.append(p.sentinel) - - while procs: - finished = connection.wait(sentinels) - for proc_sentinel in finished: - i = sentinels.index(proc_sentinel) - status = procs[i].exitcode - procs.pop(i) - sentinels.pop(i) - if status != 0: - for p in procs: - p.terminate() - sys.exit(status) - - logger.info("\n") - - def reset(self): - del self.step - del self.system - - def _simulation_process(self, cycle: int, proc: int): - setproctitle(f"diceplayer-step{cycle:0d}-p{proc:0d}") - - try: - self._make_proc_dir(cycle, proc) - self._make_dice_inputs(cycle, proc) - self._run_dice(cycle, proc) - except Exception as err: - sys.exit(err) - - def _make_proc_dir(self, cycle, proc): - simulation_dir = Path(self.step.simulation_dir) - if not simulation_dir.exists(): - simulation_dir.mkdir(parents=True) - - proc_dir = Path(simulation_dir, f"step{cycle:02d}", f"p{proc:02d}") - proc_dir.mkdir(parents=True, exist_ok=True) - - def _make_dice_inputs(self, cycle, proc): - proc_dir = Path(self.step.simulation_dir, f"step{cycle:02d}", f"p{proc:02d}") - - self._make_potentials(proc_dir) - - random.seed(self._make_dice_seed()) - - # This is logic is used to make the initial configuration file - # for the next cycle using the last.xyz file from the previous cycle. - if self.step.dice.randominit == "first" and cycle > 1: - last_xyz = Path( - self.step.simulation_dir, - f"step{(cycle - 1):02d}", - f"p{proc:02d}", - "last.xyz", - ) - if not last_xyz.exists(): - raise FileNotFoundError(f"File {last_xyz} not found.") - - with open(last_xyz, "r") as last_xyz_file: - self._make_init_file(proc_dir, last_xyz_file) - last_xyz_file.seek(0) - self.step.dice.dens = self._new_density(last_xyz_file) - - else: - self._make_nvt_ter(cycle, proc_dir) - - if len(self.step.dice.nstep) == 2: - self._make_nvt_eq(cycle, proc_dir) - - elif len(self.step.dice.nstep) == 3: - self._make_npt_ter(cycle, proc_dir) - self._make_npt_eq(proc_dir) - - def _run_dice(self, cycle: int, proc: int): - working_dir = os.getcwd() - - proc_dir = Path(self.step.simulation_dir, f"step{cycle:02d}", f"p{proc:02d}") - - logger.info( - f"Simulation process {str(proc_dir)} initiated with pid {os.getpid()}" - ) - - os.chdir(proc_dir) - - if not (self.step.dice.randominit == "first" and cycle > 1): - self.run_dice_file(cycle, proc, "NVT.ter") - - if len(self.step.dice.nstep) == 2: - self.run_dice_file(cycle, proc, "NVT.eq") - - elif len(self.step.dice.nstep) == 3: - self.run_dice_file(cycle, proc, "NPT.ter") - self.run_dice_file(cycle, proc, "NPT.eq") - - os.chdir(working_dir) - - xyz_file = Path(proc_dir, "phb.xyz") - last_xyz_file = Path(proc_dir, "last.xyz") - - if xyz_file.exists(): - shutil.copy(xyz_file, last_xyz_file) - else: - raise FileNotFoundError(f"File {xyz_file} not found.") - - @staticmethod - def _make_dice_seed() -> int: - num = time.time() - num = (num - int(num)) * 1e6 - - num = int((num - int(num)) * 1e6) - - return (os.getpid() * num) % (MAX_SEED + 1) - - def _make_init_file(self, proc_dir: Path, last_xyz_file: TextIO): - xyz_lines = last_xyz_file.readlines() - - SECONDARY_MOLECULE_LENGTH = 0 - for i in range(1, len(self.step.dice.nmol)): - SECONDARY_MOLECULE_LENGTH += self.step.dice.nmol[i] * len( - self.system.molecule[i].atom - ) - - xyz_lines = xyz_lines[-SECONDARY_MOLECULE_LENGTH:] - - input_file = Path(proc_dir, self.step.dice.outname + ".xy") - with open(input_file, "w") as f: - for atom in self.system.molecule[0].atom: - f.write(f"{atom.rx:>10.6f} {atom.ry:>10.6f} {atom.rz:>10.6f}\n") - - for line in xyz_lines: - atom = line.split() - rx = float(atom[1]) - ry = float(atom[2]) - rz = float(atom[3]) - f.write(f"{rx:>10.6f} {ry:>10.6f} {rz:>10.6f}\n") - - f.write("$end") - - def _new_density(self, last_xyz_file: TextIO): - last_xyz_lines = last_xyz_file.readlines() - - box = last_xyz_lines[1].split() - volume = float(box[-3]) * float(box[-2]) * float(box[-1]) - - total_mass = 0 - for i in range(len(self.system.molecule)): - total_mass += self.system.molecule[i].total_mass * self.step.dice.nmol[i] - - density = (total_mass / volume) * UMAANG3_TO_GCM3 - - return density - - def _make_nvt_ter(self, cycle, proc_dir): - file = Path(proc_dir, "NVT.ter") - with open(file, "w") as f: - f.write(f"title = {self.title} - NVT Thermalization\n") - f.write(f"ncores = {self.step.ncores}\n") - f.write(f"ljname = {self.step.dice.ljname}\n") - f.write(f"outname = {self.step.dice.outname}\n") - - mol_string = " ".join(str(x) for x in self.step.dice.nmol) - f.write(f"nmol = {mol_string}\n") - - f.write(f"dens = {self.step.dice.dens}\n") - f.write(f"temp = {self.step.dice.temp}\n") - - if self.step.dice.randominit == "first" and cycle > 1: - f.write("init = yesreadxyz\n") - f.write(f"nstep = {self.step.altsteps}\n") - else: - f.write("init = yes\n") - f.write(f"nstep = {self.step.dice.nstep[0]}\n") - - f.write("vstep = 0\n") - f.write("mstop = 1\n") - f.write("accum = no\n") - f.write("iprint = 1\n") - f.write("isave = 0\n") - f.write("irdf = 0\n") - - seed = int(1e6 * random.random()) - f.write(f"seed = {seed}\n") - f.write(f"upbuf = {self.step.dice.upbuf}") - - def _make_nvt_eq(self, cycle, proc_dir): - file = Path(proc_dir, "NVT.eq") - with open(file, "w") as f: - f.write(f"title = {self.title} - NVT Production\n") - f.write(f"ncores = {self.step.ncores}\n") - f.write(f"ljname = {self.step.dice.ljname}\n") - f.write(f"outname = {self.step.dice.outname}\n") - - mol_string = " ".join(str(x) for x in self.step.dice.nmol) - f.write(f"nmol = {mol_string}\n") - - f.write(f"dens = {self.step.dice.dens}\n") - f.write(f"temp = {self.step.dice.temp}\n") - - if self.step.dice.randominit == "first" and cycle > 1: - f.write("init = yesreadxyz\n") - else: - f.write("init = no\n") - - f.write(f"nstep = {self.step.dice.nstep[1]}\n") - - f.write("vstep = 0\n") - f.write("mstop = 1\n") - f.write("accum = no\n") - f.write("iprint = 1\n") - - f.write(f"isave = {self.step.dice.isave}\n") - f.write(f"irdf = {10 * self.step.nprocs}\n") - - seed = int(1e6 * random.random()) - f.write("seed = {}\n".format(seed)) - - def _make_npt_ter(self, cycle, proc_dir): - file = Path(proc_dir, "NPT.ter") - with open(file, "w") as f: - f.write(f"title = {self.title} - NPT Thermalization\n") - f.write(f"ncores = {self.step.ncores}\n") - f.write(f"ljname = {self.step.dice.ljname}\n") - f.write(f"outname = {self.step.dice.outname}\n") - - mol_string = " ".join(str(x) for x in self.step.dice.nmol) - f.write(f"nmol = {mol_string}\n") - - f.write(f"press = {self.step.dice.press}\n") - f.write(f"temp = {self.step.dice.temp}\n") - - if self.step.dice.randominit == "first" and cycle > 1: - f.write("init = yesreadxyz\n") - f.write(f"dens = {self.step.dice.dens:<8.4f}\n") - f.write(f"vstep = {int(self.step.altsteps / 5)}\n") - else: - f.write("init = no\n") - f.write(f"vstep = {int(self.step.dice.nstep[1] / 5)}\n") - - f.write("nstep = 5\n") - f.write("mstop = 1\n") - f.write("accum = no\n") - f.write("iprint = 1\n") - f.write("isave = 0\n") - f.write("irdf = 0\n") - - seed = int(1e6 * random.random()) - f.write(f"seed = {seed}\n") - - def _make_npt_eq(self, proc_dir): - file = Path(proc_dir, "NPT.eq") - with open(file, "w") as f: - f.write(f"title = {self.title} - NPT Production\n") - f.write(f"ncores = {self.step.ncores}\n") - f.write(f"ljname = {self.step.dice.ljname}\n") - f.write(f"outname = {self.step.dice.outname}\n") - - mol_string = " ".join(str(x) for x in self.step.dice.nmol) - f.write(f"nmol = {mol_string}\n") - - f.write(f"press = {self.step.dice.press}\n") - f.write(f"temp = {self.step.dice.temp}\n") - - f.write("nstep = 5\n") - - f.write(f"vstep = {int(self.step.dice.nstep[2] / 5)}\n") - f.write("init = no\n") - f.write("mstop = 1\n") - f.write("accum = no\n") - f.write("iprint = 1\n") - f.write(f"isave = {self.step.dice.isave}\n") - f.write(f"irdf = {10 * self.step.nprocs}\n") - - seed = int(1e6 * random.random()) - f.write(f"seed = {seed}\n") - - def _make_potentials(self, proc_dir): - fstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f}\n" - - file = Path(proc_dir, self.step.dice.ljname) - with open(file, "w") as f: - f.write(f"{self.step.dice.combrule}\n") - f.write(f"{len(self.step.dice.nmol)}\n") - - nsites_qm = len(self.system.molecule[0].atom) - f.write(f"{nsites_qm} {self.system.molecule[0].molname}\n") - - for atom in self.system.molecule[0].atom: - f.write( - fstr.format( - atom.lbl, - atom.na, - atom.rx, - atom.ry, - atom.rz, - atom.chg, - atom.eps, - atom.sig, - ) - ) - - for mol in self.system.molecule[1:]: - f.write(f"{len(mol.atom)} {mol.molname}\n") - for atom in mol.atom: - f.write( - fstr.format( - atom.lbl, - atom.na, - atom.rx, - atom.ry, - atom.rz, - atom.chg, - atom.eps, - atom.sig, - ) - ) - - def run_dice_file(self, cycle: int, proc: int, file_name: str): - with ( - open(Path(file_name), "r") as infile, - open(Path(file_name + ".out"), "w") as outfile, - ): - if shutil.which("bash") is not None: - exit_status = subprocess.call( - [ - "bash", - "-c", - f"exec -a dice-step{cycle}-p{proc} {self.step.dice.progname} < {infile.name} > {outfile.name}", - ] - ) - else: - exit_status = subprocess.call( - self.step.dice.progname, stdin=infile, stdout=outfile - ) - - if exit_status != 0: - raise RuntimeError( - f"Dice process step{cycle:02d}-p{proc:02d} did not exit properly" - ) - - with open(Path(file_name + ".out"), "r") as outfile: - flag = outfile.readlines()[DICE_FLAG_LINE].strip() - if flag != DICE_END_FLAG: - raise RuntimeError( - f"Dice process step{cycle:02d}-p{proc:02d} did not exit properly" - ) - - logger.info(f"Dice {file_name} - step{cycle:02d}-p{proc:02d} exited properly") diff --git a/diceplayer/interface/gaussian_interface.py b/diceplayer/interface/gaussian_interface.py deleted file mode 100644 index a3707ab..0000000 --- a/diceplayer/interface/gaussian_interface.py +++ /dev/null @@ -1,359 +0,0 @@ -from __future__ import annotations - -from diceplayer import logger -from diceplayer.config.player_config import PlayerConfig -from diceplayer.environment import Atom -from diceplayer.environment.molecule import Molecule -from diceplayer.environment.system import System -from diceplayer.interface import Interface -from diceplayer.utils.misc import date_time -from diceplayer.utils.ptable import PTable - -import numpy as np -import numpy.typing as npt -from typing_extensions import Any, Dict, List, Tuple - -import os -import shutil -import subprocess -import textwrap -from pathlib import Path - - -class GaussianInterface(Interface): - def configure(self, step_dto: PlayerConfig, system: System): - self.system = system - self.step = step_dto - - def start(self, cycle: int) -> Dict[str, NDArray]: - self._make_qm_dir(cycle) - - if cycle > 1: - self._copy_chk_file_from_previous_step(cycle) - - asec_charges = self.populate_asec_vdw(cycle) - self._make_gaussian_input_file(cycle, asec_charges) - - self._run_gaussian(cycle) - self._run_formchk(cycle) - - return_value = {} - if self.step.opt: - # return_value['position'] = np.array( - # self._run_optimization(cycle) - # ) - raise NotImplementedError("Optimization not implemented yet.") - - else: - return_value["charges"] = np.array(self._read_charges_from_fchk(cycle)) - - return return_value - - def reset(self): - del self.step - del self.system - - def _make_qm_dir(self, cycle: int): - qm_dir_path = Path(self.step.simulation_dir, f"step{cycle:02d}", "qm") - if not qm_dir_path.exists(): - qm_dir_path.mkdir() - - def _copy_chk_file_from_previous_step(self, cycle: int): - current_chk_file_path = Path( - self.step.simulation_dir, f"step{cycle:02d}", "qm", "asec.chk" - ) - if current_chk_file_path.exists(): - raise FileExistsError(f"File {current_chk_file_path} already exists.") - - previous_chk_file_path = Path( - self.step.simulation_dir, f"step{(cycle - 1):02d}", "qm", "asec.chk" - ) - if not previous_chk_file_path.exists(): - raise FileNotFoundError(f"File {previous_chk_file_path} does not exist.") - - shutil.copy(previous_chk_file_path, current_chk_file_path) - - def populate_asec_vdw(self, cycle: int) -> list[dict]: - norm_factor = self._calculate_norm_factor() - - nsitesref = len(self.system.molecule[0].atom) - - nsites_total = self._calculate_total_number_of_sites(nsitesref) - - proc_charges = [] - for proc in range(1, self.step.nprocs + 1): - proc_charges.append(self._read_charges_from_last_step(cycle, proc)) - - asec_charges, thickness, picked_mols = self._evaluate_proc_charges( - nsites_total, proc_charges - ) - - logger.info( - f"In average, {(sum(picked_mols) / norm_factor):^7.2f} molecules\n" - f"were selected from each of the {len(picked_mols)} configurations\n" - f"of the production simulations to form the ASEC, comprising a shell with\n" - f"minimum thickness of {(sum(thickness) / norm_factor):>6.2f} Angstrom\n" - ) - - for charge in asec_charges: - charge["chg"] = charge["chg"] / norm_factor - - return asec_charges - - def _calculate_norm_factor(self) -> int: - if self.step.dice.nstep[-1] % self.step.dice.isave == 0: - nconfigs = round(self.step.dice.nstep[-1] / self.step.dice.isave) - else: - nconfigs = int(self.step.dice.nstep[-1] / self.step.dice.isave) - - return nconfigs * self.step.nprocs - - def _calculate_total_number_of_sites(self, nsitesref) -> int: - nsites_total = self.step.dice.nmol[0] * nsitesref - for i in range(1, len(self.step.dice.nmol)): - nsites_total += self.step.dice.nmol[i] * len(self.system.molecule[i].atom) - - return nsites_total - - def _read_charges_from_last_step(self, cycle: int, proc: int) -> list[str]: - last_xyz_file_path = Path( - self.step.simulation_dir, f"step{cycle:02d}", f"p{proc:02d}", "last.xyz" - ) - if not last_xyz_file_path.exists(): - raise FileNotFoundError(f"File {last_xyz_file_path} does not exist.") - - with open(last_xyz_file_path, "r") as last_xyz_file: - lines = last_xyz_file.readlines() - - return lines - - def _evaluate_proc_charges( - self, total_nsites: int, proc_charges: list[list[str]] - ) -> Tuple[List[Dict[str, float | Any]], List[float], List[int]]: - asec_charges = [] - - thickness = [] - picked_mols = [] - - for charges in proc_charges: - charges_nsites = int(charges.pop(0)) - if int(charges_nsites) != total_nsites: - raise ValueError( - "Number of sites does not match total number of sites." - ) - - thickness.append(self._calculate_proc_thickness(charges)) - nsites_ref_mol = len(self.system.molecule[0].atom) - charges = charges[nsites_ref_mol:] - - mol_count = 0 - for type in range(len(self.step.dice.nmol)): - if type == 0: - # Reference Molecule must be ignored from type 0 - nmols = self.step.dice.nmol[type] - 1 - else: - nmols = self.step.dice.nmol[type] - - for mol in range(nmols): - new_molecule = Molecule("ASEC TMP MOLECULE") - for site in range(len(self.system.molecule[type].atom)): - line = charges.pop(0).split() - - if ( - line[0].title() - != PTable.get_atomic_symbol( - self.system.molecule[type].atom[site].na - ).strip() - ): - raise SyntaxError( - "Error: Invalid Dice Output. Atom type does not match." - ) - - new_molecule.add_atom( - Atom( - self.system.molecule[type].atom[site].lbl, - self.system.molecule[type].atom[site].na, - float(line[1]), - float(line[2]), - float(line[3]), - self.system.molecule[type].atom[site].chg, - self.system.molecule[type].atom[site].eps, - self.system.molecule[type].atom[site].sig, - ) - ) - - distance = self.system.molecule[0].minimum_distance(new_molecule) - - if distance < thickness[-1]: - for atom in new_molecule.atom: - asec_charges.append( - { - "lbl": PTable.get_atomic_symbol(atom.na), - "rx": atom.rx, - "ry": atom.ry, - "rz": atom.rz, - "chg": atom.chg, - } - ) - mol_count += 1 - - picked_mols.append(mol_count) - - return asec_charges, thickness, picked_mols - - def _calculate_proc_thickness(self, charges: list[str]) -> float: - box = charges.pop(0).split()[-3:] - box = [float(box[0]), float(box[1]), float(box[2])] - sizes = self.system.molecule[0].sizes_of_molecule() - - return min( - [ - (box[0] - sizes[0]) / 2, - (box[1] - sizes[1]) / 2, - (box[2] - sizes[2]) / 2, - ] - ) - - def _make_gaussian_input_file(self, cycle: int, asec_charges: list[dict]) -> None: - gaussian_input_file_path = Path( - self.step.simulation_dir, f"step{cycle:02d}", "qm", "asec.gjf" - ) - - with open(gaussian_input_file_path, "w") as gaussian_input_file: - gaussian_input_file.writelines( - self._generate_gaussian_input(cycle, asec_charges) - ) - - def _generate_gaussian_input( - self, cycle: int, asec_charges: list[dict] - ) -> list[str]: - gaussian_input = ["%Chk=asec.chk\n"] - - if self.step.mem is not None: - gaussian_input.append(f"%Mem={self.step.mem}GB\n") - - gaussian_input.append(f"%Nprocs={self.step.nprocs * self.step.ncores}\n") - - kwords_line = f"#P {self.step.gaussian.level}" - - if self.step.gaussian.keywords: - kwords_line += " " + self.step.gaussian.keywords - - if self.step.opt == "yes": - kwords_line += " Force" - - kwords_line += " NoSymm" - kwords_line += f" Pop={self.step.gaussian.pop} Density=Current" - - if cycle > 1: - kwords_line += " Guess=Read" - - gaussian_input.append(textwrap.fill(kwords_line, 90)) - gaussian_input.append("\n") - - gaussian_input.append("\nForce calculation - Cycle number {}\n".format(cycle)) - gaussian_input.append("\n") - gaussian_input.append( - f"{self.step.gaussian.chgmult[0]},{self.step.gaussian.chgmult[1]}\n" - ) - - for atom in self.system.molecule[0].atom: - symbol = PTable.get_atomic_symbol(atom.na) - gaussian_input.append( - "{:<2s} {:>10.5f} {:>10.5f} {:>10.5f}\n".format( - symbol, atom.rx, atom.ry, atom.rz - ) - ) - - gaussian_input.append("\n") - - for charge in asec_charges: - gaussian_input.append( - "{:>10.5f} {:>10.5f} {:>10.5f} {:>11.8f}\n".format( - charge["rx"], charge["ry"], charge["rz"], charge["chg"] - ) - ) - - gaussian_input.append("\n") - - return gaussian_input - - def _run_gaussian(self, cycle: int) -> None: - qm_dir = Path(self.step.simulation_dir, f"step{(cycle):02d}", "qm") - - working_dir = os.getcwd() - os.chdir(qm_dir) - - infile = "asec.gjf" - - operation = None - if self.step.opt: - operation = "forces" - else: - operation = "charges" - - logger.info( - f"Calculation of {operation} initiated with Gaussian on {date_time()}\n" - ) - - if shutil.which("bash") is not None: - exit_status = subprocess.call( - [ - "bash", - "-c", - "exec -a {}-step{} {} {}".format( - self.step.gaussian.qmprog, - cycle, - self.step.gaussian.qmprog, - infile, - ), - ] - ) - else: - exit_status = subprocess.call([self.step.gaussian.qmprog, infile]) - - if exit_status != 0: - raise SystemError("Gaussian process did not exit properly") - - logger.info(f"Calculation of {operation} finished on {date_time()}") - - os.chdir(working_dir) - - def _run_formchk(self, cycle: int): - qm_dir = Path(self.step.simulation_dir, f"step{(cycle):02d}", "qm") - - work_dir = os.getcwd() - os.chdir(qm_dir) - - logger.info("Formatting the checkpoint file... \n") - - exit_status = subprocess.call( - ["formchk", "asec.chk"], stdout=subprocess.DEVNULL - ) - - if exit_status != 0: - raise SystemError("Formchk process did not exit properly") - - logger.info("Done\n") - - os.chdir(work_dir) - - def _read_charges_from_fchk(self, cycle: int): - fchk_file_path = Path("simfiles", f"step{cycle:02d}", "qm", "asec.fchk") - with open(fchk_file_path) as fchk: - fchkfile = fchk.readlines() - - if self.step.gaussian.pop in ["chelpg", "mk"]: - CHARGE_FLAG = "ESP Charges" - else: - CHARGE_FLAG = "ESP Charges" - - start = fchkfile.pop(0).strip() - while start.find(CHARGE_FLAG) != 0: # expression in begining of line - start = fchkfile.pop(0).strip() - - charges: List[float] = [] - while len(charges) < len(self.system.molecule[0].atom): - charges.extend([float(x) for x in fchkfile.pop(0).split()]) - - return charges diff --git a/diceplayer/logger.py b/diceplayer/logger.py new file mode 100644 index 0000000..9b841c2 --- /dev/null +++ b/diceplayer/logger.py @@ -0,0 +1,4 @@ +from diceplayer import RunLogger + + +logger = RunLogger("diceplayer") diff --git a/diceplayer/player.py b/diceplayer/player.py index 70fd415..6800b28 100644 --- a/diceplayer/player.py +++ b/diceplayer/player.py @@ -1,465 +1,9 @@ -from diceplayer import VERSION, logger from diceplayer.config.player_config import PlayerConfig -from diceplayer.environment import Atom, Molecule, System -from diceplayer.interface import DiceInterface, GaussianInterface -from diceplayer.utils import PTable, weekday_date_time - -import yaml -from pydantic import BaseModel -from typing_extensions import Tuple - -import os -import pickle -import sys -from pathlib import Path - - -ENV = ["OMP_STACKSIZE"] class Player: - def __init__(self, infile: str = None, optimization: bool = False): - if infile is None and optimization is False: - raise ValueError("Must specify either infile or optimization") + def __init__(self, config: PlayerConfig): + self.config = config - elif infile is not None: - self.config = self.set_config(self.read_keywords(infile)) - - self.system = System() - - self.initial_cycle = 1 - - elif optimization is True: - save = self.load_run_from_pickle() - - self.config = save[0] - - self.system = save[1] - - self.initial_cycle = save[2] + 1 - - else: - raise ValueError("Must specify either infile or config") - - self.dice_interface = DiceInterface() - self.gaussian_interface = GaussianInterface() - - def start(self): - logger.info( - "==========================================================================================\n" - "Starting the iterative process.\n" - "==========================================================================================\n" - ) - - for cycle in range(self.initial_cycle, self.initial_cycle + self.config.maxcyc): - logger.info( - f"------------------------------------------------------------------------------------------\n" - f" Step # {cycle}\n" - f"------------------------------------------------------------------------------------------\n" - ) - - self.dice_start(cycle) - - try: - self.gaussian_start(cycle) - except StopIteration: - break - - self.save_run_in_pickle(cycle) - - def prepare_system(self): - for i, mol in enumerate(self.system.molecule): - logger.info(f"Molecule {i + 1} - {mol.molname}") - - mol.print_mol_info() - logger.info( - "\n Translating and rotating molecule to standard orientation..." - ) - - mol.rotate_to_standard_orientation() - logger.info("\n Done") - logger.info("\nNew values:\n") - mol.print_mol_info() - - logger.info("\n") - - def create_simulation_dir(self): - simulation_dir_path = Path(self.config.simulation_dir) - if simulation_dir_path.exists(): - raise FileExistsError( - f"Error: a file or a directory {self.config.simulation_dir} already exists," - f" move or delete the simfiles directory to continue." - ) - simulation_dir_path.mkdir() - - def create_geoms_file(self): - geoms_file_path = Path(self.config.geoms_file) - if geoms_file_path.exists(): - raise FileExistsError( - f"Error: a file or a directory {self.config.geoms_file} already exists," - f" move or delete the simfiles directory to continue." - ) - geoms_file_path.touch() - - def print_keywords(self) -> None: - def log_keywords(config: BaseModel): - for key, value in sorted(config.model_dump().items()): - if value is None: - continue - if isinstance(value, list): - string = " ".join(str(x) for x in value) - logger.info(f"{key} = [ {string} ]") - else: - logger.info(f"{key} = {value}") - - logger.info( - f"##########################################################################################\n" - f"############# Welcome to DICEPLAYER version {VERSION} #############\n" - f"##########################################################################################\n" - ) - logger.info("Your python version is {}\n".format(sys.version)) - logger.info("Program started on {}\n".format(weekday_date_time())) - logger.info("Environment variables:") - for var in ENV: - logger.info( - "{} = {}\n".format( - var, (os.environ[var] if var in os.environ else "Not set") - ) - ) - - logger.info( - "------------------------------------------------------------------------------------------\n" - " DICE variables being used in this run:\n" - "------------------------------------------------------------------------------------------\n" - ) - - log_keywords(self.config.dice) - - logger.info( - "------------------------------------------------------------------------------------------\n" - " GAUSSIAN variables being used in this run:\n" - "------------------------------------------------------------------------------------------\n" - ) - - log_keywords(self.config.gaussian) - - logger.info("\n") - - def read_potentials(self): - ljname_path = Path(self.config.dice.ljname) - if ljname_path.exists(): - with open(self.config.dice.ljname) as file: - ljc_data = file.readlines() - else: - raise RuntimeError(f"Potential file {self.config.dice.ljname} not found.") - - combrule = ljc_data.pop(0).split()[0] - if combrule not in ("*", "+"): - sys.exit( - "Error: expected a '*' or a '+' sign in 1st line of file {}".format( - self.config.dice.ljname - ) - ) - self.config.dice.combrule = combrule - - ntypes = ljc_data.pop(0).split()[0] - if not ntypes.isdigit(): - sys.exit( - "Error: expected an integer in the 2nd line of file {}".format( - self.config.dice.ljname - ) - ) - ntypes = int(ntypes) - - if ntypes != len(self.config.dice.nmol): - sys.exit( - f"Error: number of molecule types in file {self.config.dice.ljname} " - f"must match that of 'nmol' keyword in config file" - ) - - for i in range(ntypes): - try: - nsites, molname = ljc_data.pop(0).split()[:2] - except ValueError: - raise ValueError( - f"Error: expected nsites and molname for the molecule type {i + 1}" - ) - - if not nsites.isdigit(): - raise ValueError( - f"Error: expected nsites to be an integer for molecule type {i + 1}" - ) - - nsites = int(nsites) - self.system.add_type(Molecule(molname)) - - atom_fields = ["lbl", "na", "rx", "ry", "rz", "chg", "eps", "sig"] - for j in range(nsites): - new_atom = dict(zip(atom_fields, ljc_data.pop(0).split())) - self.system.molecule[i].add_atom( - Atom(**self.validate_atom_dict(i, j, new_atom)) - ) - - def print_potentials(self) -> None: - formatstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f} {:>9.4f}" - logger.info( - "==========================================================================================\n" - f" Potential parameters from file {self.config.dice.ljname}:\n" - "------------------------------------------------------------------------------------------" - "\n" - ) - - logger.info(f"Combination rule: {self.config.dice.combrule}") - logger.info(f"Types of molecules: {len(self.system.molecule)}\n") - - i = 0 - for mol in self.system.molecule: - i += 1 - logger.info("{} atoms in molecule type {}:".format(len(mol.atom), i)) - logger.info( - "---------------------------------------------------------------------------------" - ) - logger.info( - "Lbl AN X Y Z Charge Epsilon Sigma Mass" - ) - logger.info( - "---------------------------------------------------------------------------------" - ) - - for atom in mol.atom: - logger.info( - formatstr.format( - atom.lbl, - atom.na, - atom.rx, - atom.ry, - atom.rz, - atom.chg, - atom.eps, - atom.sig, - atom.mass, - ) - ) - - logger.info("\n") - - def dice_start(self, cycle: int): - self.dice_interface.configure( - self.config, - self.system, - ) - - self.dice_interface.start(cycle) - - self.dice_interface.reset() - - def gaussian_start(self, cycle: int): - self.gaussian_interface.configure( - self.config, - self.system, - ) - - result = self.gaussian_interface.start(cycle) - - self.gaussian_interface.reset() - - if self.config.opt: - if "position" not in result: - raise RuntimeError("Optimization failed. No position found in result.") - - else: - if "charges" not in result: - raise RuntimeError( - "Charges optimization failed. No charges found in result." - ) - - diff = self.system.molecule[0].update_charges(result["charges"]) - - self.print_charges_and_dipole(cycle) - self.print_geoms(cycle) - - if diff < self.config.gaussian.chg_tol: - logger.info(f"Charges converged after {cycle} cycles.") - raise StopIteration() - - def print_charges_and_dipole(self, cycle: int) -> None: - """ - Print the charges and dipole of the molecule in the Output file - - Args: - cycle (int): Number of the cycle - fh (TextIO): Output file - """ - - logger.info("Cycle # {}\n".format(cycle)) - logger.info("Number of site: {}\n".format(len(self.system.molecule[0].atom))) - - chargesAndDipole = self.system.molecule[0].charges_and_dipole() - - logger.info( - "{:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f}\n".format( - chargesAndDipole[0], - chargesAndDipole[1], - chargesAndDipole[2], - chargesAndDipole[3], - chargesAndDipole[4], - ) - ) - - def print_geoms(self, cycle: int): - with open(self.config.geoms_file, "a") as file: - file.write(f"Cycle # {cycle}\n") - - for atom in self.system.molecule[0].atom: - symbol = PTable.get_atomic_symbol(atom.na) - file.write( - f"{symbol:<2s} {atom.rx:>10.6f} {atom.ry:>10.6f} {atom.rz:>10.6f}\n" - ) - - file.write("\n") - - @staticmethod - def validate_atom_dict(molecule_type, molecule_site, atom_dict: dict) -> dict: - molecule_type += 1 - molecule_site += 1 - - if len(atom_dict) < 8: - raise ValueError( - f"Invalid number of fields for site {molecule_site} for molecule type {molecule_type}." - ) - - try: - atom_dict["lbl"] = int(atom_dict["lbl"]) - except Exception: - raise ValueError( - f"Invalid lbl fields for site {molecule_site} for molecule type {molecule_type}." - ) - - try: - atom_dict["na"] = int(atom_dict["na"]) - except Exception: - raise ValueError( - f"Invalid na fields for site {molecule_site} for molecule type {molecule_type}." - ) - - try: - atom_dict["rx"] = float(atom_dict["rx"]) - except Exception: - raise ValueError( - f"Invalid rx fields for site {molecule_site} for molecule type {molecule_type}. " - f"Value must be a float." - ) - - try: - atom_dict["ry"] = float(atom_dict["ry"]) - except Exception: - raise ValueError( - f"Invalid ry fields for site {molecule_site} for molecule type {molecule_type}. " - f"Value must be a float." - ) - - try: - atom_dict["rz"] = float(atom_dict["rz"]) - except Exception: - raise ValueError( - f"Invalid rz fields for site {molecule_site} for molecule type {molecule_type}. " - f"Value must be a float." - ) - - try: - atom_dict["chg"] = float(atom_dict["chg"]) - except Exception: - raise ValueError( - f"Invalid chg fields for site {molecule_site} for molecule type {molecule_type}. " - f"Value must be a float." - ) - - try: - atom_dict["eps"] = float(atom_dict["eps"]) - except Exception: - raise ValueError( - f"Invalid eps fields for site {molecule_site} for molecule type {molecule_type}. " - f"Value must be a float." - ) - - try: - atom_dict["sig"] = float(atom_dict["sig"]) - except Exception: - raise ValueError( - f"Invalid sig fields for site {molecule_site} for molecule type {molecule_type}. " - f"Value must be a float." - ) - - return atom_dict - - def print_results(self): - formatstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f} {:>9.4f}" - - mol = self.system.molecule[0] - logger.info("{} atoms in molecule type {}:".format(len(mol.atom), 1)) - logger.info( - "---------------------------------------------------------------------------------" - ) - logger.info( - "Lbl AN X Y Z Charge Epsilon Sigma Mass" - ) - logger.info( - "---------------------------------------------------------------------------------" - ) - - for atom in mol.atom: - logger.info( - formatstr.format( - atom.lbl, - atom.na, - atom.rx, - atom.ry, - atom.rz, - atom.chg, - atom.eps, - atom.sig, - atom.mass, - ) - ) - - logger.info("\n") - - def save_run_in_pickle(self, cycle): - try: - with open("latest-step.pkl", "wb") as pickle_file: - pickle.dump((self.config, self.system, cycle), pickle_file) - except Exception: - raise RuntimeError("Could not save pickle file latest-step.pkl.") - - @staticmethod - def load_run_from_pickle() -> Tuple[PlayerConfig, System, int]: - pickle_path = Path("latest-step.pkl") - try: - with open(pickle_path, "rb") as pickle_file: - save = pickle.load(pickle_file) - return save[0], save[1], save[2] + 1 - - except Exception: - raise RuntimeError(f"Could not load pickle file {pickle_path}.") - - @staticmethod - def set_config(data: dict) -> PlayerConfig: - return PlayerConfig.model_validate(data) - - @staticmethod - def read_keywords(infile) -> dict: - with open(infile, "r") as yml_file: - config = yaml.load(yml_file, Loader=yaml.SafeLoader) - - if "diceplayer" in config: - return config.get("diceplayer") - - raise RuntimeError(f"Could not find diceplayer section in {infile}.") - - @classmethod - def from_file(cls, infile: str) -> "Player": - return cls(infile=infile) - - @classmethod - def from_save(cls): - return cls(optimization=True) + def play(self, continuation = False): + ... \ No newline at end of file diff --git a/diceplayer/utils/__init__.py b/diceplayer/utils/__init__.py index ecd350f..2da9790 100644 --- a/diceplayer/utils/__init__.py +++ b/diceplayer/utils/__init__.py @@ -1,4 +1,4 @@ -from .logger import Logger, valid_logger +from .logger import RunLogger from .misc import ( compress_files_1mb, date_time, @@ -10,8 +10,7 @@ from .ptable import AtomInfo, PTable __all__ = [ - "Logger", - "valid_logger", + "RunLogger", "PTable", "AtomInfo", "weekday_date_time", diff --git a/diceplayer/utils/dataclass_protocol.py b/diceplayer/utils/dataclass_protocol.py deleted file mode 100644 index 967e7c1..0000000 --- a/diceplayer/utils/dataclass_protocol.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing_extensions import Protocol, runtime_checkable - - -@runtime_checkable -class Dataclass(Protocol): - __dataclass_fields__: dict diff --git a/diceplayer/utils/logger.py b/diceplayer/utils/logger.py index 4d4806a..d18be42 100644 --- a/diceplayer/utils/logger.py +++ b/diceplayer/utils/logger.py @@ -1,72 +1,48 @@ import logging +import sys from pathlib import Path +from typing_extensions import TypeVar + +H = TypeVar('H', bound=logging.Handler) -def valid_logger(func): - def wrapper(*args, **kwargs): - logger = args[0] - assert logger._was_set, "Logger is not set. Please call set_logger() first." +class RunLogger(logging.Logger): + def __init__(self, name, level=logging.INFO, stream=sys.stdout): + super().__init__(name, level) - return func(*args, **kwargs) + self.handlers.clear() - return wrapper + self.handlers.append( + self._configure_handler(logging.StreamHandler(stream), level) + ) -class Logger: - outfile = None + def set_output_file(self, outfile: Path, level=logging.INFO): + for handler in list(self.handlers): + if not isinstance(handler, logging.FileHandler): + continue + self.handlers.remove(handler) - _logger = None + self.handlers.append( + self._create_file_handler(outfile, level) + ) - _was_set = False - def __init__(self, logger_name): - if self._logger is None: - self._logger = logging.getLogger(logger_name) - def set_logger(self, outfile="run.log", level=logging.INFO, stream=None): - outfile_path = None - if outfile is not None and stream is None: - outfile_path = Path(outfile) - if outfile_path.exists(): - outfile_path.rename(str(outfile_path) + ".backup") + @staticmethod + def _create_file_handler(file: str|Path, level) -> logging.FileHandler: + file = Path(file) - if level is not None: - self._logger.setLevel(level) + if file.exists(): + file.rename(file.with_suffix('.log.backup')) - self._create_handlers(outfile_path, stream) + handler = logging.FileHandler(file) + return RunLogger._configure_handler(handler, level) - self._was_set = True - @valid_logger - def info(self, message): - self._logger.info(message) - - @valid_logger - def debug(self, message): - self._logger.debug(message) - - @valid_logger - def warning(self, message): - self._logger.warning(message) - - @valid_logger - def error(self, message): - self._logger.error(message) - - def _create_handlers(self, outfile_path: Path, stream): - handlers = [] - if outfile_path is not None: - handlers.append(logging.FileHandler(outfile_path, mode="a+")) - elif stream is not None: - handlers.append(logging.StreamHandler(stream)) - else: - handlers.append(logging.StreamHandler()) - - for handler in handlers: - handler.setFormatter(logging.Formatter("%(message)s")) - self._logger.addHandler(handler) - - def close(self): - for handler in self._logger.handlers: - handler.close() - self._logger.removeHandler(handler) + @staticmethod + def _configure_handler(handler: H, level) -> H: + handler.setLevel(level) + formatter = logging.Formatter('%(message)s') + handler.setFormatter(formatter) + return handler \ No newline at end of file diff --git a/tests/shared/__init__.py b/tests/environment/__init__.py similarity index 100% rename from tests/shared/__init__.py rename to tests/environment/__init__.py diff --git a/tests/shared/environment/test_atom.py b/tests/environment/test_atom.py similarity index 100% rename from tests/shared/environment/test_atom.py rename to tests/environment/test_atom.py diff --git a/tests/shared/environment/test_molecule.py b/tests/environment/test_molecule.py similarity index 100% rename from tests/shared/environment/test_molecule.py rename to tests/environment/test_molecule.py diff --git a/tests/shared/environment/test_system.py b/tests/environment/test_system.py similarity index 100% rename from tests/shared/environment/test_system.py rename to tests/environment/test_system.py diff --git a/tests/shared/environment/__init__.py b/tests/shared/environment/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/shared/interface/__init__.py b/tests/shared/interface/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/shared/interface/test_dice_interface.py b/tests/shared/interface/test_dice_interface.py deleted file mode 100644 index 339a209..0000000 --- a/tests/shared/interface/test_dice_interface.py +++ /dev/null @@ -1,643 +0,0 @@ -from diceplayer import logger -from diceplayer.config.player_config import PlayerConfig -from diceplayer.environment import Atom, Molecule, System -from diceplayer.interface import DiceInterface -from tests.mocks.mock_inputs import get_config_example -from tests.mocks.mock_proc import MockConnection, MockProc - -import yaml - -import io -import unittest -from unittest import mock - - -class TestDiceInterface(unittest.TestCase): - def setUp(self): - logger.set_logger(stream=io.StringIO()) - - config = yaml.load(get_config_example(), Loader=yaml.Loader) - self.config = PlayerConfig.model_validate(config["diceplayer"]) - - def test_class_instantiation(self): - dice = DiceInterface() - - self.assertIsInstance(dice, DiceInterface) - - def test_configure(self): - dice = DiceInterface() - - self.assertIsNone(dice.step) - self.assertIsNone(dice.system) - - # Ignoring the types for testing purposes - dice.configure(self.config, System()) - - self.assertIsNotNone(dice.step) - self.assertIsNotNone(dice.system) - - def test_reset(self): - dice = DiceInterface() - - dice.configure(self.config, System()) - - self.assertTrue(hasattr(dice, "step")) - self.assertTrue(hasattr(dice, "system")) - - dice.reset() - - self.assertFalse(hasattr(dice, "step")) - self.assertFalse(hasattr(dice, "system")) - - @mock.patch("diceplayer.interface.dice_interface.Process", MockProc()) - @mock.patch("diceplayer.interface.dice_interface.connection", MockConnection) - def test_start(self): - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.start(1) - - @mock.patch("diceplayer.interface.dice_interface.connection", MockConnection) - @mock.patch("diceplayer.interface.dice_interface.Process", MockProc(exitcode=1)) - def test_start_with_process_error(self): - dice = DiceInterface() - dice.configure(self.config, System()) - - with self.assertRaises(SystemExit): - dice.start(1) - - def test_simulation_process_raises_exception(self): - dice = DiceInterface() - - with self.assertRaises(SystemExit): - dice._simulation_process(1, 1) - - @mock.patch("diceplayer.interface.dice_interface.DiceInterface._make_proc_dir") - @mock.patch("diceplayer.interface.dice_interface.DiceInterface._make_dice_inputs") - @mock.patch("diceplayer.interface.dice_interface.DiceInterface._run_dice") - def test_simulation_process( - self, mock_run_dice, mock_make_dice_inputs, mock_make_proc_dir - ): - dice = DiceInterface() - - dice._simulation_process(1, 1) - - self.assertTrue(dice._make_proc_dir.called) - self.assertTrue(dice._make_dice_inputs.called) - self.assertTrue(dice._run_dice.called) - - @mock.patch("diceplayer.interface.dice_interface.Path.mkdir") - @mock.patch("diceplayer.interface.dice_interface.Path.exists") - def test_make_proc_dir_if_simdir_exists(self, mock_path_exists, mock_path_mkdir): - dice = DiceInterface() - dice.configure(self.config, System()) - - mock_path_exists.return_value = False - - dice._make_proc_dir(1, 1) - - self.assertEqual(mock_path_mkdir.call_count, 2) - - @mock.patch("diceplayer.interface.dice_interface.Path.mkdir") - @mock.patch("diceplayer.interface.dice_interface.Path.exists") - def test_make_proc_dir_if_simdir_doesnt_exists( - self, mock_path_exists, mock_path_mkdir - ): - dice = DiceInterface() - dice.configure(self.config, System()) - - mock_path_exists.return_value = False - - dice._make_proc_dir(1, 1) - - self.assertEqual(mock_path_mkdir.call_count, 2) - - def test_make_dice_seed(self): - seed = DiceInterface._make_dice_seed() - - self.assertIsInstance(seed, int) - - def test_make_dice_inputs_nstep_len_two_with_randoninit_first_cycle_one(self): - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.step.dice.nstep = [1, 1] - - dice._make_potentials = mock.Mock() - - dice._make_init_file = mock.Mock() - dice._new_density = mock.Mock() - - dice._make_nvt_ter = mock.Mock() - dice._make_nvt_eq = mock.Mock() - dice._make_npt_ter = mock.Mock() - dice._make_npt_eq = mock.Mock() - - dice._make_dice_inputs(1, 1) - - self.assertTrue(dice._make_potentials.called) - - self.assertFalse(dice._make_init_file.called) - self.assertFalse(dice._new_density.called) - - self.assertTrue(dice._make_nvt_ter.called) - self.assertTrue(dice._make_nvt_eq.called) - - self.assertFalse(dice._make_npt_ter.called) - self.assertFalse(dice._make_npt_eq.called) - - @mock.patch("builtins.open", new_callable=mock.mock_open, read_data="test") - @mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=True) - def test_make_dice_inputs_nstep_len_two_with_randoninit_first_cycle_two( - self, mock_path_exists, mock_open - ): - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.step.dice.nstep = [1, 1] - - dice._make_potentials = mock.Mock() - - dice._make_init_file = mock.Mock() - dice._new_density = mock.Mock() - - dice._make_nvt_ter = mock.Mock() - dice._make_nvt_eq = mock.Mock() - dice._make_npt_ter = mock.Mock() - dice._make_npt_eq = mock.Mock() - - dice._make_dice_inputs(2, 1) - - self.assertTrue(dice._make_potentials.called) - - self.assertTrue(dice._make_init_file.called) - self.assertTrue(dice._new_density.called) - - self.assertFalse(dice._make_nvt_ter.called) - self.assertTrue(dice._make_nvt_eq.called) - - self.assertFalse(dice._make_npt_ter.called) - self.assertFalse(dice._make_npt_eq.called) - - @mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=False) - def test_make_dice_inputs_raises_exception_on_last_not_found( - self, mock_path_exists - ): - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.step.dice.nstep = [1, 1] - - dice._make_potentials = mock.Mock() - - dice._make_init_file = mock.Mock() - dice._new_density = mock.Mock() - - dice._make_nvt_ter = mock.Mock() - dice._make_nvt_eq = mock.Mock() - dice._make_npt_ter = mock.Mock() - dice._make_npt_eq = mock.Mock() - - with self.assertRaises(FileNotFoundError): - dice._make_dice_inputs(2, 1) - - def test_make_dice_inputs_nstep_len_three_with_randoninit_first_cycle_one(self): - dice = DiceInterface() - dice.configure(self.config, System()) - - dice._make_potentials = mock.Mock() - - dice._make_init_file = mock.Mock() - dice._new_density = mock.Mock() - - dice._make_nvt_ter = mock.Mock() - dice._make_nvt_eq = mock.Mock() - dice._make_npt_ter = mock.Mock() - dice._make_npt_eq = mock.Mock() - - dice._make_dice_inputs(1, 1) - - self.assertTrue(dice._make_potentials.called) - - self.assertFalse(dice._make_init_file.called) - self.assertFalse(dice._new_density.called) - - self.assertTrue(dice._make_nvt_ter.called) - self.assertFalse(dice._make_nvt_eq.called) - - self.assertTrue(dice._make_npt_ter.called) - self.assertTrue(dice._make_npt_eq.called) - - @mock.patch("diceplayer.interface.dice_interface.os") - @mock.patch("diceplayer.interface.dice_interface.shutil") - @mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=True) - def test_run_dice_on_first_cycle_run_successful( - self, mock_path_exists, mock_shutils, mock_os - ): - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.step.dice.nstep = [1, 1, 1] - - dice.run_dice_file = mock.Mock() - - dice._run_dice(1, 1) - - self.assertTrue(mock_os.getcwd.called) - self.assertTrue(mock_os.chdir.called) - - self.assertEqual(dice.run_dice_file.call_count, 3) - self.assertTrue(mock_shutils.copy.called) - - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.step.dice.nstep = [1, 1] - - dice.run_dice_file = mock.Mock() - - dice._run_dice(1, 1) - - self.assertTrue(mock_os.getcwd.called) - self.assertTrue(mock_os.chdir.called) - - self.assertEqual(dice.run_dice_file.call_count, 2) - self.assertTrue(mock_shutils.copy.called) - - @mock.patch("diceplayer.interface.dice_interface.os") - @mock.patch("diceplayer.interface.dice_interface.shutil") - @mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=True) - def test_run_dice_on_second_cycle_run_successful( - self, mock_path_exists, mock_shutils, mock_os - ): - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.run_dice_file = mock.Mock() - - dice._run_dice(2, 1) - - self.assertTrue(mock_os.getcwd.called) - self.assertTrue(mock_os.chdir.called) - - self.assertEqual(dice.run_dice_file.call_count, 2) - self.assertTrue(mock_shutils.copy.called) - - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.run_dice_file = mock.Mock() - - dice._run_dice(2, 1) - - self.assertTrue(mock_os.getcwd.called) - self.assertTrue(mock_os.chdir.called) - - self.assertEqual(dice.run_dice_file.call_count, 2) - self.assertTrue(mock_shutils.copy.called) - - @mock.patch("diceplayer.interface.dice_interface.os") - @mock.patch("diceplayer.interface.dice_interface.shutil") - @mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=False) - def test_run_dice_raises_filenotfound_on_invalid_file( - self, mock_path_exists, mock_shutils, mock_os - ): - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.run_dice_file = mock.Mock() - - with self.assertRaises(FileNotFoundError): - dice._run_dice(1, 1) - - @mock.patch("builtins.open", new_callable=mock.mock_open) - def test_make_init_file(self, mock_open): - example_atom = Atom( - lbl=1, - na=1, - rx=1.0, - ry=1.0, - rz=1.0, - chg=1.0, - eps=1.0, - sig=1.0, - ) - - main_molecule = Molecule("main_molecule") - main_molecule.add_atom(example_atom) - - secondary_molecule = Molecule("secondary_molecule") - secondary_molecule.add_atom(example_atom) - - system = System() - system.add_type(main_molecule) - system.add_type(secondary_molecule) - - dice = DiceInterface() - dice.configure(self.config, system) - - dice.step.dice.nmol = [1, 1] - - last_xyz_file = io.StringIO() - last_xyz_file.writelines( - [ - " TEST\n", - " Configuration number : TEST = TEST TEST TEST\n", - " H 1.00000 1.00000 1.00000\n", - " H 1.00000 1.00000 1.00000\n", - ] - ) - last_xyz_file.seek(0) - - dice._make_init_file("test", last_xyz_file) - - mock_handler = mock_open() - calls = mock_handler.write.call_args_list - - lines = list(map(lambda x: x[0][0], calls)) - - expected_lines = [ - " 1.000000 1.000000 1.000000\n", - " 1.000000 1.000000 1.000000\n", - "$end", - ] - - self.assertEqual(lines, expected_lines) - - @mock.patch("builtins.open", new_callable=mock.mock_open) - def test_new_density(self, mock_open): - example_atom = Atom( - lbl=1, - na=1, - rx=1.0, - ry=1.0, - rz=1.0, - chg=1.0, - eps=1.0, - sig=1.0, - ) - - main_molecule = Molecule("main_molecule") - main_molecule.add_atom(example_atom) - - secondary_molecule = Molecule("secondary_molecule") - secondary_molecule.add_atom(example_atom) - - system = System() - system.add_type(main_molecule) - system.add_type(secondary_molecule) - - dice = DiceInterface() - dice.configure(self.config, system) - - last_xyz_file = io.StringIO() - last_xyz_file.writelines( - [ - " TEST\n", - " Configuration number : TEST = 1 1 1\n", - " H 1.00000 1.00000 1.00000\n", - " H 1.00000 1.00000 1.00000\n", - ] - ) - last_xyz_file.seek(0) - - density = dice._new_density(last_xyz_file) - - self.assertEqual(density, 85.35451545000001) - - @mock.patch("builtins.open", new_callable=mock.mock_open) - @mock.patch("diceplayer.interface.dice_interface.random") - def test_make_nvt_ter(self, mock_random, mock_open): - mock_random.random.return_value = 1 - - dice = DiceInterface() - dice.configure(self.config, System()) - - dice._make_nvt_ter(1, "test") - - mock_handler = mock_open() - calls = mock_handler.write.call_args_list - - lines = list(map(lambda x: x[0][0], calls)) - - expected_lines = [ - "title = Diceplayer run - NVT Thermalization\n", - "ncores = 4\n", - "ljname = phb.ljc\n", - "outname = phb\n", - "nmol = 1 50\n", - "dens = 0.75\n", - "temp = 300.0\n", - "init = yes\n", - "nstep = 2000\n", - "vstep = 0\n", - "mstop = 1\n", - "accum = no\n", - "iprint = 1\n", - "isave = 0\n", - "irdf = 0\n", - "seed = 1000000\n", - "upbuf = 360", - ] - - self.assertEqual(lines, expected_lines) - - @mock.patch("builtins.open", new_callable=mock.mock_open) - @mock.patch("diceplayer.interface.dice_interface.random") - def test_make_nvt_eq(self, mock_random, mock_open): - mock_random.random.return_value = 1 - - dice = DiceInterface() - dice.configure(self.config, System()) - - dice._make_nvt_eq(1, "test") - - mock_handler = mock_open() - calls = mock_handler.write.call_args_list - - lines = list(map(lambda x: x[0][0], calls)) - - expected_lines = [ - "title = Diceplayer run - NVT Production\n", - "ncores = 4\n", - "ljname = phb.ljc\n", - "outname = phb\n", - "nmol = 1 50\n", - "dens = 0.75\n", - "temp = 300.0\n", - "init = no\n", - "nstep = 3000\n", - "vstep = 0\n", - "mstop = 1\n", - "accum = no\n", - "iprint = 1\n", - "isave = 1000\n", - "irdf = 40\n", - "seed = 1000000\n", - ] - - self.assertEqual(lines, expected_lines) - - @mock.patch("builtins.open", new_callable=mock.mock_open) - @mock.patch("diceplayer.interface.dice_interface.random") - def test_make_npt_ter(self, mock_random, mock_open): - mock_random.random.return_value = 1 - - dice = DiceInterface() - dice.configure(self.config, System()) - - dice._make_npt_ter(1, "test") - - mock_handler = mock_open() - calls = mock_handler.write.call_args_list - - lines = list(map(lambda x: x[0][0], calls)) - - expected_lines = [ - "title = Diceplayer run - NPT Thermalization\n", - "ncores = 4\n", - "ljname = phb.ljc\n", - "outname = phb\n", - "nmol = 1 50\n", - "press = 1.0\n", - "temp = 300.0\n", - "init = no\n", - "vstep = 600\n", - "nstep = 5\n", - "mstop = 1\n", - "accum = no\n", - "iprint = 1\n", - "isave = 0\n", - "irdf = 0\n", - "seed = 1000000\n", - ] - - self.assertEqual(lines, expected_lines) - - @mock.patch("builtins.open", new_callable=mock.mock_open) - @mock.patch("diceplayer.interface.dice_interface.random") - def test_make_npt_eq(self, mock_random, mock_open): - mock_random.random.return_value = 1 - - dice = DiceInterface() - dice.configure(self.config, System()) - - dice._make_npt_eq("test") - - mock_handler = mock_open() - calls = mock_handler.write.call_args_list - - lines = list(map(lambda x: x[0][0], calls)) - - expected_lines = [ - "title = Diceplayer run - NPT Production\n", - "ncores = 4\n", - "ljname = phb.ljc\n", - "outname = phb\n", - "nmol = 1 50\n", - "press = 1.0\n", - "temp = 300.0\n", - "nstep = 5\n", - "vstep = 800\n", - "init = no\n", - "mstop = 1\n", - "accum = no\n", - "iprint = 1\n", - "isave = 1000\n", - "irdf = 40\n", - "seed = 1000000\n", - ] - - self.assertEqual(lines, expected_lines) - - @mock.patch("builtins.open", new_callable=mock.mock_open) - def test_make_potentials(self, mock_open): - example_atom = Atom( - lbl=1, - na=1, - rx=1.0, - ry=1.0, - rz=1.0, - chg=1.0, - eps=1.0, - sig=1.0, - ) - - main_molecule = Molecule("main_molecule") - main_molecule.add_atom(example_atom) - - secondary_molecule = Molecule("secondary_molecule") - secondary_molecule.add_atom(example_atom) - - system = System() - system.add_type(main_molecule) - system.add_type(secondary_molecule) - - dice = DiceInterface() - dice.configure(self.config, system) - - dice._make_potentials("test") - - mock_handler = mock_open() - calls = mock_handler.write.call_args_list - - lines = list(map(lambda x: x[0][0], calls)) - - expected_lines = [ - "*\n", - "2\n", - "1 main_molecule\n", - "1 1 1.00000 1.00000 1.00000 1.000000 1.00000 1.0000\n", - "1 secondary_molecule\n", - "1 1 1.00000 1.00000 1.00000 1.000000 1.00000 1.0000\n", - ] - - self.assertEqual(lines, expected_lines) - - @mock.patch("diceplayer.interface.dice_interface.subprocess") - @mock.patch( - "builtins.open", - new_callable=mock.mock_open, - read_data="End of simulation\nBLABLA", - ) - def test_run_dice_file(self, mock_open, mock_subprocess): - mock_subprocess.call.return_value = 0 - dice = DiceInterface() - dice.configure(self.config, System()) - - dice.run_dice_file(1, 1, "test") - - self.assertTrue(mock_subprocess.call.called) - self.assertTrue(mock_open.called) - - @mock.patch("diceplayer.interface.dice_interface.subprocess") - @mock.patch("builtins.open", new_callable=mock.mock_open, read_data="Error\nBLABLA") - def test_run_dice_file_raises_runtime_error_on_dice_file( - self, mock_open, mock_subprocess - ): - mock_subprocess.call.return_value = 0 - dice = DiceInterface() - dice.configure(self.config, System()) - - with self.assertRaises(RuntimeError): - dice.run_dice_file(1, 1, "test") - - @mock.patch("diceplayer.interface.dice_interface.subprocess") - @mock.patch( - "builtins.open", - new_callable=mock.mock_open, - read_data="End of simulation\nBLABLA", - ) - def test_run_dice_file_raises_runtime_error_of_dice_exit_code( - self, mock_open, mock_subprocess - ): - mock_subprocess.call.return_value = 1 - dice = DiceInterface() - dice.configure(self.config, System()) - - with self.assertRaises(RuntimeError): - dice.run_dice_file(1, 1, "test") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/shared/interface/test_gaussian_interface.py b/tests/shared/interface/test_gaussian_interface.py deleted file mode 100644 index 3ca5381..0000000 --- a/tests/shared/interface/test_gaussian_interface.py +++ /dev/null @@ -1,115 +0,0 @@ -from diceplayer import logger -from diceplayer.config.player_config import PlayerConfig -from diceplayer.environment import System -from diceplayer.interface import GaussianInterface -from tests.mocks.mock_inputs import get_config_example - -import yaml - -import io -import unittest -from unittest import mock - - -class TestGaussianInterface(unittest.TestCase): - def setUp(self) -> None: - logger.set_logger(stream=io.StringIO()) - - config = yaml.load(get_config_example(), Loader=yaml.Loader) - self.config = PlayerConfig.model_validate(config["diceplayer"]) - - def test_class_instantiation(self): - gaussian_interface = GaussianInterface() - self.assertIsInstance(gaussian_interface, GaussianInterface) - - def test_configure(self): - gaussian_interface = GaussianInterface() - - self.assertIsNone(gaussian_interface.step) - self.assertIsNone(gaussian_interface.system) - - gaussian_interface.configure(self.config, System()) - - self.assertIsNotNone(gaussian_interface.step) - self.assertIsNotNone(gaussian_interface.system) - - def test_reset(self): - gaussian_interface = GaussianInterface() - - gaussian_interface.configure(self.config, System()) - - self.assertIsNotNone(gaussian_interface.step) - self.assertIsNotNone(gaussian_interface.system) - - gaussian_interface.reset() - - self.assertFalse(hasattr(gaussian_interface, "step")) - self.assertFalse(hasattr(gaussian_interface, "system")) - - @mock.patch("diceplayer.interface.gaussian_interface.Path.mkdir") - @mock.patch("diceplayer.interface.gaussian_interface.Path.exists") - def test_make_qm_dir(self, mock_exists, mock_mkdir): - mock_exists.return_value = False - - gaussian_interface = GaussianInterface() - gaussian_interface.configure(self.config, System()) - - gaussian_interface._make_qm_dir(1) - - mock_exists.assert_called_once() - mock_mkdir.assert_called_once() - - @mock.patch("diceplayer.interface.gaussian_interface.shutil.copy") - @mock.patch("diceplayer.interface.gaussian_interface.Path.exists") - def test_copy_chk_file_from_previous_step(self, mock_exists, mock_copy): - gaussian_interface = GaussianInterface() - gaussian_interface.configure(self.config, System()) - - mock_exists.side_effect = [False, True] - - gaussian_interface._copy_chk_file_from_previous_step(2) - - self.assertTrue(mock_exists.called) - self.assertTrue(mock_copy.called) - - @mock.patch("diceplayer.interface.gaussian_interface.shutil.copy") - @mock.patch("diceplayer.interface.gaussian_interface.Path.exists") - def test_copy_chk_file_from_previous_step_no_previous_step( - self, mock_exists, mock_copy - ): - gaussian_interface = GaussianInterface() - gaussian_interface.configure(self.config, System()) - - mock_exists.side_effect = [False, False] - - with self.assertRaises(FileNotFoundError): - gaussian_interface._copy_chk_file_from_previous_step(2) - - @mock.patch("diceplayer.interface.gaussian_interface.shutil.copy") - @mock.patch("diceplayer.interface.gaussian_interface.Path.exists") - def test_copy_chk_file_from_previous_step_current_exists( - self, mock_exists, mock_copy - ): - gaussian_interface = GaussianInterface() - gaussian_interface.configure(self.config, System()) - - mock_exists.side_effect = [True, True] - - with self.assertRaises(FileExistsError): - gaussian_interface._copy_chk_file_from_previous_step(2) - - # def test_start(self): - # gaussian_interface = GaussianInterface() - # gaussian_interface.configure(self.config, System()) - # - # gaussian_interface._make_qm_dir = mock.Mock() - # gaussian_interface._copy_chk_file_from_previous_step = mock.Mock() - # - # gaussian_interface.start(2) - # - # gaussian_interface._make_qm_dir.assert_called_once_with(2) - # gaussian_interface._copy_chk_file_from_previous_step.assert_called_once_with(2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/shared/utils/__init__.py b/tests/shared/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/shared/utils/test_logger.py b/tests/shared/utils/test_logger.py deleted file mode 100644 index b3bba51..0000000 --- a/tests/shared/utils/test_logger.py +++ /dev/null @@ -1,132 +0,0 @@ -from diceplayer.utils import Logger, valid_logger - -import io -import logging -import unittest -from unittest import mock - - -class TestValidateLogger(unittest.TestCase): - def test_validate_logger(self): - class MockLogger: - _was_set = True - - @valid_logger - def test_func(self): - pass - - MockLogger().test_func() - - def test_validate_logger_exception(self): - class MockLogger: - _was_set = False - - @valid_logger - def test_func(self): - pass - - with self.assertRaises(AssertionError): - MockLogger().test_func() - - -class TestLogger(unittest.TestCase): - def test_class_instantiation(self): - logger = Logger("test") - - self.assertIsInstance(logger, Logger) - - @mock.patch("builtins.open", mock.mock_open()) - def test_set_logger_to_file(self): - logger = Logger("test") - - logger.set_logger(stream=io.StringIO()) - - self.assertIsNotNone(logger._logger) - self.assertEqual(logger._logger.name, "test") - - def test_set_logger_to_stream(self): - logger = Logger("test") - - logger.set_logger(stream=io.StringIO()) - - self.assertIsNotNone(logger._logger) - self.assertEqual(logger._logger.name, "test") - - @mock.patch("builtins.open", mock.mock_open()) - @mock.patch("diceplayer.utils.logger.Path.exists") - @mock.patch("diceplayer.utils.logger.Path.rename") - def test_set_logger_if_file_exists(self, mock_rename, mock_exists): - logger = Logger("test") - - mock_exists.return_value = True - logger.set_logger() - - self.assertTrue(mock_rename.called) - self.assertIsNotNone(logger._logger) - self.assertEqual(logger._logger.name, "test") - - @mock.patch("builtins.open", mock.mock_open()) - @mock.patch("diceplayer.utils.logger.Path.exists") - @mock.patch("diceplayer.utils.logger.Path.rename") - def test_set_logger_if_file_not_exists(self, mock_rename, mock_exists): - logger = Logger("test") - - mock_exists.return_value = False - logger.set_logger() - - self.assertFalse(mock_rename.called) - self.assertIsNotNone(logger._logger) - self.assertEqual(logger._logger.name, "test") - - @mock.patch("builtins.open", mock.mock_open()) - def test_close(self): - logger = Logger("test") - - logger.set_logger() - logger.close() - - self.assertEqual(len(logger._logger.handlers), 0) - - @mock.patch("builtins.open", mock.mock_open()) - def test_info(self): - logger = Logger("test") - logger.set_logger(stream=io.StringIO()) - - with self.assertLogs(level="INFO") as cm: - logger.info("test") - - self.assertEqual(cm.output, ["INFO:test:test"]) - - @mock.patch("builtins.open", mock.mock_open()) - def test_debug(self): - logger = Logger("test") - logger.set_logger(stream=io.StringIO(), level=logging.DEBUG) - - with self.assertLogs(level="DEBUG") as cm: - logger.debug("test") - - self.assertEqual(cm.output, ["DEBUG:test:test"]) - - @mock.patch("builtins.open", mock.mock_open()) - def test_warning(self): - logger = Logger("test") - logger.set_logger(stream=io.StringIO()) - - with self.assertLogs(level="WARNING") as cm: - logger.warning("test") - - self.assertEqual(cm.output, ["WARNING:test:test"]) - - @mock.patch("builtins.open", mock.mock_open()) - def test_error(self): - logger = Logger("test") - logger.set_logger(stream=io.StringIO()) - - with self.assertLogs(level="ERROR") as cm: - logger.error("test") - - self.assertEqual(cm.output, ["ERROR:test:test"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_player.py b/tests/test_player.py deleted file mode 100644 index 137c2f8..0000000 --- a/tests/test_player.py +++ /dev/null @@ -1,396 +0,0 @@ -from diceplayer import logger -from diceplayer.player import Player -from tests.mocks.mock_inputs import mock_open - -import io -import unittest -from unittest import mock - - -class TestPlayer(unittest.TestCase): - def setUp(self): - logger.set_logger(stream=io.StringIO()) - - @mock.patch("builtins.open", mock_open) - def test_class_instantiation(self): - # This file does not exist and it will be mocked - player = Player("control.test.yml") - - self.assertIsInstance(player, Player) - - @mock.patch("builtins.open", mock_open) - def test_start(self): - player = Player("control.test.yml") - - player.gaussian_start = mock.MagicMock() - player.dice_start = mock.MagicMock() - - player.start() - - self.assertEqual(player.dice_start.call_count, 3) - self.assertEqual(player.gaussian_start.call_count, 3) - - @mock.patch("builtins.open", mock_open) - @mock.patch("diceplayer.player.Path") - def test_create_simulation_dir_if_already_exists(self, mock_path): - player = Player("control.test.yml") - mock_path.return_value.exists.return_value = True - - with self.assertRaises(FileExistsError): - player.create_simulation_dir() - - self.assertTrue(mock_path.called) - - @mock.patch("builtins.open", mock_open) - @mock.patch("diceplayer.player.Path") - def test_create_simulation_dir_if_not_exists(self, mock_path): - player = Player("control.test.yml") - mock_path.return_value.exists.return_value = False - - player.create_simulation_dir() - - self.assertTrue(mock_path.called) - - @mock.patch("builtins.open", mock_open) - @mock.patch("diceplayer.player.VERSION", "test") - @mock.patch("diceplayer.player.sys") - @mock.patch("diceplayer.player.weekday_date_time") - def test_print_keywords(self, mock_date_func, mock_sys): - player = Player("control.test.yml") - - mock_sys.version = "TEST" - mock_date_func.return_value = "00 Test 0000 at 00:00:00" - - with self.assertLogs() as cm: - player.print_keywords() - - expected_output = [ - "INFO:diceplayer:##########################################################################################\n############# Welcome to DICEPLAYER version test #############\n##########################################################################################\n", - "INFO:diceplayer:Your python version is TEST\n", - "INFO:diceplayer:Program started on 00 Test 0000 at 00:00:00\n", - "INFO:diceplayer:Environment variables:", - "INFO:diceplayer:OMP_STACKSIZE = Not set\n", - "INFO:diceplayer:------------------------------------------------------------------------------------------\n DICE variables being used in this run:\n------------------------------------------------------------------------------------------\n", - "INFO:diceplayer:combrule = *", - "INFO:diceplayer:dens = 0.75", - "INFO:diceplayer:isave = 1000", - "INFO:diceplayer:ljname = phb.ljc", - "INFO:diceplayer:nmol = [ 1 50 ]", - "INFO:diceplayer:nstep = [ 2000 3000 4000 ]", - "INFO:diceplayer:outname = phb", - "INFO:diceplayer:press = 1.0", - "INFO:diceplayer:progname = ~/.local/bin/dice", - "INFO:diceplayer:randominit = first", - "INFO:diceplayer:temp = 300.0", - "INFO:diceplayer:upbuf = 360", - "INFO:diceplayer:------------------------------------------------------------------------------------------\n GAUSSIAN variables being used in this run:\n------------------------------------------------------------------------------------------\n", - "INFO:diceplayer:chg_tol = 0.01", - "INFO:diceplayer:chgmult = [ 0 1 ]", - "INFO:diceplayer:keywords = freq", - "INFO:diceplayer:level = MP2/aug-cc-pVDZ", - "INFO:diceplayer:pop = chelpg", - "INFO:diceplayer:qmprog = g16", - "INFO:diceplayer:\n", - ] - - self.assertEqual(cm.output, expected_output) - - def test_validate_atom_dict(self): - with self.assertRaises(ValueError) as context: - Player.validate_atom_dict( - molecule_type=0, - molecule_site=0, - atom_dict={ - "lbl": 0, - "na": 1, - "rx": 1.0, - "ry": 1.0, - "rz": 1.0, - "chg": 1.0, - "eps": 1.0, - }, - ) - self.assertEqual( - str(context.exception), - "Invalid number of fields for site 1 for molecule type 1.", - ) - - with self.assertRaises(ValueError) as context: - Player.validate_atom_dict( - molecule_type=0, - molecule_site=0, - atom_dict={ - "lbl": "", - "na": 1, - "rx": 1.0, - "ry": 1.0, - "rz": 1.0, - "chg": 1.0, - "eps": 1.0, - "sig": 1.0, - }, - ) - self.assertEqual( - str(context.exception), "Invalid lbl fields for site 1 for molecule type 1." - ) - - with self.assertRaises(ValueError) as context: - Player.validate_atom_dict( - molecule_type=0, - molecule_site=0, - atom_dict={ - "lbl": 1.0, - "na": "", - "rx": 1.0, - "ry": 1.0, - "rz": 1.0, - "chg": 1.0, - "eps": 1.0, - "sig": 1.0, - }, - ) - self.assertEqual( - str(context.exception), "Invalid na fields for site 1 for molecule type 1." - ) - - with self.assertRaises(ValueError) as context: - Player.validate_atom_dict( - molecule_type=0, - molecule_site=0, - atom_dict={ - "lbl": 1.0, - "na": 1, - "rx": "", - "ry": 1.0, - "rz": 1.0, - "chg": 1.0, - "eps": 1.0, - "sig": 1.0, - }, - ) - self.assertEqual( - str(context.exception), - "Invalid rx fields for site 1 for molecule type 1. Value must be a float.", - ) - - with self.assertRaises(ValueError) as context: - Player.validate_atom_dict( - molecule_type=0, - molecule_site=0, - atom_dict={ - "lbl": 1.0, - "na": 1, - "rx": 1.0, - "ry": "", - "rz": 1.0, - "chg": 1.0, - "eps": 1.0, - "sig": 1.0, - }, - ) - self.assertEqual( - str(context.exception), - "Invalid ry fields for site 1 for molecule type 1. Value must be a float.", - ) - - with self.assertRaises(ValueError) as context: - Player.validate_atom_dict( - molecule_type=0, - molecule_site=0, - atom_dict={ - "lbl": 1.0, - "na": 1, - "rx": 1.0, - "ry": 1.0, - "rz": "", - "chg": 1.0, - "eps": 1.0, - "sig": 1.0, - }, - ) - self.assertEqual( - str(context.exception), - "Invalid rz fields for site 1 for molecule type 1. Value must be a float.", - ) - - with self.assertRaises(ValueError) as context: - Player.validate_atom_dict( - molecule_type=0, - molecule_site=0, - atom_dict={ - "lbl": 1.0, - "na": 1, - "rx": 1.0, - "ry": 1.0, - "rz": 1.0, - "chg": "", - "eps": 1.0, - "sig": 1.0, - }, - ) - self.assertEqual( - str(context.exception), - "Invalid chg fields for site 1 for molecule type 1. Value must be a float.", - ) - - with self.assertRaises(ValueError) as context: - Player.validate_atom_dict( - molecule_type=0, - molecule_site=0, - atom_dict={ - "lbl": 1.0, - "na": 1, - "rx": 1.0, - "ry": 1.0, - "rz": 1.0, - "chg": 1.0, - "eps": "", - "sig": 1.0, - }, - ) - self.assertEqual( - str(context.exception), - "Invalid eps fields for site 1 for molecule type 1. Value must be a float.", - ) - - with self.assertRaises(ValueError) as context: - Player.validate_atom_dict( - molecule_type=0, - molecule_site=0, - atom_dict={ - "lbl": 1.0, - "na": 1, - "rx": 1.0, - "ry": 1.0, - "rz": 1.0, - "chg": 1.0, - "eps": 1.0, - "sig": "", - }, - ) - self.assertEqual( - str(context.exception), - "Invalid sig fields for site 1 for molecule type 1. Value must be a float.", - ) - - @mock.patch("builtins.open", mock_open) - @mock.patch("diceplayer.player.Path.exists", return_value=True) - def test_read_potentials(self, mock_path_exists): - player = Player("control.test.yml") - - player.read_potentials() - - self.assertEqual(player.system.molecule[0].molname, "TEST") - self.assertEqual(len(player.system.molecule[0].atom), 1) - - self.assertEqual(player.system.molecule[1].molname, "PLACEHOLDER") - self.assertEqual(len(player.system.molecule[1].atom), 1) - - @mock.patch("builtins.open", mock_open) - @mock.patch("diceplayer.player.Path.exists") - def test_read_potentials_error(self, mock_path_exists): - player = Player("control.test.yml") - - # Testing file not found error - mock_path_exists.return_value = False - with self.assertRaises(RuntimeError) as context: - player.read_potentials() - - self.assertEqual(str(context.exception), "Potential file phb.ljc not found.") - - # Enabling file found for next tests - mock_path_exists.return_value = True - - # Testing combrule error - with self.assertRaises(SystemExit) as context: - player.config.dice.ljname = "phb.error.combrule.ljc" - player.read_potentials() - - self.assertEqual( - str(context.exception), - "Error: expected a '*' or a '+' sign in 1st line of file phb.error.combrule.ljc", - ) - - # Testing ntypes error - with self.assertRaises(SystemExit) as context: - player.config.dice.ljname = "phb.error.ntypes.ljc" - player.read_potentials() - - self.assertEqual( - str(context.exception), - "Error: expected an integer in the 2nd line of file phb.error.ntypes.ljc", - ) - - # Testing ntypes error on config - with self.assertRaises(SystemExit) as context: - player.config.dice.ljname = "phb.error.ntypes.config.ljc" - player.read_potentials() - - self.assertEqual( - str(context.exception), - "Error: number of molecule types in file phb.error.ntypes.config.ljc " - "must match that of 'nmol' keyword in config file", - ) - - # Testing nsite error - with self.assertRaises(ValueError) as context: - player.config.dice.ljname = "phb.error.nsites.ljc" - player.read_potentials() - - self.assertEqual( - str(context.exception), - "Error: expected nsites to be an integer for molecule type 1", - ) - - # Testing molname error - with self.assertRaises(ValueError) as context: - player.config.dice.ljname = "phb.error.molname.ljc" - player.read_potentials() - - self.assertEqual( - str(context.exception), - "Error: expected nsites and molname for the molecule type 1", - ) - - @mock.patch("builtins.open", mock_open) - @mock.patch("diceplayer.player.Path.exists", return_value=True) - def test_print_potentials(self, mock_path_exists): - player = Player("control.test.yml") - player.read_potentials() - - with self.assertLogs(level="INFO") as context: - player.print_potentials() - - expected_output = [ - "INFO:diceplayer:==========================================================================================\n Potential parameters from file phb.ljc:\n------------------------------------------------------------------------------------------\n", - "INFO:diceplayer:Combination rule: *", - "INFO:diceplayer:Types of molecules: 2\n", - "INFO:diceplayer:1 atoms in molecule type 1:", - "INFO:diceplayer:---------------------------------------------------------------------------------", - "INFO:diceplayer:Lbl AN X Y Z Charge Epsilon Sigma Mass", - "INFO:diceplayer:---------------------------------------------------------------------------------", - "INFO:diceplayer:1 1 0.00000 0.00000 0.00000 0.000000 0.00000 0.0000 1.0079", - "INFO:diceplayer:\n", - "INFO:diceplayer:1 atoms in molecule type 2:", - "INFO:diceplayer:---------------------------------------------------------------------------------", - "INFO:diceplayer:Lbl AN X Y Z Charge Epsilon Sigma Mass", - "INFO:diceplayer:---------------------------------------------------------------------------------", - "INFO:diceplayer:1 1 0.00000 0.00000 0.00000 0.000000 0.00000 0.0000 1.0079", - "INFO:diceplayer:\n", - ] - - self.assertEqual(context.output, expected_output) - - @mock.patch("builtins.open", mock_open) - def test_dice_start(self): - player = Player("control.test.yml") - player.dice_interface = mock.MagicMock() - player.dice_interface.start = mock.MagicMock() - - player.dice_start(1) - - player.dice_interface.start.assert_called_once() - - -if __name__ == "__main__": - unittest.main()