feat: better state_handler

This commit is contained in:
2026-03-04 14:02:19 -03:00
parent 11ff4c0c21
commit 06ae9b41f0
17 changed files with 158 additions and 42 deletions

View File

@@ -1,3 +0,0 @@
from diceplayer.utils.logger import RunLogger

View File

@@ -1,9 +1,7 @@
from diceplayer.config.player_config import PlayerConfig from diceplayer.cli import ArgsModel, read_input
from diceplayer.logger import logger from diceplayer.logger import logger
from diceplayer.player import Player from diceplayer.player import Player
import yaml
import argparse import argparse
from importlib import metadata from importlib import metadata
@@ -11,17 +9,6 @@ from importlib import metadata
VERSION = metadata.version("diceplayer") VERSION = metadata.version("diceplayer")
def read_input(infile) -> PlayerConfig:
try:
with open(infile, "r") as f:
return PlayerConfig.model_validate(
yaml.safe_load(f)
)
except Exception as e:
logger.exception("Failed to read input file")
raise e
def main(): def main():
parser = argparse.ArgumentParser(prog="Diceplayer") parser = argparse.ArgumentParser(prog="Diceplayer")
parser.add_argument( parser.add_argument(
@@ -46,7 +33,7 @@ def main():
metavar="OUTFILE", metavar="OUTFILE",
help="output file of diceplayer [default = run.log]", help="output file of diceplayer [default = run.log]",
) )
args = parser.parse_args() args = ArgsModel.from_args(parser.parse_args())
logger.set_output_file(args.outfile) logger.set_output_file(args.outfile)
@@ -56,4 +43,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -0,0 +1,5 @@
from .args_model import ArgsModel
from .read_input_file import read_input
__all__ = ["ArgsModel", "read_input"]

View File

@@ -0,0 +1,11 @@
from pydantic import BaseModel
class ArgsModel(BaseModel):
outfile: str
infile: str
continuation: bool
@classmethod
def from_args(cls, args):
return cls(**vars(args))

View File

@@ -0,0 +1,13 @@
from diceplayer.config import PlayerConfig
from diceplayer.logger import logger
import yaml
def read_input(infile) -> PlayerConfig:
try:
with open(infile, "r") as f:
return PlayerConfig.model_validate(yaml.safe_load(f))
except Exception as e:
logger.exception("Failed to read input file")
raise e

View File

@@ -0,0 +1,10 @@
from .dice_config import DiceConfig
from .gaussian_config import GaussianConfig
from .player_config import PlayerConfig
__all__ = [
"DiceConfig",
"GaussianConfig",
"PlayerConfig",
]

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, ConfigDict
from typing_extensions import List, Literal from typing_extensions import List, Literal
@@ -6,6 +6,9 @@ class DiceConfig(BaseModel):
""" """
Data Transfer Object for the Dice configuration. Data Transfer Object for the Dice configuration.
""" """
model_config = ConfigDict(
frozen=True,
)
ljname: str = Field(..., description="Name of the Lennard-Jones potential file") ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
outname: str = Field( outname: str = Field(

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, ConfigDict
from typing_extensions import Literal from typing_extensions import Literal
@@ -6,6 +6,9 @@ class GaussianConfig(BaseModel):
""" """
Data Transfer Object for the Gaussian configuration. Data Transfer Object for the Gaussian configuration.
""" """
model_config = ConfigDict(
frozen=True,
)
level: str = Field(..., description="Level of theory for the QM calculations") level: str = Field(..., description="Level of theory for the QM calculations")
qmprog: Literal["g03", "g09", "g16"] = Field( qmprog: Literal["g03", "g09", "g16"] = Field(

View File

@@ -1,19 +1,23 @@
from diceplayer.config.dice_config import DiceConfig from diceplayer.config.dice_config import DiceConfig
from diceplayer.config.gaussian_config import GaussianConfig from diceplayer.config.gaussian_config import GaussianConfig
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator, ConfigDict
from typing_extensions import Self from typing_extensions import Self, Any
from pathlib import Path from pathlib import Path
MIN_STEP = 20000 MIN_STEP = 20000
STEP_INCREMENT = 1000
class PlayerConfig(BaseModel): class PlayerConfig(BaseModel):
""" """
Data Transfer Object for the player configuration. Data Transfer Object for the player configuration.
""" """
model_config = ConfigDict(
frozen=True,
)
opt: bool = Field(..., description="Whether to perform geometry optimization") opt: bool = Field(..., description="Whether to perform geometry optimization")
maxcyc: int = Field( maxcyc: int = Field(
@@ -42,7 +46,9 @@ class PlayerConfig(BaseModel):
"simfiles", description="Directory name for the simulation files" "simfiles", description="Directory name for the simulation files"
) )
@model_validator(mode="after") @model_validator(mode="before")
def validate_altsteps(self) -> Self: @staticmethod
self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000 def validate_altsteps(fields) -> dict[str, Any]:
return self altsteps = fields.pop("altsteps", MIN_STEP)
fields["altsteps"] = round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
return fields

View File

@@ -1,4 +1,4 @@
from diceplayer import RunLogger from diceplayer.utils import RunLogger
logger = RunLogger("diceplayer") logger = RunLogger("diceplayer")

View File

@@ -5,5 +5,4 @@ class Player:
def __init__(self, config: PlayerConfig): def __init__(self, config: PlayerConfig):
self.config = config self.config = config
def play(self, continuation = False): def play(self, continuation=False): ...
...

View File

View File

@@ -0,0 +1,31 @@
import pickle
from pathlib import Path
from diceplayer.config import PlayerConfig
from diceplayer.environment import System
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel
class StateHandler:
def __init__(self, sim_dir: Path, state_file: str = "state.pkl"):
self._state_file = sim_dir / state_file
def get_state(self, config: PlayerConfig) -> StateModel | None:
if not self._state_file.exists():
return None
with self._state_file.open(mode="r") as f:
data = pickle.load(f)
model = StateModel.model_validate(data)
if hash(model.config) != hash(config):
logger.warning("The configuration in the state file does not match the provided configuration.")
return None
return model
def save_state(self, state: StateModel) -> None:
with self._state_file.open(mode="wb") as f:
pickle.dump(state.model_dump(), f)

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
from diceplayer.config import PlayerConfig
from diceplayer.environment import System
class StateModel(BaseModel):
config: PlayerConfig
system: System
current_cycle: int

View File

@@ -5,7 +5,7 @@ import sys
from pathlib import Path from pathlib import Path
H = TypeVar('H', bound=logging.Handler) H = TypeVar("H", bound=logging.Handler)
class RunLogger(logging.Logger): class RunLogger(logging.Logger):
@@ -18,33 +18,27 @@ class RunLogger(logging.Logger):
self._configure_handler(logging.StreamHandler(stream), level) self._configure_handler(logging.StreamHandler(stream), level)
) )
def set_output_file(self, outfile: Path, level=logging.INFO): def set_output_file(self, outfile: Path, level=logging.INFO):
for handler in list(self.handlers): for handler in list(self.handlers):
if not isinstance(handler, logging.FileHandler): if not isinstance(handler, logging.FileHandler):
continue continue
self.handlers.remove(handler) self.handlers.remove(handler)
self.handlers.append( self.handlers.append(self._create_file_handler(outfile, level))
self._create_file_handler(outfile, level)
)
@staticmethod @staticmethod
def _create_file_handler(file: str|Path, level) -> logging.FileHandler: def _create_file_handler(file: str | Path, level) -> logging.FileHandler:
file = Path(file) file = Path(file)
if file.exists(): if file.exists():
file.rename(file.with_suffix('.log.backup')) file.rename(file.with_suffix(".log.backup"))
handler = logging.FileHandler(file) handler = logging.FileHandler(file)
return RunLogger._configure_handler(handler, level) return RunLogger._configure_handler(handler, level)
@staticmethod @staticmethod
def _configure_handler(handler: H, level) -> H: def _configure_handler(handler: H, level) -> H:
handler.setLevel(level) handler.setLevel(level)
formatter = logging.Formatter('%(message)s') formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter) handler.setFormatter(formatter)
return handler return handler

0
tests/state/__init__.py Normal file
View File

View File

@@ -0,0 +1,47 @@
import tempfile
from pathlib import Path
import pytest
from diceplayer.config import PlayerConfig, DiceConfig, GaussianConfig
from diceplayer.state.state_handler import StateHandler
class TestStateHandler:
@pytest.fixture
def player_config(self) -> PlayerConfig:
return PlayerConfig(
opt=True,
mem=12,
maxcyc=100,
nprocs=4,
ncores=4,
dice=DiceConfig(
ljname="test",
outname="test",
dens=1.0,
nmol=[1],
nstep=[1, 1],
),
gaussian=GaussianConfig(
level="test",
qmprog="g16",
keywords="test",
),
)
def test_state_handler_initialization(self):
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
state_handler = StateHandler(tmpdir_path)
assert isinstance(state_handler, StateHandler)
def test_state_handler_get_state(self, player_config: PlayerConfig):
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
state_handler = StateHandler(tmpdir_path)
state = state_handler.get_state(player_config)
assert state is None