diff --git a/control.example.yml b/control.example.yml index 99f5ae4..b664d27 100644 --- a/control.example.yml +++ b/control.example.yml @@ -6,8 +6,8 @@ diceplayer: qmprog: 'g16' lps: no ghosts: no - altsteps: 20000 + dice: nmol: [1, 50] dens: 0.75 diff --git a/diceplayer/player.py b/diceplayer/player.py index 2312d90..25c10a5 100644 --- a/diceplayer/player.py +++ b/diceplayer/player.py @@ -9,7 +9,7 @@ from diceplayer.shared.environment.system import System from diceplayer.shared.config.step_dto import StepDTO from diceplayer.shared.config.dice_dto import DiceDTO from diceplayer.shared.environment.atom import Atom -from diceplayer.shared.utils.ptable import atommass +from diceplayer import logger from dataclasses import fields from pathlib import Path @@ -24,13 +24,6 @@ ENV = ["OMP_STACKSIZE"] class Player: - __slots__ = [ - 'config', - 'system', - 'dice', - 'gaussian', - ] - def __init__(self, infile: str): config_data = self.read_keywords(infile) @@ -43,16 +36,16 @@ class Player: self.gaussian = GaussianInterface(config_data.get("gaussian")) self.dice = DiceInterface(config_data.get("dice")) - def start(self): + def start(self, initial_cycle: int = 1): self.print_keywords() self.create_simulation_dir() self.read_potentials() - # self.print_potentials() + self.print_potentials() - self.dice_start(1) - self.dice_start(2) + for cycle in range(initial_cycle, self.config.maxcyc + 1): + self.dice_start(cycle) def create_simulation_dir(self): simulation_dir_path = Path(self.config.simulation_dir) @@ -61,12 +54,7 @@ class Player: f"Error: a file or a directory {self.config.simulation_dir} already exists," f" move or delete the simfiles directory to continue." ) - try: - simulation_dir_path.mkdir() - except FileExistsError: - OSError( - f"Error: cannot make directory {self.config.simulation_dir}" - ) + simulation_dir_path.mkdir() def print_keywords(self) -> None: @@ -75,38 +63,38 @@ class Player: if getattr(config, key) is not None: if isinstance(getattr(config, key), list): string = " ".join(str(x) for x in getattr(config, key)) - logging.info(f"{key} = [ {string} ]") + logger.info(f"{key} = [ {string} ]") else: - logging.info(f"{key} = {getattr(config, key)}") + logger.info(f"{key} = {getattr(config, key)}") - logging.info( + logger.info( "##########################################################################################\n" "############# Welcome to DICEPLAYER version 1.0 #############\n" "##########################################################################################\n" "\n" ) - logging.info("Your python version is {}\n".format(sys.version)) - logging.info("\n") - logging.info("Program started on {}\n".format(weekday_date_time())) - logging.info("\n") - logging.info("Environment variables:\n") + logger.info("Your python version is {}\n".format(sys.version)) + logger.info("\n") + logger.info("Program started on {}\n".format(weekday_date_time())) + logger.info("\n") + logger.info("Environment variables:\n") for var in ENV: - logging.info( + logger.info( "{} = {}\n".format( var, (os.environ[var] if var in os.environ else "Not set") ) ) - logging.info( + logger.info( "\n==========================================================================================\n" " CONTROL variables being used in this run:\n" "------------------------------------------------------------------------------------------\n" "\n" ) - logging.info("\n") + logger.info("\n") - logging.info( + logger.info( "------------------------------------------------------------------------------------------\n" " DICE variables being used in this run:\n" "------------------------------------------------------------------------------------------\n" @@ -115,9 +103,9 @@ class Player: log_keywords(self.dice.config, DiceDTO) - logging.info("\n") + logger.info("\n") - logging.info( + logger.info( "------------------------------------------------------------------------------------------\n" " GAUSSIAN variables being used in this run:\n" "------------------------------------------------------------------------------------------\n" @@ -126,18 +114,19 @@ class Player: log_keywords(self.gaussian.config, GaussianDTO) - logging.info("\n") + logger.info("\n") def read_potentials(self): - try: + ljname_path = Path(self.dice.config.ljname) + if ljname_path.exists(): with open(self.dice.config.ljname) as file: - ljdata = file.readlines() - except FileNotFoundError: + ljc_data = file.readlines() + else: raise RuntimeError( f"Potential file {self.dice.config.ljname} not found." ) - combrule = ljdata.pop(0).split()[0] + combrule = ljc_data.pop(0).split()[0] if combrule not in ("*", "+"): sys.exit( "Error: expected a '*' or a '+' sign in 1st line of file {}".format( @@ -146,7 +135,7 @@ class Player: ) self.dice.config.combrule = combrule - ntypes = ljdata.pop(0).split()[0] + ntypes = ljc_data.pop(0).split()[0] if not ntypes.isdigit(): sys.exit( "Error: expected an integer in the 2nd line of file {}".format( @@ -157,22 +146,22 @@ class Player: if ntypes != len(self.dice.config.nmol): sys.exit( - f"Error: number of molecule types in file {self.dice.config.ljname}" + f"Error: number of molecule types in file {self.dice.config.ljname} " f"must match that of 'nmol' keyword in config file" ) for i in range(ntypes): - nsites, molname = ljdata.pop(0).split()[:2] + try: + nsites, molname = ljc_data.pop(0).split()[:2] + except ValueError: + raise ValueError( + f"Error: expected nsites and molname for the molecule type {i+1}" + ) if not nsites.isdigit(): raise ValueError( - f"Error: expected nsites to be an integer for molecule type {i}" - ) - - if molname is None: - raise ValueError( - f"Error: expected molecule name for molecule type {i}" + f"Error: expected nsites to be an integer for molecule type {i+1}" ) nsites = int(nsites) @@ -182,12 +171,68 @@ class Player: for j in range(nsites): new_atom = dict(zip( atom_fields, - ljdata.pop(0).split() + ljc_data.pop(0).split() )) self.system.molecule[i].add_atom( Atom(**self.validate_atom_dict(i, j, new_atom)) ) + def print_potentials(self) -> None: + + formatstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f} {:>9.4f}" + logger.info( + "==========================================================================================\n" + ) + logger.info( + f" Potential parameters from file {self.dice.config.ljname}:" + ) + logger.info( + "------------------------------------------------------------------------------------------" + "\n" + ) + + logger.info(f"Combination rule: {self.dice.config.combrule}") + logger.info( + f"Types of molecules: {len(self.system.molecule)}\n" + ) + + i = 0 + for mol in self.system.molecule: + i += 1 + logger.info( + "{} atoms in molecule type {}:".format(len(mol.atom), i) + ) + logger.info( + "---------------------------------------------------------------------------------" + ) + logger.info( + "Lbl AN X Y Z Charge Epsilon Sigma Mass" + ) + logger.info( + "---------------------------------------------------------------------------------" + ) + + for atom in mol.atom: + logger.info( + formatstr.format( + atom.lbl, + atom.na, + atom.rx, + atom.ry, + atom.rz, + atom.chg, + atom.eps, + atom.sig, + atom.mass, + ) + ) + + logger.info("\n") + + logger.info( + "==========================================================================================" + ) + def dice_start(self, cycle: int): self.dice.configure( StepDTO( @@ -204,8 +249,8 @@ class Player: self.dice.reset() - def gaussian_start(self): - self.gaussian.start() + def gaussian_start(self, cycle: int): + self.gaussian.start(cycle) @staticmethod def validate_atom_dict(molecule_type, molecule_site, atom_dict: dict) -> dict: @@ -219,63 +264,63 @@ class Player: try: atom_dict['lbl'] = int(atom_dict['lbl']) - except ValueError: + except Exception: raise ValueError( f'Invalid lbl fields for site {molecule_site} for molecule type {molecule_type}.' ) try: atom_dict['na'] = int(atom_dict['na']) - except ValueError: + except Exception: raise ValueError( f'Invalid na fields for site {molecule_site} for molecule type {molecule_type}.' ) try: atom_dict['rx'] = float(atom_dict['rx']) - except ValueError: + except Exception: raise ValueError( - f'Invalid rx fields for site {molecule_site} for molecule type {molecule_type}.' + f'Invalid rx fields for site {molecule_site} for molecule type {molecule_type}. ' f'Value must be a float.' ) try: atom_dict['ry'] = float(atom_dict['ry']) - except ValueError: + except Exception: raise ValueError( - f'Invalid ry fields for site {molecule_site} for molecule type {molecule_type}.' + f'Invalid ry fields for site {molecule_site} for molecule type {molecule_type}. ' f'Value must be a float.' ) try: - atom_dict['rz'] = float(atom_dict['rx']) - except ValueError: + atom_dict['rz'] = float(atom_dict['rz']) + except Exception: raise ValueError( - f'Invalid rz fields for site {molecule_site} for molecule type {molecule_type}.' + f'Invalid rz fields for site {molecule_site} for molecule type {molecule_type}. ' f'Value must be a float.' ) try: atom_dict['chg'] = float(atom_dict['chg']) - except ValueError: + except Exception: raise ValueError( - f'Invalid chg fields for site {molecule_site} for molecule type {molecule_type}.' + f'Invalid chg fields for site {molecule_site} for molecule type {molecule_type}. ' f'Value must be a float.' ) try: atom_dict['eps'] = float(atom_dict['eps']) - except ValueError: + except Exception: raise ValueError( - f'Invalid eps fields for site {molecule_site} for molecule type {molecule_type}.' + f'Invalid eps fields for site {molecule_site} for molecule type {molecule_type}. ' f'Value must be a float.' ) try: atom_dict['sig'] = float(atom_dict['sig']) - except ValueError: + except Exception: raise ValueError( - f'Invalid sig fields for site {molecule_site} for molecule type {molecule_type}.' + f'Invalid sig fields for site {molecule_site} for molecule type {molecule_type}. ' f'Value must be a float.' ) diff --git a/diceplayer/shared/interface/dice_interface.py b/diceplayer/shared/interface/dice_interface.py index 89bcbd1..68d4be8 100644 --- a/diceplayer/shared/interface/dice_interface.py +++ b/diceplayer/shared/interface/dice_interface.py @@ -3,6 +3,7 @@ from __future__ import annotations from diceplayer.shared.config.dice_dto import DiceDTO from diceplayer.shared.config.step_dto import StepDTO from diceplayer.shared.interface import Interface +from diceplayer import logger from multiprocessing import Process, connection from setproctitle import setproctitle @@ -15,7 +16,6 @@ import time import sys import os - DICE_END_FLAG: Final[str] = "End of simulation" DICE_FLAG_LINE: Final[int] = -2 UMAANG3_TO_GCM3: Final[float] = 1.6605 @@ -41,8 +41,9 @@ class DiceInterface(Interface): procs = [] sentinels = [] - for proc in range(1, self.step.nprocs + 1): + logger.info(f"---------------------- DICE - CYCLE {cycle} --------------------------\n") + for proc in range(1, self.step.nprocs + 1): p = Process(target=self._simulation_process, args=(cycle, proc)) p.start() @@ -61,6 +62,8 @@ class DiceInterface(Interface): p.terminate() sys.exit(status) + logger.info("\n") + def reset(self): del self.step @@ -132,6 +135,11 @@ class DiceInterface(Interface): f"step{cycle:02d}", f"p{proc:02d}" ) + + logger.info( + f"Simulation process {str(proc_dir)} initiated with pid {os.getpid()}" + ) + os.chdir(proc_dir) if not (self.config.randominit == 'first' and cycle > 1): @@ -385,3 +393,5 @@ class DiceInterface(Interface): flag = outfile.readlines()[DICE_FLAG_LINE].strip() if flag != DICE_END_FLAG: raise RuntimeError(f"Dice process step{cycle:02d}-p{proc:02d} did not exit properly") + + logger.info(f"Dice {file_name} - step{cycle:02d}-p{proc:02d} exited properly") diff --git a/diceplayer/shared/interface/gaussian_interface.py b/diceplayer/shared/interface/gaussian_interface.py index 71bdf54..113878a 100644 --- a/diceplayer/shared/interface/gaussian_interface.py +++ b/diceplayer/shared/interface/gaussian_interface.py @@ -14,7 +14,7 @@ class GaussianInterface(Interface): def configure(self): pass - def start(self): + def start(self, cycle: int): pass def reset(self): diff --git a/diceplayer/shared/utils/logger.py b/diceplayer/shared/utils/logger.py index 2ed8372..2404f20 100644 --- a/diceplayer/shared/utils/logger.py +++ b/diceplayer/shared/utils/logger.py @@ -24,15 +24,17 @@ class Logger: if self._logger is None: self._logger = logging.getLogger(logger_name) - def set_logger(self, outfile='run.log', level=logging.INFO): - self.outfile = Path(outfile) - if self.outfile.exists(): - self.outfile.rename(str(self.outfile) + ".backup") + def set_logger(self, outfile='run.log', level=logging.INFO, stream=None): + outfile_path = None + if outfile is not None and stream is None: + outfile_path = Path(outfile) + if outfile_path.exists(): + outfile_path.rename(str(outfile_path) + ".backup") if level is not None: self._logger.setLevel(level) - self._create_handlers() + self._create_handlers(outfile_path, stream) self._was_set = True @@ -52,10 +54,12 @@ class Logger: def error(self, message): self._logger.error(message) - def _create_handlers(self): + def _create_handlers(self, outfile_path: Path, stream): handlers = [] - if self.outfile is not None: - handlers.append(logging.FileHandler(self.outfile, mode='a+')) + if outfile_path is not None: + handlers.append(logging.FileHandler(outfile_path, mode='a+')) + elif stream is not None: + handlers.append(logging.StreamHandler(stream)) else: handlers.append(logging.StreamHandler()) diff --git a/tests/shared/interface/test_dice_interface.py b/tests/shared/interface/test_dice_interface.py index c6bbfdb..6d39c4f 100644 --- a/tests/shared/interface/test_dice_interface.py +++ b/tests/shared/interface/test_dice_interface.py @@ -2,6 +2,7 @@ from diceplayer.shared.interface.dice_interface import DiceInterface from diceplayer.shared.environment.molecule import Molecule from diceplayer.shared.environment.atom import Atom from diceplayer.shared.config.step_dto import StepDTO +from diceplayer import logger import io @@ -12,6 +13,9 @@ import unittest class TestDiceInterface(unittest.TestCase): + def setUp(self): + logger.set_logger(stream=io.StringIO()) + def test_class_instantiation(self): dice = DiceInterface( { diff --git a/tests/shared/utils/test_logger.py b/tests/shared/utils/test_logger.py index 482341f..abe41a8 100644 --- a/tests/shared/utils/test_logger.py +++ b/tests/shared/utils/test_logger.py @@ -1,6 +1,7 @@ from diceplayer.shared.utils.logger import Logger, valid_logger import logging +import io from unittest import mock import unittest @@ -35,6 +36,23 @@ class TestLogger(unittest.TestCase): self.assertIsInstance(logger, Logger) + @mock.patch('builtins.open', mock.mock_open()) + def test_set_logger_to_file(self): + logger = Logger('test') + + logger.set_logger(stream=io.StringIO()) + + self.assertIsNotNone(logger._logger) + self.assertEqual(logger._logger.name, 'test') + + def test_set_logger_to_stream(self): + logger = Logger('test') + + logger.set_logger(stream=io.StringIO()) + + self.assertIsNotNone(logger._logger) + self.assertEqual(logger._logger.name, 'test') + @mock.patch('builtins.open', mock.mock_open()) @mock.patch('diceplayer.shared.utils.logger.Path.exists') @mock.patch('diceplayer.shared.utils.logger.Path.rename') @@ -64,6 +82,7 @@ class TestLogger(unittest.TestCase): @mock.patch('builtins.open', mock.mock_open()) def test_close(self): logger = Logger('test') + logger.set_logger() logger.close() @@ -72,7 +91,7 @@ class TestLogger(unittest.TestCase): @mock.patch('builtins.open', mock.mock_open()) def test_info(self): logger = Logger('test') - logger.set_logger() + logger.set_logger(stream=io.StringIO()) with self.assertLogs(level='INFO') as cm: logger.info('test') @@ -82,7 +101,7 @@ class TestLogger(unittest.TestCase): @mock.patch('builtins.open', mock.mock_open()) def test_debug(self): logger = Logger('test') - logger.set_logger(level=logging.DEBUG) + logger.set_logger(stream=io.StringIO(), level=logging.DEBUG) with self.assertLogs(level='DEBUG') as cm: logger.debug('test') @@ -92,7 +111,7 @@ class TestLogger(unittest.TestCase): @mock.patch('builtins.open', mock.mock_open()) def test_warning(self): logger = Logger('test') - logger.set_logger() + logger.set_logger(stream=io.StringIO()) with self.assertLogs(level='WARNING') as cm: logger.warning('test') @@ -102,7 +121,7 @@ class TestLogger(unittest.TestCase): @mock.patch('builtins.open', mock.mock_open()) def test_error(self): logger = Logger('test') - logger.set_logger() + logger.set_logger(stream=io.StringIO()) with self.assertLogs(level='ERROR') as cm: logger.error('test') diff --git a/tests/test_player.py b/tests/test_player.py new file mode 100644 index 0000000..e421331 --- /dev/null +++ b/tests/test_player.py @@ -0,0 +1,424 @@ +from diceplayer.player import Player +from diceplayer import logger + +import io + +from unittest import mock +import unittest + + +def get_config_example(): + return """ +diceplayer: + maxcyc: 3 + opt: no + ncores: 4 + nprocs: 4 + qmprog: 'g16' + lps: no + ghosts: no + altsteps: 20000 + +dice: + nmol: [1, 50] + dens: 0.75 + nstep: [2000, 3000, 4000] + isave: 1000 + outname: 'phb' + progname: '~/.local/bin/dice' + ljname: 'phb.ljc' + randominit: 'first' + +gaussian: + qmprog: 'g16' + level: 'MP2/aug-cc-pVDZ' + keywords: 'freq' +""" + + +def get_potentials_exemple(): + return """\ +* +2 +1 TEST + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +1 PLACEHOLDER + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +""" + + +def get_potentials_error_combrule(): + return """\ +. +2 +1 TEST + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +1 PLACEHOLDER + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +""" + + +def get_potentials_error_ntypes(): + return """\ +* +a +1 TEST + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +1 PLACEHOLDER + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +""" + + +def get_potentials_error_ntypes_config(): + return """\ +* +3 +1 TEST + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +1 PLACEHOLDER + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +""" + + +def get_potentials_error_nsites(): + return """\ +* +2 +. TEST + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +1 PLACEHOLDER + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +""" + + +def get_potentials_error_molname(): + return """\ +* +2 +1 + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +1 PLACEHOLDER + 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 +""" + + +def mock_open(file, *args, **kwargs): + values = { + "control.test.yml": get_config_example(), + "phb.ljc": get_potentials_exemple(), + "phb.error.combrule.ljc": get_potentials_error_combrule(), + "phb.error.ntypes.ljc": get_potentials_error_ntypes(), + "phb.error.ntypes.config.ljc": get_potentials_error_ntypes_config(), + "phb.error.nsites.ljc": get_potentials_error_nsites(), + "phb.error.molname.ljc": get_potentials_error_molname(), + } + mock_file = mock.mock_open(read_data=values[file]) + return mock_file() + + +class TestPlayer(unittest.TestCase): + def setUp(self): + logger.set_logger(stream=io.StringIO()) + + @mock.patch("builtins.open", mock_open) + def test_class_instantiation(self): + # This file does not exist and it will be mocked + player = Player("control.test.yml") + + self.assertIsInstance(player, Player) + + @mock.patch("builtins.open", mock_open) + def test_start(self): + player = Player("control.test.yml") + + player.print_keywords = mock.MagicMock() + player.create_simulation_dir = mock.MagicMock() + player.read_potentials = mock.MagicMock() + player.print_potentials = mock.MagicMock() + player.dice_start = mock.MagicMock() + + player.start(1) + + self.assertTrue(player.print_keywords.called) + self.assertTrue(player.create_simulation_dir.called) + self.assertTrue(player.read_potentials.called) + self.assertTrue(player.print_potentials.called) + self.assertEqual(player.dice_start.call_count, 3) + + @mock.patch("builtins.open", mock_open) + @mock.patch("diceplayer.player.Path") + def test_create_simulation_dir_if_already_exists(self, mock_path): + player = Player("control.test.yml") + mock_path.return_value.exists.return_value = True + + with self.assertRaises(FileExistsError): + player.create_simulation_dir() + + self.assertTrue(mock_path.called) + + @mock.patch("builtins.open", mock_open) + @mock.patch("diceplayer.player.Path") + def test_create_simulation_dir_if_not_exists(self, mock_path): + player = Player("control.test.yml") + mock_path.return_value.exists.return_value = False + + player.create_simulation_dir() + + self.assertTrue(mock_path.called) + + @mock.patch("diceplayer.player.sys") + @mock.patch("diceplayer.player.weekday_date_time") + @mock.patch("builtins.open", mock_open) + def test_print_keywords(self, mock_date_func, mock_sys): + player = Player("control.test.yml") + + mock_sys.version = 'TEST' + mock_date_func.return_value = '00 Test 0000 at 00:00:00' + + with self.assertLogs() as cm: + player.print_keywords() + + expected_output = ['INFO:diceplayer:##########################################################################################\n############# Welcome to DICEPLAYER version 1.0 #############\n##########################################################################################\n\n', 'INFO:diceplayer:Your python version is TEST\n', 'INFO:diceplayer:\n', 'INFO:diceplayer:Program started on 00 Test 0000 at 00:00:00\n', 'INFO:diceplayer:\n', 'INFO:diceplayer:Environment variables:\n', 'INFO:diceplayer:OMP_STACKSIZE = Not set\n', 'INFO:diceplayer:\n==========================================================================================\n CONTROL variables being used in this run:\n------------------------------------------------------------------------------------------\n\n', 'INFO:diceplayer:\n', 'INFO:diceplayer:------------------------------------------------------------------------------------------\n DICE variables being used in this run:\n------------------------------------------------------------------------------------------\n\n', 'INFO:diceplayer:dens = 0.75', 'INFO:diceplayer:isave = 1000', 'INFO:diceplayer:ljname = phb.ljc', 'INFO:diceplayer:nmol = [ 1 50 ]', 'INFO:diceplayer:nstep = [ 2000 3000 4000 ]', 'INFO:diceplayer:outname = phb', 'INFO:diceplayer:press = 1.0', 'INFO:diceplayer:progname = ~/.local/bin/dice', 'INFO:diceplayer:randominit = first', 'INFO:diceplayer:temp = 300.0', 'INFO:diceplayer:\n', 'INFO:diceplayer:------------------------------------------------------------------------------------------\n GAUSSIAN variables being used in this run:\n------------------------------------------------------------------------------------------\n\n', 'INFO:diceplayer:keywords = freq', 'INFO:diceplayer:level = MP2/aug-cc-pVDZ', 'INFO:diceplayer:pop = chelpg', 'INFO:diceplayer:qmprog = g16', 'INFO:diceplayer:\n'] + + self.assertEqual(cm.output, expected_output) + + def test_validate_atom_dict(self): + with self.assertRaises(ValueError) as context: + Player.validate_atom_dict( + molecule_type=0, + molecule_site=0, + atom_dict={ + "lbl": 0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0 + } + ) + self.assertEqual( + str(context.exception), + "Invalid number of fields for site 1 for molecule type 1." + ) + + with self.assertRaises(ValueError) as context: + Player.validate_atom_dict( + molecule_type=0, + molecule_site=0, + atom_dict={ + "lbl": '', "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": 1.0 + } + ) + self.assertEqual( + str(context.exception), + "Invalid lbl fields for site 1 for molecule type 1." + ) + + with self.assertRaises(ValueError) as context: + Player.validate_atom_dict( + molecule_type=0, + molecule_site=0, + atom_dict={ + "lbl": 1.0, "na": '', "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": 1.0 + } + ) + self.assertEqual( + str(context.exception), + "Invalid na fields for site 1 for molecule type 1." + ) + + with self.assertRaises(ValueError) as context: + Player.validate_atom_dict( + molecule_type=0, + molecule_site=0, + atom_dict={ + "lbl": 1.0, "na": 1, "rx": '', "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": 1.0 + } + ) + self.assertEqual( + str(context.exception), + "Invalid rx fields for site 1 for molecule type 1. Value must be a float." + ) + + with self.assertRaises(ValueError) as context: + Player.validate_atom_dict( + molecule_type=0, + molecule_site=0, + atom_dict={ + "lbl": 1.0, "na": 1, "rx": 1.0, "ry": '', "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": 1.0 + } + ) + self.assertEqual( + str(context.exception), + "Invalid ry fields for site 1 for molecule type 1. Value must be a float." + ) + + with self.assertRaises(ValueError) as context: + Player.validate_atom_dict( + molecule_type=0, + molecule_site=0, + atom_dict={ + "lbl": 1.0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": '', "chg": 1.0, "eps": 1.0, "sig": 1.0 + } + ) + self.assertEqual( + str(context.exception), + "Invalid rz fields for site 1 for molecule type 1. Value must be a float." + ) + + with self.assertRaises(ValueError) as context: + Player.validate_atom_dict( + molecule_type=0, + molecule_site=0, + atom_dict={ + "lbl": 1.0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": '', "eps": 1.0, "sig": 1.0 + } + ) + self.assertEqual( + str(context.exception), + "Invalid chg fields for site 1 for molecule type 1. Value must be a float." + ) + + with self.assertRaises(ValueError) as context: + Player.validate_atom_dict( + molecule_type=0, + molecule_site=0, + atom_dict={ + "lbl": 1.0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": '', "sig": 1.0 + } + ) + self.assertEqual( + str(context.exception), + "Invalid eps fields for site 1 for molecule type 1. Value must be a float." + ) + + with self.assertRaises(ValueError) as context: + Player.validate_atom_dict( + molecule_type=0, + molecule_site=0, + atom_dict={ + "lbl": 1.0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": '' + } + ) + self.assertEqual( + str(context.exception), + "Invalid sig fields for site 1 for molecule type 1. Value must be a float." + ) + + @mock.patch("builtins.open", mock_open) + @mock.patch("diceplayer.player.Path.exists", return_value=True) + def test_read_potentials(self, mock_path_exists): + player = Player("control.test.yml") + + player.read_potentials() + + self.assertEqual(player.system.molecule[0].molname, "TEST") + self.assertEqual(len(player.system.molecule[0].atom), 1) + + self.assertEqual(player.system.molecule[1].molname, "PLACEHOLDER") + self.assertEqual(len(player.system.molecule[1].atom), 1) + + @mock.patch("builtins.open", mock_open) + @mock.patch("diceplayer.player.Path.exists") + def test_read_potentials_error(self, mock_path_exists): + player = Player("control.test.yml") + + # Testing file not found error + mock_path_exists.return_value = False + with self.assertRaises(RuntimeError) as context: + player.read_potentials() + + self.assertEqual( + str(context.exception), + "Potential file phb.ljc not found." + ) + + # Enabling file found for next tests + mock_path_exists.return_value = True + + # Testing combrule error + with self.assertRaises(SystemExit) as context: + player.dice.config.ljname = "phb.error.combrule.ljc" + player.read_potentials() + + self.assertEqual( + str(context.exception), + "Error: expected a '*' or a '+' sign in 1st line of file phb.error.combrule.ljc" + ) + + # Testing ntypes error + with self.assertRaises(SystemExit) as context: + player.dice.config.ljname = "phb.error.ntypes.ljc" + player.read_potentials() + + self.assertEqual( + str(context.exception), + "Error: expected an integer in the 2nd line of file phb.error.ntypes.ljc" + ) + + # Testing ntypes error on config + with self.assertRaises(SystemExit) as context: + player.dice.config.ljname = "phb.error.ntypes.config.ljc" + player.read_potentials() + + self.assertEqual( + str(context.exception), + "Error: number of molecule types in file phb.error.ntypes.config.ljc " + "must match that of 'nmol' keyword in config file" + ) + + # Testing nsite error + with self.assertRaises(ValueError) as context: + player.dice.config.ljname = "phb.error.nsites.ljc" + player.read_potentials() + + self.assertEqual( + str(context.exception), + "Error: expected nsites to be an integer for molecule type 1" + ) + + # Testing molname error + with self.assertRaises(ValueError) as context: + player.dice.config.ljname = "phb.error.molname.ljc" + player.read_potentials() + + self.assertEqual( + str(context.exception), + "Error: expected nsites and molname for the molecule type 1" + ) + + @mock.patch("builtins.open", mock_open) + @mock.patch("diceplayer.player.Path.exists", return_value=True) + def test_print_potentials(self, mock_path_exists): + player = Player("control.test.yml") + player.read_potentials() + + with self.assertLogs(level='INFO') as context: + player.print_potentials() + + expected_output = ['INFO:diceplayer:==========================================================================================\n', 'INFO:diceplayer: Potential parameters from file phb.ljc:', 'INFO:diceplayer:------------------------------------------------------------------------------------------\n', 'INFO:diceplayer:Combination rule: *', 'INFO:diceplayer:Types of molecules: 2\n', 'INFO:diceplayer:1 atoms in molecule type 1:', 'INFO:diceplayer:---------------------------------------------------------------------------------', 'INFO:diceplayer:Lbl AN X Y Z Charge Epsilon Sigma Mass', 'INFO:diceplayer:---------------------------------------------------------------------------------', 'INFO:diceplayer:1 1 0.00000 0.00000 0.00000 0.000000 0.00000 0.0000 1.0079', 'INFO:diceplayer:\n', 'INFO:diceplayer:1 atoms in molecule type 2:', 'INFO:diceplayer:---------------------------------------------------------------------------------', 'INFO:diceplayer:Lbl AN X Y Z Charge Epsilon Sigma Mass', 'INFO:diceplayer:---------------------------------------------------------------------------------', 'INFO:diceplayer:1 1 0.00000 0.00000 0.00000 0.000000 0.00000 0.0000 1.0079', 'INFO:diceplayer:\n', 'INFO:diceplayer:=========================================================================================='] + + self.assertEqual( + context.output, + expected_output + ) + + @mock.patch("builtins.open", mock_open) + def test_dice_start(self): + player = Player("control.test.yml") + player.dice = mock.MagicMock() + player.dice.start = mock.MagicMock() + + player.dice_start(1) + + player.dice.start.assert_called_once() + + @mock.patch("builtins.open", mock_open) + def test_gaussian_start(self): + player = Player("control.test.yml") + player.gaussian = mock.MagicMock() + player.gaussian.start = mock.MagicMock() + + player.gaussian_start(1) + + player.gaussian.start.assert_called_once() + + + +if __name__ == '__main__': + unittest.main()