feat: initial dice-wrapper structure and working prototype

This commit is contained in:
2026-03-16 00:11:07 -03:00
parent 6a154429e9
commit 30be88e6b4
8 changed files with 174 additions and 25 deletions

View File

@@ -1,8 +1,80 @@
import shutil
from diceplayer.dice.dice_input import NVTTerConfig, NVTEqConfig, NPTEqConfig, NPTTerConfig
from diceplayer.dice.dice_wrapper import DiceWrapper
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel
from pathlib import Path
from threading import Thread
class DiceHandler:
@staticmethod
def run(state: StateModel, current_cycle: int) -> StateModel:
print(f"Running Dice - {current_cycle}")
def __init__(self, step_directory: Path):
self.dice_directory = step_directory / "dice"
def run(self, state: StateModel, cycle: int) -> StateModel:
if self.dice_directory.exists():
logger.info(f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state")
shutil.rmtree(self.dice_directory)
self.dice_directory.mkdir(parents=True)
simulation_results = self.run_simulations(state, cycle)
result = self.aggregate_results(simulation_results)
return self.commit_simulation_state(state, result)
def run_simulations(self, state: StateModel, cycle: int) -> list[dict]:
results = []
threads = []
for p in range(state.config.dice.nprocs):
t = Thread(target=self._simulation_process, args=(state, cycle, p, results))
threads.append(t)
t.start()
for t in threads:
t.join()
if len(results) != state.config.dice.nprocs:
raise RuntimeError(f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}")
return results
def aggregate_results(self, simulation_results: list[dict]) -> dict: ...
def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
return state
def _simulation_process(self, state: StateModel, cycle: int, proc: int, results: list[dict]) -> None:
proc_directory = self.dice_directory / f"{proc:02d}"
if proc_directory.exists():
shutil.rmtree(proc_directory)
proc_directory.mkdir(parents=True)
dice = DiceWrapper(
state.config.dice, proc_directory
)
if state.config.dice.randominit == "first" and cycle == 0:
nvt_ter_config = NVTTerConfig.from_config(state.config)
dice.run(nvt_ter_config)
else:
self._generate_last_xyz(state, proc_directory)
if len(state.config.dice.nstep) == 2:
nvt_eq_config = NVTEqConfig.from_config(state.config)
dice.run(nvt_eq_config)
elif len(state.config.dice.nstep) == 3:
npt_ter_config = NPTTerConfig.from_config(state.config)
dice.run(npt_ter_config)
npt_eq_config = NPTEqConfig.from_config(state.config)
dice.run(npt_eq_config)
results.append(dice.extract_results())
def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None:
...

View File

@@ -1,10 +1,11 @@
from diceplayer.config import PlayerConfig
from diceplayer.logger import logger
from typing_extensions import Self
import random
from abc import ABC
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Sequence, TextIO
@@ -19,6 +20,19 @@ class BaseConfig(ABC):
isave: int
press: float = 1.0
def write(self, directory: Path, filename: str = "input") -> Path:
input_path = directory / filename
if input_path.exists():
logger.info(f"Dice input file {input_path} already exists and will be overwritten")
input_path.unlink()
input_path.parent.mkdir(parents=True, exist_ok=True)
with open(input_path, "w") as io:
self.write_dice_config(io)
return input_path
def write_dice_config(self, io_writer: TextIO) -> None:
for field in fields(self):
key = field.name
@@ -38,19 +52,13 @@ class BaseConfig(ABC):
@staticmethod
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
seed: int
if config.dice.seed is not None:
seed = config.dice.seed
else:
seed = random.randint(0, 2**32 - 1)
return dict(
ncores=config.ncores,
ncores=int(config.ncores / config.dice.nprocs),
ljname=config.dice.ljname,
outname=config.dice.outname,
nmol=config.dice.nmol,
temp=config.dice.temp,
seed=seed,
seed=config.dice.seed,
isave=config.dice.isave,
press=config.dice.press,
)
@@ -109,6 +117,11 @@ class NVTTerConfig(NVTConfig):
**kwargs,
)
def write(self, directory: Path, filename: str = "nvt.ter") -> Path:
return super(NVTTerConfig, self).write(
directory, filename
)
# -----------------------------------------------------
# NVT PRODUCTION
@@ -131,6 +144,11 @@ class NVTEqConfig(NVTConfig):
**kwargs,
)
def write(self, directory: Path, filename: str = "nvt.eq") -> Path:
return super(NVTEqConfig, self).write(
directory, filename
)
# -----------------------------------------------------
# NPT BASE
@@ -164,6 +182,11 @@ class NPTTerConfig(NPTConfig):
**kwargs,
)
def write(self, directory: Path, filename: str = "npt.ter") -> Path:
return super(NPTTerConfig, self).write(
directory, filename
)
# -----------------------------------------------------
# NPT PRODUCTION
@@ -183,4 +206,9 @@ class NPTEqConfig(NPTConfig):
nstep=cls._get_nstep(config, 2),
vstep=config.dice.vstep,
**kwargs,
)
def write(self, directory: Path, filename: str = "npt.eq") -> Path:
return super(NPTEqConfig, self).write(
directory, filename
)

View File

@@ -0,0 +1,41 @@
import subprocess
from typing import Final
import diceplayer.dice.dice_input as dice_input
from pathlib import Path
from diceplayer.config import DiceConfig
DICE_FLAG_LINE: Final[int] = -2
DICE_END_FLAG: Final[str] = "End of simulation"
class DiceWrapper:
def __init__(self, dice_config: DiceConfig, working_directory: Path):
self.dice_config = dice_config
self.working_directory = working_directory
def run(self, dice_config: dice_input.BaseConfig) -> None:
input_path = dice_config.write(self.working_directory)
output_path = input_path.parent / (input_path.name + ".out")
with open(output_path, "w") as outfile, open(input_path, "r") as infile:
exit_status = subprocess.call(
self.dice_config.progname, stdin=infile, stdout=outfile
)
if exit_status != 0:
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
with open(output_path, "r") as outfile:
line = outfile.readlines()[DICE_FLAG_LINE]
if line.strip() == DICE_END_FLAG:
return
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
def extract_results(self) -> dict:
return {}