feat: improves and initilize player pipeline
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
diceplayer:
|
||||
opt: no
|
||||
type: both
|
||||
switch_cyc: 3
|
||||
max_cyc: 5
|
||||
mem: 24
|
||||
maxcyc: 5
|
||||
ncores: 5
|
||||
nprocs: 4
|
||||
qmprog: 'g16'
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from diceplayer.cli import ArgsModel, read_input
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.player import Player
|
||||
|
||||
@@ -33,13 +34,26 @@ def main():
|
||||
metavar="OUTFILE",
|
||||
help="output file of diceplayer [default = run.log]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--force",
|
||||
dest="force",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="force overwrite existing state file if it exists [default = False]",
|
||||
)
|
||||
args = ArgsModel.from_args(parser.parse_args())
|
||||
|
||||
logger.set_output_file(args.outfile)
|
||||
|
||||
config = read_input(args.infile)
|
||||
config: PlayerConfig
|
||||
try:
|
||||
config = read_input(args.infile)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read input file: {e}")
|
||||
return
|
||||
|
||||
Player(config).play(continuation=args.continuation)
|
||||
Player(config).play(continuation=args.continuation, force=args.force)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -5,6 +5,7 @@ class ArgsModel(BaseModel):
|
||||
outfile: str
|
||||
infile: str
|
||||
continuation: bool
|
||||
force: bool
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args):
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.logger import logger
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def read_input(infile) -> PlayerConfig:
|
||||
try:
|
||||
with open(infile, "r") as f:
|
||||
return PlayerConfig.model_validate(yaml.safe_load(f))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to read input file")
|
||||
raise e
|
||||
with open(infile, "r") as f:
|
||||
values = yaml.safe_load(f)
|
||||
return PlayerConfig.model_validate(values["diceplayer"])
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing_extensions import List, Literal
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
class DiceConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the Dice configuration.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
@@ -15,10 +16,10 @@ class DiceConfig(BaseModel):
|
||||
..., description="Name of the output file for the simulation results"
|
||||
)
|
||||
dens: float = Field(..., description="Density of the system")
|
||||
nmol: List[int] = Field(
|
||||
nmol: list[int] = Field(
|
||||
..., description="List of the number of molecules for each component"
|
||||
)
|
||||
nstep: List[int] = Field(
|
||||
nstep: list[int] = Field(
|
||||
...,
|
||||
description="List of the number of steps for each component",
|
||||
min_length=2,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@@ -6,14 +6,15 @@ class GaussianConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the Gaussian configuration.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
level: str = Field(..., description="Level of theory for the QM calculations")
|
||||
qmprog: Literal["g03", "g09", "g16"] = Field(
|
||||
"g16", description="QM program to use for the calculations"
|
||||
)
|
||||
level: str = Field(..., description="Level of theory for the QM calculations")
|
||||
|
||||
chgmult: list[int] = Field(
|
||||
default_factory=lambda: [0, 1],
|
||||
@@ -23,6 +24,6 @@ class GaussianConfig(BaseModel):
|
||||
"chelpg", description="Population analysis method for the QM calculations"
|
||||
)
|
||||
chg_tol: float = Field(0.01, description="Charge tolerance for the QM calculations")
|
||||
keywords: str = Field(
|
||||
keywords: str | None = Field(
|
||||
None, description="Additional keywords for the QM calculations"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from diceplayer.config.dice_config import DiceConfig
|
||||
from diceplayer.config.gaussian_config import GaussianConfig
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator, ConfigDict
|
||||
from typing_extensions import Self, Any
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Any
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -11,18 +12,28 @@ MIN_STEP = 20000
|
||||
STEP_INCREMENT = 1000
|
||||
|
||||
|
||||
class RoutineType(str, Enum):
|
||||
CHARGE = "charge"
|
||||
GEOMETRY = "geometry"
|
||||
BOTH = "both"
|
||||
|
||||
|
||||
class PlayerConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the player configuration.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
opt: bool = Field(..., description="Whether to perform geometry optimization")
|
||||
maxcyc: int = Field(
|
||||
..., description="Maximum number of cycles for the geometry optimization"
|
||||
type: RoutineType = Field(..., description="Type of simulation to perform")
|
||||
max_cyc: int = Field(
|
||||
..., description="Maximum number of cycles for the geometry optimization", gt=0
|
||||
)
|
||||
switch_cyc: int = Field(..., description="Switch cycle configuration")
|
||||
|
||||
mem: int = Field(None, description="Memory configuration")
|
||||
nprocs: int = Field(
|
||||
..., description="Number of processors to use for the QM calculations"
|
||||
)
|
||||
@@ -33,22 +44,37 @@ class PlayerConfig(BaseModel):
|
||||
dice: DiceConfig = Field(..., description="Dice configuration")
|
||||
gaussian: GaussianConfig = Field(..., description="Gaussian configuration")
|
||||
|
||||
mem: int = Field(None, description="Memory configuration")
|
||||
switchcyc: int = Field(3, description="Switch cycle configuration")
|
||||
qmprog: str = Field("g16", description="QM program to use for the calculations")
|
||||
altsteps: int = Field(
|
||||
20000, description="Number of steps for the alternate simulation"
|
||||
)
|
||||
geoms_file: Path = Field(
|
||||
"geoms.xyz", description="File name for the geometries output"
|
||||
Path("geoms.xyz"), description="File name for the geometries output"
|
||||
)
|
||||
simulation_dir: Path = Field(
|
||||
"simfiles", description="Directory name for the simulation files"
|
||||
Path("simfiles"), description="Directory name for the simulation files"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@staticmethod
|
||||
def validate_altsteps(fields) -> dict[str, Any]:
|
||||
altsteps = fields.pop("altsteps", MIN_STEP)
|
||||
fields["altsteps"] = round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
|
||||
fields["altsteps"] = (
|
||||
round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
|
||||
)
|
||||
return fields
|
||||
|
||||
@model_validator(mode="before")
|
||||
@staticmethod
|
||||
def validate_switch_cyc(fields: dict[str, Any]) -> dict[str, Any]:
|
||||
max_cyc = int(fields.get("max_cyc", 0))
|
||||
switch_cyc = int(fields.get("switch_cyc", max_cyc))
|
||||
|
||||
if fields.get("type") == "both" and not switch_cyc < max_cyc:
|
||||
raise ValueError("switch_cyc must be less than max_cyc when type='both'.")
|
||||
|
||||
if fields.get("type") != "both" and switch_cyc != max_cyc:
|
||||
raise ValueError(
|
||||
"switch_cyc must be equal to max_cyc when type is not 'both'."
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
@@ -1,8 +1,41 @@
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
from typing_extensions import TypedDict, Unpack
|
||||
|
||||
|
||||
class PlayerFlags(TypedDict):
|
||||
continuation: bool
|
||||
force: bool
|
||||
|
||||
|
||||
class Player:
|
||||
def __init__(self, config: PlayerConfig):
|
||||
self.config = config
|
||||
|
||||
def play(self, continuation=False): ...
|
||||
def play(self, **flags: Unpack[PlayerFlags]):
|
||||
state_handler = StateHandler(self.config.simulation_dir)
|
||||
|
||||
if not flags["continuation"]:
|
||||
logger.info(
|
||||
"Continuation flag is not set. Starting a new simulation and deleting any existing state."
|
||||
)
|
||||
state_handler.delete()
|
||||
|
||||
state = state_handler.get(self.config, force=flags["force"])
|
||||
|
||||
if state is None:
|
||||
state = StateModel.from_config(self.config)
|
||||
else:
|
||||
logger.info("Resuming from existing state.")
|
||||
|
||||
while state.current_cycle < self.config.max_cyc:
|
||||
logger.info(
|
||||
f"Starting cycle {state.current_cycle + 1} of {self.config.max_cyc}."
|
||||
)
|
||||
state.current_cycle += 1
|
||||
state_handler.save(state)
|
||||
|
||||
logger.info("Reached maximum number of cycles. Simulation complete.")
|
||||
|
||||
@@ -1,31 +1,37 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class StateHandler:
|
||||
def __init__(self, sim_dir: Path, state_file: str = "state.pkl"):
|
||||
if not sim_dir.exists():
|
||||
sim_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._state_file = sim_dir / state_file
|
||||
|
||||
def get_state(self, config: PlayerConfig) -> StateModel | None:
|
||||
def get(self, config: PlayerConfig, force=False) -> StateModel | None:
|
||||
if not self._state_file.exists():
|
||||
return None
|
||||
|
||||
with self._state_file.open(mode="r") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
with open(self._state_file, mode="rb") as file:
|
||||
data = pickle.load(file)
|
||||
model = StateModel.model_validate(data)
|
||||
|
||||
if hash(model.config) != hash(config):
|
||||
logger.warning("The configuration in the state file does not match the provided configuration.")
|
||||
if config != model.config and not force:
|
||||
logger.warning(
|
||||
"The configuration in the state file does not match the provided configuration."
|
||||
)
|
||||
return None
|
||||
|
||||
return model
|
||||
|
||||
def save_state(self, state: StateModel) -> None:
|
||||
def save(self, state: StateModel) -> None:
|
||||
with self._state_file.open(mode="wb") as f:
|
||||
pickle.dump(state.model_dump(), f)
|
||||
|
||||
def delete(self) -> None:
|
||||
if self._state_file.exists():
|
||||
self._state_file.unlink()
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class StateModel(BaseModel):
|
||||
config: PlayerConfig
|
||||
system: System
|
||||
current_cycle: int
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig) -> Self:
|
||||
return cls(
|
||||
config=config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
0
tests/cli/__init__.py
Normal file
0
tests/cli/__init__.py
Normal file
30
tests/cli/test_read_input_file.py
Normal file
30
tests/cli/test_read_input_file.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import diceplayer
|
||||
from diceplayer.cli import read_input
|
||||
from diceplayer.config import PlayerConfig
|
||||
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestReadInputFile:
|
||||
@pytest.fixture
|
||||
def example_config(self) -> Path:
|
||||
return Path(diceplayer.__path__[0]).parent / "control.example.yml"
|
||||
|
||||
def test_read_input_file(self, example_config: Path):
|
||||
config = read_input(example_config)
|
||||
|
||||
assert config is not None
|
||||
assert isinstance(config, PlayerConfig)
|
||||
|
||||
def test_read_input_non_existing_file(self):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
read_input("nonexistent_file.yml")
|
||||
|
||||
def test_read_input_invalid_yaml(self, tmp_path: Path):
|
||||
invalid_yaml_file = tmp_path / "invalid.yml"
|
||||
invalid_yaml_file.write_text("This is not valid YAML: [unbalanced brackets")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
read_input(invalid_yaml_file)
|
||||
@@ -1,10 +1,11 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
import pytest
|
||||
|
||||
from diceplayer.config import PlayerConfig, DiceConfig, GaussianConfig
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestStateHandler:
|
||||
@@ -30,18 +31,87 @@ class TestStateHandler:
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_handler_initialization(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
state_handler = StateHandler(tmpdir_path)
|
||||
def test_initialization(self, tmp_path: Path):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
assert isinstance(state_handler, StateHandler)
|
||||
assert isinstance(state_handler, StateHandler)
|
||||
|
||||
def test_state_handler_get_state(self, player_config: PlayerConfig):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
state_handler = StateHandler(tmpdir_path)
|
||||
def test_save(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = state_handler.get_state(player_config)
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
assert state is None
|
||||
state_handler.save(state)
|
||||
|
||||
assert (tmp_path / "state.pkl").exists()
|
||||
|
||||
def test_get_when_empty(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = state_handler.get(player_config)
|
||||
|
||||
assert state is None
|
||||
|
||||
def test_get(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
retrieved_state = state_handler.get(player_config)
|
||||
|
||||
assert retrieved_state is not None
|
||||
assert retrieved_state.config == state.config
|
||||
assert retrieved_state.system == state.system
|
||||
assert retrieved_state.current_cycle == state.current_cycle
|
||||
|
||||
def test_get_with_different_config(
|
||||
self, tmp_path: Path, player_config: PlayerConfig
|
||||
):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
different_config = player_config.model_copy(update={"opt": False})
|
||||
|
||||
retrieved_state = state_handler.get(different_config)
|
||||
|
||||
assert retrieved_state is None
|
||||
|
||||
def test_get_with_different_config_force(
|
||||
self, tmp_path: Path, player_config: PlayerConfig
|
||||
):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
different_config = player_config.model_copy(update={"opt": False})
|
||||
|
||||
retrieved_state = state_handler.get(different_config, force=True)
|
||||
|
||||
assert retrieved_state is not None
|
||||
assert retrieved_state.config == state.config
|
||||
assert retrieved_state.config != different_config
|
||||
assert retrieved_state.system == state.system
|
||||
assert retrieved_state.current_cycle == state.current_cycle
|
||||
|
||||
Reference in New Issue
Block a user