chore: fixes tests
This commit is contained in:
27
diceplayer/config/dice_config.py
Normal file
27
diceplayer/config/dice_config.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from diceplayer.shared.utils.dataclass_protocol import Dataclass
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import List, Literal
|
||||
|
||||
|
||||
class DiceConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the Dice configuration.
|
||||
"""
|
||||
|
||||
ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
|
||||
outname: str = Field(..., description="Name of the output file for the simulation results")
|
||||
dens: float = Field(..., description="Density of the system")
|
||||
nmol: List[int] = Field(..., description="List of the number of molecules for each component")
|
||||
nstep: List[int] = Field(..., description="List of the number of steps for each component", min_length=2, max_length=3)
|
||||
|
||||
upbuf: int = Field(360, description="Buffer size for the potential energy calculations")
|
||||
combrule: Literal["+", "*"] = Field("*", description="Combination rule for the Lennard-Jones potential")
|
||||
isave: int = Field(1000, description="Frequency of saving the simulation results")
|
||||
press: float = Field(1.0, description="Pressure of the system")
|
||||
temp: float = Field(300.0, description="Temperature of the system")
|
||||
progname: str = Field("dice", description="Name of the program to run the simulation")
|
||||
randominit: str = Field("first", description="Method for initializing the random number generator")
|
||||
seed: int | None = Field(None, description="Seed for the random number generator")
|
||||
21
diceplayer/config/gaussian_config.py
Normal file
21
diceplayer/config/gaussian_config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from diceplayer.shared.utils.dataclass_protocol import Dataclass
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
|
||||
class GaussianConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the Gaussian configuration.
|
||||
"""
|
||||
|
||||
level: str = Field(..., description="Level of theory for the QM calculations")
|
||||
qmprog: Literal["g03", "g09", "g16"] = Field("g16", description="QM program to use for the calculations")
|
||||
|
||||
chgmult: list[int] = Field(default_factory=lambda: [0, 1], description="List of charge and multiplicity for the QM calculations")
|
||||
pop: str = Field("chelpg", description="Population analysis method for the QM calculations")
|
||||
chg_tol: float = Field(0.01, description="Charge tolerance for the QM calculations")
|
||||
keywords: str = Field(None, description="Additional keywords for the QM calculations")
|
||||
@@ -1,5 +1,5 @@
|
||||
from diceplayer.shared.config.dice_config import DiceConfig
|
||||
from diceplayer.shared.config.gaussian_config import GaussianDTO
|
||||
from diceplayer.config.dice_config import DiceConfig
|
||||
from diceplayer.config.gaussian_config import GaussianConfig
|
||||
from diceplayer.shared.utils.dataclass_protocol import Dataclass
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
@@ -17,7 +17,7 @@ class PlayerConfig(Dataclass):
|
||||
ncores: int
|
||||
|
||||
dice: DiceConfig
|
||||
gaussian: GaussianDTO
|
||||
gaussian: GaussianConfig
|
||||
|
||||
mem: int = None
|
||||
switchcyc: int = 3
|
||||
@@ -35,11 +35,11 @@ class PlayerConfig(Dataclass):
|
||||
def from_dict(cls, params: dict):
|
||||
if params["dice"] is None:
|
||||
raise ValueError("Error: 'dice' keyword not specified in config file.")
|
||||
params["dice"] = DiceConfig.from_dict(params["dice"])
|
||||
params["dice"] = DiceConfig.model_validate(params["dice"])
|
||||
|
||||
if params["gaussian"] is None:
|
||||
raise ValueError("Error: 'gaussian' keyword not specified in config file.")
|
||||
params["gaussian"] = GaussianDTO.from_dict(params["gaussian"])
|
||||
params["gaussian"] = GaussianConfig.model_validate(params["gaussian"])
|
||||
|
||||
params = {f.name: params[f.name] for f in fields(cls) if f.name in params}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from diceplayer import VERSION, logger
|
||||
from diceplayer.shared.config.dice_config import DiceConfig
|
||||
from diceplayer.shared.config.gaussian_config import GaussianDTO
|
||||
from diceplayer.shared.config.player_config import PlayerConfig
|
||||
from diceplayer.config.dice_config import DiceConfig
|
||||
from diceplayer.config.gaussian_config import GaussianConfig
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.shared.environment.atom import Atom
|
||||
from diceplayer.shared.environment.molecule import Molecule
|
||||
from diceplayer.shared.environment.system import System
|
||||
@@ -16,7 +18,6 @@ import yaml
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from dataclasses import fields
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Type
|
||||
|
||||
@@ -108,14 +109,16 @@ class Player:
|
||||
geoms_file_path.touch()
|
||||
|
||||
def print_keywords(self) -> None:
|
||||
def log_keywords(config: Dataclass, dto: Type[Dataclass]):
|
||||
for key in sorted(list(map(lambda f: f.name, fields(dto)))):
|
||||
if getattr(config, key) is not None:
|
||||
if isinstance(getattr(config, key), list):
|
||||
string = " ".join(str(x) for x in getattr(config, key))
|
||||
logger.info(f"{key} = [ {string} ]")
|
||||
else:
|
||||
logger.info(f"{key} = {getattr(config, key)}")
|
||||
def log_keywords(config: BaseModel):
|
||||
for key, value in sorted(config.model_dump().items()):
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, list):
|
||||
string = " ".join(str(x) for x in value)
|
||||
logger.info(f"{key} = [ {string} ]")
|
||||
else:
|
||||
logger.info(f"{key} = {value}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"##########################################################################################\n"
|
||||
@@ -138,7 +141,7 @@ class Player:
|
||||
"------------------------------------------------------------------------------------------\n"
|
||||
)
|
||||
|
||||
log_keywords(self.config.dice, DiceConfig)
|
||||
log_keywords(self.config.dice)
|
||||
|
||||
logger.info(
|
||||
"------------------------------------------------------------------------------------------\n"
|
||||
@@ -146,7 +149,7 @@ class Player:
|
||||
"------------------------------------------------------------------------------------------\n"
|
||||
)
|
||||
|
||||
log_keywords(self.config.gaussian, GaussianDTO)
|
||||
log_keywords(self.config.gaussian)
|
||||
|
||||
logger.info("\n")
|
||||
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from diceplayer.shared.utils.dataclass_protocol import Dataclass
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiceConfig(Dataclass):
|
||||
"""
|
||||
Data Transfer Object for the Dice configuration.
|
||||
"""
|
||||
|
||||
ljname: str
|
||||
outname: str
|
||||
dens: float
|
||||
nmol: List[int]
|
||||
nstep: List[int]
|
||||
|
||||
upbuf = 360
|
||||
combrule = "*"
|
||||
isave: int = 1000
|
||||
press: float = 1.0
|
||||
temp: float = 300.0
|
||||
progname: str = "dice"
|
||||
randominit: str = "first"
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.ljname, str):
|
||||
raise ValueError("Error: 'ljname' keyword not specified in config file")
|
||||
|
||||
if not isinstance(self.outname, str):
|
||||
raise ValueError("Error: 'outname' keyword not specified in config file")
|
||||
|
||||
if not isinstance(self.dens, float):
|
||||
raise ValueError("Error: 'dens' keyword not specified in config file")
|
||||
|
||||
if not isinstance(self.nmol, list):
|
||||
raise ValueError(
|
||||
"Error: 'nmol' keyword not defined appropriately in config file"
|
||||
)
|
||||
|
||||
if not isinstance(self.nstep, list) or len(self.nstep) not in (2, 3):
|
||||
raise ValueError(
|
||||
"Error: 'nstep' keyword not defined appropriately in config file"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params: dict):
|
||||
params = {f.name: params[f.name] for f in fields(cls) if f.name in params}
|
||||
|
||||
return cls(**params)
|
||||
@@ -1,30 +0,0 @@
|
||||
from diceplayer.shared.utils.dataclass_protocol import Dataclass
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
|
||||
@dataclass
|
||||
class GaussianDTO(Dataclass):
|
||||
"""
|
||||
Data Transfer Object for the Gaussian configuration.
|
||||
"""
|
||||
|
||||
level: str
|
||||
qmprog: str
|
||||
|
||||
chgmult = [0, 1]
|
||||
pop: str = "chelpg"
|
||||
chg_tol: float = 0.01
|
||||
keywords: str = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.qmprog not in ("g03", "g09", "g16"):
|
||||
raise ValueError("Error: invalid qmprog value.")
|
||||
if self.level is None:
|
||||
raise ValueError("Error: 'level' keyword not specified in config file.")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params: dict):
|
||||
params = {f.name: params[f.name] for f in fields(cls) if f.name in params}
|
||||
|
||||
return cls(**params)
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from diceplayer.shared.config.player_config import PlayerConfig
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.shared.environment.system import System
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from diceplayer import logger
|
||||
from diceplayer.shared.config.player_config import PlayerConfig
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.shared.environment.system import System
|
||||
from diceplayer.shared.interface import Interface
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from diceplayer import logger
|
||||
from diceplayer.shared.config.player_config import PlayerConfig
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.shared.environment.atom import Atom
|
||||
from diceplayer.shared.environment.molecule import Molecule
|
||||
from diceplayer.shared.environment.system import System
|
||||
|
||||
Reference in New Issue
Block a user