feat: improves and initilize player pipeline

This commit is contained in:
2026-03-05 00:33:48 -03:00
parent 06ae9b41f0
commit 53eb34a83e
13 changed files with 248 additions and 60 deletions

View File

@@ -1,7 +1,8 @@
diceplayer:
opt: no
type: both
switch_cyc: 3
max_cyc: 5
mem: 24
maxcyc: 5
ncores: 5
nprocs: 4
qmprog: 'g16'

View File

@@ -1,4 +1,5 @@
from diceplayer.cli import ArgsModel, read_input
from diceplayer.config import PlayerConfig
from diceplayer.logger import logger
from diceplayer.player import Player
@@ -33,13 +34,26 @@ def main():
metavar="OUTFILE",
help="output file of diceplayer [default = run.log]",
)
parser.add_argument(
"-f",
"--force",
dest="force",
default=False,
action="store_true",
help="force overwrite existing state file if it exists [default = False]",
)
args = ArgsModel.from_args(parser.parse_args())
logger.set_output_file(args.outfile)
config = read_input(args.infile)
config: PlayerConfig
try:
config = read_input(args.infile)
except Exception as e:
logger.error(f"Failed to read input file: {e}")
return
Player(config).play(continuation=args.continuation)
Player(config).play(continuation=args.continuation, force=args.force)
if __name__ == "__main__":

View File

@@ -5,6 +5,7 @@ class ArgsModel(BaseModel):
outfile: str
infile: str
continuation: bool
force: bool
@classmethod
def from_args(cls, args):

View File

@@ -1,13 +1,9 @@
from diceplayer.config import PlayerConfig
from diceplayer.logger import logger
import yaml
def read_input(infile) -> PlayerConfig:
try:
with open(infile, "r") as f:
return PlayerConfig.model_validate(yaml.safe_load(f))
except Exception as e:
logger.exception("Failed to read input file")
raise e
with open(infile, "r") as f:
values = yaml.safe_load(f)
return PlayerConfig.model_validate(values["diceplayer"])

View File

@@ -1,11 +1,12 @@
from pydantic import BaseModel, Field, ConfigDict
from typing_extensions import List, Literal
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Literal
class DiceConfig(BaseModel):
"""
Data Transfer Object for the Dice configuration.
"""
model_config = ConfigDict(
frozen=True,
)
@@ -15,10 +16,10 @@ class DiceConfig(BaseModel):
..., description="Name of the output file for the simulation results"
)
dens: float = Field(..., description="Density of the system")
nmol: List[int] = Field(
nmol: list[int] = Field(
..., description="List of the number of molecules for each component"
)
nstep: List[int] = Field(
nstep: list[int] = Field(
...,
description="List of the number of steps for each component",
min_length=2,

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Literal
@@ -6,14 +6,15 @@ class GaussianConfig(BaseModel):
"""
Data Transfer Object for the Gaussian configuration.
"""
model_config = ConfigDict(
frozen=True,
)
level: str = Field(..., description="Level of theory for the QM calculations")
qmprog: Literal["g03", "g09", "g16"] = Field(
"g16", description="QM program to use for the calculations"
)
level: str = Field(..., description="Level of theory for the QM calculations")
chgmult: list[int] = Field(
default_factory=lambda: [0, 1],
@@ -23,6 +24,6 @@ class GaussianConfig(BaseModel):
"chelpg", description="Population analysis method for the QM calculations"
)
chg_tol: float = Field(0.01, description="Charge tolerance for the QM calculations")
keywords: str = Field(
keywords: str | None = Field(
None, description="Additional keywords for the QM calculations"
)

View File

@@ -1,9 +1,10 @@
from diceplayer.config.dice_config import DiceConfig
from diceplayer.config.gaussian_config import GaussianConfig
from pydantic import BaseModel, Field, model_validator, ConfigDict
from typing_extensions import Self, Any
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Any
from enum import Enum
from pathlib import Path
@@ -11,18 +12,28 @@ MIN_STEP = 20000
STEP_INCREMENT = 1000
class RoutineType(str, Enum):
CHARGE = "charge"
GEOMETRY = "geometry"
BOTH = "both"
class PlayerConfig(BaseModel):
"""
Data Transfer Object for the player configuration.
"""
model_config = ConfigDict(
frozen=True,
)
opt: bool = Field(..., description="Whether to perform geometry optimization")
maxcyc: int = Field(
..., description="Maximum number of cycles for the geometry optimization"
type: RoutineType = Field(..., description="Type of simulation to perform")
max_cyc: int = Field(
..., description="Maximum number of cycles for the geometry optimization", gt=0
)
switch_cyc: int = Field(..., description="Switch cycle configuration")
mem: int = Field(None, description="Memory configuration")
nprocs: int = Field(
..., description="Number of processors to use for the QM calculations"
)
@@ -33,22 +44,37 @@ class PlayerConfig(BaseModel):
dice: DiceConfig = Field(..., description="Dice configuration")
gaussian: GaussianConfig = Field(..., description="Gaussian configuration")
mem: int = Field(None, description="Memory configuration")
switchcyc: int = Field(3, description="Switch cycle configuration")
qmprog: str = Field("g16", description="QM program to use for the calculations")
altsteps: int = Field(
20000, description="Number of steps for the alternate simulation"
)
geoms_file: Path = Field(
"geoms.xyz", description="File name for the geometries output"
Path("geoms.xyz"), description="File name for the geometries output"
)
simulation_dir: Path = Field(
"simfiles", description="Directory name for the simulation files"
Path("simfiles"), description="Directory name for the simulation files"
)
@model_validator(mode="before")
@staticmethod
def validate_altsteps(fields) -> dict[str, Any]:
altsteps = fields.pop("altsteps", MIN_STEP)
fields["altsteps"] = round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
fields["altsteps"] = (
round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
)
return fields
@model_validator(mode="before")
@staticmethod
def validate_switch_cyc(fields: dict[str, Any]) -> dict[str, Any]:
max_cyc = int(fields.get("max_cyc", 0))
switch_cyc = int(fields.get("switch_cyc", max_cyc))
if fields.get("type") == "both" and not switch_cyc < max_cyc:
raise ValueError("switch_cyc must be less than max_cyc when type='both'.")
if fields.get("type") != "both" and switch_cyc != max_cyc:
raise ValueError(
"switch_cyc must be equal to max_cyc when type is not 'both'."
)
return fields

View File

@@ -1,8 +1,41 @@
from diceplayer.config.player_config import PlayerConfig
from diceplayer.logger import logger
from diceplayer.state.state_handler import StateHandler
from diceplayer.state.state_model import StateModel
from typing_extensions import TypedDict, Unpack
class PlayerFlags(TypedDict):
continuation: bool
force: bool
class Player:
def __init__(self, config: PlayerConfig):
self.config = config
def play(self, continuation=False): ...
def play(self, **flags: Unpack[PlayerFlags]):
state_handler = StateHandler(self.config.simulation_dir)
if not flags["continuation"]:
logger.info(
"Continuation flag is not set. Starting a new simulation and deleting any existing state."
)
state_handler.delete()
state = state_handler.get(self.config, force=flags["force"])
if state is None:
state = StateModel.from_config(self.config)
else:
logger.info("Resuming from existing state.")
while state.current_cycle < self.config.max_cyc:
logger.info(
f"Starting cycle {state.current_cycle + 1} of {self.config.max_cyc}."
)
state.current_cycle += 1
state_handler.save(state)
logger.info("Reached maximum number of cycles. Simulation complete.")

View File

@@ -1,31 +1,37 @@
import pickle
from pathlib import Path
from diceplayer.config import PlayerConfig
from diceplayer.environment import System
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel
import pickle
from pathlib import Path
class StateHandler:
def __init__(self, sim_dir: Path, state_file: str = "state.pkl"):
if not sim_dir.exists():
sim_dir.mkdir(parents=True, exist_ok=True)
self._state_file = sim_dir / state_file
def get_state(self, config: PlayerConfig) -> StateModel | None:
def get(self, config: PlayerConfig, force=False) -> StateModel | None:
if not self._state_file.exists():
return None
with self._state_file.open(mode="r") as f:
data = pickle.load(f)
with open(self._state_file, mode="rb") as file:
data = pickle.load(file)
model = StateModel.model_validate(data)
if hash(model.config) != hash(config):
logger.warning("The configuration in the state file does not match the provided configuration.")
if config != model.config and not force:
logger.warning(
"The configuration in the state file does not match the provided configuration."
)
return None
return model
def save_state(self, state: StateModel) -> None:
def save(self, state: StateModel) -> None:
with self._state_file.open(mode="wb") as f:
pickle.dump(state.model_dump(), f)
pickle.dump(state.model_dump(), f)
def delete(self) -> None:
if self._state_file.exists():
self._state_file.unlink()

View File

@@ -1,10 +1,19 @@
from pydantic import BaseModel
from diceplayer.config import PlayerConfig
from diceplayer.environment import System
from pydantic import BaseModel
from typing_extensions import Self
class StateModel(BaseModel):
config: PlayerConfig
system: System
current_cycle: int
current_cycle: int
@classmethod
def from_config(cls, config: PlayerConfig) -> Self:
return cls(
config=config,
system=System(),
current_cycle=0,
)

0
tests/cli/__init__.py Normal file
View File

View File

@@ -0,0 +1,30 @@
import diceplayer
from diceplayer.cli import read_input
from diceplayer.config import PlayerConfig
import pytest
from pathlib import Path
class TestReadInputFile:
@pytest.fixture
def example_config(self) -> Path:
return Path(diceplayer.__path__[0]).parent / "control.example.yml"
def test_read_input_file(self, example_config: Path):
config = read_input(example_config)
assert config is not None
assert isinstance(config, PlayerConfig)
def test_read_input_non_existing_file(self):
with pytest.raises(FileNotFoundError):
read_input("nonexistent_file.yml")
def test_read_input_invalid_yaml(self, tmp_path: Path):
invalid_yaml_file = tmp_path / "invalid.yml"
invalid_yaml_file.write_text("This is not valid YAML: [unbalanced brackets")
with pytest.raises(Exception):
read_input(invalid_yaml_file)

View File

@@ -1,10 +1,11 @@
import tempfile
from pathlib import Path
from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
from diceplayer.environment import System
from diceplayer.state.state_handler import StateHandler
from diceplayer.state.state_model import StateModel
import pytest
from diceplayer.config import PlayerConfig, DiceConfig, GaussianConfig
from diceplayer.state.state_handler import StateHandler
from pathlib import Path
class TestStateHandler:
@@ -30,18 +31,87 @@ class TestStateHandler:
),
)
def test_state_handler_initialization(self):
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
state_handler = StateHandler(tmpdir_path)
def test_initialization(self, tmp_path: Path):
state_handler = StateHandler(tmp_path)
assert isinstance(state_handler, StateHandler)
assert isinstance(state_handler, StateHandler)
def test_state_handler_get_state(self, player_config: PlayerConfig):
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
state_handler = StateHandler(tmpdir_path)
def test_save(self, tmp_path: Path, player_config: PlayerConfig):
state_handler = StateHandler(tmp_path)
state = state_handler.get_state(player_config)
state = StateModel(
config=player_config,
system=System(),
current_cycle=0,
)
assert state is None
state_handler.save(state)
assert (tmp_path / "state.pkl").exists()
def test_get_when_empty(self, tmp_path: Path, player_config: PlayerConfig):
state_handler = StateHandler(tmp_path)
state = state_handler.get(player_config)
assert state is None
def test_get(self, tmp_path: Path, player_config: PlayerConfig):
state_handler = StateHandler(tmp_path)
state = StateModel(
config=player_config,
system=System(),
current_cycle=0,
)
state_handler.save(state)
retrieved_state = state_handler.get(player_config)
assert retrieved_state is not None
assert retrieved_state.config == state.config
assert retrieved_state.system == state.system
assert retrieved_state.current_cycle == state.current_cycle
def test_get_with_different_config(
self, tmp_path: Path, player_config: PlayerConfig
):
state_handler = StateHandler(tmp_path)
state = StateModel(
config=player_config,
system=System(),
current_cycle=0,
)
state_handler.save(state)
different_config = player_config.model_copy(update={"opt": False})
retrieved_state = state_handler.get(different_config)
assert retrieved_state is None
def test_get_with_different_config_force(
self, tmp_path: Path, player_config: PlayerConfig
):
state_handler = StateHandler(tmp_path)
state = StateModel(
config=player_config,
system=System(),
current_cycle=0,
)
state_handler.save(state)
different_config = player_config.model_copy(update={"opt": False})
retrieved_state = state_handler.get(different_config, force=True)
assert retrieved_state is not None
assert retrieved_state.config == state.config
assert retrieved_state.config != different_config
assert retrieved_state.system == state.system
assert retrieved_state.current_cycle == state.current_cycle