feat: better state_handler

This commit is contained in:
2026-03-04 14:02:19 -03:00
parent 11ff4c0c21
commit 06ae9b41f0
17 changed files with 158 additions and 42 deletions

View File

@@ -1,3 +0,0 @@
from diceplayer.utils.logger import RunLogger

View File

@@ -1,9 +1,7 @@
from diceplayer.config.player_config import PlayerConfig
from diceplayer.cli import ArgsModel, read_input
from diceplayer.logger import logger
from diceplayer.player import Player
import yaml
import argparse
from importlib import metadata
@@ -11,17 +9,6 @@ from importlib import metadata
VERSION = metadata.version("diceplayer")
def read_input(infile) -> PlayerConfig:
try:
with open(infile, "r") as f:
return PlayerConfig.model_validate(
yaml.safe_load(f)
)
except Exception as e:
logger.exception("Failed to read input file")
raise e
def main():
parser = argparse.ArgumentParser(prog="Diceplayer")
parser.add_argument(
@@ -46,7 +33,7 @@ def main():
metavar="OUTFILE",
help="output file of diceplayer [default = run.log]",
)
args = parser.parse_args()
args = ArgsModel.from_args(parser.parse_args())
logger.set_output_file(args.outfile)
@@ -56,4 +43,4 @@ def main():
if __name__ == "__main__":
main()
main()

View File

@@ -0,0 +1,5 @@
from .args_model import ArgsModel
from .read_input_file import read_input
__all__ = ["ArgsModel", "read_input"]

View File

@@ -0,0 +1,11 @@
from pydantic import BaseModel
class ArgsModel(BaseModel):
outfile: str
infile: str
continuation: bool
@classmethod
def from_args(cls, args):
return cls(**vars(args))

View File

@@ -0,0 +1,13 @@
from diceplayer.config import PlayerConfig
from diceplayer.logger import logger
import yaml
def read_input(infile) -> PlayerConfig:
try:
with open(infile, "r") as f:
return PlayerConfig.model_validate(yaml.safe_load(f))
except Exception as e:
logger.exception("Failed to read input file")
raise e

View File

@@ -0,0 +1,10 @@
from .dice_config import DiceConfig
from .gaussian_config import GaussianConfig
from .player_config import PlayerConfig
__all__ = [
"DiceConfig",
"GaussianConfig",
"PlayerConfig",
]

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing_extensions import List, Literal
@@ -6,6 +6,9 @@ class DiceConfig(BaseModel):
"""
Data Transfer Object for the Dice configuration.
"""
model_config = ConfigDict(
frozen=True,
)
ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
outname: str = Field(

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing_extensions import Literal
@@ -6,6 +6,9 @@ class GaussianConfig(BaseModel):
"""
Data Transfer Object for the Gaussian configuration.
"""
model_config = ConfigDict(
frozen=True,
)
level: str = Field(..., description="Level of theory for the QM calculations")
qmprog: Literal["g03", "g09", "g16"] = Field(

View File

@@ -1,19 +1,23 @@
from diceplayer.config.dice_config import DiceConfig
from diceplayer.config.gaussian_config import GaussianConfig
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from pydantic import BaseModel, Field, model_validator, ConfigDict
from typing_extensions import Self, Any
from pathlib import Path
MIN_STEP = 20000
STEP_INCREMENT = 1000
class PlayerConfig(BaseModel):
"""
Data Transfer Object for the player configuration.
"""
model_config = ConfigDict(
frozen=True,
)
opt: bool = Field(..., description="Whether to perform geometry optimization")
maxcyc: int = Field(
@@ -42,7 +46,9 @@ class PlayerConfig(BaseModel):
"simfiles", description="Directory name for the simulation files"
)
@model_validator(mode="after")
def validate_altsteps(self) -> Self:
self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000
return self
@model_validator(mode="before")
@staticmethod
def validate_altsteps(fields) -> dict[str, Any]:
altsteps = fields.pop("altsteps", MIN_STEP)
fields["altsteps"] = round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
return fields

View File

@@ -1,4 +1,4 @@
from diceplayer import RunLogger
from diceplayer.utils import RunLogger
logger = RunLogger("diceplayer")

View File

@@ -5,5 +5,4 @@ class Player:
def __init__(self, config: PlayerConfig):
self.config = config
def play(self, continuation = False):
...
def play(self, continuation=False): ...

View File

View File

@@ -0,0 +1,31 @@
import pickle
from pathlib import Path
from diceplayer.config import PlayerConfig
from diceplayer.environment import System
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel
class StateHandler:
def __init__(self, sim_dir: Path, state_file: str = "state.pkl"):
self._state_file = sim_dir / state_file
def get_state(self, config: PlayerConfig) -> StateModel | None:
if not self._state_file.exists():
return None
with self._state_file.open(mode="r") as f:
data = pickle.load(f)
model = StateModel.model_validate(data)
if hash(model.config) != hash(config):
logger.warning("The configuration in the state file does not match the provided configuration.")
return None
return model
def save_state(self, state: StateModel) -> None:
with self._state_file.open(mode="wb") as f:
pickle.dump(state.model_dump(), f)

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
from diceplayer.config import PlayerConfig
from diceplayer.environment import System
class StateModel(BaseModel):
config: PlayerConfig
system: System
current_cycle: int

View File

@@ -5,7 +5,7 @@ import sys
from pathlib import Path
H = TypeVar('H', bound=logging.Handler)
H = TypeVar("H", bound=logging.Handler)
class RunLogger(logging.Logger):
@@ -18,33 +18,27 @@ class RunLogger(logging.Logger):
self._configure_handler(logging.StreamHandler(stream), level)
)
def set_output_file(self, outfile: Path, level=logging.INFO):
for handler in list(self.handlers):
if not isinstance(handler, logging.FileHandler):
continue
self.handlers.remove(handler)
self.handlers.append(
self._create_file_handler(outfile, level)
)
self.handlers.append(self._create_file_handler(outfile, level))
@staticmethod
def _create_file_handler(file: str|Path, level) -> logging.FileHandler:
def _create_file_handler(file: str | Path, level) -> logging.FileHandler:
file = Path(file)
if file.exists():
file.rename(file.with_suffix('.log.backup'))
file.rename(file.with_suffix(".log.backup"))
handler = logging.FileHandler(file)
return RunLogger._configure_handler(handler, level)
@staticmethod
def _configure_handler(handler: H, level) -> H:
handler.setLevel(level)
formatter = logging.Formatter('%(message)s')
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
return handler
return handler