feat: better state_handler
This commit is contained in:
@@ -1,3 +0,0 @@
|
||||
from diceplayer.utils.logger import RunLogger
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.cli import ArgsModel, read_input
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.player import Player
|
||||
|
||||
import yaml
|
||||
|
||||
import argparse
|
||||
from importlib import metadata
|
||||
|
||||
@@ -11,17 +9,6 @@ from importlib import metadata
|
||||
VERSION = metadata.version("diceplayer")
|
||||
|
||||
|
||||
def read_input(infile) -> PlayerConfig:
|
||||
try:
|
||||
with open(infile, "r") as f:
|
||||
return PlayerConfig.model_validate(
|
||||
yaml.safe_load(f)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to read input file")
|
||||
raise e
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="Diceplayer")
|
||||
parser.add_argument(
|
||||
@@ -46,7 +33,7 @@ def main():
|
||||
metavar="OUTFILE",
|
||||
help="output file of diceplayer [default = run.log]",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = ArgsModel.from_args(parser.parse_args())
|
||||
|
||||
logger.set_output_file(args.outfile)
|
||||
|
||||
|
||||
5
diceplayer/cli/__init__.py
Normal file
5
diceplayer/cli/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .args_model import ArgsModel
|
||||
from .read_input_file import read_input
|
||||
|
||||
|
||||
__all__ = ["ArgsModel", "read_input"]
|
||||
11
diceplayer/cli/args_model.py
Normal file
11
diceplayer/cli/args_model.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ArgsModel(BaseModel):
|
||||
outfile: str
|
||||
infile: str
|
||||
continuation: bool
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args):
|
||||
return cls(**vars(args))
|
||||
13
diceplayer/cli/read_input_file.py
Normal file
13
diceplayer/cli/read_input_file.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.logger import logger
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def read_input(infile) -> PlayerConfig:
|
||||
try:
|
||||
with open(infile, "r") as f:
|
||||
return PlayerConfig.model_validate(yaml.safe_load(f))
|
||||
except Exception as e:
|
||||
logger.exception("Failed to read input file")
|
||||
raise e
|
||||
@@ -0,0 +1,10 @@
|
||||
from .dice_config import DiceConfig
|
||||
from .gaussian_config import GaussianConfig
|
||||
from .player_config import PlayerConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DiceConfig",
|
||||
"GaussianConfig",
|
||||
"PlayerConfig",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing_extensions import List, Literal
|
||||
|
||||
|
||||
@@ -6,6 +6,9 @@ class DiceConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the Dice configuration.
|
||||
"""
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
|
||||
outname: str = Field(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@@ -6,6 +6,9 @@ class GaussianConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the Gaussian configuration.
|
||||
"""
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
level: str = Field(..., description="Level of theory for the QM calculations")
|
||||
qmprog: Literal["g03", "g09", "g16"] = Field(
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
from diceplayer.config.dice_config import DiceConfig
|
||||
from diceplayer.config.gaussian_config import GaussianConfig
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
from pydantic import BaseModel, Field, model_validator, ConfigDict
|
||||
from typing_extensions import Self, Any
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
MIN_STEP = 20000
|
||||
STEP_INCREMENT = 1000
|
||||
|
||||
|
||||
class PlayerConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the player configuration.
|
||||
"""
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
opt: bool = Field(..., description="Whether to perform geometry optimization")
|
||||
maxcyc: int = Field(
|
||||
@@ -42,7 +46,9 @@ class PlayerConfig(BaseModel):
|
||||
"simfiles", description="Directory name for the simulation files"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_altsteps(self) -> Self:
|
||||
self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000
|
||||
return self
|
||||
@model_validator(mode="before")
|
||||
@staticmethod
|
||||
def validate_altsteps(fields) -> dict[str, Any]:
|
||||
altsteps = fields.pop("altsteps", MIN_STEP)
|
||||
fields["altsteps"] = round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
|
||||
return fields
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from diceplayer import RunLogger
|
||||
from diceplayer.utils import RunLogger
|
||||
|
||||
|
||||
logger = RunLogger("diceplayer")
|
||||
|
||||
@@ -5,5 +5,4 @@ class Player:
|
||||
def __init__(self, config: PlayerConfig):
|
||||
self.config = config
|
||||
|
||||
def play(self, continuation = False):
|
||||
...
|
||||
def play(self, continuation=False): ...
|
||||
|
||||
0
diceplayer/state/__init__.py
Normal file
0
diceplayer/state/__init__.py
Normal file
31
diceplayer/state/state_handler.py
Normal file
31
diceplayer/state/state_handler.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
|
||||
class StateHandler:
|
||||
def __init__(self, sim_dir: Path, state_file: str = "state.pkl"):
|
||||
self._state_file = sim_dir / state_file
|
||||
|
||||
def get_state(self, config: PlayerConfig) -> StateModel | None:
|
||||
if not self._state_file.exists():
|
||||
return None
|
||||
|
||||
with self._state_file.open(mode="r") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
model = StateModel.model_validate(data)
|
||||
|
||||
if hash(model.config) != hash(config):
|
||||
logger.warning("The configuration in the state file does not match the provided configuration.")
|
||||
return None
|
||||
|
||||
return model
|
||||
|
||||
def save_state(self, state: StateModel) -> None:
|
||||
with self._state_file.open(mode="wb") as f:
|
||||
pickle.dump(state.model_dump(), f)
|
||||
10
diceplayer/state/state_model.py
Normal file
10
diceplayer/state/state_model.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
|
||||
|
||||
class StateModel(BaseModel):
|
||||
config: PlayerConfig
|
||||
system: System
|
||||
current_cycle: int
|
||||
@@ -5,7 +5,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
H = TypeVar('H', bound=logging.Handler)
|
||||
H = TypeVar("H", bound=logging.Handler)
|
||||
|
||||
|
||||
class RunLogger(logging.Logger):
|
||||
@@ -18,33 +18,27 @@ class RunLogger(logging.Logger):
|
||||
self._configure_handler(logging.StreamHandler(stream), level)
|
||||
)
|
||||
|
||||
|
||||
def set_output_file(self, outfile: Path, level=logging.INFO):
|
||||
for handler in list(self.handlers):
|
||||
if not isinstance(handler, logging.FileHandler):
|
||||
continue
|
||||
self.handlers.remove(handler)
|
||||
|
||||
self.handlers.append(
|
||||
self._create_file_handler(outfile, level)
|
||||
)
|
||||
|
||||
|
||||
self.handlers.append(self._create_file_handler(outfile, level))
|
||||
|
||||
@staticmethod
|
||||
def _create_file_handler(file: str|Path, level) -> logging.FileHandler:
|
||||
def _create_file_handler(file: str | Path, level) -> logging.FileHandler:
|
||||
file = Path(file)
|
||||
|
||||
if file.exists():
|
||||
file.rename(file.with_suffix('.log.backup'))
|
||||
file.rename(file.with_suffix(".log.backup"))
|
||||
|
||||
handler = logging.FileHandler(file)
|
||||
return RunLogger._configure_handler(handler, level)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _configure_handler(handler: H, level) -> H:
|
||||
handler.setLevel(level)
|
||||
formatter = logging.Formatter('%(message)s')
|
||||
formatter = logging.Formatter("%(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
return handler
|
||||
0
tests/state/__init__.py
Normal file
0
tests/state/__init__.py
Normal file
47
tests/state/test_state_handler.py
Normal file
47
tests/state/test_state_handler.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from diceplayer.config import PlayerConfig, DiceConfig, GaussianConfig
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
|
||||
|
||||
class TestStateHandler:
|
||||
@pytest.fixture
|
||||
def player_config(self) -> PlayerConfig:
|
||||
return PlayerConfig(
|
||||
opt=True,
|
||||
mem=12,
|
||||
maxcyc=100,
|
||||
nprocs=4,
|
||||
ncores=4,
|
||||
dice=DiceConfig(
|
||||
ljname="test",
|
||||
outname="test",
|
||||
dens=1.0,
|
||||
nmol=[1],
|
||||
nstep=[1, 1],
|
||||
),
|
||||
gaussian=GaussianConfig(
|
||||
level="test",
|
||||
qmprog="g16",
|
||||
keywords="test",
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_handler_initialization(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
state_handler = StateHandler(tmpdir_path)
|
||||
|
||||
assert isinstance(state_handler, StateHandler)
|
||||
|
||||
def test_state_handler_get_state(self, player_config: PlayerConfig):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
state_handler = StateHandler(tmpdir_path)
|
||||
|
||||
state = state_handler.get_state(player_config)
|
||||
|
||||
assert state is None
|
||||
Reference in New Issue
Block a user