# DicePlayer/diceplayer/dice/dice_input.py

from diceplayer.config import PlayerConfig
from diceplayer.logger import logger
from pydantic import BaseModel, Field
from typing_extensions import Self
import random
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any, Literal, TextIO
_ALLOWED_DICE_KEYWORD_IN_ORDER = [
"title",
"ncores",
"ljname",
"outname",
"nmol",
"dens",
"temp",
"press",
"seed",
"init",
"nstep",
"vstep",
"mstop",
"accum",
"iprint",
"isave",
"irdf",
"upbuf",
]
class DiceRoutineType(StrEnum):
NVT_TER = "nvt.ter"
NVT_EQ = "nvt.eq"
NPT_TER = "npt.ter"
NPT_EQ = "npt.eq"
def get_nstep(config, idx: int) -> int:
    """Return the step count configured for routine position ``idx``.

    ``config.dice.nstep`` is indexed by routine position.  When it has no
    entry at ``idx`` (it is too short), its last entry is reused instead.
    An empty ``nstep`` raises ``IndexError``.
    """
    schedule = config.dice.nstep
    return schedule[idx] if idx < len(schedule) else schedule[-1]
def get_seed(config) -> int:
    """Return the configured DICE seed, or draw a random one when unset.

    Any falsy ``config.dice.seed`` (``None`` and also ``0``) counts as unset.
    It is replaced by ``random.randint(0, 2**32 - 1)``, so every call without
    a configured seed yields a fresh value from the ``random`` module state.
    """
    configured = config.dice.seed
    if configured:
        return configured
    # NOTE(review): an explicit seed of 0 is discarded here.  Confirm whether
    # that is intentional (e.g. 0 being an invalid DICE seed) or whether the
    # test should be `is not None`.
    return random.randint(0, 2**32 - 1)
def get_ncores(config) -> int:
    """Return ``config.ncores`` divided by ``config.dice.nprocs``, at least 1.

    The quotient is computed with true division and truncated toward zero
    by ``int``.  Any result below 1 is raised to 1.
    """
    share = int(config.ncores / config.dice.nprocs)
    return share if share > 1 else 1
# -----------------------------------------------------
# NVT THERMALIZATION
# -----------------------------------------------------
class NVTTerConfig(BaseModel):
    """DICE input for the NVT thermalization routine (``nvt.ter``).

    Usually built with :meth:`from_config`, which fills every required field
    from a ``PlayerConfig``.
    """

    type: Literal[DiceRoutineType.NVT_TER] = DiceRoutineType.NVT_TER  # discriminator
    title: str = "NVT Thermalization"
    ncores: int
    ljname: str
    outname: str
    nmol: list[int]
    dens: float
    temp: float
    seed: int
    init: Literal["yes"] = "yes"  # the only value this routine accepts
    nstep: int
    vstep: Literal[0] = 0  # pinned to 0 for this routine
    isave: int
    upbuf: int

    @classmethod
    def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
        """Create the thermalization input from the player configuration.

        Uses the first ``nstep`` entry (index 0), a per-process core count
        from ``get_ncores``, and a seed from ``get_seed``.  The remaining
        values are copied from ``config.dice``.

        Args:
            config: Full player configuration.
            **kwargs: Extra fields passed to the constructor (e.g. ``title``).
                Repeating a field computed here raises ``TypeError``.
        """
        dice = config.dice
        computed = dict(
            ncores=get_ncores(config),
            ljname=str(dice.ljname),
            outname=dice.outname,
            nmol=dice.nmol,
            dens=dice.dens,
            temp=dice.temp,
            seed=get_seed(config),
            nstep=get_nstep(config, 0),
            isave=dice.isave,
            upbuf=dice.upbuf,
        )
        return cls(**computed, **kwargs)
# -----------------------------------------------------
# NVT PRODUCTION
# -----------------------------------------------------
class NVTEqConfig(BaseModel):
    """DICE input for the NVT production routine (``nvt.eq``).

    Usually built with :meth:`from_config`, which fills every required field
    from a ``PlayerConfig``.
    """

    type: Literal[DiceRoutineType.NVT_EQ] = DiceRoutineType.NVT_EQ  # discriminator
    title: str = "NVT Production"
    ncores: int
    ljname: str
    outname: str
    nmol: list[int]
    dens: float
    temp: float
    seed: int
    init: Literal["no", "yesreadxyz"] = "no"
    nstep: int
    vstep: int
    isave: int
    irdf: Literal[0] = 0  # pinned to 0 for this routine
    upbuf: int

    @classmethod
    def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
        """Create the NVT production input from the player configuration.

        Uses the ``nstep`` entry at index 1 and sets ``isave`` to one tenth
        of that step count (never below 1).  ``ncores`` comes from
        ``get_ncores``, the seed from ``get_seed``, and the remaining values
        from ``config.dice``.

        Args:
            config: Full player configuration.
            **kwargs: Extra fields passed to the constructor
                (e.g. ``init="yesreadxyz"``).  Repeating a field computed
                here raises ``TypeError``.
        """
        dice = config.dice
        computed = dict(
            ncores=get_ncores(config),
            ljname=str(dice.ljname),
            outname=dice.outname,
            nmol=dice.nmol,
            dens=dice.dens,
            temp=dice.temp,
            seed=get_seed(config),
            # Bind the step count once; isave below is derived from it.
            nstep=(production_steps := get_nstep(config, 1)),
            vstep=dice.vstep,
            isave=max(1, production_steps // 10),
            upbuf=dice.upbuf,
        )
        return cls(**computed, **kwargs)
# -----------------------------------------------------
# NPT THERMALIZATION
# -----------------------------------------------------
class NPTTerConfig(BaseModel):
    """DICE input for the NPT thermalization routine (``npt.ter``).

    Usually built with :meth:`from_config`, which fills every required field
    from a ``PlayerConfig``.
    """

    type: Literal[DiceRoutineType.NPT_TER] = DiceRoutineType.NPT_TER  # discriminator
    title: str = "NPT Thermalization"
    ncores: int
    ljname: str
    outname: str
    nmol: list[int]
    dens: float
    temp: float
    press: float
    seed: int
    init: Literal["yes", "yesreadxyz"] = "yes"
    nstep: int
    vstep: int
    isave: int
    upbuf: int

    @classmethod
    def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
        """Create the NPT thermalization input from the player configuration.

        Uses the ``nstep`` entry at index 1 and raises ``vstep`` to at least
        1.  ``ncores`` comes from ``get_ncores``, the seed from ``get_seed``,
        and the remaining values (including ``press``) from ``config.dice``.

        Args:
            config: Full player configuration.
            **kwargs: Extra fields passed to the constructor.  Repeating a
                field computed here raises ``TypeError``.
        """
        # NOTE(review): index 1 is also what NVTEqConfig uses; presumably the
        # NVT-production and NPT schedules are alternatives — TODO confirm.
        dice = config.dice
        computed = dict(
            ncores=get_ncores(config),
            ljname=str(dice.ljname),
            outname=dice.outname,
            nmol=dice.nmol,
            dens=dice.dens,
            temp=dice.temp,
            press=dice.press,
            seed=get_seed(config),
            nstep=get_nstep(config, 1),
            vstep=max(1, dice.vstep),
            isave=dice.isave,
            upbuf=dice.upbuf,
        )
        return cls(**computed, **kwargs)
# -----------------------------------------------------
# NPT PRODUCTION
# -----------------------------------------------------
class NPTEqConfig(BaseModel):
    """DICE input for the NPT production routine (``npt.eq``).

    Usually built with :meth:`from_config`, which fills every required field
    from a ``PlayerConfig``.
    """

    type: Literal[DiceRoutineType.NPT_EQ] = DiceRoutineType.NPT_EQ  # discriminator
    title: str = "NPT Production"
    ncores: int
    ljname: str
    outname: str
    nmol: list[int]
    dens: float
    temp: float
    press: float
    seed: int
    # NOTE(review): NVTEqConfig (the other production routine) accepts and
    # defaults to "no", whereas this declaration rejects "no" and defaults
    # to "yes" — identical to NPTTerConfig.  Confirm this is intended.
    init: Literal["yes", "yesreadxyz"] = "yes"
    nstep: int
    vstep: int
    isave: int
    irdf: Literal[0] = 0  # pinned to 0 for this routine
    upbuf: int

    @classmethod
    def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
        """Create the NPT production input from the player configuration.

        Uses the ``nstep`` entry at index 2 and sets ``isave`` to one tenth
        of that step count (never below 1).  ``ncores`` comes from
        ``get_ncores``, the seed from ``get_seed``, and the remaining values
        (including ``press``) from ``config.dice``.

        Args:
            config: Full player configuration.
            **kwargs: Extra fields passed to the constructor.  Repeating a
                field computed here raises ``TypeError``.
        """
        dice = config.dice
        computed = dict(
            ncores=get_ncores(config),
            ljname=str(dice.ljname),
            outname=dice.outname,
            nmol=dice.nmol,
            dens=dice.dens,
            temp=dice.temp,
            press=dice.press,
            seed=get_seed(config),
            # Bind the step count once; isave below is derived from it.
            nstep=(production_steps := get_nstep(config, 2)),
            vstep=dice.vstep,
            isave=max(1, production_steps // 10),
            upbuf=dice.upbuf,
        )
        return cls(**computed, **kwargs)
# Tagged union of the four DICE routine input models.  Pydantic chooses the
# concrete model from the value of the ``type`` field; each member declares
# its own distinct ``Literal`` value for ``type``.
DiceInputConfig = Annotated[
    NVTTerConfig | NVTEqConfig | NPTTerConfig | NPTEqConfig,
    Field(discriminator="type"),
]
def _serialize_value(value: Any) -> str:
if value is None:
raise ValueError("DICE configuration cannot serialize None values")
if isinstance(value, bool):
return "yes" if value else "no"
if isinstance(value, (list, tuple)):
return " ".join(str(v) for v in value)
return str(value)
def write_dice_config(obj: DiceInputConfig, io_writer: TextIO) -> None:
    """Write a routine model to ``io_writer`` as ``key = value`` lines.

    Only model fields named in ``_ALLOWED_DICE_KEYWORD_IN_ORDER`` are written,
    in that list's order.  Fields whose value is ``None`` are skipped.  The
    output ends with a ``$end`` line.

    Args:
        obj: The routine configuration to serialize.
        io_writer: Text stream that receives the lines.
    """
    model_fields = type(obj).model_fields
    emitted = [
        f"{keyword} = {_serialize_value(getattr(obj, keyword))}\n"
        for keyword in _ALLOWED_DICE_KEYWORD_IN_ORDER
        if keyword in model_fields and getattr(obj, keyword) is not None
    ]
    emitted.append("$end\n")
    io_writer.writelines(emitted)
def write_config(config: DiceInputConfig, directory: Path) -> Path:
    """Write ``config`` as a DICE input file inside ``directory``.

    The file is named after the routine's ``type`` value (for example
    ``nvt.ter``).  If a file with that name already exists, this is logged
    and the file is removed before the new one is written.  ``directory``
    is created (with parents) if it is missing.

    Args:
        config: The routine configuration to write.
        directory: Directory that will contain the input file.

    Returns:
        Path of the written input file.
    """
    input_path = directory / config.type
    if input_path.exists():
        logger.info(
            f"Dice input file {input_path} already exists and will be overwritten"
        )
        input_path.unlink()
    input_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding: without it open() uses the locale's preferred
    # encoding, so the written bytes could differ between machines.
    with input_path.open("w", encoding="utf-8") as io_writer:
        write_dice_config(config, io_writer)
    return input_path