from enum import Enum
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Any

from diceplayer.config.dice_config import DiceConfig
from diceplayer.config.gaussian_config import GaussianConfig

# Lower bound applied to ``altsteps`` by ``PlayerConfig.validate_altsteps``;
# also the default number of alternate-simulation steps.
MIN_STEP = 20000
# ``altsteps`` is rounded to a multiple of this value.
STEP_INCREMENT = 1000


class RoutineType(str, Enum):
    """Kind of routine a DICEPlayer run performs.

    Inherits from ``str`` so members compare equal to their plain string
    values (e.g. ``RoutineType.BOTH == "both"``).
    """

    CHARGE = "charge"
    GEOMETRY = "geometry"
    BOTH = "both"


class PlayerConfig(BaseModel):
    """
    Configuration for DICEPlayer simulations.

    Instances are immutable (``frozen=True``).

    Attributes:
        type: Type of simulation to perform (charge, geometry, or both).
        max_cyc: Maximum number of cycles for the geometry optimization
            (must be > 0).
        switch_cyc: Cycle at which to switch from charge to geometry
            optimization. When type is "both" it must be less than max_cyc;
            for any other type it must equal max_cyc, and it defaults to
            max_cyc when omitted.
        mem: Memory configuration for QM calculations (optional, default None).
        nprocs: Number of processors to use for QM calculations.
        ncores: Number of cores to use for QM calculations.
        dice: Configuration parameters specific to DICE simulations.
        gaussian: Configuration parameters specific to Gaussian calculations.
        altsteps: Number of steps for the alternate simulation. Values below
            MIN_STEP (20000) are raised to MIN_STEP, and the result is rounded
            to a multiple of STEP_INCREMENT (1000). Default: 20000.
        geoms_file: File name for the geometries output (default: "geoms.xyz").
        simulation_dir: Directory name for the simulation files
            (default: "simfiles").
    """

    model_config = ConfigDict(
        frozen=True,
    )

    type: RoutineType = Field(..., description="Type of simulation to perform")
    max_cyc: int = Field(
        ..., description="Maximum number of cycles for the geometry optimization", gt=0
    )
    # Required at schema level; validate_switch_cyc fills it with max_cyc when
    # it is omitted from the input.
    switch_cyc: int = Field(..., description="Switch cycle configuration")
    mem: Optional[int] = Field(None, description="Memory configuration")
    nprocs: int = Field(
        ..., description="Number of processors to use for the QM calculations"
    )
    ncores: int = Field(
        ..., description="Number of cores to use for the QM calculations"
    )

    dice: DiceConfig = Field(..., description="Dice configuration")
    gaussian: GaussianConfig = Field(..., description="Gaussian configuration")

    altsteps: int = Field(
        MIN_STEP, description="Number of steps for the alternate simulation"
    )
    geoms_file: Path = Field(
        Path("geoms.xyz"), description="File name for the geometries output"
    )
    simulation_dir: Path = Field(
        Path("simfiles"), description="Directory name for the simulation files"
    )

    @model_validator(mode="before")
    @classmethod
    def validate_altsteps(cls, data: Any) -> Any:
        """Clamp ``altsteps`` to at least MIN_STEP and round it to STEP_INCREMENT.

        A missing ``altsteps`` is treated as MIN_STEP. The input mapping is
        copied first, so the caller's dict is never mutated.

        Args:
            data: Raw input handed to the model before field validation.

        Returns:
            A shallow copy of ``data`` with ``altsteps`` normalized, or ``data``
            unchanged if it is not a dict.
        """
        if not isinstance(data, dict):
            # Not a plain mapping (e.g. an existing instance): let pydantic
            # handle it instead of failing on dict-only operations.
            return data

        data = dict(data)
        altsteps = data.get("altsteps", MIN_STEP)
        # Python's round() rounds exact halves to the even neighbour.
        data["altsteps"] = (
            round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
        )
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_switch_cyc(cls, data: Any) -> Any:
        """Check ``switch_cyc`` against ``max_cyc`` for the selected ``type``.

        When ``switch_cyc`` is omitted it defaults to ``max_cyc``. That default
        is written into the returned data, so a config whose type is not
        "both" does not have to repeat it.

        Args:
            data: Raw input handed to the model before field validation.

        Returns:
            A shallow copy of ``data`` (with ``switch_cyc`` filled in if it was
            missing), or ``data`` unchanged if it is not a dict.

        Raises:
            ValueError: if type is "both" and switch_cyc is not less than
                max_cyc, or if type is anything else and switch_cyc differs
                from max_cyc.
        """
        if not isinstance(data, dict):
            return data

        data = dict(data)
        max_cyc = int(data.get("max_cyc", 0))
        switch_cyc = int(data.get("switch_cyc", max_cyc))
        # RoutineType is a str enum, so this matches both "both" and
        # RoutineType.BOTH.
        is_both = data.get("type") == "both"

        if is_both and not switch_cyc < max_cyc:
            raise ValueError("switch_cyc must be less than max_cyc when type='both'.")
        if not is_both and switch_cyc != max_cyc:
            raise ValueError(
                "switch_cyc must be equal to max_cyc when type is not 'both'."
            )

        # Persist the implied default; otherwise the required field would
        # fail validation even though the check above accepted it.
        data.setdefault("switch_cyc", max_cyc)
        return data