feat: improved implementation and validations of configs
This commit is contained in:
@@ -2,45 +2,48 @@ from diceplayer.config.dice_config import DiceConfig
|
||||
from diceplayer.config.gaussian_config import GaussianConfig
|
||||
from diceplayer.shared.utils.dataclass_protocol import Dataclass
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from pathlib import Path
|
||||
from typing import Self
|
||||
|
||||
MIN_STEP = 20000
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlayerConfig(Dataclass):
|
||||
class PlayerConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the player configuration.
|
||||
"""
|
||||
|
||||
opt: bool
|
||||
maxcyc: int
|
||||
nprocs: int
|
||||
ncores: int
|
||||
opt: bool = Field(..., description="Whether to perform geometry optimization")
|
||||
maxcyc: int = Field(
|
||||
..., description="Maximum number of cycles for the geometry optimization"
|
||||
)
|
||||
nprocs: int = Field(
|
||||
..., description="Number of processors to use for the QM calculations"
|
||||
)
|
||||
ncores: int = Field(
|
||||
..., description="Number of cores to use for the QM calculations"
|
||||
)
|
||||
|
||||
dice: DiceConfig
|
||||
gaussian: GaussianConfig
|
||||
dice: DiceConfig = Field(..., description="Dice configuration")
|
||||
gaussian: GaussianConfig = Field(..., description="Gaussian configuration")
|
||||
|
||||
mem: int = None
|
||||
switchcyc: int = 3
|
||||
qmprog: str = "g16"
|
||||
altsteps: int = 20000
|
||||
geoms_file = "geoms.xyz"
|
||||
simulation_dir = "simfiles"
|
||||
mem: int = Field(None, description="Memory configuration")
|
||||
switchcyc: int = Field(3, description="Switch cycle configuration")
|
||||
qmprog: str = Field("g16", description="QM program to use for the calculations")
|
||||
altsteps: int = Field(
|
||||
20000, description="Number of steps for the alternate simulation"
|
||||
)
|
||||
geoms_file: Path = Field(
|
||||
"geoms.xyz", description="File name for the geometries output"
|
||||
)
|
||||
simulation_dir: Path = Field(
|
||||
"simfiles", description="Directory name for the simulation files"
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
MIN_STEP = 20000
|
||||
# altsteps value is always the nearest multiple of 1000
|
||||
@model_validator(mode="after")
|
||||
def validate_altsteps(self) -> Self:
|
||||
self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params: dict):
|
||||
if params["dice"] is None:
|
||||
raise ValueError("Error: 'dice' keyword not specified in config file.")
|
||||
params["dice"] = DiceConfig.model_validate(params["dice"])
|
||||
|
||||
if params["gaussian"] is None:
|
||||
raise ValueError("Error: 'gaussian' keyword not specified in config file.")
|
||||
params["gaussian"] = GaussianConfig.model_validate(params["gaussian"])
|
||||
|
||||
params = {f.name: params[f.name] for f in fields(cls) if f.name in params}
|
||||
|
||||
return cls(**params)
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user