feat: initial dice-wrapper structure and working prototype

This commit is contained in:
2026-03-16 00:11:07 -03:00
parent 6a154429e9
commit 30be88e6b4
8 changed files with 174 additions and 25 deletions

View File

@@ -4,13 +4,13 @@ diceplayer:
max_cyc: 5 max_cyc: 5
mem: 24 mem: 24
ncores: 5 ncores: 5
nprocs: 4
qmprog: 'g16' qmprog: 'g16'
lps: no lps: no
ghosts: no ghosts: no
altsteps: 2000 altsteps: 2000
dice: dice:
nprocs: 4
nmol: [1, 100] nmol: [1, 100]
dens: 1.5 dens: 1.5
nstep: [2000, 3000] nstep: [2000, 3000]

View File

@@ -1,6 +1,8 @@
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Literal from typing_extensions import Literal
import random
class DiceConfig(BaseModel): class DiceConfig(BaseModel):
""" """
@@ -10,6 +12,9 @@ class DiceConfig(BaseModel):
model_config = ConfigDict( model_config = ConfigDict(
frozen=True, frozen=True,
) )
nprocs: int = Field(
..., description="Number of processes to use for the DICE simulations"
)
ljname: str = Field(..., description="Name of the Lennard-Jones potential file") ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
outname: str = Field( outname: str = Field(
@@ -48,4 +53,4 @@ class DiceConfig(BaseModel):
randominit: str = Field( randominit: str = Field(
"first", description="Method for initializing the random number generator" "first", description="Method for initializing the random number generator"
) )
seed: int | None = Field(None, description="Seed for the random number generator") seed: int = Field(default_factory=lambda: random.randint(0, 2**32 - 1), description="Seed for the random number generator")

View File

@@ -47,9 +47,6 @@ class PlayerConfig(BaseModel):
switch_cyc: int = Field(..., description="Switch cycle configuration") switch_cyc: int = Field(..., description="Switch cycle configuration")
mem: int = Field(None, description="Memory configuration") mem: int = Field(None, description="Memory configuration")
nprocs: int = Field(
..., description="Number of processors to use for the QM calculations"
)
ncores: int = Field( ncores: int = Field(
..., description="Number of cores to use for the QM calculations" ..., description="Number of cores to use for the QM calculations"
) )

View File

@@ -1,8 +1,80 @@
import shutil
from diceplayer.dice.dice_input import NVTTerConfig, NVTEqConfig, NPTEqConfig, NPTTerConfig
from diceplayer.dice.dice_wrapper import DiceWrapper
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel from diceplayer.state.state_model import StateModel
from pathlib import Path
from threading import Thread
class DiceHandler: class DiceHandler:
@staticmethod def __init__(self, step_directory: Path):
def run(state: StateModel, current_cycle: int) -> StateModel: self.dice_directory = step_directory / "dice"
print(f"Running Dice - {current_cycle}")
def run(self, state: StateModel, cycle: int) -> StateModel:
if self.dice_directory.exists():
logger.info(f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state")
shutil.rmtree(self.dice_directory)
self.dice_directory.mkdir(parents=True)
simulation_results = self.run_simulations(state, cycle)
result = self.aggregate_results(simulation_results)
return self.commit_simulation_state(state, result)
def run_simulations(self, state: StateModel, cycle: int) -> list[dict]:
results = []
threads = []
for p in range(state.config.dice.nprocs):
t = Thread(target=self._simulation_process, args=(state, cycle, p, results))
threads.append(t)
t.start()
for t in threads:
t.join()
if len(results) != state.config.dice.nprocs:
raise RuntimeError(f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}")
return results
def aggregate_results(self, simulation_results: list[dict]) -> dict: ...
def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
return state return state
def _simulation_process(self, state: StateModel, cycle: int, proc: int, results: list[dict]) -> None:
proc_directory = self.dice_directory / f"{proc:02d}"
if proc_directory.exists():
shutil.rmtree(proc_directory)
proc_directory.mkdir(parents=True)
dice = DiceWrapper(
state.config.dice, proc_directory
)
if state.config.dice.randominit == "first" and cycle == 0:
nvt_ter_config = NVTTerConfig.from_config(state.config)
dice.run(nvt_ter_config)
else:
self._generate_last_xyz(state, proc_directory)
if len(state.config.dice.nstep) == 2:
nvt_eq_config = NVTEqConfig.from_config(state.config)
dice.run(nvt_eq_config)
elif len(state.config.dice.nstep) == 3:
npt_ter_config = NPTTerConfig.from_config(state.config)
dice.run(npt_ter_config)
npt_eq_config = NPTEqConfig.from_config(state.config)
dice.run(npt_eq_config)
results.append(dice.extract_results())
def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None:
...

View File

@@ -1,10 +1,11 @@
from diceplayer.config import PlayerConfig from diceplayer.config import PlayerConfig
from diceplayer.logger import logger
from typing_extensions import Self from typing_extensions import Self
import random
from abc import ABC from abc import ABC
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Sequence, TextIO from typing import Any, Sequence, TextIO
@@ -19,6 +20,19 @@ class BaseConfig(ABC):
isave: int isave: int
press: float = 1.0 press: float = 1.0
def write(self, directory: Path, filename: str = "input") -> Path:
input_path = directory / filename
if input_path.exists():
logger.info(f"Dice input file {input_path} already exists and will be overwritten")
input_path.unlink()
input_path.parent.mkdir(parents=True, exist_ok=True)
with open(input_path, "w") as io:
self.write_dice_config(io)
return input_path
def write_dice_config(self, io_writer: TextIO) -> None: def write_dice_config(self, io_writer: TextIO) -> None:
for field in fields(self): for field in fields(self):
key = field.name key = field.name
@@ -38,19 +52,13 @@ class BaseConfig(ABC):
@staticmethod @staticmethod
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]: def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
seed: int
if config.dice.seed is not None:
seed = config.dice.seed
else:
seed = random.randint(0, 2**32 - 1)
return dict( return dict(
ncores=config.ncores, ncores=int(config.ncores / config.dice.nprocs),
ljname=config.dice.ljname, ljname=config.dice.ljname,
outname=config.dice.outname, outname=config.dice.outname,
nmol=config.dice.nmol, nmol=config.dice.nmol,
temp=config.dice.temp, temp=config.dice.temp,
seed=seed, seed=config.dice.seed,
isave=config.dice.isave, isave=config.dice.isave,
press=config.dice.press, press=config.dice.press,
) )
@@ -109,6 +117,11 @@ class NVTTerConfig(NVTConfig):
**kwargs, **kwargs,
) )
def write(self, directory: Path, filename: str = "nvt.ter") -> Path:
return super(NVTTerConfig, self).write(
directory, filename
)
# ----------------------------------------------------- # -----------------------------------------------------
# NVT PRODUCTION # NVT PRODUCTION
@@ -131,6 +144,11 @@ class NVTEqConfig(NVTConfig):
**kwargs, **kwargs,
) )
def write(self, directory: Path, filename: str = "nvt.eq") -> Path:
return super(NVTEqConfig, self).write(
directory, filename
)
# ----------------------------------------------------- # -----------------------------------------------------
# NPT BASE # NPT BASE
@@ -164,6 +182,11 @@ class NPTTerConfig(NPTConfig):
**kwargs, **kwargs,
) )
def write(self, directory: Path, filename: str = "npt.ter") -> Path:
return super(NPTTerConfig, self).write(
directory, filename
)
# ----------------------------------------------------- # -----------------------------------------------------
# NPT PRODUCTION # NPT PRODUCTION
@@ -184,3 +207,8 @@ class NPTEqConfig(NPTConfig):
vstep=config.dice.vstep, vstep=config.dice.vstep,
**kwargs, **kwargs,
) )
def write(self, directory: Path, filename: str = "npt.eq") -> Path:
return super(NPTEqConfig, self).write(
directory, filename
)

View File

@@ -0,0 +1,41 @@
import subprocess
from typing import Final
import diceplayer.dice.dice_input as dice_input
from pathlib import Path
from diceplayer.config import DiceConfig
DICE_FLAG_LINE: Final[int] = -2
DICE_END_FLAG: Final[str] = "End of simulation"
class DiceWrapper:
def __init__(self, dice_config: DiceConfig, working_directory: Path):
self.dice_config = dice_config
self.working_directory = working_directory
def run(self, dice_config: dice_input.BaseConfig) -> None:
input_path = dice_config.write(self.working_directory)
output_path = input_path.parent / (input_path.name + ".out")
with open(output_path, "w") as outfile, open(input_path, "r") as infile:
exit_status = subprocess.call(
self.dice_config.progname, stdin=infile, stdout=outfile
)
if exit_status != 0:
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
with open(output_path, "r") as outfile:
line = outfile.readlines()[DICE_FLAG_LINE]
if line.strip() == DICE_END_FLAG:
return
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
def extract_results(self) -> dict:
return {}

View File

@@ -1,4 +1,4 @@
from diceplayer.config.player_config import PlayerConfig, RoutineType from diceplayer.config.player_config import PlayerConfig
from diceplayer.dice.dice_handler import DiceHandler from diceplayer.dice.dice_handler import DiceHandler
from diceplayer.logger import logger from diceplayer.logger import logger
from diceplayer.optimization.optimization_handler import OptimizationHandler from diceplayer.optimization.optimization_handler import OptimizationHandler
@@ -22,13 +22,14 @@ class Player:
continuation = flags.get("continuation", False) continuation = flags.get("continuation", False)
force = flags.get("force", False) force = flags.get("force", False)
state: StateModel = self._state_handler.get(self.config, force=force) state = self._state_handler.get(self.config, force=force)
if not continuation and state is not None: if not continuation and state is not None:
logger.info( logger.info(
"Continuation flag is not set. Starting a new simulation and deleting any existing state." "Continuation flag is not set. Starting a new simulation and deleting any existing state."
) )
self._state_handler.delete() self._state_handler.delete()
state = self._state_handler.get(self.config, force=force) state = None
if state is None: if state is None:
state = StateModel.from_config(self.config) state = StateModel.from_config(self.config)
@@ -44,9 +45,9 @@ class Player:
if not step_directory.exists(): if not step_directory.exists():
step_directory.mkdir(parents=True) step_directory.mkdir(parents=True)
state = DiceHandler.run(state, state.current_cycle) state = DiceHandler(step_directory).run(state, state.current_cycle)
state = OptimizationHandler.run(state, state.current_cycle) # state = OptimizationHandler.run(state, state.current_cycle)
state.current_cycle += 1 state.current_cycle += 1
self._state_handler.save(state) self._state_handler.save(state)

View File

@@ -1,10 +1,15 @@
from pathlib import Path
from diceplayer.config import PlayerConfig from diceplayer.config import PlayerConfig
from diceplayer.dice.dice_input import NVTEqConfig, NVTTerConfig, NPTTerConfig, NPTEqConfig from diceplayer.dice.dice_input import (
NPTEqConfig,
NPTTerConfig,
NVTEqConfig,
NVTTerConfig,
)
import pytest import pytest
from pathlib import Path
class TestDiceInput: class TestDiceInput:
@pytest.fixture @pytest.fixture