Implements Additional Logs and class Player Tests

This commit is contained in:
2023-05-03 03:14:26 -03:00
parent 56994dba27
commit b440a0f05d
8 changed files with 586 additions and 80 deletions

View File

@@ -6,8 +6,8 @@ diceplayer:
qmprog: 'g16' qmprog: 'g16'
lps: no lps: no
ghosts: no ghosts: no
altsteps: 20000 altsteps: 20000
dice: dice:
nmol: [1, 50] nmol: [1, 50]
dens: 0.75 dens: 0.75

View File

@@ -9,7 +9,7 @@ from diceplayer.shared.environment.system import System
from diceplayer.shared.config.step_dto import StepDTO from diceplayer.shared.config.step_dto import StepDTO
from diceplayer.shared.config.dice_dto import DiceDTO from diceplayer.shared.config.dice_dto import DiceDTO
from diceplayer.shared.environment.atom import Atom from diceplayer.shared.environment.atom import Atom
from diceplayer.shared.utils.ptable import atommass from diceplayer import logger
from dataclasses import fields from dataclasses import fields
from pathlib import Path from pathlib import Path
@@ -24,13 +24,6 @@ ENV = ["OMP_STACKSIZE"]
class Player: class Player:
__slots__ = [
'config',
'system',
'dice',
'gaussian',
]
def __init__(self, infile: str): def __init__(self, infile: str):
config_data = self.read_keywords(infile) config_data = self.read_keywords(infile)
@@ -43,16 +36,16 @@ class Player:
self.gaussian = GaussianInterface(config_data.get("gaussian")) self.gaussian = GaussianInterface(config_data.get("gaussian"))
self.dice = DiceInterface(config_data.get("dice")) self.dice = DiceInterface(config_data.get("dice"))
def start(self): def start(self, initial_cycle: int = 1):
self.print_keywords() self.print_keywords()
self.create_simulation_dir() self.create_simulation_dir()
self.read_potentials() self.read_potentials()
# self.print_potentials() self.print_potentials()
self.dice_start(1) for cycle in range(initial_cycle, self.config.maxcyc + 1):
self.dice_start(2) self.dice_start(cycle)
def create_simulation_dir(self): def create_simulation_dir(self):
simulation_dir_path = Path(self.config.simulation_dir) simulation_dir_path = Path(self.config.simulation_dir)
@@ -61,12 +54,7 @@ class Player:
f"Error: a file or a directory {self.config.simulation_dir} already exists," f"Error: a file or a directory {self.config.simulation_dir} already exists,"
f" move or delete the simfiles directory to continue." f" move or delete the simfiles directory to continue."
) )
try:
simulation_dir_path.mkdir() simulation_dir_path.mkdir()
except FileExistsError:
OSError(
f"Error: cannot make directory {self.config.simulation_dir}"
)
def print_keywords(self) -> None: def print_keywords(self) -> None:
@@ -75,38 +63,38 @@ class Player:
if getattr(config, key) is not None: if getattr(config, key) is not None:
if isinstance(getattr(config, key), list): if isinstance(getattr(config, key), list):
string = " ".join(str(x) for x in getattr(config, key)) string = " ".join(str(x) for x in getattr(config, key))
logging.info(f"{key} = [ {string} ]") logger.info(f"{key} = [ {string} ]")
else: else:
logging.info(f"{key} = {getattr(config, key)}") logger.info(f"{key} = {getattr(config, key)}")
logging.info( logger.info(
"##########################################################################################\n" "##########################################################################################\n"
"############# Welcome to DICEPLAYER version 1.0 #############\n" "############# Welcome to DICEPLAYER version 1.0 #############\n"
"##########################################################################################\n" "##########################################################################################\n"
"\n" "\n"
) )
logging.info("Your python version is {}\n".format(sys.version)) logger.info("Your python version is {}\n".format(sys.version))
logging.info("\n") logger.info("\n")
logging.info("Program started on {}\n".format(weekday_date_time())) logger.info("Program started on {}\n".format(weekday_date_time()))
logging.info("\n") logger.info("\n")
logging.info("Environment variables:\n") logger.info("Environment variables:\n")
for var in ENV: for var in ENV:
logging.info( logger.info(
"{} = {}\n".format( "{} = {}\n".format(
var, (os.environ[var] if var in os.environ else "Not set") var, (os.environ[var] if var in os.environ else "Not set")
) )
) )
logging.info( logger.info(
"\n==========================================================================================\n" "\n==========================================================================================\n"
" CONTROL variables being used in this run:\n" " CONTROL variables being used in this run:\n"
"------------------------------------------------------------------------------------------\n" "------------------------------------------------------------------------------------------\n"
"\n" "\n"
) )
logging.info("\n") logger.info("\n")
logging.info( logger.info(
"------------------------------------------------------------------------------------------\n" "------------------------------------------------------------------------------------------\n"
" DICE variables being used in this run:\n" " DICE variables being used in this run:\n"
"------------------------------------------------------------------------------------------\n" "------------------------------------------------------------------------------------------\n"
@@ -115,9 +103,9 @@ class Player:
log_keywords(self.dice.config, DiceDTO) log_keywords(self.dice.config, DiceDTO)
logging.info("\n") logger.info("\n")
logging.info( logger.info(
"------------------------------------------------------------------------------------------\n" "------------------------------------------------------------------------------------------\n"
" GAUSSIAN variables being used in this run:\n" " GAUSSIAN variables being used in this run:\n"
"------------------------------------------------------------------------------------------\n" "------------------------------------------------------------------------------------------\n"
@@ -126,18 +114,19 @@ class Player:
log_keywords(self.gaussian.config, GaussianDTO) log_keywords(self.gaussian.config, GaussianDTO)
logging.info("\n") logger.info("\n")
def read_potentials(self): def read_potentials(self):
try: ljname_path = Path(self.dice.config.ljname)
if ljname_path.exists():
with open(self.dice.config.ljname) as file: with open(self.dice.config.ljname) as file:
ljdata = file.readlines() ljc_data = file.readlines()
except FileNotFoundError: else:
raise RuntimeError( raise RuntimeError(
f"Potential file {self.dice.config.ljname} not found." f"Potential file {self.dice.config.ljname} not found."
) )
combrule = ljdata.pop(0).split()[0] combrule = ljc_data.pop(0).split()[0]
if combrule not in ("*", "+"): if combrule not in ("*", "+"):
sys.exit( sys.exit(
"Error: expected a '*' or a '+' sign in 1st line of file {}".format( "Error: expected a '*' or a '+' sign in 1st line of file {}".format(
@@ -146,7 +135,7 @@ class Player:
) )
self.dice.config.combrule = combrule self.dice.config.combrule = combrule
ntypes = ljdata.pop(0).split()[0] ntypes = ljc_data.pop(0).split()[0]
if not ntypes.isdigit(): if not ntypes.isdigit():
sys.exit( sys.exit(
"Error: expected an integer in the 2nd line of file {}".format( "Error: expected an integer in the 2nd line of file {}".format(
@@ -157,22 +146,22 @@ class Player:
if ntypes != len(self.dice.config.nmol): if ntypes != len(self.dice.config.nmol):
sys.exit( sys.exit(
f"Error: number of molecule types in file {self.dice.config.ljname}" f"Error: number of molecule types in file {self.dice.config.ljname} "
f"must match that of 'nmol' keyword in config file" f"must match that of 'nmol' keyword in config file"
) )
for i in range(ntypes): for i in range(ntypes):
nsites, molname = ljdata.pop(0).split()[:2] try:
nsites, molname = ljc_data.pop(0).split()[:2]
except ValueError:
raise ValueError(
f"Error: expected nsites and molname for the molecule type {i+1}"
)
if not nsites.isdigit(): if not nsites.isdigit():
raise ValueError( raise ValueError(
f"Error: expected nsites to be an integer for molecule type {i}" f"Error: expected nsites to be an integer for molecule type {i+1}"
)
if molname is None:
raise ValueError(
f"Error: expected molecule name for molecule type {i}"
) )
nsites = int(nsites) nsites = int(nsites)
@@ -182,12 +171,68 @@ class Player:
for j in range(nsites): for j in range(nsites):
new_atom = dict(zip( new_atom = dict(zip(
atom_fields, atom_fields,
ljdata.pop(0).split() ljc_data.pop(0).split()
)) ))
self.system.molecule[i].add_atom( self.system.molecule[i].add_atom(
Atom(**self.validate_atom_dict(i, j, new_atom)) Atom(**self.validate_atom_dict(i, j, new_atom))
) )
def print_potentials(self) -> None:
formatstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f} {:>9.4f}"
logger.info(
"==========================================================================================\n"
)
logger.info(
f" Potential parameters from file {self.dice.config.ljname}:"
)
logger.info(
"------------------------------------------------------------------------------------------"
"\n"
)
logger.info(f"Combination rule: {self.dice.config.combrule}")
logger.info(
f"Types of molecules: {len(self.system.molecule)}\n"
)
i = 0
for mol in self.system.molecule:
i += 1
logger.info(
"{} atoms in molecule type {}:".format(len(mol.atom), i)
)
logger.info(
"---------------------------------------------------------------------------------"
)
logger.info(
"Lbl AN X Y Z Charge Epsilon Sigma Mass"
)
logger.info(
"---------------------------------------------------------------------------------"
)
for atom in mol.atom:
logger.info(
formatstr.format(
atom.lbl,
atom.na,
atom.rx,
atom.ry,
atom.rz,
atom.chg,
atom.eps,
atom.sig,
atom.mass,
)
)
logger.info("\n")
logger.info(
"=========================================================================================="
)
def dice_start(self, cycle: int): def dice_start(self, cycle: int):
self.dice.configure( self.dice.configure(
StepDTO( StepDTO(
@@ -204,8 +249,8 @@ class Player:
self.dice.reset() self.dice.reset()
def gaussian_start(self): def gaussian_start(self, cycle: int):
self.gaussian.start() self.gaussian.start(cycle)
@staticmethod @staticmethod
def validate_atom_dict(molecule_type, molecule_site, atom_dict: dict) -> dict: def validate_atom_dict(molecule_type, molecule_site, atom_dict: dict) -> dict:
@@ -219,63 +264,63 @@ class Player:
try: try:
atom_dict['lbl'] = int(atom_dict['lbl']) atom_dict['lbl'] = int(atom_dict['lbl'])
except ValueError: except Exception:
raise ValueError( raise ValueError(
f'Invalid lbl fields for site {molecule_site} for molecule type {molecule_type}.' f'Invalid lbl fields for site {molecule_site} for molecule type {molecule_type}.'
) )
try: try:
atom_dict['na'] = int(atom_dict['na']) atom_dict['na'] = int(atom_dict['na'])
except ValueError: except Exception:
raise ValueError( raise ValueError(
f'Invalid na fields for site {molecule_site} for molecule type {molecule_type}.' f'Invalid na fields for site {molecule_site} for molecule type {molecule_type}.'
) )
try: try:
atom_dict['rx'] = float(atom_dict['rx']) atom_dict['rx'] = float(atom_dict['rx'])
except ValueError: except Exception:
raise ValueError( raise ValueError(
f'Invalid rx fields for site {molecule_site} for molecule type {molecule_type}.' f'Invalid rx fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.' f'Value must be a float.'
) )
try: try:
atom_dict['ry'] = float(atom_dict['ry']) atom_dict['ry'] = float(atom_dict['ry'])
except ValueError: except Exception:
raise ValueError( raise ValueError(
f'Invalid ry fields for site {molecule_site} for molecule type {molecule_type}.' f'Invalid ry fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.' f'Value must be a float.'
) )
try: try:
atom_dict['rz'] = float(atom_dict['rx']) atom_dict['rz'] = float(atom_dict['rz'])
except ValueError: except Exception:
raise ValueError( raise ValueError(
f'Invalid rz fields for site {molecule_site} for molecule type {molecule_type}.' f'Invalid rz fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.' f'Value must be a float.'
) )
try: try:
atom_dict['chg'] = float(atom_dict['chg']) atom_dict['chg'] = float(atom_dict['chg'])
except ValueError: except Exception:
raise ValueError( raise ValueError(
f'Invalid chg fields for site {molecule_site} for molecule type {molecule_type}.' f'Invalid chg fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.' f'Value must be a float.'
) )
try: try:
atom_dict['eps'] = float(atom_dict['eps']) atom_dict['eps'] = float(atom_dict['eps'])
except ValueError: except Exception:
raise ValueError( raise ValueError(
f'Invalid eps fields for site {molecule_site} for molecule type {molecule_type}.' f'Invalid eps fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.' f'Value must be a float.'
) )
try: try:
atom_dict['sig'] = float(atom_dict['sig']) atom_dict['sig'] = float(atom_dict['sig'])
except ValueError: except Exception:
raise ValueError( raise ValueError(
f'Invalid sig fields for site {molecule_site} for molecule type {molecule_type}.' f'Invalid sig fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.' f'Value must be a float.'
) )

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
from diceplayer.shared.config.dice_dto import DiceDTO from diceplayer.shared.config.dice_dto import DiceDTO
from diceplayer.shared.config.step_dto import StepDTO from diceplayer.shared.config.step_dto import StepDTO
from diceplayer.shared.interface import Interface from diceplayer.shared.interface import Interface
from diceplayer import logger
from multiprocessing import Process, connection from multiprocessing import Process, connection
from setproctitle import setproctitle from setproctitle import setproctitle
@@ -15,7 +16,6 @@ import time
import sys import sys
import os import os
DICE_END_FLAG: Final[str] = "End of simulation" DICE_END_FLAG: Final[str] = "End of simulation"
DICE_FLAG_LINE: Final[int] = -2 DICE_FLAG_LINE: Final[int] = -2
UMAANG3_TO_GCM3: Final[float] = 1.6605 UMAANG3_TO_GCM3: Final[float] = 1.6605
@@ -41,8 +41,9 @@ class DiceInterface(Interface):
procs = [] procs = []
sentinels = [] sentinels = []
for proc in range(1, self.step.nprocs + 1): logger.info(f"---------------------- DICE - CYCLE {cycle} --------------------------\n")
for proc in range(1, self.step.nprocs + 1):
p = Process(target=self._simulation_process, args=(cycle, proc)) p = Process(target=self._simulation_process, args=(cycle, proc))
p.start() p.start()
@@ -61,6 +62,8 @@ class DiceInterface(Interface):
p.terminate() p.terminate()
sys.exit(status) sys.exit(status)
logger.info("\n")
def reset(self): def reset(self):
del self.step del self.step
@@ -132,6 +135,11 @@ class DiceInterface(Interface):
f"step{cycle:02d}", f"step{cycle:02d}",
f"p{proc:02d}" f"p{proc:02d}"
) )
logger.info(
f"Simulation process {str(proc_dir)} initiated with pid {os.getpid()}"
)
os.chdir(proc_dir) os.chdir(proc_dir)
if not (self.config.randominit == 'first' and cycle > 1): if not (self.config.randominit == 'first' and cycle > 1):
@@ -385,3 +393,5 @@ class DiceInterface(Interface):
flag = outfile.readlines()[DICE_FLAG_LINE].strip() flag = outfile.readlines()[DICE_FLAG_LINE].strip()
if flag != DICE_END_FLAG: if flag != DICE_END_FLAG:
raise RuntimeError(f"Dice process step{cycle:02d}-p{proc:02d} did not exit properly") raise RuntimeError(f"Dice process step{cycle:02d}-p{proc:02d} did not exit properly")
logger.info(f"Dice {file_name} - step{cycle:02d}-p{proc:02d} exited properly")

View File

@@ -14,7 +14,7 @@ class GaussianInterface(Interface):
def configure(self): def configure(self):
pass pass
def start(self): def start(self, cycle: int):
pass pass
def reset(self): def reset(self):

View File

@@ -24,15 +24,17 @@ class Logger:
if self._logger is None: if self._logger is None:
self._logger = logging.getLogger(logger_name) self._logger = logging.getLogger(logger_name)
def set_logger(self, outfile='run.log', level=logging.INFO): def set_logger(self, outfile='run.log', level=logging.INFO, stream=None):
self.outfile = Path(outfile) outfile_path = None
if self.outfile.exists(): if outfile is not None and stream is None:
self.outfile.rename(str(self.outfile) + ".backup") outfile_path = Path(outfile)
if outfile_path.exists():
outfile_path.rename(str(outfile_path) + ".backup")
if level is not None: if level is not None:
self._logger.setLevel(level) self._logger.setLevel(level)
self._create_handlers() self._create_handlers(outfile_path, stream)
self._was_set = True self._was_set = True
@@ -52,10 +54,12 @@ class Logger:
def error(self, message): def error(self, message):
self._logger.error(message) self._logger.error(message)
def _create_handlers(self): def _create_handlers(self, outfile_path: Path, stream):
handlers = [] handlers = []
if self.outfile is not None: if outfile_path is not None:
handlers.append(logging.FileHandler(self.outfile, mode='a+')) handlers.append(logging.FileHandler(outfile_path, mode='a+'))
elif stream is not None:
handlers.append(logging.StreamHandler(stream))
else: else:
handlers.append(logging.StreamHandler()) handlers.append(logging.StreamHandler())

View File

@@ -2,6 +2,7 @@ from diceplayer.shared.interface.dice_interface import DiceInterface
from diceplayer.shared.environment.molecule import Molecule from diceplayer.shared.environment.molecule import Molecule
from diceplayer.shared.environment.atom import Atom from diceplayer.shared.environment.atom import Atom
from diceplayer.shared.config.step_dto import StepDTO from diceplayer.shared.config.step_dto import StepDTO
from diceplayer import logger
import io import io
@@ -12,6 +13,9 @@ import unittest
class TestDiceInterface(unittest.TestCase): class TestDiceInterface(unittest.TestCase):
def setUp(self):
logger.set_logger(stream=io.StringIO())
def test_class_instantiation(self): def test_class_instantiation(self):
dice = DiceInterface( dice = DiceInterface(
{ {

View File

@@ -1,6 +1,7 @@
from diceplayer.shared.utils.logger import Logger, valid_logger from diceplayer.shared.utils.logger import Logger, valid_logger
import logging import logging
import io
from unittest import mock from unittest import mock
import unittest import unittest
@@ -35,6 +36,23 @@ class TestLogger(unittest.TestCase):
self.assertIsInstance(logger, Logger) self.assertIsInstance(logger, Logger)
@mock.patch('builtins.open', mock.mock_open())
def test_set_logger_to_file(self):
logger = Logger('test')
logger.set_logger(stream=io.StringIO())
self.assertIsNotNone(logger._logger)
self.assertEqual(logger._logger.name, 'test')
def test_set_logger_to_stream(self):
logger = Logger('test')
logger.set_logger(stream=io.StringIO())
self.assertIsNotNone(logger._logger)
self.assertEqual(logger._logger.name, 'test')
@mock.patch('builtins.open', mock.mock_open()) @mock.patch('builtins.open', mock.mock_open())
@mock.patch('diceplayer.shared.utils.logger.Path.exists') @mock.patch('diceplayer.shared.utils.logger.Path.exists')
@mock.patch('diceplayer.shared.utils.logger.Path.rename') @mock.patch('diceplayer.shared.utils.logger.Path.rename')
@@ -64,6 +82,7 @@ class TestLogger(unittest.TestCase):
@mock.patch('builtins.open', mock.mock_open()) @mock.patch('builtins.open', mock.mock_open())
def test_close(self): def test_close(self):
logger = Logger('test') logger = Logger('test')
logger.set_logger() logger.set_logger()
logger.close() logger.close()
@@ -72,7 +91,7 @@ class TestLogger(unittest.TestCase):
@mock.patch('builtins.open', mock.mock_open()) @mock.patch('builtins.open', mock.mock_open())
def test_info(self): def test_info(self):
logger = Logger('test') logger = Logger('test')
logger.set_logger() logger.set_logger(stream=io.StringIO())
with self.assertLogs(level='INFO') as cm: with self.assertLogs(level='INFO') as cm:
logger.info('test') logger.info('test')
@@ -82,7 +101,7 @@ class TestLogger(unittest.TestCase):
@mock.patch('builtins.open', mock.mock_open()) @mock.patch('builtins.open', mock.mock_open())
def test_debug(self): def test_debug(self):
logger = Logger('test') logger = Logger('test')
logger.set_logger(level=logging.DEBUG) logger.set_logger(stream=io.StringIO(), level=logging.DEBUG)
with self.assertLogs(level='DEBUG') as cm: with self.assertLogs(level='DEBUG') as cm:
logger.debug('test') logger.debug('test')
@@ -92,7 +111,7 @@ class TestLogger(unittest.TestCase):
@mock.patch('builtins.open', mock.mock_open()) @mock.patch('builtins.open', mock.mock_open())
def test_warning(self): def test_warning(self):
logger = Logger('test') logger = Logger('test')
logger.set_logger() logger.set_logger(stream=io.StringIO())
with self.assertLogs(level='WARNING') as cm: with self.assertLogs(level='WARNING') as cm:
logger.warning('test') logger.warning('test')
@@ -102,7 +121,7 @@ class TestLogger(unittest.TestCase):
@mock.patch('builtins.open', mock.mock_open()) @mock.patch('builtins.open', mock.mock_open())
def test_error(self): def test_error(self):
logger = Logger('test') logger = Logger('test')
logger.set_logger() logger.set_logger(stream=io.StringIO())
with self.assertLogs(level='ERROR') as cm: with self.assertLogs(level='ERROR') as cm:
logger.error('test') logger.error('test')

424
tests/test_player.py Normal file
View File

@@ -0,0 +1,424 @@
from diceplayer.player import Player
from diceplayer import logger
import io
from unittest import mock
import unittest
def get_config_example():
return """
diceplayer:
maxcyc: 3
opt: no
ncores: 4
nprocs: 4
qmprog: 'g16'
lps: no
ghosts: no
altsteps: 20000
dice:
nmol: [1, 50]
dens: 0.75
nstep: [2000, 3000, 4000]
isave: 1000
outname: 'phb'
progname: '~/.local/bin/dice'
ljname: 'phb.ljc'
randominit: 'first'
gaussian:
qmprog: 'g16'
level: 'MP2/aug-cc-pVDZ'
keywords: 'freq'
"""
def get_potentials_exemple():
return """\
*
2
1 TEST
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
1 PLACEHOLDER
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
"""
def get_potentials_error_combrule():
return """\
.
2
1 TEST
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
1 PLACEHOLDER
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
"""
def get_potentials_error_ntypes():
return """\
*
a
1 TEST
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
1 PLACEHOLDER
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
"""
def get_potentials_error_ntypes_config():
return """\
*
3
1 TEST
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
1 PLACEHOLDER
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
"""
def get_potentials_error_nsites():
return """\
*
2
. TEST
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
1 PLACEHOLDER
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
"""
def get_potentials_error_molname():
return """\
*
2
1
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
1 PLACEHOLDER
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
"""
def mock_open(file, *args, **kwargs):
values = {
"control.test.yml": get_config_example(),
"phb.ljc": get_potentials_exemple(),
"phb.error.combrule.ljc": get_potentials_error_combrule(),
"phb.error.ntypes.ljc": get_potentials_error_ntypes(),
"phb.error.ntypes.config.ljc": get_potentials_error_ntypes_config(),
"phb.error.nsites.ljc": get_potentials_error_nsites(),
"phb.error.molname.ljc": get_potentials_error_molname(),
}
mock_file = mock.mock_open(read_data=values[file])
return mock_file()
class TestPlayer(unittest.TestCase):
def setUp(self):
logger.set_logger(stream=io.StringIO())
@mock.patch("builtins.open", mock_open)
def test_class_instantiation(self):
# This file does not exist and it will be mocked
player = Player("control.test.yml")
self.assertIsInstance(player, Player)
@mock.patch("builtins.open", mock_open)
def test_start(self):
player = Player("control.test.yml")
player.print_keywords = mock.MagicMock()
player.create_simulation_dir = mock.MagicMock()
player.read_potentials = mock.MagicMock()
player.print_potentials = mock.MagicMock()
player.dice_start = mock.MagicMock()
player.start(1)
self.assertTrue(player.print_keywords.called)
self.assertTrue(player.create_simulation_dir.called)
self.assertTrue(player.read_potentials.called)
self.assertTrue(player.print_potentials.called)
self.assertEqual(player.dice_start.call_count, 3)
@mock.patch("builtins.open", mock_open)
@mock.patch("diceplayer.player.Path")
def test_create_simulation_dir_if_already_exists(self, mock_path):
player = Player("control.test.yml")
mock_path.return_value.exists.return_value = True
with self.assertRaises(FileExistsError):
player.create_simulation_dir()
self.assertTrue(mock_path.called)
@mock.patch("builtins.open", mock_open)
@mock.patch("diceplayer.player.Path")
def test_create_simulation_dir_if_not_exists(self, mock_path):
player = Player("control.test.yml")
mock_path.return_value.exists.return_value = False
player.create_simulation_dir()
self.assertTrue(mock_path.called)
@mock.patch("diceplayer.player.sys")
@mock.patch("diceplayer.player.weekday_date_time")
@mock.patch("builtins.open", mock_open)
def test_print_keywords(self, mock_date_func, mock_sys):
player = Player("control.test.yml")
mock_sys.version = 'TEST'
mock_date_func.return_value = '00 Test 0000 at 00:00:00'
with self.assertLogs() as cm:
player.print_keywords()
expected_output = ['INFO:diceplayer:##########################################################################################\n############# Welcome to DICEPLAYER version 1.0 #############\n##########################################################################################\n\n', 'INFO:diceplayer:Your python version is TEST\n', 'INFO:diceplayer:\n', 'INFO:diceplayer:Program started on 00 Test 0000 at 00:00:00\n', 'INFO:diceplayer:\n', 'INFO:diceplayer:Environment variables:\n', 'INFO:diceplayer:OMP_STACKSIZE = Not set\n', 'INFO:diceplayer:\n==========================================================================================\n CONTROL variables being used in this run:\n------------------------------------------------------------------------------------------\n\n', 'INFO:diceplayer:\n', 'INFO:diceplayer:------------------------------------------------------------------------------------------\n DICE variables being used in this run:\n------------------------------------------------------------------------------------------\n\n', 'INFO:diceplayer:dens = 0.75', 'INFO:diceplayer:isave = 1000', 'INFO:diceplayer:ljname = phb.ljc', 'INFO:diceplayer:nmol = [ 1 50 ]', 'INFO:diceplayer:nstep = [ 2000 3000 4000 ]', 'INFO:diceplayer:outname = phb', 'INFO:diceplayer:press = 1.0', 'INFO:diceplayer:progname = ~/.local/bin/dice', 'INFO:diceplayer:randominit = first', 'INFO:diceplayer:temp = 300.0', 'INFO:diceplayer:\n', 'INFO:diceplayer:------------------------------------------------------------------------------------------\n GAUSSIAN variables being used in this run:\n------------------------------------------------------------------------------------------\n\n', 'INFO:diceplayer:keywords = freq', 'INFO:diceplayer:level = MP2/aug-cc-pVDZ', 'INFO:diceplayer:pop = chelpg', 'INFO:diceplayer:qmprog = g16', 'INFO:diceplayer:\n']
self.assertEqual(cm.output, expected_output)
def test_validate_atom_dict(self):
with self.assertRaises(ValueError) as context:
Player.validate_atom_dict(
molecule_type=0,
molecule_site=0,
atom_dict={
"lbl": 0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0
}
)
self.assertEqual(
str(context.exception),
"Invalid number of fields for site 1 for molecule type 1."
)
with self.assertRaises(ValueError) as context:
Player.validate_atom_dict(
molecule_type=0,
molecule_site=0,
atom_dict={
"lbl": '', "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": 1.0
}
)
self.assertEqual(
str(context.exception),
"Invalid lbl fields for site 1 for molecule type 1."
)
with self.assertRaises(ValueError) as context:
Player.validate_atom_dict(
molecule_type=0,
molecule_site=0,
atom_dict={
"lbl": 1.0, "na": '', "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": 1.0
}
)
self.assertEqual(
str(context.exception),
"Invalid na fields for site 1 for molecule type 1."
)
with self.assertRaises(ValueError) as context:
Player.validate_atom_dict(
molecule_type=0,
molecule_site=0,
atom_dict={
"lbl": 1.0, "na": 1, "rx": '', "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": 1.0
}
)
self.assertEqual(
str(context.exception),
"Invalid rx fields for site 1 for molecule type 1. Value must be a float."
)
with self.assertRaises(ValueError) as context:
Player.validate_atom_dict(
molecule_type=0,
molecule_site=0,
atom_dict={
"lbl": 1.0, "na": 1, "rx": 1.0, "ry": '', "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": 1.0
}
)
self.assertEqual(
str(context.exception),
"Invalid ry fields for site 1 for molecule type 1. Value must be a float."
)
with self.assertRaises(ValueError) as context:
Player.validate_atom_dict(
molecule_type=0,
molecule_site=0,
atom_dict={
"lbl": 1.0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": '', "chg": 1.0, "eps": 1.0, "sig": 1.0
}
)
self.assertEqual(
str(context.exception),
"Invalid rz fields for site 1 for molecule type 1. Value must be a float."
)
with self.assertRaises(ValueError) as context:
Player.validate_atom_dict(
molecule_type=0,
molecule_site=0,
atom_dict={
"lbl": 1.0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": '', "eps": 1.0, "sig": 1.0
}
)
self.assertEqual(
str(context.exception),
"Invalid chg fields for site 1 for molecule type 1. Value must be a float."
)
with self.assertRaises(ValueError) as context:
Player.validate_atom_dict(
molecule_type=0,
molecule_site=0,
atom_dict={
"lbl": 1.0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": '', "sig": 1.0
}
)
self.assertEqual(
str(context.exception),
"Invalid eps fields for site 1 for molecule type 1. Value must be a float."
)
with self.assertRaises(ValueError) as context:
Player.validate_atom_dict(
molecule_type=0,
molecule_site=0,
atom_dict={
"lbl": 1.0, "na": 1, "rx": 1.0, "ry": 1.0, "rz": 1.0, "chg": 1.0, "eps": 1.0, "sig": ''
}
)
self.assertEqual(
str(context.exception),
"Invalid sig fields for site 1 for molecule type 1. Value must be a float."
)
@mock.patch("builtins.open", mock_open)
@mock.patch("diceplayer.player.Path.exists", return_value=True)
def test_read_potentials(self, mock_path_exists):
player = Player("control.test.yml")
player.read_potentials()
self.assertEqual(player.system.molecule[0].molname, "TEST")
self.assertEqual(len(player.system.molecule[0].atom), 1)
self.assertEqual(player.system.molecule[1].molname, "PLACEHOLDER")
self.assertEqual(len(player.system.molecule[1].atom), 1)
@mock.patch("builtins.open", mock_open)
@mock.patch("diceplayer.player.Path.exists")
def test_read_potentials_error(self, mock_path_exists):
player = Player("control.test.yml")
# Testing file not found error
mock_path_exists.return_value = False
with self.assertRaises(RuntimeError) as context:
player.read_potentials()
self.assertEqual(
str(context.exception),
"Potential file phb.ljc not found."
)
# Enabling file found for next tests
mock_path_exists.return_value = True
# Testing combrule error
with self.assertRaises(SystemExit) as context:
player.dice.config.ljname = "phb.error.combrule.ljc"
player.read_potentials()
self.assertEqual(
str(context.exception),
"Error: expected a '*' or a '+' sign in 1st line of file phb.error.combrule.ljc"
)
# Testing ntypes error
with self.assertRaises(SystemExit) as context:
player.dice.config.ljname = "phb.error.ntypes.ljc"
player.read_potentials()
self.assertEqual(
str(context.exception),
"Error: expected an integer in the 2nd line of file phb.error.ntypes.ljc"
)
# Testing ntypes error on config
with self.assertRaises(SystemExit) as context:
player.dice.config.ljname = "phb.error.ntypes.config.ljc"
player.read_potentials()
self.assertEqual(
str(context.exception),
"Error: number of molecule types in file phb.error.ntypes.config.ljc "
"must match that of 'nmol' keyword in config file"
)
# Testing nsite error
with self.assertRaises(ValueError) as context:
player.dice.config.ljname = "phb.error.nsites.ljc"
player.read_potentials()
self.assertEqual(
str(context.exception),
"Error: expected nsites to be an integer for molecule type 1"
)
# Testing molname error
with self.assertRaises(ValueError) as context:
player.dice.config.ljname = "phb.error.molname.ljc"
player.read_potentials()
self.assertEqual(
str(context.exception),
"Error: expected nsites and molname for the molecule type 1"
)
@mock.patch("builtins.open", mock_open)
@mock.patch("diceplayer.player.Path.exists", return_value=True)
def test_print_potentials(self, mock_path_exists):
player = Player("control.test.yml")
player.read_potentials()
with self.assertLogs(level='INFO') as context:
player.print_potentials()
expected_output = ['INFO:diceplayer:==========================================================================================\n', 'INFO:diceplayer: Potential parameters from file phb.ljc:', 'INFO:diceplayer:------------------------------------------------------------------------------------------\n', 'INFO:diceplayer:Combination rule: *', 'INFO:diceplayer:Types of molecules: 2\n', 'INFO:diceplayer:1 atoms in molecule type 1:', 'INFO:diceplayer:---------------------------------------------------------------------------------', 'INFO:diceplayer:Lbl AN X Y Z Charge Epsilon Sigma Mass', 'INFO:diceplayer:---------------------------------------------------------------------------------', 'INFO:diceplayer:1 1 0.00000 0.00000 0.00000 0.000000 0.00000 0.0000 1.0079', 'INFO:diceplayer:\n', 'INFO:diceplayer:1 atoms in molecule type 2:', 'INFO:diceplayer:---------------------------------------------------------------------------------', 'INFO:diceplayer:Lbl AN X Y Z Charge Epsilon Sigma Mass', 'INFO:diceplayer:---------------------------------------------------------------------------------', 'INFO:diceplayer:1 1 0.00000 0.00000 0.00000 0.000000 0.00000 0.0000 1.0079', 'INFO:diceplayer:\n', 'INFO:diceplayer:==========================================================================================']
self.assertEqual(
context.output,
expected_output
)
@mock.patch("builtins.open", mock_open)
def test_dice_start(self):
player = Player("control.test.yml")
player.dice = mock.MagicMock()
player.dice.start = mock.MagicMock()
player.dice_start(1)
player.dice.start.assert_called_once()
@mock.patch("builtins.open", mock_open)
def test_gaussian_start(self):
player = Player("control.test.yml")
player.gaussian = mock.MagicMock()
player.gaussian.start = mock.MagicMock()
player.gaussian_start(1)
player.gaussian.start.assert_called_once()
if __name__ == '__main__':
unittest.main()