"""DICE input-file models and writers.

Each pydantic model below describes one DICE simulation stage (NVT/NPT
thermalization and production). ``write_config`` serializes a model as
``key = value`` lines, in a fixed keyword order, terminated by ``$end``.
"""

import random
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any, Literal, TextIO

from pydantic import BaseModel, Field
from typing_extensions import Self

from diceplayer.config import PlayerConfig
from diceplayer.logger import logger

# Keywords written to a DICE input file, in emission order. Model fields
# not listed here (e.g. ``type``) are never written.
_ALLOWED_DICE_KEYWORD_IN_ORDER = [
    "title",
    "ncores",
    "ljname",
    "outname",
    "nmol",
    "dens",
    "temp",
    "press",
    "seed",
    "init",
    "nstep",
    "vstep",
    "mstop",
    "accum",
    "iprint",
    "isave",
    "irdf",
    "upbuf",
]


class DiceRoutineType(StrEnum):
    """Identifier of a DICE stage; ``write_config`` also uses it as the file name."""

    NVT_TER = "nvt.ter"
    NVT_EQ = "nvt.eq"
    NPT_TER = "npt.ter"
    NPT_EQ = "npt.eq"


def get_nstep(config, idx: int) -> int:
    """Return ``config.dice.nstep[idx]``, falling back to the last entry.

    Raises:
        IndexError: if ``config.dice.nstep`` is empty.
    """
    nsteps = config.dice.nstep
    if len(nsteps) > idx:
        return nsteps[idx]
    return nsteps[-1]


def get_seed(config) -> int:
    """Return the configured seed, or a random value in ``[0, 2**32 - 1]`` if unset.

    Only ``None`` counts as "unset"; an explicit seed of ``0`` is honored.
    """
    seed = config.dice.seed
    if seed is None:
        return random.randint(0, 2**32 - 1)
    return seed


def get_ncores(config) -> int:
    """Return ``config.ncores`` divided (floor) by ``config.dice.nprocs``, at least 1."""
    return max(1, int(config.ncores // config.dice.nprocs))


# -----------------------------------------------------
# NVT THERMALIZATION
# -----------------------------------------------------


class NVTTerConfig(BaseModel):
    """DICE input for the NVT thermalization stage (``nvt.ter``)."""

    # Discriminator for ``DiceInputConfig``; never written to the file.
    type: Literal[DiceRoutineType.NVT_TER] = DiceRoutineType.NVT_TER
    title: str = "NVT Thermalization"
    ncores: int
    ljname: str
    outname: str
    nmol: list[int]
    dens: float
    temp: float
    seed: int
    init: Literal["yes"] = "yes"
    nstep: int
    # NVT stages pin vstep to 0 (presumably disabling volume moves;
    # confirm against the DICE manual).
    vstep: Literal[0] = 0
    isave: int
    upbuf: int

    @classmethod
    def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
        """Build the stage from ``config`` using ``nstep[0]``.

        Extra ``kwargs`` are forwarded to the constructor. They may set
        fields not derived here (e.g. ``title``); repeating a derived field
        raises ``TypeError``.
        """
        return cls(
            ncores=get_ncores(config),
            ljname=str(config.dice.ljname),
            outname=config.dice.outname,
            nmol=config.dice.nmol,
            dens=config.dice.dens,
            temp=config.dice.temp,
            seed=get_seed(config),
            nstep=get_nstep(config, 0),
            isave=config.dice.isave,
            upbuf=config.dice.upbuf,
            **kwargs,
        )


# -----------------------------------------------------
# NVT PRODUCTION
# -----------------------------------------------------


class NVTEqConfig(BaseModel):
    """DICE input for the NVT production stage (``nvt.eq``)."""

    # Discriminator for ``DiceInputConfig``; never written to the file.
    type: Literal[DiceRoutineType.NVT_EQ] = DiceRoutineType.NVT_EQ
    title: str = "NVT Production"
    ncores: int
    ljname: str
    outname: str
    nmol: list[int]
    dens: float
    temp: float
    seed: int
    init: Literal["no", "yesreadxyz"] = "no"
    nstep: int
    vstep: int
    isave: int
    irdf: Literal[0] = 0
    upbuf: int

    @classmethod
    def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
        """Build the stage from ``config`` using ``nstep[1]``.

        ``isave`` is derived as ``max(1, nstep // 10)``. Extra ``kwargs``
        are forwarded to the constructor (e.g. ``init``); repeating a
        derived field raises ``TypeError``.
        """
        nstep = get_nstep(config, 1)
        return cls(
            ncores=get_ncores(config),
            ljname=str(config.dice.ljname),
            outname=config.dice.outname,
            nmol=config.dice.nmol,
            dens=config.dice.dens,
            temp=config.dice.temp,
            seed=get_seed(config),
            nstep=nstep,
            vstep=config.dice.vstep,
            # Presumably a save interval: ~10 saves per run, never 0.
            isave=max(1, nstep // 10),
            upbuf=config.dice.upbuf,
            **kwargs,
        )


# -----------------------------------------------------
# NPT THERMALIZATION
# -----------------------------------------------------


class NPTTerConfig(BaseModel):
    """DICE input for the NPT thermalization stage (``npt.ter``)."""

    # Discriminator for ``DiceInputConfig``; never written to the file.
    type: Literal[DiceRoutineType.NPT_TER] = DiceRoutineType.NPT_TER
    title: str = "NPT Thermalization"
    ncores: int
    ljname: str
    outname: str
    nmol: list[int]
    dens: float
    temp: float
    press: float
    seed: int
    init: Literal["yes", "yesreadxyz"] = "yes"
    nstep: int
    vstep: int
    isave: int
    upbuf: int

    @classmethod
    def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
        """Build the stage from ``config`` using ``nstep[1]``.

        ``vstep`` is clamped to at least 1. Extra ``kwargs`` are forwarded
        to the constructor (e.g. ``init``); repeating a derived field
        raises ``TypeError``.
        """
        return cls(
            ncores=get_ncores(config),
            ljname=str(config.dice.ljname),
            outname=config.dice.outname,
            nmol=config.dice.nmol,
            dens=config.dice.dens,
            temp=config.dice.temp,
            press=config.dice.press,
            seed=get_seed(config),
            nstep=get_nstep(config, 1),
            vstep=max(1, config.dice.vstep),
            isave=config.dice.isave,
            upbuf=config.dice.upbuf,
            **kwargs,
        )


# -----------------------------------------------------
# NPT PRODUCTION
# -----------------------------------------------------


class NPTEqConfig(BaseModel):
    """DICE input for the NPT production stage (``npt.eq``)."""

    # Discriminator for ``DiceInputConfig``; never written to the file.
    type: Literal[DiceRoutineType.NPT_EQ] = DiceRoutineType.NPT_EQ
    title: str = "NPT Production"
    ncores: int
    ljname: str
    outname: str
    nmol: list[int]
    dens: float
    temp: float
    press: float
    seed: int
    init: Literal["yes", "yesreadxyz"] = "yes"
    nstep: int
    vstep: int
    isave: int
    irdf: Literal[0] = 0
    upbuf: int

    @classmethod
    def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
        """Build the stage from ``config`` using ``nstep[2]``.

        ``isave`` is derived as ``max(1, nstep // 10)``. Extra ``kwargs``
        are forwarded to the constructor (e.g. ``init``); repeating a
        derived field raises ``TypeError``.
        """
        nstep = get_nstep(config, 2)
        return cls(
            ncores=get_ncores(config),
            ljname=str(config.dice.ljname),
            outname=config.dice.outname,
            nmol=config.dice.nmol,
            dens=config.dice.dens,
            temp=config.dice.temp,
            press=config.dice.press,
            seed=get_seed(config),
            nstep=nstep,
            vstep=config.dice.vstep,
            # Presumably a save interval: ~10 saves per run, never 0.
            isave=max(1, nstep // 10),
            upbuf=config.dice.upbuf,
            **kwargs,
        )


# Tagged union of all stage models, discriminated by their ``type`` field.
DiceInputConfig = Annotated[
    NVTTerConfig | NVTEqConfig | NPTTerConfig | NPTEqConfig,
    Field(discriminator="type"),
]


def _serialize_value(value: Any) -> str:
    """Render a field value for the right-hand side of a ``key = value`` line.

    Booleans become ``yes``/``no``, lists/tuples are space-joined, and
    anything else goes through ``str``.

    Raises:
        ValueError: if ``value`` is None.
    """
    if value is None:
        raise ValueError("DICE configuration cannot serialize None values")
    # bool must be checked before the generic str() fallback.
    if isinstance(value, bool):
        return "yes" if value else "no"
    if isinstance(value, (list, tuple)):
        return " ".join(str(v) for v in value)
    return str(value)


def write_dice_config(obj: DiceInputConfig, io_writer: TextIO) -> None:
    """Write ``obj`` as DICE ``key = value`` lines followed by ``$end``.

    Only keywords in ``_ALLOWED_DICE_KEYWORD_IN_ORDER`` are written, in that
    order; keywords the model does not define, or whose value is None, are
    skipped.
    """
    model_fields = obj.__class__.model_fields
    for key in _ALLOWED_DICE_KEYWORD_IN_ORDER:
        if key not in model_fields:
            continue
        value = getattr(obj, key)
        if value is None:
            continue
        io_writer.write(f"{key} = {_serialize_value(value)}\n")
    io_writer.write("$end\n")


def write_config(config: DiceInputConfig, directory: Path) -> Path:
    """Write ``config`` to ``directory / config.type`` and return that path.

    An existing file at that path is logged and removed first; the target
    directory (and its parents) is created if missing.
    """
    input_path = directory / config.type
    if input_path.exists():
        logger.info(
            f"Dice input file {input_path} already exists and will be overwritten"
        )
        input_path.unlink()

    input_path.parent.mkdir(parents=True, exist_ok=True)
    with open(input_path, "w", encoding="utf-8") as handle:
        write_dice_config(config, handle)
    return input_path