feat: improved implementation and validations of configs

This commit is contained in:
2026-02-26 08:35:59 -03:00
parent e5c6282c86
commit c51d07cff2
11 changed files with 84 additions and 57 deletions

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from diceplayer.shared.utils.dataclass_protocol import Dataclass from diceplayer.shared.utils.dataclass_protocol import Dataclass
from pydantic import BaseModel, Field
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import List, Literal from typing import List, Literal
@@ -12,16 +12,33 @@ class DiceConfig(BaseModel):
""" """
ljname: str = Field(..., description="Name of the Lennard-Jones potential file") ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
outname: str = Field(..., description="Name of the output file for the simulation results") outname: str = Field(
..., description="Name of the output file for the simulation results"
)
dens: float = Field(..., description="Density of the system") dens: float = Field(..., description="Density of the system")
nmol: List[int] = Field(..., description="List of the number of molecules for each component") nmol: List[int] = Field(
nstep: List[int] = Field(..., description="List of the number of steps for each component", min_length=2, max_length=3) ..., description="List of the number of molecules for each component"
)
nstep: List[int] = Field(
...,
description="List of the number of steps for each component",
min_length=2,
max_length=3,
)
upbuf: int = Field(360, description="Buffer size for the potential energy calculations") upbuf: int = Field(
combrule: Literal["+", "*"] = Field("*", description="Combination rule for the Lennard-Jones potential") 360, description="Buffer size for the potential energy calculations"
)
combrule: Literal["+", "*"] = Field(
"*", description="Combination rule for the Lennard-Jones potential"
)
isave: int = Field(1000, description="Frequency of saving the simulation results") isave: int = Field(1000, description="Frequency of saving the simulation results")
press: float = Field(1.0, description="Pressure of the system") press: float = Field(1.0, description="Pressure of the system")
temp: float = Field(300.0, description="Temperature of the system") temp: float = Field(300.0, description="Temperature of the system")
progname: str = Field("dice", description="Name of the program to run the simulation") progname: str = Field(
randominit: str = Field("first", description="Method for initializing the random number generator") "dice", description="Name of the program to run the simulation"
)
randominit: str = Field(
"first", description="Method for initializing the random number generator"
)
seed: int | None = Field(None, description="Seed for the random number generator") seed: int | None = Field(None, description="Seed for the random number generator")

View File

@@ -1,10 +1,9 @@
from typing import Literal from diceplayer.shared.utils.dataclass_protocol import Dataclass
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from diceplayer.shared.utils.dataclass_protocol import Dataclass
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import Literal
class GaussianConfig(BaseModel): class GaussianConfig(BaseModel):
@@ -13,9 +12,18 @@ class GaussianConfig(BaseModel):
""" """
level: str = Field(..., description="Level of theory for the QM calculations") level: str = Field(..., description="Level of theory for the QM calculations")
qmprog: Literal["g03", "g09", "g16"] = Field("g16", description="QM program to use for the calculations") qmprog: Literal["g03", "g09", "g16"] = Field(
"g16", description="QM program to use for the calculations"
)
chgmult: list[int] = Field(default_factory=lambda: [0, 1], description="List of charge and multiplicity for the QM calculations") chgmult: list[int] = Field(
pop: str = Field("chelpg", description="Population analysis method for the QM calculations") default_factory=lambda: [0, 1],
description="List of charge and multiplicity for the QM calculations",
)
pop: str = Field(
"chelpg", description="Population analysis method for the QM calculations"
)
chg_tol: float = Field(0.01, description="Charge tolerance for the QM calculations") chg_tol: float = Field(0.01, description="Charge tolerance for the QM calculations")
keywords: str = Field(None, description="Additional keywords for the QM calculations") keywords: str = Field(
None, description="Additional keywords for the QM calculations"
)

View File

@@ -2,45 +2,48 @@ from diceplayer.config.dice_config import DiceConfig
from diceplayer.config.gaussian_config import GaussianConfig from diceplayer.config.gaussian_config import GaussianConfig
from diceplayer.shared.utils.dataclass_protocol import Dataclass from diceplayer.shared.utils.dataclass_protocol import Dataclass
from pydantic import BaseModel, Field, field_validator, model_validator
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from pathlib import Path
from typing import Self
MIN_STEP = 20000
@dataclass class PlayerConfig(BaseModel):
class PlayerConfig(Dataclass):
""" """
Data Transfer Object for the player configuration. Data Transfer Object for the player configuration.
""" """
opt: bool opt: bool = Field(..., description="Whether to perform geometry optimization")
maxcyc: int maxcyc: int = Field(
nprocs: int ..., description="Maximum number of cycles for the geometry optimization"
ncores: int )
nprocs: int = Field(
..., description="Number of processors to use for the QM calculations"
)
ncores: int = Field(
..., description="Number of cores to use for the QM calculations"
)
dice: DiceConfig dice: DiceConfig = Field(..., description="Dice configuration")
gaussian: GaussianConfig gaussian: GaussianConfig = Field(..., description="Gaussian configuration")
mem: int = None mem: int = Field(None, description="Memory configuration")
switchcyc: int = 3 switchcyc: int = Field(3, description="Switch cycle configuration")
qmprog: str = "g16" qmprog: str = Field("g16", description="QM program to use for the calculations")
altsteps: int = 20000 altsteps: int = Field(
geoms_file = "geoms.xyz" 20000, description="Number of steps for the alternate simulation"
simulation_dir = "simfiles" )
geoms_file: Path = Field(
"geoms.xyz", description="File name for the geometries output"
)
simulation_dir: Path = Field(
"simfiles", description="Directory name for the simulation files"
)
def __post_init__(self): @model_validator(mode="after")
MIN_STEP = 20000 def validate_altsteps(self) -> Self:
# altsteps value is always the nearest multiple of 1000
self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000 self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000
return self
@classmethod
def from_dict(cls, params: dict):
if params["dice"] is None:
raise ValueError("Error: 'dice' keyword not specified in config file.")
params["dice"] = DiceConfig.model_validate(params["dice"])
if params["gaussian"] is None:
raise ValueError("Error: 'gaussian' keyword not specified in config file.")
params["gaussian"] = GaussianConfig.model_validate(params["gaussian"])
params = {f.name: params[f.name] for f in fields(cls) if f.name in params}
return cls(**params)

View File

@@ -1,5 +1,3 @@
from pydantic import BaseModel
from diceplayer import VERSION, logger from diceplayer import VERSION, logger
from diceplayer.config.dice_config import DiceConfig from diceplayer.config.dice_config import DiceConfig
from diceplayer.config.gaussian_config import GaussianConfig from diceplayer.config.gaussian_config import GaussianConfig
@@ -14,6 +12,7 @@ from diceplayer.shared.utils.misc import weekday_date_time
from diceplayer.shared.utils.ptable import atomsymb from diceplayer.shared.utils.ptable import atomsymb
import yaml import yaml
from pydantic import BaseModel
import os import os
import pickle import pickle
@@ -119,7 +118,6 @@ class Player:
else: else:
logger.info(f"{key} = {value}") logger.info(f"{key} = {value}")
logger.info( logger.info(
f"##########################################################################################\n" f"##########################################################################################\n"
f"############# Welcome to DICEPLAYER version {VERSION} #############\n" f"############# Welcome to DICEPLAYER version {VERSION} #############\n"
@@ -430,7 +428,7 @@ class Player:
@staticmethod @staticmethod
def set_config(data: dict) -> PlayerConfig: def set_config(data: dict) -> PlayerConfig:
return PlayerConfig.from_dict(data) return PlayerConfig.model_validate(data)
@staticmethod @staticmethod
def read_keywords(infile) -> dict: def read_keywords(infile) -> dict:

View File

@@ -356,9 +356,10 @@ class DiceInterface(Interface):
) )
def run_dice_file(self, cycle: int, proc: int, file_name: str): def run_dice_file(self, cycle: int, proc: int, file_name: str):
with open(Path(file_name), "r") as infile, open( with (
Path(file_name + ".out"), "w" open(Path(file_name), "r") as infile,
) as outfile: open(Path(file_name + ".out"), "w") as outfile,
):
if shutil.which("bash") is not None: if shutil.which("bash") is not None:
exit_status = subprocess.call( exit_status = subprocess.call(
[ [

View File

@@ -72,7 +72,7 @@ class TestPlayerDTO(unittest.TestCase):
self.assertEqual(player_dto.altsteps, 20000) self.assertEqual(player_dto.altsteps, 20000)
def test_from_dict(self): def test_from_dict(self):
player_dto = PlayerConfig.from_dict(get_config_dict()) player_dto = PlayerConfig.model_validate(get_config_dict())
self.assertIsInstance(player_dto, PlayerConfig) self.assertIsInstance(player_dto, PlayerConfig)
self.assertIsInstance(player_dto.dice, DiceConfig) self.assertIsInstance(player_dto.dice, DiceConfig)

View File

@@ -19,7 +19,7 @@ class TestDiceInterface(unittest.TestCase):
logger.set_logger(stream=io.StringIO()) logger.set_logger(stream=io.StringIO())
config = yaml.load(get_config_example(), Loader=yaml.Loader) config = yaml.load(get_config_example(), Loader=yaml.Loader)
self.config = PlayerConfig.from_dict(config["diceplayer"]) self.config = PlayerConfig.model_validate(config["diceplayer"])
def test_class_instantiation(self): def test_class_instantiation(self):
dice = DiceInterface() dice = DiceInterface()

View File

@@ -16,7 +16,7 @@ class TestGaussianInterface(unittest.TestCase):
logger.set_logger(stream=io.StringIO()) logger.set_logger(stream=io.StringIO())
config = yaml.load(get_config_example(), Loader=yaml.Loader) config = yaml.load(get_config_example(), Loader=yaml.Loader)
self.config = PlayerConfig.from_dict(config["diceplayer"]) self.config = PlayerConfig.model_validate(config["diceplayer"])
def test_class_instantiation(self): def test_class_instantiation(self):
gaussian_interface = GaussianInterface() gaussian_interface = GaussianInterface()