diff --git a/control.example.yml b/control.example.yml index 0830d4d..6ecbb53 100644 --- a/control.example.yml +++ b/control.example.yml @@ -1,7 +1,8 @@ diceplayer: - opt: no + type: both + switch_cyc: 3 + max_cyc: 5 mem: 24 - maxcyc: 5 ncores: 5 nprocs: 4 qmprog: 'g16' diff --git a/diceplayer/__main__.py b/diceplayer/__main__.py index 762680f..fed4452 100644 --- a/diceplayer/__main__.py +++ b/diceplayer/__main__.py @@ -1,4 +1,5 @@ from diceplayer.cli import ArgsModel, read_input +from diceplayer.config import PlayerConfig from diceplayer.logger import logger from diceplayer.player import Player @@ -33,13 +34,26 @@ def main(): metavar="OUTFILE", help="output file of diceplayer [default = run.log]", ) + parser.add_argument( + "-f", + "--force", + dest="force", + default=False, + action="store_true", + help="force overwrite existing state file if it exists [default = False]", + ) args = ArgsModel.from_args(parser.parse_args()) logger.set_output_file(args.outfile) - config = read_input(args.infile) + config: PlayerConfig + try: + config = read_input(args.infile) + except Exception as e: + logger.error(f"Failed to read input file: {e}") + return - Player(config).play(continuation=args.continuation) + Player(config).play(continuation=args.continuation, force=args.force) if __name__ == "__main__": diff --git a/diceplayer/cli/args_model.py b/diceplayer/cli/args_model.py index 8c5feaa..5cf107a 100644 --- a/diceplayer/cli/args_model.py +++ b/diceplayer/cli/args_model.py @@ -5,6 +5,7 @@ class ArgsModel(BaseModel): outfile: str infile: str continuation: bool + force: bool @classmethod def from_args(cls, args): diff --git a/diceplayer/cli/read_input_file.py b/diceplayer/cli/read_input_file.py index 36f8946..5fd6d82 100644 --- a/diceplayer/cli/read_input_file.py +++ b/diceplayer/cli/read_input_file.py @@ -1,13 +1,9 @@ from diceplayer.config import PlayerConfig -from diceplayer.logger import logger import yaml def read_input(infile) -> PlayerConfig: - try: - with open(infile, "r") as f: - return PlayerConfig.model_validate(yaml.safe_load(f)) - except Exception as e: - logger.exception("Failed to read input file") - raise e + with open(infile, "r") as f: + values = yaml.safe_load(f) + return PlayerConfig.model_validate(values["diceplayer"]) diff --git a/diceplayer/config/dice_config.py b/diceplayer/config/dice_config.py index ccfebd8..e788378 100644 --- a/diceplayer/config/dice_config.py +++ b/diceplayer/config/dice_config.py @@ -1,11 +1,12 @@ -from pydantic import BaseModel, Field, ConfigDict -from typing_extensions import List, Literal +from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Literal class DiceConfig(BaseModel): """ Data Transfer Object for the Dice configuration. """ + model_config = ConfigDict( frozen=True, ) @@ -15,10 +16,10 @@ class DiceConfig(BaseModel): ..., description="Name of the output file for the simulation results" ) dens: float = Field(..., description="Density of the system") - nmol: List[int] = Field( + nmol: list[int] = Field( ..., description="List of the number of molecules for each component" ) - nstep: List[int] = Field( + nstep: list[int] = Field( ..., description="List of the number of steps for each component", min_length=2, diff --git a/diceplayer/config/gaussian_config.py b/diceplayer/config/gaussian_config.py index ea2d5c2..3ea014e 100644 --- a/diceplayer/config/gaussian_config.py +++ b/diceplayer/config/gaussian_config.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Literal @@ -6,14 +6,15 @@ class GaussianConfig(BaseModel): """ Data Transfer Object for the Gaussian configuration. """ + model_config = ConfigDict( frozen=True, ) - level: str = Field(..., description="Level of theory for the QM calculations") qmprog: Literal["g03", "g09", "g16"] = Field( "g16", description="QM program to use for the calculations" ) + level: str = Field(..., description="Level of theory for the QM calculations") chgmult: list[int] = Field( default_factory=lambda: [0, 1], @@ -23,6 +24,6 @@ class GaussianConfig(BaseModel): "chelpg", description="Population analysis method for the QM calculations" ) chg_tol: float = Field(0.01, description="Charge tolerance for the QM calculations") - keywords: str = Field( + keywords: str | None = Field( None, description="Additional keywords for the QM calculations" ) diff --git a/diceplayer/config/player_config.py b/diceplayer/config/player_config.py index ab7caa3..82701ef 100644 --- a/diceplayer/config/player_config.py +++ b/diceplayer/config/player_config.py @@ -1,9 +1,10 @@ from diceplayer.config.dice_config import DiceConfig from diceplayer.config.gaussian_config import GaussianConfig -from pydantic import BaseModel, Field, model_validator, ConfigDict -from typing_extensions import Self, Any +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Any +from enum import Enum from pathlib import Path @@ -11,18 +12,28 @@ MIN_STEP = 20000 STEP_INCREMENT = 1000 +class RoutineType(str, Enum): + CHARGE = "charge" + GEOMETRY = "geometry" + BOTH = "both" + + class PlayerConfig(BaseModel): """ Data Transfer Object for the player configuration. """ + model_config = ConfigDict( frozen=True, ) - opt: bool = Field(..., description="Whether to perform geometry optimization") - maxcyc: int = Field( - ..., description="Maximum number of cycles for the geometry optimization" + type: RoutineType = Field(..., description="Type of simulation to perform") + max_cyc: int = Field( + ..., description="Maximum number of cycles for the geometry optimization", gt=0 ) + switch_cyc: int = Field(..., description="Switch cycle configuration") + + mem: int = Field(None, description="Memory configuration") nprocs: int = Field( ..., description="Number of processors to use for the QM calculations" ) @@ -33,22 +44,37 @@ class PlayerConfig(BaseModel): dice: DiceConfig = Field(..., description="Dice configuration") gaussian: GaussianConfig = Field(..., description="Gaussian configuration") - mem: int = Field(None, description="Memory configuration") - switchcyc: int = Field(3, description="Switch cycle configuration") - qmprog: str = Field("g16", description="QM program to use for the calculations") altsteps: int = Field( 20000, description="Number of steps for the alternate simulation" ) geoms_file: Path = Field( - "geoms.xyz", description="File name for the geometries output" + Path("geoms.xyz"), description="File name for the geometries output" ) simulation_dir: Path = Field( - "simfiles", description="Directory name for the simulation files" + Path("simfiles"), description="Directory name for the simulation files" ) @model_validator(mode="before") @staticmethod def validate_altsteps(fields) -> dict[str, Any]: altsteps = fields.pop("altsteps", MIN_STEP) - fields["altsteps"] = round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT + fields["altsteps"] = ( + round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT + ) + return fields + + @model_validator(mode="before") + @staticmethod + def validate_switch_cyc(fields: dict[str, Any]) -> dict[str, Any]: + max_cyc = int(fields.get("max_cyc", 0)) + switch_cyc = int(fields.get("switch_cyc", max_cyc)) + + if fields.get("type") == "both" and not switch_cyc < max_cyc: + raise ValueError("switch_cyc must be less than max_cyc when type='both'.") + + if fields.get("type") != "both" and switch_cyc != max_cyc: + raise ValueError( + "switch_cyc must be equal to max_cyc when type is not 'both'." + ) + return fields diff --git a/diceplayer/player.py b/diceplayer/player.py index 8aff1c7..54cb094 100644 --- a/diceplayer/player.py +++ b/diceplayer/player.py @@ -1,8 +1,41 @@ from diceplayer.config.player_config import PlayerConfig +from diceplayer.logger import logger +from diceplayer.state.state_handler import StateHandler +from diceplayer.state.state_model import StateModel + +from typing_extensions import TypedDict, Unpack + + +class PlayerFlags(TypedDict): + continuation: bool + force: bool class Player: def __init__(self, config: PlayerConfig): self.config = config - def play(self, continuation=False): ... + def play(self, **flags: Unpack[PlayerFlags]): + state_handler = StateHandler(self.config.simulation_dir) + + if not flags["continuation"]: + logger.info( + "Continuation flag is not set. Starting a new simulation and deleting any existing state." + ) + state_handler.delete() + + state = state_handler.get(self.config, force=flags["force"]) + + if state is None: + state = StateModel.from_config(self.config) + else: + logger.info("Resuming from existing state.") + + while state.current_cycle < self.config.max_cyc: + logger.info( + f"Starting cycle {state.current_cycle + 1} of {self.config.max_cyc}." + ) + state.current_cycle += 1 + state_handler.save(state) + + logger.info("Reached maximum number of cycles. Simulation complete.") diff --git a/diceplayer/state/state_handler.py b/diceplayer/state/state_handler.py index fa456f4..b2056f4 100644 --- a/diceplayer/state/state_handler.py +++ b/diceplayer/state/state_handler.py @@ -1,31 +1,37 @@ -import pickle -from pathlib import Path - from diceplayer.config import PlayerConfig -from diceplayer.environment import System from diceplayer.logger import logger from diceplayer.state.state_model import StateModel +import pickle +from pathlib import Path + class StateHandler: def __init__(self, sim_dir: Path, state_file: str = "state.pkl"): + if not sim_dir.exists(): + sim_dir.mkdir(parents=True, exist_ok=True) self._state_file = sim_dir / state_file - def get_state(self, config: PlayerConfig) -> StateModel | None: + def get(self, config: PlayerConfig, force=False) -> StateModel | None: if not self._state_file.exists(): return None - with self._state_file.open(mode="r") as f: - data = pickle.load(f) - + with open(self._state_file, mode="rb") as file: + data = pickle.load(file) model = StateModel.model_validate(data) - if hash(model.config) != hash(config): - logger.warning("The configuration in the state file does not match the provided configuration.") + if config != model.config and not force: + logger.warning( + "The configuration in the state file does not match the provided configuration." + ) return None return model - def save_state(self, state: StateModel) -> None: + def save(self, state: StateModel) -> None: with self._state_file.open(mode="wb") as f: - pickle.dump(state.model_dump(), f) \ No newline at end of file + pickle.dump(state.model_dump(), f) + + def delete(self) -> None: + if self._state_file.exists(): + self._state_file.unlink() diff --git a/diceplayer/state/state_model.py b/diceplayer/state/state_model.py index 7a1099f..a7e05b6 100644 --- a/diceplayer/state/state_model.py +++ b/diceplayer/state/state_model.py @@ -1,10 +1,19 @@ -from pydantic import BaseModel - from diceplayer.config import PlayerConfig from diceplayer.environment import System +from pydantic import BaseModel +from typing_extensions import Self + class StateModel(BaseModel): config: PlayerConfig system: System - current_cycle: int \ No newline at end of file + current_cycle: int + + @classmethod + def from_config(cls, config: PlayerConfig) -> Self: + return cls( + config=config, + system=System(), + current_cycle=0, + ) diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cli/test_read_input_file.py b/tests/cli/test_read_input_file.py new file mode 100644 index 0000000..191d797 --- /dev/null +++ b/tests/cli/test_read_input_file.py @@ -0,0 +1,30 @@ +import diceplayer +from diceplayer.cli import read_input +from diceplayer.config import PlayerConfig + +import pytest + +from pathlib import Path + + +class TestReadInputFile: + @pytest.fixture + def example_config(self) -> Path: + return Path(diceplayer.__path__[0]).parent / "control.example.yml" + + def test_read_input_file(self, example_config: Path): + config = read_input(example_config) + + assert config is not None + assert isinstance(config, PlayerConfig) + + def test_read_input_non_existing_file(self): + with pytest.raises(FileNotFoundError): + read_input("nonexistent_file.yml") + + def test_read_input_invalid_yaml(self, tmp_path: Path): + invalid_yaml_file = tmp_path / "invalid.yml" + invalid_yaml_file.write_text("This is not valid YAML: [unbalanced brackets") + + with pytest.raises(Exception): + read_input(invalid_yaml_file) diff --git a/tests/state/test_state_handler.py b/tests/state/test_state_handler.py index 226df9f..b86196f 100644 --- a/tests/state/test_state_handler.py +++ b/tests/state/test_state_handler.py @@ -1,10 +1,11 @@ -import tempfile -from pathlib import Path +from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig +from diceplayer.environment import System +from diceplayer.state.state_handler import StateHandler +from diceplayer.state.state_model import StateModel import pytest -from diceplayer.config import PlayerConfig, DiceConfig, GaussianConfig -from diceplayer.state.state_handler import StateHandler +from pathlib import Path class TestStateHandler: @@ -30,18 +31,87 @@ class TestStateHandler: ), ) - def test_state_handler_initialization(self): - with tempfile.TemporaryDirectory() as tmpdir: - tmpdir_path = Path(tmpdir) - state_handler = StateHandler(tmpdir_path) + def test_initialization(self, tmp_path: Path): + state_handler = StateHandler(tmp_path) - assert isinstance(state_handler, StateHandler) + assert isinstance(state_handler, StateHandler) - def test_state_handler_get_state(self, player_config: PlayerConfig): - with tempfile.TemporaryDirectory() as tmpdir: - tmpdir_path = Path(tmpdir) - state_handler = StateHandler(tmpdir_path) + def test_save(self, tmp_path: Path, player_config: PlayerConfig): + state_handler = StateHandler(tmp_path) - state = state_handler.get_state(player_config) + state = StateModel( + config=player_config, + system=System(), + current_cycle=0, + ) - assert state is None \ No newline at end of file + state_handler.save(state) + + assert (tmp_path / "state.pkl").exists() + + def test_get_when_empty(self, tmp_path: Path, player_config: PlayerConfig): + state_handler = StateHandler(tmp_path) + + state = state_handler.get(player_config) + + assert state is None + + def test_get(self, tmp_path: Path, player_config: PlayerConfig): + state_handler = StateHandler(tmp_path) + + state = StateModel( + config=player_config, + system=System(), + current_cycle=0, + ) + + state_handler.save(state) + + retrieved_state = state_handler.get(player_config) + + assert retrieved_state is not None + assert retrieved_state.config == state.config + assert retrieved_state.system == state.system + assert retrieved_state.current_cycle == state.current_cycle + + def test_get_with_different_config( + self, tmp_path: Path, player_config: PlayerConfig + ): + state_handler = StateHandler(tmp_path) + + state = StateModel( + config=player_config, + system=System(), + current_cycle=0, + ) + + state_handler.save(state) + + different_config = player_config.model_copy(update={"opt": False}) + + retrieved_state = state_handler.get(different_config) + + assert retrieved_state is None + + def test_get_with_different_config_force( + self, tmp_path: Path, player_config: PlayerConfig + ): + state_handler = StateHandler(tmp_path) + + state = StateModel( + config=player_config, + system=System(), + current_cycle=0, + ) + + state_handler.save(state) + + different_config = player_config.model_copy(update={"opt": False}) + + retrieved_state = state_handler.get(different_config, force=True) + + assert retrieved_state is not None + assert retrieved_state.config == state.config + assert retrieved_state.config != different_config + assert retrieved_state.system == state.system + assert retrieved_state.current_cycle == state.current_cycle