refactor: update Python version and optimize dice configuration parameters

This commit is contained in:
2026-03-24 23:01:45 -03:00
parent 0763c4a9e1
commit 0470200d00
12 changed files with 228 additions and 233 deletions

View File

@@ -1,16 +1,16 @@
from diceplayer.config import PlayerConfig
from diceplayer.logger import logger
from pydantic import BaseModel, Field
from typing_extensions import Self
from abc import ABC
from dataclasses import dataclass, fields
import random
from enum import StrEnum
from pathlib import Path
from typing import Any, Sequence, TextIO
from typing import Annotated, Any, Literal, TextIO
DICE_KEYWORD_ORDER = [
_ALLOWED_DICE_KEYWORD_IN_ORDER = [
"title",
"ncores",
"ljname",
@@ -32,220 +32,226 @@ DICE_KEYWORD_ORDER = [
]
@dataclass(slots=True)
class BaseConfig(ABC):
ncores: int
ljname: str
outname: str
nmol: Sequence[int]
temp: float
seed: int
isave: int
def write(self, directory: Path, filename: str = "input") -> Path:
input_path = directory / filename
if input_path.exists():
logger.info(
f"Dice input file {input_path} already exists and will be overwritten"
)
input_path.unlink()
input_path.parent.mkdir(parents=True, exist_ok=True)
with open(input_path, "w") as io:
self.write_dice_config(io)
return input_path
def write_dice_config(self, io_writer: TextIO) -> None:
values = {f.name: getattr(self, f.name) for f in fields(self)}
for key in DICE_KEYWORD_ORDER:
value = values.pop(key, None)
if value is None:
continue
io_writer.write(f"{key} = {self._serialize_value(value)}\n")
# write any remaining fields (future extensions)
for key, value in values.items():
if value is None:
continue
io_writer.write(f"{key} = {self._serialize_value(value)}\n")
io_writer.write("$end\n")
@classmethod
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
base_fields = cls._extract_base_fields(config)
return cls(**(base_fields | kwargs))
@staticmethod
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
return dict(
ncores=int(config.ncores / config.dice.nprocs),
ljname=config.dice.ljname,
outname=config.dice.outname,
nmol=config.dice.nmol,
temp=config.dice.temp,
seed=config.dice.seed,
isave=config.dice.isave,
)
@staticmethod
def _get_nstep(config: PlayerConfig, idx: int) -> int:
if len(config.dice.nstep) > idx:
return config.dice.nstep[idx]
return config.dice.nstep[-1]
@staticmethod
def _serialize_value(value: Any) -> str:
if value is None:
raise ValueError("DICE configuration cannot serialize None values")
if isinstance(value, bool):
return "yes" if value else "no"
if isinstance(value, (list, tuple)):
return " ".join(str(v) for v in value)
return str(value)
class DiceRoutineType(StrEnum):
NVT_TER = "nvt.ter"
NVT_EQ = "nvt.eq"
NPT_TER = "npt.ter"
NPT_EQ = "npt.eq"
# -----------------------------------------------------
# NVT BASE
# -----------------------------------------------------
def get_nstep(config, idx: int) -> int:
if len(config.dice.nstep) > idx:
return config.dice.nstep[idx]
return config.dice.nstep[-1]
@dataclass(slots=True)
class NVTConfig(BaseConfig):
title: str = "Diceplayer Run - NVT"
dens: float = ...
nstep: int = ...
vstep: int = 0
def get_seed(config) -> int:
return config.dice.seed or random.randint(0, 2**32 - 1)
@classmethod
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
return super(NVTConfig, cls).from_config(
config,
dens=config.dice.dens,
nstep=cls._get_nstep(config, 0),
)
def get_ncores(config) -> int:
return max(1, int(config.ncores / config.dice.nprocs))
# -----------------------------------------------------
# NVT THERMALIZATION
# -----------------------------------------------------
class NVTTerConfig(BaseModel):
type: Literal[DiceRoutineType.NVT_TER] = DiceRoutineType.NVT_TER
@dataclass(slots=True)
class NVTTerConfig(NVTConfig):
title: str = "Diceplayer Run - NVT Thermalization"
upbuf: int = 360
init: str = "yes"
title: str = "NVT Thermalization"
ncores: int
ljname: str
outname: str
nmol: list[int]
dens: float
temp: float
seed: int
init: Literal["yes"] = "yes"
nstep: int
vstep: Literal[0] = 0
isave: int
upbuf: int
@classmethod
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
return super(NVTTerConfig, cls).from_config(
config,
nstep=cls._get_nstep(config, 0),
return cls(
ncores=get_ncores(config),
ljname=str(config.dice.ljname),
outname=config.dice.outname,
nmol=config.dice.nmol,
dens=config.dice.dens,
temp=config.dice.temp,
seed=get_seed(config),
nstep=get_nstep(config, 0),
isave=config.dice.isave,
upbuf=config.dice.upbuf,
vstep=0,
**kwargs,
)
def write(self, directory: Path, filename: str = "nvt.ter") -> Path:
return super(NVTTerConfig, self).write(directory, filename)
# -----------------------------------------------------
# NVT PRODUCTION
# -----------------------------------------------------
class NVTEqConfig(BaseModel):
type: Literal[DiceRoutineType.NVT_EQ] = DiceRoutineType.NVT_EQ
@dataclass(slots=True)
class NVTEqConfig(NVTConfig):
title: str = "Diceplayer Run - NVT Production"
irdf: int = 0
init: str = "yesreadxyz"
title: str = "NVT Production"
ncores: int
ljname: str
outname: str
nmol: list[int]
dens: float
temp: float
seed: int
init: Literal["no", "yesreadxyz"] = "no"
nstep: int
vstep: int
isave: int
irdf: Literal[0] = 0
upbuf: int
@classmethod
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
return super(NVTEqConfig, cls).from_config(
config,
nstep=cls._get_nstep(config, 1),
irdf=config.dice.irdf,
vstep=0,
return cls(
ncores=get_ncores(config),
ljname=str(config.dice.ljname),
outname=config.dice.outname,
nmol=config.dice.nmol,
dens=config.dice.dens,
temp=config.dice.temp,
seed=get_seed(config),
nstep=get_nstep(config, 1),
vstep=config.dice.vstep,
isave=max(1, get_nstep(config, 1) // 10),
upbuf=config.dice.upbuf,
**kwargs,
)
def write(self, directory: Path, filename: str = "nvt.eq") -> Path:
return super(NVTEqConfig, self).write(directory, filename)
# -----------------------------------------------------
# NPT BASE
# -----------------------------------------------------
@dataclass(slots=True)
class NPTConfig(BaseConfig):
title: str = "Diceplayer Run - NPT"
nstep: int = 0
vstep: int = 5000
press: float = 1.0
@classmethod
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
return super(NPTConfig, cls).from_config(
config,
press=config.dice.press,
)
# -----------------------------------------------------
# NPT THERMALIZATION
# -----------------------------------------------------
class NPTTerConfig(BaseModel):
type: Literal[DiceRoutineType.NPT_TER] = DiceRoutineType.NPT_TER
@dataclass(slots=True)
class NPTTerConfig(NPTConfig):
title: str = "Diceplayer Run - NPT Thermalization"
dens: float | None = None
title: str = "NPT Thermalization"
ncores: int
ljname: str
outname: str
nmol: list[int]
dens: float
temp: float
press: float
seed: int
init: Literal["yes", "yesreadxyz"] = "yes"
nstep: int
vstep: int
isave: int
upbuf: int
@classmethod
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
return super(NPTTerConfig, cls).from_config(
config,
return cls(
ncores=get_ncores(config),
ljname=str(config.dice.ljname),
outname=config.dice.outname,
nmol=config.dice.nmol,
dens=config.dice.dens,
nstep=cls._get_nstep(config, 1),
vstep=config.dice.vstep,
temp=config.dice.temp,
press=config.dice.press,
seed=get_seed(config),
nstep=get_nstep(config, 1),
vstep=max(1, config.dice.vstep),
isave=config.dice.isave,
upbuf=config.dice.upbuf,
**kwargs,
)
def write(self, directory: Path, filename: str = "npt.ter") -> Path:
return super(NPTTerConfig, self).write(directory, filename)
# -----------------------------------------------------
# NPT PRODUCTION
# -----------------------------------------------------
class NPTEqConfig(BaseModel):
type: Literal[DiceRoutineType.NPT_EQ] = DiceRoutineType.NPT_EQ
@dataclass(slots=True)
class NPTEqConfig(NPTConfig):
title: str = "Diceplayer Run - NPT Production"
dens: float | None = None
title: str = "NPT Production"
ncores: int
ljname: str
outname: str
nmol: list[int]
dens: float
temp: float
press: float
seed: int
init: Literal["yes", "yesreadxyz"] = "yes"
nstep: int
vstep: int
isave: int
irdf: Literal[0] = 0
upbuf: int
@classmethod
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
return super(NPTEqConfig, cls).from_config(
config,
return cls(
ncores=get_ncores(config),
ljname=str(config.dice.ljname),
outname=config.dice.outname,
nmol=config.dice.nmol,
dens=config.dice.dens,
nstep=cls._get_nstep(config, 2),
temp=config.dice.temp,
press=config.dice.press,
seed=get_seed(config),
nstep=get_nstep(config, 2),
vstep=config.dice.vstep,
isave=max(1, get_nstep(config, 2) // 10),
upbuf=config.dice.upbuf,
**kwargs,
)
def write(self, directory: Path, filename: str = "npt.eq") -> Path:
return super(NPTEqConfig, self).write(directory, filename)
DiceInputConfig = Annotated[
NVTTerConfig | NVTEqConfig | NPTTerConfig | NPTEqConfig,
Field(discriminator="type"),
]
def _serialize_value(value: Any) -> str:
if value is None:
raise ValueError("DICE configuration cannot serialize None values")
if isinstance(value, bool):
return "yes" if value else "no"
if isinstance(value, (list, tuple)):
return " ".join(str(v) for v in value)
return str(value)
def write_dice_config(obj: DiceInputConfig, io_writer: TextIO) -> None:
values = {f: getattr(obj, f) for f in obj.__class__.model_fields}
for key in _ALLOWED_DICE_KEYWORD_IN_ORDER:
value = values.pop(key, None)
if value is None:
continue
io_writer.write(f"{key} = {_serialize_value(value)}\n")
io_writer.write("$end\n")
def write_config(config: DiceInputConfig, directory: Path) -> Path:
input_path = directory / config.type
if input_path.exists():
logger.info(
f"Dice input file {input_path} already exists and will be overwritten"
)
input_path.unlink()
input_path.parent.mkdir(parents=True, exist_ok=True)
with open(input_path, "w") as io:
write_dice_config(config, io)
return input_path