diff --git a/control.example.yml b/control.example.yml index 705e4cf..d75e9c4 100644 --- a/control.example.yml +++ b/control.example.yml @@ -3,7 +3,7 @@ diceplayer: switch_cyc: 3 max_cyc: 5 mem: 24 - ncores: 5 + ncores: 20 qmprog: 'g16' lps: no ghosts: no @@ -11,14 +11,15 @@ diceplayer: dice: nprocs: 4 - nmol: [1, 100] + nmol: [1, 1000] dens: 1.5 nstep: [2000, 3000] - isave: 1000 + isave: 0 outname: 'phb' progname: '~/.local/bin/dice' - ljname: 'phb.ljc' + ljname: 'phb.ljc.example' randominit: 'always' + seed: 12345 gaussian: qmprog: 'g16' diff --git a/diceplayer/config/dice_config.py b/diceplayer/config/dice_config.py index 99b59ab..c9b0429 100644 --- a/diceplayer/config/dice_config.py +++ b/diceplayer/config/dice_config.py @@ -1,3 +1,5 @@ +from pathlib import Path + from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Literal @@ -16,7 +18,7 @@ class DiceConfig(BaseModel): ..., description="Number of processes to use for the DICE simulations" ) - ljname: str = Field(..., description="Name of the Lennard-Jones potential file") + ljname: Path = Field(..., description="Name of the Lennard-Jones potential file") outname: str = Field( ..., description="Name of the output file for the simulation results" ) @@ -47,10 +49,13 @@ class DiceConfig(BaseModel): isave: int = Field(1000, description="Frequency of saving the simulation results") press: float = Field(1.0, description="Pressure of the system") temp: float = Field(300.0, description="Temperature of the system") - progname: str = Field( + progname: Path = Field( "dice", description="Name of the program to run the simulation" ) randominit: str = Field( "first", description="Method for initializing the random number generator" ) - seed: int = Field(default_factory=lambda: random.randint(0, 2**32 - 1), description="Seed for the random number generator") + seed: int = Field( + default_factory=lambda: int(1e6 * random.random()), + description="Seed for the random number generator", + ) diff --git a/diceplayer/dice/dice_handler.py b/diceplayer/dice/dice_handler.py index f418cda..0a8e878 100644 --- a/diceplayer/dice/dice_handler.py +++ b/diceplayer/dice/dice_handler.py @@ -1,10 +1,14 @@ -import shutil - -from diceplayer.dice.dice_input import NVTTerConfig, NVTEqConfig, NPTEqConfig, NPTTerConfig +from diceplayer.dice.dice_input import ( + NPTEqConfig, + NPTTerConfig, + NVTEqConfig, + NVTTerConfig, +) from diceplayer.dice.dice_wrapper import DiceWrapper from diceplayer.logger import logger from diceplayer.state.state_model import StateModel +import shutil from pathlib import Path from threading import Thread @@ -15,7 +19,9 @@ class DiceHandler: def run(self, state: StateModel, cycle: int) -> StateModel: if self.dice_directory.exists(): - logger.info(f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state") + logger.info( + f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state" + ) shutil.rmtree(self.dice_directory) self.dice_directory.mkdir(parents=True) @@ -38,7 +44,9 @@ class DiceHandler: t.join() if len(results) != state.config.dice.nprocs: - raise RuntimeError(f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}") + raise RuntimeError( + f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}" + ) return results @@ -47,21 +55,23 @@ class DiceHandler: def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel: return state - def _simulation_process(self, state: StateModel, cycle: int, proc: int, results: list[dict]) -> None: + def _simulation_process( + self, state: StateModel, cycle: int, proc: int, results: list[dict] + ) -> None: proc_directory = self.dice_directory / f"{proc:02d}" if proc_directory.exists(): shutil.rmtree(proc_directory) proc_directory.mkdir(parents=True) - dice = DiceWrapper( - state.config.dice, proc_directory - ) + dice = DiceWrapper(state.config.dice, proc_directory) - if state.config.dice.randominit == "first" and cycle == 0: + self._generate_phb_file(state, proc_directory) + + if state.config.dice.randominit == "first" and cycle >= 0: + self._generate_last_xyz(state, proc_directory) + else: nvt_ter_config = NVTTerConfig.from_config(state.config) dice.run(nvt_ter_config) - else: - self._generate_last_xyz(state, proc_directory) if len(state.config.dice.nstep) == 2: nvt_eq_config = NVTEqConfig.from_config(state.config) @@ -76,5 +86,30 @@ class DiceHandler: results.append(dice.extract_results()) - def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None: - ... + @staticmethod + def _generate_phb_file(state: StateModel, proc_directory: Path) -> None: + fstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f}\n" + + phb_file = proc_directory / state.config.dice.ljname + + with open(phb_file, "w") as f: + f.write(f"{state.config.dice.combrule}\n") + f.write(f"{len(state.config.dice.nmol)}\n") + + for molecule in state.system.molecule: + f.write(f"{len(molecule.atom)} {molecule.molname}\n") + for atom in molecule.atom: + f.write( + fstr.format( + atom.lbl, + atom.na, + atom.rx, + atom.ry, + atom.rz, + atom.chg, + atom.eps, + atom.sig, + ) + ) + + def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None: ... diff --git a/diceplayer/dice/dice_input.py b/diceplayer/dice/dice_input.py index 34adecd..5750128 100644 --- a/diceplayer/dice/dice_input.py +++ b/diceplayer/dice/dice_input.py @@ -9,6 +9,30 @@ from pathlib import Path from typing import Any, Sequence, TextIO + +DICE_KEYWORD_ORDER = [ + "title", + "ncores", + "ljname", + "outname", + "nmol", + "dens", + "temp", + "press", + "seed", + "init", + "nstep", + "vstep", + "mstop", + "accum", + "iprint", + "isave", + "irdf", + "upbuf", +] + + + @dataclass(slots=True) class BaseConfig(ABC): ncores: int @@ -18,13 +42,14 @@ class BaseConfig(ABC): temp: float seed: int isave: int - press: float = 1.0 def write(self, directory: Path, filename: str = "input") -> Path: input_path = directory / filename if input_path.exists(): - logger.info(f"Dice input file {input_path} already exists and will be overwritten") + logger.info( + f"Dice input file {input_path} already exists and will be overwritten" + ) input_path.unlink() input_path.parent.mkdir(parents=True, exist_ok=True) @@ -34,13 +59,18 @@ class BaseConfig(ABC): return input_path def write_dice_config(self, io_writer: TextIO) -> None: - for field in fields(self): - key = field.name - value = getattr(self, key) + values = {f.name: getattr(self, f.name) for f in fields(self)} + for key in DICE_KEYWORD_ORDER: + value = values.pop(key, None) if value is None: continue + io_writer.write(f"{key} = {self._serialize_value(value)}\n") + # write any remaining fields (future extensions) + for key, value in values.items(): + if value is None: + continue io_writer.write(f"{key} = {self._serialize_value(value)}\n") io_writer.write("$end\n") @@ -48,7 +78,7 @@ class BaseConfig(ABC): @classmethod def from_config(cls, config: PlayerConfig, **kwargs) -> Self: base_fields = cls._extract_base_fields(config) - return cls(**base_fields, **kwargs) + return cls(**(base_fields | kwargs)) @staticmethod def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]: @@ -60,7 +90,6 @@ class BaseConfig(ABC): temp=config.dice.temp, seed=config.dice.seed, isave=config.dice.isave, - press=config.dice.press, ) @staticmethod @@ -91,10 +120,18 @@ class BaseConfig(ABC): @dataclass(slots=True) class NVTConfig(BaseConfig): title: str = "Diceplayer Run - NVT" - dens: float = 0.0 - nstep: int = 0 + dens: float = ... + nstep: int = ... vstep: int = 0 + @classmethod + def from_config(cls, config: PlayerConfig, **kwargs) -> Self: + return super(NVTConfig, cls).from_config( + config, + dens=config.dice.dens, + nstep=cls._get_nstep(config, 0), + ) + # ----------------------------------------------------- # NVT THERMALIZATION @@ -105,12 +142,12 @@ class NVTConfig(BaseConfig): class NVTTerConfig(NVTConfig): title: str = "Diceplayer Run - NVT Thermalization" upbuf: int = 360 + init: str = "yes" @classmethod def from_config(cls, config: PlayerConfig, **kwargs) -> Self: return super(NVTTerConfig, cls).from_config( config, - dens=config.dice.dens, nstep=cls._get_nstep(config, 0), upbuf=config.dice.upbuf, vstep=0, @@ -118,9 +155,7 @@ class NVTTerConfig(NVTConfig): ) def write(self, directory: Path, filename: str = "nvt.ter") -> Path: - return super(NVTTerConfig, self).write( - directory, filename - ) + return super(NVTTerConfig, self).write(directory, filename) # ----------------------------------------------------- @@ -132,12 +167,12 @@ class NVTTerConfig(NVTConfig): class NVTEqConfig(NVTConfig): title: str = "Diceplayer Run - NVT Production" irdf: int = 0 + init: str = "yesreadxyz" @classmethod def from_config(cls, config: PlayerConfig, **kwargs) -> Self: return super(NVTEqConfig, cls).from_config( config, - dens=config.dice.dens, nstep=cls._get_nstep(config, 1), irdf=config.dice.irdf, vstep=0, @@ -145,9 +180,7 @@ class NVTEqConfig(NVTConfig): ) def write(self, directory: Path, filename: str = "nvt.eq") -> Path: - return super(NVTEqConfig, self).write( - directory, filename - ) + return super(NVTEqConfig, self).write(directory, filename) # ----------------------------------------------------- @@ -160,6 +193,14 @@ class NPTConfig(BaseConfig): title: str = "Diceplayer Run - NPT" nstep: int = 0 vstep: int = 5000 + press: float = 1.0 + + @classmethod + def from_config(cls, config: PlayerConfig, **kwargs) -> Self: + return super(NPTConfig, cls).from_config( + config, + press=config.dice.press, + ) # ----------------------------------------------------- @@ -183,9 +224,7 @@ class NPTTerConfig(NPTConfig): ) def write(self, directory: Path, filename: str = "npt.ter") -> Path: - return super(NPTTerConfig, self).write( - directory, filename - ) + return super(NPTTerConfig, self).write(directory, filename) # ----------------------------------------------------- @@ -209,6 +248,4 @@ class NPTEqConfig(NPTConfig): ) def write(self, directory: Path, filename: str = "npt.eq") -> Path: - return super(NPTEqConfig, self).write( - directory, filename - ) \ No newline at end of file + return super(NPTEqConfig, self).write(directory, filename) diff --git a/diceplayer/dice/dice_wrapper.py b/diceplayer/dice/dice_wrapper.py index 46796a9..4ec76c5 100644 --- a/diceplayer/dice/dice_wrapper.py +++ b/diceplayer/dice/dice_wrapper.py @@ -1,12 +1,10 @@ -import subprocess -from typing import Final - import diceplayer.dice.dice_input as dice_input - -from pathlib import Path - from diceplayer.config import DiceConfig +import subprocess +from pathlib import Path +from typing import Final + DICE_FLAG_LINE: Final[int] = -2 DICE_END_FLAG: Final[str] = "End of simulation" @@ -22,8 +20,9 @@ class DiceWrapper: output_path = input_path.parent / (input_path.name + ".out") with open(output_path, "w") as outfile, open(input_path, "r") as infile: + bin_path = self.dice_config.progname.expanduser() exit_status = subprocess.call( - self.dice_config.progname, stdin=infile, stdout=outfile + bin_path, stdin=infile, stdout=outfile, cwd=self.working_directory ) if exit_status != 0: @@ -38,4 +37,3 @@ class DiceWrapper: def extract_results(self) -> dict: return {} - diff --git a/diceplayer/environment/atom.py b/diceplayer/environment/atom.py index c5c9dee..ac284fb 100644 --- a/diceplayer/environment/atom.py +++ b/diceplayer/environment/atom.py @@ -1,9 +1,9 @@ from diceplayer.utils.ptable import AtomInfo, PTable -from dataclasses import dataclass +from pydantic.dataclasses import dataclass -@dataclass +@dataclass(slots=True) class Atom: """ Atom class declaration. This class is used throughout the DicePlayer program to represent atoms. diff --git a/diceplayer/environment/molecule.py b/diceplayer/environment/molecule.py index 5ae59ea..ea96abe 100644 --- a/diceplayer/environment/molecule.py +++ b/diceplayer/environment/molecule.py @@ -10,10 +10,11 @@ import numpy as np import numpy.linalg as linalg import numpy.typing as npt from typing_extensions import List, Self, Tuple +from pydantic.dataclasses import dataclass import math from copy import deepcopy -from dataclasses import dataclass, field +from dataclasses import field from functools import cached_property diff --git a/diceplayer/optimization/optimization_handler.py b/diceplayer/optimization/optimization_handler.py index 9d0d86a..9fbd86a 100644 --- a/diceplayer/optimization/optimization_handler.py +++ b/diceplayer/optimization/optimization_handler.py @@ -16,4 +16,4 @@ class OptimizationHandler: if current_cycle < state.config.switch_cyc: return RoutineType.CHARGE - return RoutineType.GEOMETRY \ No newline at end of file + return RoutineType.GEOMETRY diff --git a/diceplayer/player.py b/diceplayer/player.py index 4f4979c..6b3c3ec 100644 --- a/diceplayer/player.py +++ b/diceplayer/player.py @@ -1,12 +1,13 @@ from diceplayer.config.player_config import PlayerConfig from diceplayer.dice.dice_handler import DiceHandler from diceplayer.logger import logger -from diceplayer.optimization.optimization_handler import OptimizationHandler from diceplayer.state.state_handler import StateHandler from diceplayer.state.state_model import StateModel from typing_extensions import TypedDict, Unpack +from diceplayer.utils.potential import read_system_from_phb + class PlayerFlags(TypedDict): continuation: bool @@ -30,9 +31,11 @@ class Player: self._state_handler.delete() state = None - if state is None: - state = StateModel.from_config(self.config) + system = read_system_from_phb(self.config) + state = StateModel( + config=self.config, system=system + ) else: logger.info("Resuming from existing state.") diff --git a/diceplayer/state/state_model.py b/diceplayer/state/state_model.py index a7e05b6..66b9df4 100644 --- a/diceplayer/state/state_model.py +++ b/diceplayer/state/state_model.py @@ -8,7 +8,7 @@ from typing_extensions import Self class StateModel(BaseModel): config: PlayerConfig system: System - current_cycle: int + current_cycle: int = 0 @classmethod def from_config(cls, config: PlayerConfig) -> Self: diff --git a/diceplayer/utils/potential.py b/diceplayer/utils/potential.py new file mode 100644 index 0000000..5d1ee73 --- /dev/null +++ b/diceplayer/utils/potential.py @@ -0,0 +1,62 @@ +from diceplayer.config import PlayerConfig +from diceplayer.environment import System, Molecule, Atom + +from pathlib import Path + +from diceplayer.logger import logger +from diceplayer.state.state_model import StateModel + + +def read_system_from_phb(config: PlayerConfig) -> System: + phb_file = Path(config.dice.ljname) + if not phb_file.exists(): + raise FileNotFoundError + + ljc_data = phb_file.read_text(encoding="utf-8").splitlines() + + combrule = ljc_data.pop(0).strip() + if combrule != config.dice.combrule: + raise ValueError( + f"Invalid combrule defined in {phb_file}. Expected the same value configured in the config file" + ) + + ntypes = ljc_data.pop(0).strip() + if not ntypes.isdigit(): + raise ValueError(f"Invalid ntypes defined in {phb_file}") + + nmol = int(ntypes) + if nmol != len(config.dice.nmol): + raise ValueError(f"Invalid nmol defined in {phb_file}") + + sys = System() + + for i in range(nmol): + nsites, molname = ljc_data.pop(0).split() + + if not nsites.isdigit(): + raise ValueError(f"Invalid nsites defined in {phb_file}") + nsites = int(nsites) + + mol = Molecule(molname) + + for j in range(nsites): + _fields = ljc_data.pop(0).split() + mol.add_atom( + Atom(*_fields) + ) + + sys.add_type(mol) + + return sys + + +# def write_phb(phb_file: Path, state: StateModel) -> None: +# if phb_file.exists(): +# raise RuntimeError(f"File {phb_file} already exists") +# +# with open(phb_file, "w") as f: +# f.write(f"{state.config.dice.combrule}\n") +# f.write(f"{len(state.system.nmols)}\n") +# f.write(f"{state.config.dice.nmol}\n") + + diff --git a/tests/config/test_dice_config.py b/tests/config/test_dice_config.py index 8afe05e..6765c24 100644 --- a/tests/config/test_dice_config.py +++ b/tests/config/test_dice_config.py @@ -6,6 +6,7 @@ import pytest class TestDiceConfig: def test_class_instantiation(self): dice_dto = DiceConfig( + nprocs=1, ljname="test", outname="test", dens=1.0, @@ -18,6 +19,7 @@ class TestDiceConfig: def test_validate_jname(self): with pytest.raises(ValueError) as ex: DiceConfig( + nprocs=1, ljname=None, outname="test", dens=1.0, @@ -30,6 +32,7 @@ class TestDiceConfig: def test_validate_outname(self): with pytest.raises(ValueError) as ex: DiceConfig( + nprocs=1, ljname="test", outname=None, dens=1.0, @@ -42,6 +45,7 @@ class TestDiceConfig: def test_validate_dens(self): with pytest.raises(ValueError) as ex: DiceConfig( + nprocs=1, ljname="test", outname="test", dens=None, @@ -54,6 +58,7 @@ class TestDiceConfig: def test_validate_nmol(self): with pytest.raises(ValueError) as ex: DiceConfig( + nprocs=1, ljname="test", outname="test", dens=1.0, @@ -66,6 +71,7 @@ class TestDiceConfig: def test_validate_nstep(self): with pytest.raises(ValueError) as ex: DiceConfig( + nprocs=1, ljname="test", outname="test", dens=1.0, @@ -78,6 +84,7 @@ class TestDiceConfig: def test_from_dict(self): dice_dto = DiceConfig.model_validate( { + "nprocs": 1, "ljname": "test", "outname": "test", "dens": 1.0, diff --git a/tests/config/test_player_config.py b/tests/config/test_player_config.py index b26baf9..9e994da 100644 --- a/tests/config/test_player_config.py +++ b/tests/config/test_player_config.py @@ -11,6 +11,7 @@ class TestPlayerConfig: @pytest.fixture def dice_payload(self) -> dict[str, Any]: return { + "nprocs": 4, "ljname": "test", "outname": "test", "dens": 1.0, @@ -35,7 +36,6 @@ class TestPlayerConfig: "mem": 12, "max_cyc": 100, "switch_cyc": 50, - "nprocs": 4, "ncores": 4, "dice": dice_payload, "gaussian": gaussian_payload, @@ -57,7 +57,6 @@ class TestPlayerConfig: mem=12, max_cyc=100, switch_cyc=50, - nprocs=4, ncores=4, dice=dice_config, gaussian=gaussian_config, @@ -75,7 +74,6 @@ class TestPlayerConfig: mem=12, max_cyc=100, switch_cyc=50, - nprocs=4, ncores=4, altsteps=0, dice=dice_config, diff --git a/tests/dice/test_dice_input.py b/tests/dice/test_dice_input.py index 2e1dad1..8ed6dc0 100644 --- a/tests/dice/test_dice_input.py +++ b/tests/dice/test_dice_input.py @@ -20,9 +20,9 @@ class TestDiceInput: "mem": 12, "max_cyc": 100, "switch_cyc": 50, - "nprocs": 4, "ncores": 4, "dice": { + "nprocs": 4, "ljname": "test", "outname": "test", "dens": 1.0, diff --git a/tests/mocks/mock_inputs.py b/tests/mocks/mock_inputs.py deleted file mode 100644 index eadf05b..0000000 --- a/tests/mocks/mock_inputs.py +++ /dev/null @@ -1,113 +0,0 @@ -from unittest import mock - - -def get_config_example(): - return """ -diceplayer: - opt: no - mem: 12 - maxcyc: 3 - ncores: 4 - nprocs: 4 - qmprog: 'g16' - lps: no - ghosts: no - altsteps: 20000 - - dice: - nmol: [1, 50] - dens: 0.75 - nstep: [2000, 3000, 4000] - isave: 1000 - outname: 'phb' - progname: '~/.local/bin/dice' - ljname: 'phb.ljc' - randominit: 'first' - - gaussian: - qmprog: 'g16' - level: 'MP2/aug-cc-pVDZ' - keywords: 'freq' -""" - - -def get_potentials_exemple(): - return """\ -* -2 -1 TEST - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -1 PLACEHOLDER - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -""" - - -def get_potentials_error_combrule(): - return """\ -. -2 -1 TEST - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -1 PLACEHOLDER - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -""" - - -def get_potentials_error_ntypes(): - return """\ -* -a -1 TEST - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -1 PLACEHOLDER - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -""" - - -def get_potentials_error_ntypes_config(): - return """\ -* -3 -1 TEST - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -1 PLACEHOLDER - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -""" - - -def get_potentials_error_nsites(): - return """\ -* -2 -. TEST - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -1 PLACEHOLDER - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -""" - - -def get_potentials_error_molname(): - return """\ -* -2 -1 - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -1 PLACEHOLDER - 1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000 -""" - - -def mock_open(file, *args, **kwargs): - values = { - "control.test.yml": get_config_example(), - "phb.ljc": get_potentials_exemple(), - "phb.error.combrule.ljc": get_potentials_error_combrule(), - "phb.error.ntypes.ljc": get_potentials_error_ntypes(), - "phb.error.ntypes.config.ljc": get_potentials_error_ntypes_config(), - "phb.error.nsites.ljc": get_potentials_error_nsites(), - "phb.error.molname.ljc": get_potentials_error_molname(), - } - if file in values: - return mock.mock_open(read_data=values[file])() - - return mock.mock_open(read_data="")() diff --git a/tests/mocks/mock_proc.py b/tests/mocks/mock_proc.py deleted file mode 100644 index 2dfa1d1..0000000 --- a/tests/mocks/mock_proc.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing_extensions import List - -import itertools - - -class MockProc: - pid_counter = itertools.count() - - def __init__(self, *args, **kwargs): - self.pid = next(MockProc.pid_counter) - - if "exitcode" in kwargs: - self.exitcode = kwargs["exitcode"] - else: - self.exitcode = 0 - - self.sentinel = self.pid - - def __call__(self, *args, **kwargs): - return self - - def start(self): - pass - - def terminate(self): - pass - - -class MockConnection: - @staticmethod - def wait(sentinels: List[int]): - return sentinels diff --git a/tests/state/test_state_handler.py b/tests/state/test_state_handler.py index 6c0a77c..7287d45 100644 --- a/tests/state/test_state_handler.py +++ b/tests/state/test_state_handler.py @@ -16,9 +16,9 @@ class TestStateHandler: mem=12, max_cyc=100, switch_cyc=50, - nprocs=4, ncores=4, dice=DiceConfig( + nprocs=4, ljname="test", outname="test", dens=1.0, diff --git a/tests/mocks/__init__.py b/tests/utils/__init__.py similarity index 100% rename from tests/mocks/__init__.py rename to tests/utils/__init__.py diff --git a/tests/utils/test_potential.py b/tests/utils/test_potential.py new file mode 100644 index 0000000..a3b984d --- /dev/null +++ b/tests/utils/test_potential.py @@ -0,0 +1,38 @@ +from pathlib import Path +from typing import Any + +import pytest + +from diceplayer.config import PlayerConfig +from diceplayer.environment import System +from diceplayer.utils.potential import read_system_from_phb + + +class TestPotential: + @pytest.fixture + def player_config(self) -> PlayerConfig: + return PlayerConfig.model_validate({ + "type": "both", + "mem": 12, + "max_cyc": 100, + "switch_cyc": 50, + "ncores": 4, + "dice": { + "nprocs": 4, + "ljname": "phb.ljc.example", + "outname": "test", + "dens": 1.0, + "nmol": [12, 16], + "nstep": [1, 1], + }, + "gaussian": { + "level": "test", + "qmprog": "g16", + "keywords": "test", + }, + }) + + def test_read_phb(self, player_config: PlayerConfig): + system = read_system_from_phb(player_config) + + assert isinstance(system, System)