diff --git a/diceplayer/config/dice_config.py b/diceplayer/config/dice_config.py index 8730ed2..907d5bf 100644 --- a/diceplayer/config/dice_config.py +++ b/diceplayer/config/dice_config.py @@ -1,7 +1,7 @@ -from pydantic import BaseModel, Field - from diceplayer.shared.utils.dataclass_protocol import Dataclass +from pydantic import BaseModel, Field + from dataclasses import dataclass, fields from typing import List, Literal @@ -12,16 +12,33 @@ class DiceConfig(BaseModel): """ ljname: str = Field(..., description="Name of the Lennard-Jones potential file") - outname: str = Field(..., description="Name of the output file for the simulation results") + outname: str = Field( + ..., description="Name of the output file for the simulation results" + ) dens: float = Field(..., description="Density of the system") - nmol: List[int] = Field(..., description="List of the number of molecules for each component") - nstep: List[int] = Field(..., description="List of the number of steps for each component", min_length=2, max_length=3) + nmol: List[int] = Field( + ..., description="List of the number of molecules for each component" + ) + nstep: List[int] = Field( + ..., + description="List of the number of steps for each component", + min_length=2, + max_length=3, + ) - upbuf: int = Field(360, description="Buffer size for the potential energy calculations") - combrule: Literal["+", "*"] = Field("*", description="Combination rule for the Lennard-Jones potential") + upbuf: int = Field( + 360, description="Buffer size for the potential energy calculations" + ) + combrule: Literal["+", "*"] = Field( + "*", description="Combination rule for the Lennard-Jones potential" + ) isave: int = Field(1000, description="Frequency of saving the simulation results") press: float = Field(1.0, description="Pressure of the system") temp: float = Field(300.0, description="Temperature of the system") - progname: str = Field("dice", description="Name of the program to run the simulation") - randominit: str = Field("first", description="Method for initializing the random number generator") + progname: str = Field( + "dice", description="Name of the program to run the simulation" + ) + randominit: str = Field( + "first", description="Method for initializing the random number generator" + ) seed: int | None = Field(None, description="Seed for the random number generator") diff --git a/diceplayer/config/gaussian_config.py b/diceplayer/config/gaussian_config.py index 2f9b4c5..bc741b2 100644 --- a/diceplayer/config/gaussian_config.py +++ b/diceplayer/config/gaussian_config.py @@ -1,10 +1,9 @@ -from typing import Literal +from diceplayer.shared.utils.dataclass_protocol import Dataclass from pydantic import BaseModel, Field -from diceplayer.shared.utils.dataclass_protocol import Dataclass - from dataclasses import dataclass, fields +from typing import Literal class GaussianConfig(BaseModel): @@ -13,9 +12,18 @@ class GaussianConfig(BaseModel): """ level: str = Field(..., description="Level of theory for the QM calculations") - qmprog: Literal["g03", "g09", "g16"] = Field("g16", description="QM program to use for the calculations") + qmprog: Literal["g03", "g09", "g16"] = Field( + "g16", description="QM program to use for the calculations" + ) - chgmult: list[int] = Field(default_factory=lambda: [0, 1], description="List of charge and multiplicity for the QM calculations") - pop: str = Field("chelpg", description="Population analysis method for the QM calculations") + chgmult: list[int] = Field( + default_factory=lambda: [0, 1], + description="List of charge and multiplicity for the QM calculations", + ) + pop: str = Field( + "chelpg", description="Population analysis method for the QM calculations" + ) chg_tol: float = Field(0.01, description="Charge tolerance for the QM calculations") - keywords: str = Field(None, description="Additional keywords for the QM calculations") + keywords: str = Field( + None, description="Additional keywords for the QM calculations" + ) diff --git a/diceplayer/config/player_config.py b/diceplayer/config/player_config.py index 29d1e23..1ec4993 100644 --- a/diceplayer/config/player_config.py +++ b/diceplayer/config/player_config.py @@ -2,45 +2,48 @@ from diceplayer.config.dice_config import DiceConfig from diceplayer.config.gaussian_config import GaussianConfig from diceplayer.shared.utils.dataclass_protocol import Dataclass +from pydantic import BaseModel, Field, field_validator, model_validator + from dataclasses import dataclass, fields +from pathlib import Path +from typing import Self + +MIN_STEP = 20000 -@dataclass -class PlayerConfig(Dataclass): +class PlayerConfig(BaseModel): """ Data Transfer Object for the player configuration. """ - opt: bool - maxcyc: int - nprocs: int - ncores: int + opt: bool = Field(..., description="Whether to perform geometry optimization") + maxcyc: int = Field( + ..., description="Maximum number of cycles for the geometry optimization" + ) + nprocs: int = Field( + ..., description="Number of processors to use for the QM calculations" + ) + ncores: int = Field( + ..., description="Number of cores to use for the QM calculations" + ) - dice: DiceConfig - gaussian: GaussianConfig + dice: DiceConfig = Field(..., description="Dice configuration") + gaussian: GaussianConfig = Field(..., description="Gaussian configuration") - mem: int = None - switchcyc: int = 3 - qmprog: str = "g16" - altsteps: int = 20000 - geoms_file = "geoms.xyz" - simulation_dir = "simfiles" + mem: int = Field(None, description="Memory configuration") + switchcyc: int = Field(3, description="Switch cycle configuration") + qmprog: str = Field("g16", description="QM program to use for the calculations") + altsteps: int = Field( + 20000, description="Number of steps for the alternate simulation" + ) + geoms_file: Path = Field( + "geoms.xyz", description="File name for the geometries output" + ) + simulation_dir: Path = Field( + "simfiles", description="Directory name for the simulation files" + ) - def __post_init__(self): - MIN_STEP = 20000 - # altsteps value is always the nearest multiple of 1000 + @model_validator(mode="after") + def validate_altsteps(self) -> Self: self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000 - - @classmethod - def from_dict(cls, params: dict): - if params["dice"] is None: - raise ValueError("Error: 'dice' keyword not specified in config file.") - params["dice"] = DiceConfig.model_validate(params["dice"]) - - if params["gaussian"] is None: - raise ValueError("Error: 'gaussian' keyword not specified in config file.") - params["gaussian"] = GaussianConfig.model_validate(params["gaussian"]) - - params = {f.name: params[f.name] for f in fields(cls) if f.name in params} - - return cls(**params) + return self diff --git a/diceplayer/player.py b/diceplayer/player.py index 3e77848..5aa2f5b 100644 --- a/diceplayer/player.py +++ b/diceplayer/player.py @@ -1,5 +1,3 @@ -from pydantic import BaseModel - from diceplayer import VERSION, logger from diceplayer.config.dice_config import DiceConfig from diceplayer.config.gaussian_config import GaussianConfig @@ -14,6 +12,7 @@ from diceplayer.shared.utils.misc import weekday_date_time from diceplayer.shared.utils.ptable import atomsymb import yaml +from pydantic import BaseModel import os import pickle @@ -119,7 +118,6 @@ class Player: else: logger.info(f"{key} = {value}") - logger.info( f"##########################################################################################\n" f"############# Welcome to DICEPLAYER version {VERSION} #############\n" @@ -430,7 +428,7 @@ class Player: @staticmethod def set_config(data: dict) -> PlayerConfig: - return PlayerConfig.from_dict(data) + return PlayerConfig.model_validate(data) @staticmethod def read_keywords(infile) -> dict: diff --git a/diceplayer/shared/interface/dice_interface.py b/diceplayer/shared/interface/dice_interface.py index 848661c..c13ef63 100644 --- a/diceplayer/shared/interface/dice_interface.py +++ b/diceplayer/shared/interface/dice_interface.py @@ -356,9 +356,10 @@ class DiceInterface(Interface): ) def run_dice_file(self, cycle: int, proc: int, file_name: str): - with open(Path(file_name), "r") as infile, open( - Path(file_name + ".out"), "w" - ) as outfile: + with ( + open(Path(file_name), "r") as infile, + open(Path(file_name + ".out"), "w") as outfile, + ): if shutil.which("bash") is not None: exit_status = subprocess.call( [ diff --git a/tests/shared/config/__init__.py b/tests/config/__init__.py similarity index 100% rename from tests/shared/config/__init__.py rename to tests/config/__init__.py diff --git a/tests/shared/config/test_dice_dto.py b/tests/config/test_dice_dto.py similarity index 100% rename from tests/shared/config/test_dice_dto.py rename to tests/config/test_dice_dto.py diff --git a/tests/shared/config/test_gaussian_dto.py b/tests/config/test_gaussian_dto.py similarity index 100% rename from tests/shared/config/test_gaussian_dto.py rename to tests/config/test_gaussian_dto.py diff --git a/tests/shared/config/test_player_dto.py b/tests/config/test_player_dto.py similarity index 96% rename from tests/shared/config/test_player_dto.py rename to tests/config/test_player_dto.py index 7a74a95..883fad0 100644 --- a/tests/shared/config/test_player_dto.py +++ b/tests/config/test_player_dto.py @@ -72,7 +72,7 @@ class TestPlayerDTO(unittest.TestCase): self.assertEqual(player_dto.altsteps, 20000) def test_from_dict(self): - player_dto = PlayerConfig.from_dict(get_config_dict()) + player_dto = PlayerConfig.model_validate(get_config_dict()) self.assertIsInstance(player_dto, PlayerConfig) self.assertIsInstance(player_dto.dice, DiceConfig) diff --git a/tests/shared/interface/test_dice_interface.py b/tests/shared/interface/test_dice_interface.py index 54ce426..b246f08 100644 --- a/tests/shared/interface/test_dice_interface.py +++ b/tests/shared/interface/test_dice_interface.py @@ -19,7 +19,7 @@ class TestDiceInterface(unittest.TestCase): logger.set_logger(stream=io.StringIO()) config = yaml.load(get_config_example(), Loader=yaml.Loader) - self.config = PlayerConfig.from_dict(config["diceplayer"]) + self.config = PlayerConfig.model_validate(config["diceplayer"]) def test_class_instantiation(self): dice = DiceInterface() diff --git a/tests/shared/interface/test_gaussian_interface.py b/tests/shared/interface/test_gaussian_interface.py index 0a7f17f..723de74 100644 --- a/tests/shared/interface/test_gaussian_interface.py +++ b/tests/shared/interface/test_gaussian_interface.py @@ -16,7 +16,7 @@ class TestGaussianInterface(unittest.TestCase): logger.set_logger(stream=io.StringIO()) config = yaml.load(get_config_example(), Loader=yaml.Loader) - self.config = PlayerConfig.from_dict(config["diceplayer"]) + self.config = PlayerConfig.model_validate(config["diceplayer"]) def test_class_instantiation(self): gaussian_interface = GaussianInterface()