From 30be88e6b43f5d3d3af75c8d8fa8f4f644fe6444 Mon Sep 17 00:00:00 2001 From: Vitor Hideyoshi Date: Mon, 16 Mar 2026 00:11:07 -0300 Subject: [PATCH] feat: initial dice-wrapper structure and working prototype --- control.example.yml | 2 +- diceplayer/config/dice_config.py | 7 ++- diceplayer/config/player_config.py | 3 -- diceplayer/dice/dice_handler.py | 78 ++++++++++++++++++++++++++++-- diceplayer/dice/dice_input.py | 46 ++++++++++++++---- diceplayer/dice/dice_wrapper.py | 41 ++++++++++++++++ diceplayer/player.py | 11 +++-- tests/dice/test_dice_input.py | 11 +++-- 8 files changed, 174 insertions(+), 25 deletions(-) create mode 100644 diceplayer/dice/dice_wrapper.py diff --git a/control.example.yml b/control.example.yml index 6ecbb53..705e4cf 100644 --- a/control.example.yml +++ b/control.example.yml @@ -4,13 +4,13 @@ diceplayer: max_cyc: 5 mem: 24 ncores: 5 - nprocs: 4 qmprog: 'g16' lps: no ghosts: no altsteps: 2000 dice: + nprocs: 4 nmol: [1, 100] dens: 1.5 nstep: [2000, 3000] diff --git a/diceplayer/config/dice_config.py b/diceplayer/config/dice_config.py index 21a3027..99b59ab 100644 --- a/diceplayer/config/dice_config.py +++ b/diceplayer/config/dice_config.py @@ -1,6 +1,8 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Literal +import random + class DiceConfig(BaseModel): """ @@ -10,6 +12,9 @@ class DiceConfig(BaseModel): model_config = ConfigDict( frozen=True, ) + nprocs: int = Field( + ..., description="Number of processes to use for the DICE simulations" + ) ljname: str = Field(..., description="Name of the Lennard-Jones potential file") outname: str = Field( @@ -48,4 +53,4 @@ class DiceConfig(BaseModel): randominit: str = Field( "first", description="Method for initializing the random number generator" ) - seed: int | None = Field(None, description="Seed for the random number generator") + seed: int = Field(default_factory=lambda: random.randint(0, 2**32 - 1), description="Seed for the random number generator") diff --git a/diceplayer/config/player_config.py b/diceplayer/config/player_config.py index e2d1109..74403f6 100644 --- a/diceplayer/config/player_config.py +++ b/diceplayer/config/player_config.py @@ -47,9 +47,6 @@ class PlayerConfig(BaseModel): switch_cyc: int = Field(..., description="Switch cycle configuration") mem: int = Field(None, description="Memory configuration") - nprocs: int = Field( - ..., description="Number of processors to use for the QM calculations" - ) ncores: int = Field( ..., description="Number of cores to use for the QM calculations" ) diff --git a/diceplayer/dice/dice_handler.py b/diceplayer/dice/dice_handler.py index 2842e77..f418cda 100644 --- a/diceplayer/dice/dice_handler.py +++ b/diceplayer/dice/dice_handler.py @@ -1,8 +1,80 @@ +import shutil + +from diceplayer.dice.dice_input import NVTTerConfig, NVTEqConfig, NPTEqConfig, NPTTerConfig +from diceplayer.dice.dice_wrapper import DiceWrapper +from diceplayer.logger import logger from diceplayer.state.state_model import StateModel +from pathlib import Path +from threading import Thread + class DiceHandler: - @staticmethod - def run(state: StateModel, current_cycle: int) -> StateModel: - print(f"Running Dice - {current_cycle}") + def __init__(self, step_directory: Path): + self.dice_directory = step_directory / "dice" + + def run(self, state: StateModel, cycle: int) -> StateModel: + if self.dice_directory.exists(): + logger.info(f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state") + shutil.rmtree(self.dice_directory) + self.dice_directory.mkdir(parents=True) + + simulation_results = self.run_simulations(state, cycle) + + result = self.aggregate_results(simulation_results) + + return self.commit_simulation_state(state, result) + + def run_simulations(self, state: StateModel, cycle: int) -> list[dict]: + results = [] + + threads = [] + for p in range(state.config.dice.nprocs): + t = Thread(target=self._simulation_process, args=(state, cycle, p, results)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + if len(results) != state.config.dice.nprocs: + raise RuntimeError(f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}") + + return results + + def aggregate_results(self, simulation_results: list[dict]) -> dict: ... + + def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel: return state + + def _simulation_process(self, state: StateModel, cycle: int, proc: int, results: list[dict]) -> None: + proc_directory = self.dice_directory / f"{proc:02d}" + if proc_directory.exists(): + shutil.rmtree(proc_directory) + proc_directory.mkdir(parents=True) + + dice = DiceWrapper( + state.config.dice, proc_directory + ) + + if state.config.dice.randominit == "first" and cycle == 0: + nvt_ter_config = NVTTerConfig.from_config(state.config) + dice.run(nvt_ter_config) + else: + self._generate_last_xyz(state, proc_directory) + + if len(state.config.dice.nstep) == 2: + nvt_eq_config = NVTEqConfig.from_config(state.config) + dice.run(nvt_eq_config) + + elif len(state.config.dice.nstep) == 3: + npt_ter_config = NPTTerConfig.from_config(state.config) + dice.run(npt_ter_config) + + npt_eq_config = NPTEqConfig.from_config(state.config) + dice.run(npt_eq_config) + + results.append(dice.extract_results()) + + def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None: + ... diff --git a/diceplayer/dice/dice_input.py b/diceplayer/dice/dice_input.py index 75fc6db..34adecd 100644 --- a/diceplayer/dice/dice_input.py +++ b/diceplayer/dice/dice_input.py @@ -1,10 +1,11 @@ from diceplayer.config import PlayerConfig +from diceplayer.logger import logger from typing_extensions import Self -import random from abc import ABC from dataclasses import dataclass, fields +from pathlib import Path from typing import Any, Sequence, TextIO @@ -19,6 +20,19 @@ class BaseConfig(ABC): isave: int press: float = 1.0 + def write(self, directory: Path, filename: str = "input") -> Path: + input_path = directory / filename + + if input_path.exists(): + logger.info(f"Dice input file {input_path} already exists and will be overwritten") + input_path.unlink() + input_path.parent.mkdir(parents=True, exist_ok=True) + + with open(input_path, "w") as io: + self.write_dice_config(io) + + return input_path + def write_dice_config(self, io_writer: TextIO) -> None: for field in fields(self): key = field.name @@ -38,19 +52,13 @@ class BaseConfig(ABC): @staticmethod def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]: - seed: int - if config.dice.seed is not None: - seed = config.dice.seed - else: - seed = random.randint(0, 2**32 - 1) - return dict( - ncores=config.ncores, + ncores=int(config.ncores / config.dice.nprocs), ljname=config.dice.ljname, outname=config.dice.outname, nmol=config.dice.nmol, temp=config.dice.temp, - seed=seed, + seed=config.dice.seed, isave=config.dice.isave, press=config.dice.press, ) @@ -109,6 +117,11 @@ class NVTTerConfig(NVTConfig): **kwargs, ) + def write(self, directory: Path, filename: str = "nvt.ter") -> Path: + return super(NVTTerConfig, self).write( + directory, filename + ) + # ----------------------------------------------------- # NVT PRODUCTION @@ -131,6 +144,11 @@ class NVTEqConfig(NVTConfig): **kwargs, ) + def write(self, directory: Path, filename: str = "nvt.eq") -> Path: + return super(NVTEqConfig, self).write( + directory, filename + ) + # ----------------------------------------------------- # NPT BASE @@ -164,6 +182,11 @@ class NPTTerConfig(NPTConfig): **kwargs, ) + def write(self, directory: Path, filename: str = "npt.ter") -> Path: + return super(NPTTerConfig, self).write( + directory, filename + ) + # ----------------------------------------------------- # NPT PRODUCTION @@ -183,4 +206,9 @@ class NPTEqConfig(NPTConfig): nstep=cls._get_nstep(config, 2), vstep=config.dice.vstep, **kwargs, + ) + + def write(self, directory: Path, filename: str = "npt.eq") -> Path: + return super(NPTEqConfig, self).write( + directory, filename ) \ No newline at end of file diff --git a/diceplayer/dice/dice_wrapper.py b/diceplayer/dice/dice_wrapper.py new file mode 100644 index 0000000..46796a9 --- /dev/null +++ b/diceplayer/dice/dice_wrapper.py @@ -0,0 +1,41 @@ +import subprocess +from typing import Final + +import diceplayer.dice.dice_input as dice_input + +from pathlib import Path + +from diceplayer.config import DiceConfig + + +DICE_FLAG_LINE: Final[int] = -2 +DICE_END_FLAG: Final[str] = "End of simulation" + + +class DiceWrapper: + def __init__(self, dice_config: DiceConfig, working_directory: Path): + self.dice_config = dice_config + self.working_directory = working_directory + + def run(self, dice_config: dice_input.BaseConfig) -> None: + input_path = dice_config.write(self.working_directory) + output_path = input_path.parent / (input_path.name + ".out") + + with open(output_path, "w") as outfile, open(input_path, "r") as infile: + exit_status = subprocess.call( + self.dice_config.progname, stdin=infile, stdout=outfile + ) + + if exit_status != 0: + raise RuntimeError(f"Dice simulation failed with exit status {exit_status}") + + with open(output_path, "r") as outfile: + line = outfile.readlines()[DICE_FLAG_LINE] + if line.strip() == DICE_END_FLAG: + return + + raise RuntimeError(f"Dice simulation failed with exit status {exit_status}") + + def extract_results(self) -> dict: + return {} + diff --git a/diceplayer/player.py b/diceplayer/player.py index b2ce936..4f4979c 100644 --- a/diceplayer/player.py +++ b/diceplayer/player.py @@ -1,4 +1,4 @@ -from diceplayer.config.player_config import PlayerConfig, RoutineType +from diceplayer.config.player_config import PlayerConfig from diceplayer.dice.dice_handler import DiceHandler from diceplayer.logger import logger from diceplayer.optimization.optimization_handler import OptimizationHandler @@ -22,13 +22,14 @@ class Player: continuation = flags.get("continuation", False) force = flags.get("force", False) - state: StateModel = self._state_handler.get(self.config, force=force) + state = self._state_handler.get(self.config, force=force) if not continuation and state is not None: logger.info( "Continuation flag is not set. Starting a new simulation and deleting any existing state." ) self._state_handler.delete() - state = self._state_handler.get(self.config, force=force) + state = None + if state is None: state = StateModel.from_config(self.config) @@ -44,9 +45,9 @@ class Player: if not step_directory.exists(): step_directory.mkdir(parents=True) - state = DiceHandler.run(state, state.current_cycle) + state = DiceHandler(step_directory).run(state, state.current_cycle) - state = OptimizationHandler.run(state, state.current_cycle) + # state = OptimizationHandler.run(state, state.current_cycle) state.current_cycle += 1 self._state_handler.save(state) diff --git a/tests/dice/test_dice_input.py b/tests/dice/test_dice_input.py index d2141af..2e1dad1 100644 --- a/tests/dice/test_dice_input.py +++ b/tests/dice/test_dice_input.py @@ -1,10 +1,15 @@ -from pathlib import Path - from diceplayer.config import PlayerConfig -from diceplayer.dice.dice_input import NVTEqConfig, NVTTerConfig, NPTTerConfig, NPTEqConfig +from diceplayer.dice.dice_input import ( + NPTEqConfig, + NPTTerConfig, + NVTEqConfig, + NVTTerConfig, +) import pytest +from pathlib import Path + class TestDiceInput: @pytest.fixture