feat: initial dice-wrapper structure and working prototype

This commit is contained in:
2026-03-16 00:11:07 -03:00
parent 6a154429e9
commit 30be88e6b4
8 changed files with 174 additions and 25 deletions

View File

@@ -1,10 +1,11 @@
from diceplayer.config import PlayerConfig
from diceplayer.logger import logger
from typing_extensions import Self
import random
from abc import ABC
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Sequence, TextIO
@@ -19,6 +20,19 @@ class BaseConfig(ABC):
isave: int
press: float = 1.0
def write(self, directory: Path, filename: str = "input") -> Path:
input_path = directory / filename
if input_path.exists():
logger.info(f"Dice input file {input_path} already exists and will be overwritten")
input_path.unlink()
input_path.parent.mkdir(parents=True, exist_ok=True)
with open(input_path, "w") as io:
self.write_dice_config(io)
return input_path
def write_dice_config(self, io_writer: TextIO) -> None:
for field in fields(self):
key = field.name
@@ -38,19 +52,13 @@ class BaseConfig(ABC):
@staticmethod
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
seed: int
if config.dice.seed is not None:
seed = config.dice.seed
else:
seed = random.randint(0, 2**32 - 1)
return dict(
ncores=config.ncores,
ncores=int(config.ncores / config.dice.nprocs),
ljname=config.dice.ljname,
outname=config.dice.outname,
nmol=config.dice.nmol,
temp=config.dice.temp,
seed=seed,
seed=config.dice.seed,
isave=config.dice.isave,
press=config.dice.press,
)
@@ -109,6 +117,11 @@ class NVTTerConfig(NVTConfig):
**kwargs,
)
def write(self, directory: Path, filename: str = "nvt.ter") -> Path:
return super(NVTTerConfig, self).write(
directory, filename
)
# -----------------------------------------------------
# NVT PRODUCTION
@@ -131,6 +144,11 @@ class NVTEqConfig(NVTConfig):
**kwargs,
)
def write(self, directory: Path, filename: str = "nvt.eq") -> Path:
return super(NVTEqConfig, self).write(
directory, filename
)
# -----------------------------------------------------
# NPT BASE
@@ -164,6 +182,11 @@ class NPTTerConfig(NPTConfig):
**kwargs,
)
def write(self, directory: Path, filename: str = "npt.ter") -> Path:
return super(NPTTerConfig, self).write(
directory, filename
)
# -----------------------------------------------------
# NPT PRODUCTION
@@ -183,4 +206,9 @@ class NPTEqConfig(NPTConfig):
nstep=cls._get_nstep(config, 2),
vstep=config.dice.vstep,
**kwargs,
)
def write(self, directory: Path, filename: str = "npt.eq") -> Path:
return super(NPTEqConfig, self).write(
directory, filename
)