feat: initial dice-wrapper structure and working prototype
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.logger import logger
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
import random
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass, fields
|
||||
from pathlib import Path
|
||||
from typing import Any, Sequence, TextIO
|
||||
|
||||
|
||||
@@ -19,6 +20,19 @@ class BaseConfig(ABC):
|
||||
isave: int
|
||||
press: float = 1.0
|
||||
|
||||
def write(self, directory: Path, filename: str = "input") -> Path:
|
||||
input_path = directory / filename
|
||||
|
||||
if input_path.exists():
|
||||
logger.info(f"Dice input file {input_path} already exists and will be overwritten")
|
||||
input_path.unlink()
|
||||
input_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(input_path, "w") as io:
|
||||
self.write_dice_config(io)
|
||||
|
||||
return input_path
|
||||
|
||||
def write_dice_config(self, io_writer: TextIO) -> None:
|
||||
for field in fields(self):
|
||||
key = field.name
|
||||
@@ -38,19 +52,13 @@ class BaseConfig(ABC):
|
||||
|
||||
@staticmethod
|
||||
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
|
||||
seed: int
|
||||
if config.dice.seed is not None:
|
||||
seed = config.dice.seed
|
||||
else:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
return dict(
|
||||
ncores=config.ncores,
|
||||
ncores=int(config.ncores / config.dice.nprocs),
|
||||
ljname=config.dice.ljname,
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
temp=config.dice.temp,
|
||||
seed=seed,
|
||||
seed=config.dice.seed,
|
||||
isave=config.dice.isave,
|
||||
press=config.dice.press,
|
||||
)
|
||||
@@ -109,6 +117,11 @@ class NVTTerConfig(NVTConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def write(self, directory: Path, filename: str = "nvt.ter") -> Path:
|
||||
return super(NVTTerConfig, self).write(
|
||||
directory, filename
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NVT PRODUCTION
|
||||
@@ -131,6 +144,11 @@ class NVTEqConfig(NVTConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def write(self, directory: Path, filename: str = "nvt.eq") -> Path:
|
||||
return super(NVTEqConfig, self).write(
|
||||
directory, filename
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NPT BASE
|
||||
@@ -164,6 +182,11 @@ class NPTTerConfig(NPTConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def write(self, directory: Path, filename: str = "npt.ter") -> Path:
|
||||
return super(NPTTerConfig, self).write(
|
||||
directory, filename
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NPT PRODUCTION
|
||||
@@ -183,4 +206,9 @@ class NPTEqConfig(NPTConfig):
|
||||
nstep=cls._get_nstep(config, 2),
|
||||
vstep=config.dice.vstep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def write(self, directory: Path, filename: str = "npt.eq") -> Path:
|
||||
return super(NPTEqConfig, self).write(
|
||||
directory, filename
|
||||
)
|
||||
Reference in New Issue
Block a user