# Source: DicePlayer/diceplayer/config/player_config.py (Python; listed as 91 lines, 3.2 KiB in the file viewer)

from enum import Enum
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Any

from diceplayer.config.dice_config import DiceConfig
from diceplayer.config.gaussian_config import GaussianConfig

# Granularity of ``altsteps``: PlayerConfig.validate_altsteps rounds the
# requested value to a multiple of this.
STEP_INCREMENT = 1_000

# Lower bound for ``altsteps``; also the value used when none is supplied.
MIN_STEP = 20_000
class RoutineType(str, Enum):
    """
    Kinds of routine accepted by ``PlayerConfig.type``.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (for example ``RoutineType.BOTH == "both"``).
    """

    # Declaration order is kept as-is because it is the iteration order.
    CHARGE = "charge"
    GEOMETRY = "geometry"
    BOTH = "both"
class PlayerConfig(BaseModel):
    """
    Configuration for DICEPlayer simulations.

    The model is declared with ``frozen=True``, so fields cannot be
    reassigned after construction.

    Attributes:
        type: Type of simulation to perform (charge, geometry, or both).
        max_cyc: Maximum number of cycles for the geometry optimization;
            must be greater than 0.
        switch_cyc: Cycle at which to switch from charge to geometry
            optimization. Must be less than ``max_cyc`` when type is "both"
            and equal to ``max_cyc`` otherwise.
        mem: Memory configuration for QM calculations (optional,
            default: None).
        ncores: Number of cores to use for QM calculations.
        dice: Configuration parameters specific to DICE simulations.
        gaussian: Configuration parameters specific to Gaussian calculations.
        altsteps: Number of steps for the alternate simulation. Values below
            ``MIN_STEP`` are raised to ``MIN_STEP`` and the result is rounded
            with ``round()`` to a multiple of ``STEP_INCREMENT``
            (default: 20000).
        geoms_file: File name for the geometries output (default: "geoms.xyz").
        simulation_dir: Directory name for the simulation files
            (default: "simfiles").
    """

    model_config = ConfigDict(
        frozen=True,
    )

    type: RoutineType = Field(..., description="Type of simulation to perform")
    max_cyc: int = Field(
        ..., description="Maximum number of cycles for the geometry optimization", gt=0
    )
    switch_cyc: int = Field(..., description="Switch cycle configuration")
    # The default is None, so the annotation must admit None as well;
    # otherwise an explicit ``mem=None`` is rejected while the default passes.
    mem: Optional[int] = Field(None, description="Memory configuration")
    ncores: int = Field(
        ..., description="Number of cores to use for the QM calculations"
    )
    dice: DiceConfig = Field(..., description="Dice configuration")
    gaussian: GaussianConfig = Field(..., description="Gaussian configuration")
    altsteps: int = Field(
        MIN_STEP, description="Number of steps for the alternate simulation"
    )
    geoms_file: Path = Field(
        Path("geoms.xyz"), description="File name for the geometries output"
    )
    simulation_dir: Path = Field(
        Path("simfiles"), description="Directory name for the simulation files"
    )

    @model_validator(mode="before")
    @staticmethod
    def validate_altsteps(fields: Any) -> Any:
        """
        Clamp ``altsteps`` to ``MIN_STEP`` and round it to ``STEP_INCREMENT``.

        Args:
            fields: Raw input handed to the model before field validation.

        Returns:
            A new dict with the normalized ``altsteps`` entry, or ``fields``
            unchanged when it is not a dict.
        """
        # "before" validators receive whatever the caller passed in; only a
        # dict can be normalized here, anything else goes on to pydantic.
        if not isinstance(fields, dict):
            return fields

        requested = fields.get("altsteps", MIN_STEP)
        normalized = (
            round(max(MIN_STEP, requested) / STEP_INCREMENT) * STEP_INCREMENT
        )
        # Build a copy so the caller's dict is never modified in place.
        return {**fields, "altsteps": normalized}

    @model_validator(mode="before")
    @staticmethod
    def validate_switch_cyc(fields: Any) -> Any:
        """
        Check that ``switch_cyc`` is consistent with ``max_cyc`` and ``type``.

        For this check only, a missing ``max_cyc`` counts as 0 and a missing
        ``switch_cyc`` counts as ``max_cyc``; the input itself is not changed.

        Args:
            fields: Raw input handed to the model before field validation.

        Returns:
            ``fields`` unchanged.

        Raises:
            ValueError: If type is "both" and ``switch_cyc`` is not less than
                ``max_cyc``, or if type is anything else and ``switch_cyc``
                differs from ``max_cyc``.
        """
        # Non-dict input cannot be inspected with .get(); leave it to pydantic.
        if not isinstance(fields, dict):
            return fields

        max_cyc = int(fields.get("max_cyc", 0))
        switch_cyc = int(fields.get("switch_cyc", max_cyc))
        # RoutineType is a str enum, so this matches both the plain string
        # "both" and the RoutineType.BOTH member.
        runs_both = fields.get("type") == RoutineType.BOTH.value

        if runs_both and switch_cyc >= max_cyc:
            raise ValueError("switch_cyc must be less than max_cyc when type='both'.")
        if not runs_both and switch_cyc != max_cyc:
            raise ValueError(
                "switch_cyc must be equal to max_cyc when type is not 'both'."
            )
        return fields