81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
from enum import Enum
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Any

from diceplayer.config.dice_config import DiceConfig
from diceplayer.config.gaussian_config import GaussianConfig
|
|
|
|
|
|
# Lower bound applied to ``PlayerConfig.altsteps``: smaller requested values
# are raised to this number by the ``altsteps`` validator.
MIN_STEP: int = 20_000
# Granularity of ``PlayerConfig.altsteps``: the validator rounds the value to
# a multiple of this step.
STEP_INCREMENT: int = 1_000
|
|
|
|
|
|
class RoutineType(str, Enum):
    """
    Kind of routine selected by ``PlayerConfig.type``.

    The enum mixes in ``str``, so every member compares equal to its plain
    string value (``RoutineType.BOTH == "both"``), and ``RoutineType("both")``
    looks a member up by that value.
    """

    CHARGE = "charge"
    GEOMETRY = "geometry"
    BOTH = "both"
|
|
|
|
|
|
class PlayerConfig(BaseModel):
    """
    Data Transfer Object for the player configuration.

    Instances are immutable (``frozen=True``). Two ``mode="before"`` model
    validators normalise and cross-check the raw input before the per-field
    validation runs:

    * ``altsteps`` is raised to at least ``MIN_STEP`` and rounded to a
      multiple of ``STEP_INCREMENT`` (``MIN_STEP`` when absent).
    * ``switch_cyc`` must be strictly less than ``max_cyc`` when
      ``type == "both"`` and equal to ``max_cyc`` for any other type.
    """

    model_config = ConfigDict(
        frozen=True,
    )

    type: RoutineType = Field(..., description="Type of simulation to perform")
    max_cyc: int = Field(
        ..., description="Maximum number of cycles for the geometry optimization", gt=0
    )
    switch_cyc: int = Field(..., description="Switch cycle configuration")

    # The default is None, so the annotation must admit None; otherwise an
    # explicit ``mem=None`` would be rejected while omitting it is accepted.
    mem: Optional[int] = Field(None, description="Memory configuration")
    nprocs: int = Field(
        ..., description="Number of processors to use for the QM calculations"
    )
    ncores: int = Field(
        ..., description="Number of cores to use for the QM calculations"
    )

    dice: DiceConfig = Field(..., description="Dice configuration")
    gaussian: GaussianConfig = Field(..., description="Gaussian configuration")

    # Kept equal to MIN_STEP: validate_altsteps also fills in MIN_STEP when
    # the key is absent, so both paths agree.
    altsteps: int = Field(
        MIN_STEP, description="Number of steps for the alternate simulation"
    )
    geoms_file: Path = Field(
        Path("geoms.xyz"), description="File name for the geometries output"
    )
    simulation_dir: Path = Field(
        Path("simfiles"), description="Directory name for the simulation files"
    )

    @model_validator(mode="before")
    @staticmethod
    def validate_altsteps(fields: Any) -> Any:
        """
        Clamp ``altsteps`` to at least ``MIN_STEP`` and round it to a
        multiple of ``STEP_INCREMENT``.

        A missing ``altsteps`` becomes ``MIN_STEP``. Numeric strings are
        accepted. Non-dict input and values that cannot be converted to a
        finite number are returned untouched, so the regular field
        validation reports them.

        Returns:
            A shallow copy of ``fields`` with the normalised ``altsteps``.
            The caller's mapping is never mutated.
        """
        # Before-validators receive the raw input, which is not necessarily a dict.
        if not isinstance(fields, dict):
            return fields

        try:
            requested = float(fields.get("altsteps", MIN_STEP))
            # round() rounds exact halves to the even quotient, e.g. 20500 -> 20000.
            altsteps = (
                round(max(MIN_STEP, requested) / STEP_INCREMENT) * STEP_INCREMENT
            )
        except (TypeError, ValueError, OverflowError):
            # None / non-numeric / infinite: let the int field reject it.
            return fields

        return {**fields, "altsteps": altsteps}

    @model_validator(mode="before")
    @staticmethod
    def validate_switch_cyc(fields: Any) -> Any:
        """
        Check ``switch_cyc`` against ``max_cyc`` for the selected routine type.

        When ``type`` is ``"both"``, ``switch_cyc`` must be strictly less than
        ``max_cyc``. For any other type the two must be equal. A missing
        ``switch_cyc`` is treated as equal to ``max_cyc`` for this check. If
        ``max_cyc`` is missing or either value is not an integer, the check is
        skipped and the per-field validation reports the real problem.

        Returns:
            ``fields`` unchanged.

        Raises:
            ValueError: if the relationship above does not hold.
        """
        if not isinstance(fields, dict):
            return fields

        try:
            max_cyc = int(fields["max_cyc"])
            switch_cyc = int(fields.get("switch_cyc", max_cyc))
        except (KeyError, TypeError, ValueError):
            return fields

        if fields.get("type") == "both":
            if not switch_cyc < max_cyc:
                raise ValueError("switch_cyc must be less than max_cyc when type='both'.")
        elif switch_cyc != max_cyc:
            raise ValueError(
                "switch_cyc must be equal to max_cyc when type is not 'both'."
            )

        return fields
|