diff --git a/diceplayer/__init__.py b/diceplayer/__init__.py index 36e8f16..e69de29 100644 --- a/diceplayer/__init__.py +++ b/diceplayer/__init__.py @@ -1,3 +0,0 @@ -from diceplayer.utils.logger import RunLogger - - diff --git a/diceplayer/__main__.py b/diceplayer/__main__.py index 538be5f..762680f 100644 --- a/diceplayer/__main__.py +++ b/diceplayer/__main__.py @@ -1,9 +1,7 @@ -from diceplayer.config.player_config import PlayerConfig +from diceplayer.cli import ArgsModel, read_input from diceplayer.logger import logger from diceplayer.player import Player -import yaml - import argparse from importlib import metadata @@ -11,17 +9,6 @@ from importlib import metadata VERSION = metadata.version("diceplayer") -def read_input(infile) -> PlayerConfig: - try: - with open(infile, "r") as f: - return PlayerConfig.model_validate( - yaml.safe_load(f) - ) - except Exception as e: - logger.exception("Failed to read input file") - raise e - - def main(): parser = argparse.ArgumentParser(prog="Diceplayer") parser.add_argument( @@ -46,7 +33,7 @@ def main(): metavar="OUTFILE", help="output file of diceplayer [default = run.log]", ) - args = parser.parse_args() + args = ArgsModel.from_args(parser.parse_args()) logger.set_output_file(args.outfile) @@ -56,4 +43,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/diceplayer/cli/__init__.py b/diceplayer/cli/__init__.py new file mode 100644 index 0000000..266597a --- /dev/null +++ b/diceplayer/cli/__init__.py @@ -0,0 +1,5 @@ +from .args_model import ArgsModel +from .read_input_file import read_input + + +__all__ = ["ArgsModel", "read_input"] diff --git a/diceplayer/cli/args_model.py b/diceplayer/cli/args_model.py new file mode 100644 index 0000000..8c5feaa --- /dev/null +++ b/diceplayer/cli/args_model.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + + +class ArgsModel(BaseModel): + outfile: str + infile: str + continuation: bool + + @classmethod + def from_args(cls, args): + return cls(**vars(args)) diff --git a/diceplayer/cli/read_input_file.py b/diceplayer/cli/read_input_file.py new file mode 100644 index 0000000..36f8946 --- /dev/null +++ b/diceplayer/cli/read_input_file.py @@ -0,0 +1,13 @@ +from diceplayer.config import PlayerConfig +from diceplayer.logger import logger + +import yaml + + +def read_input(infile) -> PlayerConfig: + try: + with open(infile, "r") as f: + return PlayerConfig.model_validate(yaml.safe_load(f)) + except Exception as e: + logger.exception("Failed to read input file") + raise e diff --git a/diceplayer/config/__init__.py b/diceplayer/config/__init__.py index e69de29..6d7a441 100644 --- a/diceplayer/config/__init__.py +++ b/diceplayer/config/__init__.py @@ -0,0 +1,10 @@ +from .dice_config import DiceConfig +from .gaussian_config import GaussianConfig +from .player_config import PlayerConfig + + +__all__ = [ + "DiceConfig", + "GaussianConfig", + "PlayerConfig", +] diff --git a/diceplayer/config/dice_config.py b/diceplayer/config/dice_config.py index 30af670..ccfebd8 100644 --- a/diceplayer/config/dice_config.py +++ b/diceplayer/config/dice_config.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing_extensions import List, Literal @@ -6,6 +6,9 @@ class DiceConfig(BaseModel): """ Data Transfer Object for the Dice configuration. """ + model_config = ConfigDict( + frozen=True, + ) ljname: str = Field(..., description="Name of the Lennard-Jones potential file") outname: str = Field( diff --git a/diceplayer/config/gaussian_config.py b/diceplayer/config/gaussian_config.py index 9cd838f..ea2d5c2 100644 --- a/diceplayer/config/gaussian_config.py +++ b/diceplayer/config/gaussian_config.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing_extensions import Literal @@ -6,6 +6,9 @@ class GaussianConfig(BaseModel): """ Data Transfer Object for the Gaussian configuration. """ + model_config = ConfigDict( + frozen=True, + ) level: str = Field(..., description="Level of theory for the QM calculations") qmprog: Literal["g03", "g09", "g16"] = Field( diff --git a/diceplayer/config/player_config.py b/diceplayer/config/player_config.py index c47248c..ab7caa3 100644 --- a/diceplayer/config/player_config.py +++ b/diceplayer/config/player_config.py @@ -1,19 +1,23 @@ from diceplayer.config.dice_config import DiceConfig from diceplayer.config.gaussian_config import GaussianConfig -from pydantic import BaseModel, Field, model_validator -from typing_extensions import Self +from pydantic import BaseModel, Field, model_validator, ConfigDict +from typing_extensions import Self, Any from pathlib import Path MIN_STEP = 20000 +STEP_INCREMENT = 1000 class PlayerConfig(BaseModel): """ Data Transfer Object for the player configuration. """ + model_config = ConfigDict( + frozen=True, + ) opt: bool = Field(..., description="Whether to perform geometry optimization") maxcyc: int = Field( @@ -42,7 +46,9 @@ class PlayerConfig(BaseModel): "simfiles", description="Directory name for the simulation files" ) - @model_validator(mode="after") - def validate_altsteps(self) -> Self: - self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000 - return self + @model_validator(mode="before") + @staticmethod + def validate_altsteps(fields) -> dict[str, Any]: + altsteps = fields.pop("altsteps", MIN_STEP) + fields["altsteps"] = round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT + return fields diff --git a/diceplayer/logger.py b/diceplayer/logger.py index 9b841c2..6bb5249 100644 --- a/diceplayer/logger.py +++ b/diceplayer/logger.py @@ -1,4 +1,4 @@ -from diceplayer import RunLogger +from diceplayer.utils import RunLogger logger = RunLogger("diceplayer") diff --git a/diceplayer/player.py b/diceplayer/player.py index 6800b28..8aff1c7 100644 --- a/diceplayer/player.py +++ b/diceplayer/player.py @@ -5,5 +5,4 @@ class Player: def __init__(self, config: PlayerConfig): self.config = config - def play(self, continuation = False): - ... \ No newline at end of file + def play(self, continuation=False): ... diff --git a/diceplayer/state/__init__.py b/diceplayer/state/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diceplayer/state/state_handler.py b/diceplayer/state/state_handler.py new file mode 100644 index 0000000..fa456f4 --- /dev/null +++ b/diceplayer/state/state_handler.py @@ -0,0 +1,31 @@ +import pickle +from pathlib import Path + +from diceplayer.config import PlayerConfig +from diceplayer.environment import System +from diceplayer.logger import logger +from diceplayer.state.state_model import StateModel + + +class StateHandler: + def __init__(self, sim_dir: Path, state_file: str = "state.pkl"): + self._state_file = sim_dir / state_file + + def get_state(self, config: PlayerConfig) -> StateModel | None: + if not self._state_file.exists(): + return None + + with self._state_file.open(mode="r") as f: + data = pickle.load(f) + + model = StateModel.model_validate(data) + + if hash(model.config) != hash(config): + logger.warning("The configuration in the state file does not match the provided configuration.") + return None + + return model + + def save_state(self, state: StateModel) -> None: + with self._state_file.open(mode="wb") as f: + pickle.dump(state.model_dump(), f) \ No newline at end of file diff --git a/diceplayer/state/state_model.py b/diceplayer/state/state_model.py new file mode 100644 index 0000000..7a1099f --- /dev/null +++ b/diceplayer/state/state_model.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + +from diceplayer.config import PlayerConfig +from diceplayer.environment import System + + +class StateModel(BaseModel): + config: PlayerConfig + system: System + current_cycle: int \ No newline at end of file diff --git a/diceplayer/utils/logger.py b/diceplayer/utils/logger.py index 51bcb8c..8c08d02 100644 --- a/diceplayer/utils/logger.py +++ b/diceplayer/utils/logger.py @@ -5,7 +5,7 @@ import sys from pathlib import Path -H = TypeVar('H', bound=logging.Handler) +H = TypeVar("H", bound=logging.Handler) class RunLogger(logging.Logger): @@ -18,33 +18,27 @@ class RunLogger(logging.Logger): self._configure_handler(logging.StreamHandler(stream), level) ) - def set_output_file(self, outfile: Path, level=logging.INFO): for handler in list(self.handlers): if not isinstance(handler, logging.FileHandler): continue self.handlers.remove(handler) - self.handlers.append( - self._create_file_handler(outfile, level) - ) - - + self.handlers.append(self._create_file_handler(outfile, level)) @staticmethod - def _create_file_handler(file: str|Path, level) -> logging.FileHandler: + def _create_file_handler(file: str | Path, level) -> logging.FileHandler: file = Path(file) if file.exists(): - file.rename(file.with_suffix('.log.backup')) + file.rename(file.with_suffix(".log.backup")) handler = logging.FileHandler(file) return RunLogger._configure_handler(handler, level) - @staticmethod def _configure_handler(handler: H, level) -> H: handler.setLevel(level) - formatter = logging.Formatter('%(message)s') + formatter = logging.Formatter("%(message)s") handler.setFormatter(formatter) - return handler \ No newline at end of file + return handler diff --git a/tests/state/__init__.py b/tests/state/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/state/test_state_handler.py b/tests/state/test_state_handler.py new file mode 100644 index 0000000..226df9f --- /dev/null +++ b/tests/state/test_state_handler.py @@ -0,0 +1,47 @@ +import tempfile +from pathlib import Path + +import pytest + +from diceplayer.config import PlayerConfig, DiceConfig, GaussianConfig +from diceplayer.state.state_handler import StateHandler + + +class TestStateHandler: + @pytest.fixture + def player_config(self) -> PlayerConfig: + return PlayerConfig( + opt=True, + mem=12, + maxcyc=100, + nprocs=4, + ncores=4, + dice=DiceConfig( + ljname="test", + outname="test", + dens=1.0, + nmol=[1], + nstep=[1, 1], + ), + gaussian=GaussianConfig( + level="test", + qmprog="g16", + keywords="test", + ), + ) + + def test_state_handler_initialization(self): + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir_path = Path(tmpdir) + state_handler = StateHandler(tmpdir_path) + + assert isinstance(state_handler, StateHandler) + + def test_state_handler_get_state(self, player_config: PlayerConfig): + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir_path = Path(tmpdir) + state_handler = StateHandler(tmpdir_path) + + state = state_handler.get_state(player_config) + + assert state is None \ No newline at end of file