feat: better state_handler
This commit is contained in:
@@ -1,3 +0,0 @@
|
|||||||
from diceplayer.utils.logger import RunLogger
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
from diceplayer.config.player_config import PlayerConfig
|
from diceplayer.cli import ArgsModel, read_input
|
||||||
from diceplayer.logger import logger
|
from diceplayer.logger import logger
|
||||||
from diceplayer.player import Player
|
from diceplayer.player import Player
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from importlib import metadata
|
from importlib import metadata
|
||||||
|
|
||||||
@@ -11,17 +9,6 @@ from importlib import metadata
|
|||||||
VERSION = metadata.version("diceplayer")
|
VERSION = metadata.version("diceplayer")
|
||||||
|
|
||||||
|
|
||||||
def read_input(infile) -> PlayerConfig:
|
|
||||||
try:
|
|
||||||
with open(infile, "r") as f:
|
|
||||||
return PlayerConfig.model_validate(
|
|
||||||
yaml.safe_load(f)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Failed to read input file")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(prog="Diceplayer")
|
parser = argparse.ArgumentParser(prog="Diceplayer")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -46,7 +33,7 @@ def main():
|
|||||||
metavar="OUTFILE",
|
metavar="OUTFILE",
|
||||||
help="output file of diceplayer [default = run.log]",
|
help="output file of diceplayer [default = run.log]",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = ArgsModel.from_args(parser.parse_args())
|
||||||
|
|
||||||
logger.set_output_file(args.outfile)
|
logger.set_output_file(args.outfile)
|
||||||
|
|
||||||
@@ -56,4 +43,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
5
diceplayer/cli/__init__.py
Normal file
5
diceplayer/cli/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .args_model import ArgsModel
|
||||||
|
from .read_input_file import read_input
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ArgsModel", "read_input"]
|
||||||
11
diceplayer/cli/args_model.py
Normal file
11
diceplayer/cli/args_model.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ArgsModel(BaseModel):
|
||||||
|
outfile: str
|
||||||
|
infile: str
|
||||||
|
continuation: bool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_args(cls, args):
|
||||||
|
return cls(**vars(args))
|
||||||
13
diceplayer/cli/read_input_file.py
Normal file
13
diceplayer/cli/read_input_file.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from diceplayer.config import PlayerConfig
|
||||||
|
from diceplayer.logger import logger
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def read_input(infile) -> PlayerConfig:
|
||||||
|
try:
|
||||||
|
with open(infile, "r") as f:
|
||||||
|
return PlayerConfig.model_validate(yaml.safe_load(f))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to read input file")
|
||||||
|
raise e
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
from .dice_config import DiceConfig
|
||||||
|
from .gaussian_config import GaussianConfig
|
||||||
|
from .player_config import PlayerConfig
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DiceConfig",
|
||||||
|
"GaussianConfig",
|
||||||
|
"PlayerConfig",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
from typing_extensions import List, Literal
|
from typing_extensions import List, Literal
|
||||||
|
|
||||||
|
|
||||||
@@ -6,6 +6,9 @@ class DiceConfig(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Data Transfer Object for the Dice configuration.
|
Data Transfer Object for the Dice configuration.
|
||||||
"""
|
"""
|
||||||
|
model_config = ConfigDict(
|
||||||
|
frozen=True,
|
||||||
|
)
|
||||||
|
|
||||||
ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
|
ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
|
||||||
outname: str = Field(
|
outname: str = Field(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
@@ -6,6 +6,9 @@ class GaussianConfig(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Data Transfer Object for the Gaussian configuration.
|
Data Transfer Object for the Gaussian configuration.
|
||||||
"""
|
"""
|
||||||
|
model_config = ConfigDict(
|
||||||
|
frozen=True,
|
||||||
|
)
|
||||||
|
|
||||||
level: str = Field(..., description="Level of theory for the QM calculations")
|
level: str = Field(..., description="Level of theory for the QM calculations")
|
||||||
qmprog: Literal["g03", "g09", "g16"] = Field(
|
qmprog: Literal["g03", "g09", "g16"] = Field(
|
||||||
|
|||||||
@@ -1,19 +1,23 @@
|
|||||||
from diceplayer.config.dice_config import DiceConfig
|
from diceplayer.config.dice_config import DiceConfig
|
||||||
from diceplayer.config.gaussian_config import GaussianConfig
|
from diceplayer.config.gaussian_config import GaussianConfig
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator, ConfigDict
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self, Any
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
MIN_STEP = 20000
|
MIN_STEP = 20000
|
||||||
|
STEP_INCREMENT = 1000
|
||||||
|
|
||||||
|
|
||||||
class PlayerConfig(BaseModel):
|
class PlayerConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Data Transfer Object for the player configuration.
|
Data Transfer Object for the player configuration.
|
||||||
"""
|
"""
|
||||||
|
model_config = ConfigDict(
|
||||||
|
frozen=True,
|
||||||
|
)
|
||||||
|
|
||||||
opt: bool = Field(..., description="Whether to perform geometry optimization")
|
opt: bool = Field(..., description="Whether to perform geometry optimization")
|
||||||
maxcyc: int = Field(
|
maxcyc: int = Field(
|
||||||
@@ -42,7 +46,9 @@ class PlayerConfig(BaseModel):
|
|||||||
"simfiles", description="Directory name for the simulation files"
|
"simfiles", description="Directory name for the simulation files"
|
||||||
)
|
)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="before")
|
||||||
def validate_altsteps(self) -> Self:
|
@staticmethod
|
||||||
self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000
|
def validate_altsteps(fields) -> dict[str, Any]:
|
||||||
return self
|
altsteps = fields.pop("altsteps", MIN_STEP)
|
||||||
|
fields["altsteps"] = round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
|
||||||
|
return fields
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from diceplayer import RunLogger
|
from diceplayer.utils import RunLogger
|
||||||
|
|
||||||
|
|
||||||
logger = RunLogger("diceplayer")
|
logger = RunLogger("diceplayer")
|
||||||
|
|||||||
@@ -5,5 +5,4 @@ class Player:
|
|||||||
def __init__(self, config: PlayerConfig):
|
def __init__(self, config: PlayerConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def play(self, continuation = False):
|
def play(self, continuation=False): ...
|
||||||
...
|
|
||||||
|
|||||||
0
diceplayer/state/__init__.py
Normal file
0
diceplayer/state/__init__.py
Normal file
31
diceplayer/state/state_handler.py
Normal file
31
diceplayer/state/state_handler.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from diceplayer.config import PlayerConfig
|
||||||
|
from diceplayer.environment import System
|
||||||
|
from diceplayer.logger import logger
|
||||||
|
from diceplayer.state.state_model import StateModel
|
||||||
|
|
||||||
|
|
||||||
|
class StateHandler:
|
||||||
|
def __init__(self, sim_dir: Path, state_file: str = "state.pkl"):
|
||||||
|
self._state_file = sim_dir / state_file
|
||||||
|
|
||||||
|
def get_state(self, config: PlayerConfig) -> StateModel | None:
|
||||||
|
if not self._state_file.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
with self._state_file.open(mode="r") as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
|
||||||
|
model = StateModel.model_validate(data)
|
||||||
|
|
||||||
|
if hash(model.config) != hash(config):
|
||||||
|
logger.warning("The configuration in the state file does not match the provided configuration.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def save_state(self, state: StateModel) -> None:
|
||||||
|
with self._state_file.open(mode="wb") as f:
|
||||||
|
pickle.dump(state.model_dump(), f)
|
||||||
10
diceplayer/state/state_model.py
Normal file
10
diceplayer/state/state_model.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from diceplayer.config import PlayerConfig
|
||||||
|
from diceplayer.environment import System
|
||||||
|
|
||||||
|
|
||||||
|
class StateModel(BaseModel):
|
||||||
|
config: PlayerConfig
|
||||||
|
system: System
|
||||||
|
current_cycle: int
|
||||||
@@ -5,7 +5,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
H = TypeVar('H', bound=logging.Handler)
|
H = TypeVar("H", bound=logging.Handler)
|
||||||
|
|
||||||
|
|
||||||
class RunLogger(logging.Logger):
|
class RunLogger(logging.Logger):
|
||||||
@@ -18,33 +18,27 @@ class RunLogger(logging.Logger):
|
|||||||
self._configure_handler(logging.StreamHandler(stream), level)
|
self._configure_handler(logging.StreamHandler(stream), level)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_output_file(self, outfile: Path, level=logging.INFO):
|
def set_output_file(self, outfile: Path, level=logging.INFO):
|
||||||
for handler in list(self.handlers):
|
for handler in list(self.handlers):
|
||||||
if not isinstance(handler, logging.FileHandler):
|
if not isinstance(handler, logging.FileHandler):
|
||||||
continue
|
continue
|
||||||
self.handlers.remove(handler)
|
self.handlers.remove(handler)
|
||||||
|
|
||||||
self.handlers.append(
|
self.handlers.append(self._create_file_handler(outfile, level))
|
||||||
self._create_file_handler(outfile, level)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_file_handler(file: str|Path, level) -> logging.FileHandler:
|
def _create_file_handler(file: str | Path, level) -> logging.FileHandler:
|
||||||
file = Path(file)
|
file = Path(file)
|
||||||
|
|
||||||
if file.exists():
|
if file.exists():
|
||||||
file.rename(file.with_suffix('.log.backup'))
|
file.rename(file.with_suffix(".log.backup"))
|
||||||
|
|
||||||
handler = logging.FileHandler(file)
|
handler = logging.FileHandler(file)
|
||||||
return RunLogger._configure_handler(handler, level)
|
return RunLogger._configure_handler(handler, level)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_handler(handler: H, level) -> H:
|
def _configure_handler(handler: H, level) -> H:
|
||||||
handler.setLevel(level)
|
handler.setLevel(level)
|
||||||
formatter = logging.Formatter('%(message)s')
|
formatter = logging.Formatter("%(message)s")
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
return handler
|
return handler
|
||||||
|
|||||||
0
tests/state/__init__.py
Normal file
0
tests/state/__init__.py
Normal file
47
tests/state/test_state_handler.py
Normal file
47
tests/state/test_state_handler.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from diceplayer.config import PlayerConfig, DiceConfig, GaussianConfig
|
||||||
|
from diceplayer.state.state_handler import StateHandler
|
||||||
|
|
||||||
|
|
||||||
|
class TestStateHandler:
|
||||||
|
@pytest.fixture
|
||||||
|
def player_config(self) -> PlayerConfig:
|
||||||
|
return PlayerConfig(
|
||||||
|
opt=True,
|
||||||
|
mem=12,
|
||||||
|
maxcyc=100,
|
||||||
|
nprocs=4,
|
||||||
|
ncores=4,
|
||||||
|
dice=DiceConfig(
|
||||||
|
ljname="test",
|
||||||
|
outname="test",
|
||||||
|
dens=1.0,
|
||||||
|
nmol=[1],
|
||||||
|
nstep=[1, 1],
|
||||||
|
),
|
||||||
|
gaussian=GaussianConfig(
|
||||||
|
level="test",
|
||||||
|
qmprog="g16",
|
||||||
|
keywords="test",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_state_handler_initialization(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
tmpdir_path = Path(tmpdir)
|
||||||
|
state_handler = StateHandler(tmpdir_path)
|
||||||
|
|
||||||
|
assert isinstance(state_handler, StateHandler)
|
||||||
|
|
||||||
|
def test_state_handler_get_state(self, player_config: PlayerConfig):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
tmpdir_path = Path(tmpdir)
|
||||||
|
state_handler = StateHandler(tmpdir_path)
|
||||||
|
|
||||||
|
state = state_handler.get_state(player_config)
|
||||||
|
|
||||||
|
assert state is None
|
||||||
Reference in New Issue
Block a user