refactor: update Python version and optimize dice configuration parameters
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.logger import logger
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass, fields
|
||||
import random
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, Sequence, TextIO
|
||||
from typing import Annotated, Any, Literal, TextIO
|
||||
|
||||
|
||||
|
||||
DICE_KEYWORD_ORDER = [
|
||||
_ALLOWED_DICE_KEYWORD_IN_ORDER = [
|
||||
"title",
|
||||
"ncores",
|
||||
"ljname",
|
||||
@@ -32,220 +32,226 @@ DICE_KEYWORD_ORDER = [
|
||||
]
|
||||
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BaseConfig(ABC):
|
||||
ncores: int
|
||||
ljname: str
|
||||
outname: str
|
||||
nmol: Sequence[int]
|
||||
temp: float
|
||||
seed: int
|
||||
isave: int
|
||||
|
||||
def write(self, directory: Path, filename: str = "input") -> Path:
|
||||
input_path = directory / filename
|
||||
|
||||
if input_path.exists():
|
||||
logger.info(
|
||||
f"Dice input file {input_path} already exists and will be overwritten"
|
||||
)
|
||||
input_path.unlink()
|
||||
input_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(input_path, "w") as io:
|
||||
self.write_dice_config(io)
|
||||
|
||||
return input_path
|
||||
|
||||
def write_dice_config(self, io_writer: TextIO) -> None:
|
||||
values = {f.name: getattr(self, f.name) for f in fields(self)}
|
||||
|
||||
for key in DICE_KEYWORD_ORDER:
|
||||
value = values.pop(key, None)
|
||||
if value is None:
|
||||
continue
|
||||
io_writer.write(f"{key} = {self._serialize_value(value)}\n")
|
||||
|
||||
# write any remaining fields (future extensions)
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
continue
|
||||
io_writer.write(f"{key} = {self._serialize_value(value)}\n")
|
||||
|
||||
io_writer.write("$end\n")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
base_fields = cls._extract_base_fields(config)
|
||||
return cls(**(base_fields | kwargs))
|
||||
|
||||
@staticmethod
|
||||
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
|
||||
return dict(
|
||||
ncores=int(config.ncores / config.dice.nprocs),
|
||||
ljname=config.dice.ljname,
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
temp=config.dice.temp,
|
||||
seed=config.dice.seed,
|
||||
isave=config.dice.isave,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_nstep(config: PlayerConfig, idx: int) -> int:
|
||||
if len(config.dice.nstep) > idx:
|
||||
return config.dice.nstep[idx]
|
||||
return config.dice.nstep[-1]
|
||||
|
||||
@staticmethod
|
||||
def _serialize_value(value: Any) -> str:
|
||||
if value is None:
|
||||
raise ValueError("DICE configuration cannot serialize None values")
|
||||
|
||||
if isinstance(value, bool):
|
||||
return "yes" if value else "no"
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
return " ".join(str(v) for v in value)
|
||||
|
||||
return str(value)
|
||||
class DiceRoutineType(StrEnum):
|
||||
NVT_TER = "nvt.ter"
|
||||
NVT_EQ = "nvt.eq"
|
||||
NPT_TER = "npt.ter"
|
||||
NPT_EQ = "npt.eq"
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NVT BASE
|
||||
# -----------------------------------------------------
|
||||
def get_nstep(config, idx: int) -> int:
|
||||
if len(config.dice.nstep) > idx:
|
||||
return config.dice.nstep[idx]
|
||||
return config.dice.nstep[-1]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class NVTConfig(BaseConfig):
|
||||
title: str = "Diceplayer Run - NVT"
|
||||
dens: float = ...
|
||||
nstep: int = ...
|
||||
vstep: int = 0
|
||||
def get_seed(config) -> int:
|
||||
return config.dice.seed or random.randint(0, 2**32 - 1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return super(NVTConfig, cls).from_config(
|
||||
config,
|
||||
dens=config.dice.dens,
|
||||
nstep=cls._get_nstep(config, 0),
|
||||
)
|
||||
|
||||
def get_ncores(config) -> int:
|
||||
return max(1, int(config.ncores / config.dice.nprocs))
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NVT THERMALIZATION
|
||||
# -----------------------------------------------------
|
||||
class NVTTerConfig(BaseModel):
|
||||
type: Literal[DiceRoutineType.NVT_TER] = DiceRoutineType.NVT_TER
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class NVTTerConfig(NVTConfig):
|
||||
title: str = "Diceplayer Run - NVT Thermalization"
|
||||
upbuf: int = 360
|
||||
init: str = "yes"
|
||||
title: str = "NVT Thermalization"
|
||||
ncores: int
|
||||
ljname: str
|
||||
outname: str
|
||||
nmol: list[int]
|
||||
dens: float
|
||||
temp: float
|
||||
seed: int
|
||||
init: Literal["yes"] = "yes"
|
||||
nstep: int
|
||||
vstep: Literal[0] = 0
|
||||
isave: int
|
||||
upbuf: int
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return super(NVTTerConfig, cls).from_config(
|
||||
config,
|
||||
nstep=cls._get_nstep(config, 0),
|
||||
return cls(
|
||||
ncores=get_ncores(config),
|
||||
ljname=str(config.dice.ljname),
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
dens=config.dice.dens,
|
||||
temp=config.dice.temp,
|
||||
seed=get_seed(config),
|
||||
nstep=get_nstep(config, 0),
|
||||
isave=config.dice.isave,
|
||||
upbuf=config.dice.upbuf,
|
||||
vstep=0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def write(self, directory: Path, filename: str = "nvt.ter") -> Path:
|
||||
return super(NVTTerConfig, self).write(directory, filename)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NVT PRODUCTION
|
||||
# -----------------------------------------------------
|
||||
class NVTEqConfig(BaseModel):
|
||||
type: Literal[DiceRoutineType.NVT_EQ] = DiceRoutineType.NVT_EQ
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class NVTEqConfig(NVTConfig):
|
||||
title: str = "Diceplayer Run - NVT Production"
|
||||
irdf: int = 0
|
||||
init: str = "yesreadxyz"
|
||||
title: str = "NVT Production"
|
||||
ncores: int
|
||||
ljname: str
|
||||
outname: str
|
||||
nmol: list[int]
|
||||
dens: float
|
||||
temp: float
|
||||
seed: int
|
||||
init: Literal["no", "yesreadxyz"] = "no"
|
||||
nstep: int
|
||||
vstep: int
|
||||
isave: int
|
||||
irdf: Literal[0] = 0
|
||||
upbuf: int
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return super(NVTEqConfig, cls).from_config(
|
||||
config,
|
||||
nstep=cls._get_nstep(config, 1),
|
||||
irdf=config.dice.irdf,
|
||||
vstep=0,
|
||||
return cls(
|
||||
ncores=get_ncores(config),
|
||||
ljname=str(config.dice.ljname),
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
dens=config.dice.dens,
|
||||
temp=config.dice.temp,
|
||||
seed=get_seed(config),
|
||||
nstep=get_nstep(config, 1),
|
||||
vstep=config.dice.vstep,
|
||||
isave=max(1, get_nstep(config, 1) // 10),
|
||||
upbuf=config.dice.upbuf,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def write(self, directory: Path, filename: str = "nvt.eq") -> Path:
|
||||
return super(NVTEqConfig, self).write(directory, filename)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NPT BASE
|
||||
# -----------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class NPTConfig(BaseConfig):
|
||||
title: str = "Diceplayer Run - NPT"
|
||||
nstep: int = 0
|
||||
vstep: int = 5000
|
||||
press: float = 1.0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return super(NPTConfig, cls).from_config(
|
||||
config,
|
||||
press=config.dice.press,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NPT THERMALIZATION
|
||||
# -----------------------------------------------------
|
||||
class NPTTerConfig(BaseModel):
|
||||
type: Literal[DiceRoutineType.NPT_TER] = DiceRoutineType.NPT_TER
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class NPTTerConfig(NPTConfig):
|
||||
title: str = "Diceplayer Run - NPT Thermalization"
|
||||
dens: float | None = None
|
||||
title: str = "NPT Thermalization"
|
||||
ncores: int
|
||||
ljname: str
|
||||
outname: str
|
||||
nmol: list[int]
|
||||
dens: float
|
||||
temp: float
|
||||
press: float
|
||||
seed: int
|
||||
init: Literal["yes", "yesreadxyz"] = "yes"
|
||||
nstep: int
|
||||
vstep: int
|
||||
isave: int
|
||||
upbuf: int
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return super(NPTTerConfig, cls).from_config(
|
||||
config,
|
||||
return cls(
|
||||
ncores=get_ncores(config),
|
||||
ljname=str(config.dice.ljname),
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
dens=config.dice.dens,
|
||||
nstep=cls._get_nstep(config, 1),
|
||||
vstep=config.dice.vstep,
|
||||
temp=config.dice.temp,
|
||||
press=config.dice.press,
|
||||
seed=get_seed(config),
|
||||
nstep=get_nstep(config, 1),
|
||||
vstep=max(1, config.dice.vstep),
|
||||
isave=config.dice.isave,
|
||||
upbuf=config.dice.upbuf,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def write(self, directory: Path, filename: str = "npt.ter") -> Path:
|
||||
return super(NPTTerConfig, self).write(directory, filename)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NPT PRODUCTION
|
||||
# -----------------------------------------------------
|
||||
class NPTEqConfig(BaseModel):
|
||||
type: Literal[DiceRoutineType.NPT_EQ] = DiceRoutineType.NPT_EQ
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class NPTEqConfig(NPTConfig):
|
||||
title: str = "Diceplayer Run - NPT Production"
|
||||
dens: float | None = None
|
||||
title: str = "NPT Production"
|
||||
ncores: int
|
||||
ljname: str
|
||||
outname: str
|
||||
nmol: list[int]
|
||||
dens: float
|
||||
temp: float
|
||||
press: float
|
||||
seed: int
|
||||
init: Literal["yes", "yesreadxyz"] = "yes"
|
||||
nstep: int
|
||||
vstep: int
|
||||
isave: int
|
||||
irdf: Literal[0] = 0
|
||||
upbuf: int
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return super(NPTEqConfig, cls).from_config(
|
||||
config,
|
||||
return cls(
|
||||
ncores=get_ncores(config),
|
||||
ljname=str(config.dice.ljname),
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
dens=config.dice.dens,
|
||||
nstep=cls._get_nstep(config, 2),
|
||||
temp=config.dice.temp,
|
||||
press=config.dice.press,
|
||||
seed=get_seed(config),
|
||||
nstep=get_nstep(config, 2),
|
||||
vstep=config.dice.vstep,
|
||||
isave=max(1, get_nstep(config, 2) // 10),
|
||||
upbuf=config.dice.upbuf,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def write(self, directory: Path, filename: str = "npt.eq") -> Path:
|
||||
return super(NPTEqConfig, self).write(directory, filename)
|
||||
|
||||
DiceInputConfig = Annotated[
|
||||
NVTTerConfig | NVTEqConfig | NPTTerConfig | NPTEqConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def _serialize_value(value: Any) -> str:
|
||||
if value is None:
|
||||
raise ValueError("DICE configuration cannot serialize None values")
|
||||
|
||||
if isinstance(value, bool):
|
||||
return "yes" if value else "no"
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
return " ".join(str(v) for v in value)
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
def write_dice_config(obj: DiceInputConfig, io_writer: TextIO) -> None:
|
||||
values = {f: getattr(obj, f) for f in obj.__class__.model_fields}
|
||||
|
||||
for key in _ALLOWED_DICE_KEYWORD_IN_ORDER:
|
||||
value = values.pop(key, None)
|
||||
if value is None:
|
||||
continue
|
||||
io_writer.write(f"{key} = {_serialize_value(value)}\n")
|
||||
|
||||
io_writer.write("$end\n")
|
||||
|
||||
|
||||
def write_config(config: DiceInputConfig, directory: Path) -> Path:
|
||||
input_path = directory / config.type
|
||||
|
||||
if input_path.exists():
|
||||
logger.info(
|
||||
f"Dice input file {input_path} already exists and will be overwritten"
|
||||
)
|
||||
input_path.unlink()
|
||||
input_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(input_path, "w") as io:
|
||||
write_dice_config(config, io)
|
||||
|
||||
return input_path
|
||||
|
||||
Reference in New Issue
Block a user