feat: initial dice-wrapper structure and working prototype
This commit is contained in:
@@ -4,13 +4,13 @@ diceplayer:
|
|||||||
max_cyc: 5
|
max_cyc: 5
|
||||||
mem: 24
|
mem: 24
|
||||||
ncores: 5
|
ncores: 5
|
||||||
nprocs: 4
|
|
||||||
qmprog: 'g16'
|
qmprog: 'g16'
|
||||||
lps: no
|
lps: no
|
||||||
ghosts: no
|
ghosts: no
|
||||||
altsteps: 2000
|
altsteps: 2000
|
||||||
|
|
||||||
dice:
|
dice:
|
||||||
|
nprocs: 4
|
||||||
nmol: [1, 100]
|
nmol: [1, 100]
|
||||||
dens: 1.5
|
dens: 1.5
|
||||||
nstep: [2000, 3000]
|
nstep: [2000, 3000]
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class DiceConfig(BaseModel):
|
class DiceConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
@@ -10,6 +12,9 @@ class DiceConfig(BaseModel):
|
|||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
frozen=True,
|
frozen=True,
|
||||||
)
|
)
|
||||||
|
nprocs: int = Field(
|
||||||
|
..., description="Number of processes to use for the DICE simulations"
|
||||||
|
)
|
||||||
|
|
||||||
ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
|
ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
|
||||||
outname: str = Field(
|
outname: str = Field(
|
||||||
@@ -48,4 +53,4 @@ class DiceConfig(BaseModel):
|
|||||||
randominit: str = Field(
|
randominit: str = Field(
|
||||||
"first", description="Method for initializing the random number generator"
|
"first", description="Method for initializing the random number generator"
|
||||||
)
|
)
|
||||||
seed: int | None = Field(None, description="Seed for the random number generator")
|
seed: int = Field(default_factory=lambda: random.randint(0, 2**32 - 1), description="Seed for the random number generator")
|
||||||
|
|||||||
@@ -47,9 +47,6 @@ class PlayerConfig(BaseModel):
|
|||||||
switch_cyc: int = Field(..., description="Switch cycle configuration")
|
switch_cyc: int = Field(..., description="Switch cycle configuration")
|
||||||
|
|
||||||
mem: int = Field(None, description="Memory configuration")
|
mem: int = Field(None, description="Memory configuration")
|
||||||
nprocs: int = Field(
|
|
||||||
..., description="Number of processors to use for the QM calculations"
|
|
||||||
)
|
|
||||||
ncores: int = Field(
|
ncores: int = Field(
|
||||||
..., description="Number of cores to use for the QM calculations"
|
..., description="Number of cores to use for the QM calculations"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,80 @@
|
|||||||
|
import shutil
|
||||||
|
|
||||||
|
from diceplayer.dice.dice_input import NVTTerConfig, NVTEqConfig, NPTEqConfig, NPTTerConfig
|
||||||
|
from diceplayer.dice.dice_wrapper import DiceWrapper
|
||||||
|
from diceplayer.logger import logger
|
||||||
from diceplayer.state.state_model import StateModel
|
from diceplayer.state.state_model import StateModel
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
|
||||||
class DiceHandler:
|
class DiceHandler:
|
||||||
@staticmethod
|
def __init__(self, step_directory: Path):
|
||||||
def run(state: StateModel, current_cycle: int) -> StateModel:
|
self.dice_directory = step_directory / "dice"
|
||||||
print(f"Running Dice - {current_cycle}")
|
|
||||||
|
def run(self, state: StateModel, cycle: int) -> StateModel:
|
||||||
|
if self.dice_directory.exists():
|
||||||
|
logger.info(f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state")
|
||||||
|
shutil.rmtree(self.dice_directory)
|
||||||
|
self.dice_directory.mkdir(parents=True)
|
||||||
|
|
||||||
|
simulation_results = self.run_simulations(state, cycle)
|
||||||
|
|
||||||
|
result = self.aggregate_results(simulation_results)
|
||||||
|
|
||||||
|
return self.commit_simulation_state(state, result)
|
||||||
|
|
||||||
|
def run_simulations(self, state: StateModel, cycle: int) -> list[dict]:
|
||||||
|
results = []
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for p in range(state.config.dice.nprocs):
|
||||||
|
t = Thread(target=self._simulation_process, args=(state, cycle, p, results))
|
||||||
|
threads.append(t)
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
if len(results) != state.config.dice.nprocs:
|
||||||
|
raise RuntimeError(f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def aggregate_results(self, simulation_results: list[dict]) -> dict: ...
|
||||||
|
|
||||||
|
def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
def _simulation_process(self, state: StateModel, cycle: int, proc: int, results: list[dict]) -> None:
|
||||||
|
proc_directory = self.dice_directory / f"{proc:02d}"
|
||||||
|
if proc_directory.exists():
|
||||||
|
shutil.rmtree(proc_directory)
|
||||||
|
proc_directory.mkdir(parents=True)
|
||||||
|
|
||||||
|
dice = DiceWrapper(
|
||||||
|
state.config.dice, proc_directory
|
||||||
|
)
|
||||||
|
|
||||||
|
if state.config.dice.randominit == "first" and cycle == 0:
|
||||||
|
nvt_ter_config = NVTTerConfig.from_config(state.config)
|
||||||
|
dice.run(nvt_ter_config)
|
||||||
|
else:
|
||||||
|
self._generate_last_xyz(state, proc_directory)
|
||||||
|
|
||||||
|
if len(state.config.dice.nstep) == 2:
|
||||||
|
nvt_eq_config = NVTEqConfig.from_config(state.config)
|
||||||
|
dice.run(nvt_eq_config)
|
||||||
|
|
||||||
|
elif len(state.config.dice.nstep) == 3:
|
||||||
|
npt_ter_config = NPTTerConfig.from_config(state.config)
|
||||||
|
dice.run(npt_ter_config)
|
||||||
|
|
||||||
|
npt_eq_config = NPTEqConfig.from_config(state.config)
|
||||||
|
dice.run(npt_eq_config)
|
||||||
|
|
||||||
|
results.append(dice.extract_results())
|
||||||
|
|
||||||
|
def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None:
|
||||||
|
...
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from diceplayer.config import PlayerConfig
|
from diceplayer.config import PlayerConfig
|
||||||
|
from diceplayer.logger import logger
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
import random
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Sequence, TextIO
|
from typing import Any, Sequence, TextIO
|
||||||
|
|
||||||
|
|
||||||
@@ -19,6 +20,19 @@ class BaseConfig(ABC):
|
|||||||
isave: int
|
isave: int
|
||||||
press: float = 1.0
|
press: float = 1.0
|
||||||
|
|
||||||
|
def write(self, directory: Path, filename: str = "input") -> Path:
|
||||||
|
input_path = directory / filename
|
||||||
|
|
||||||
|
if input_path.exists():
|
||||||
|
logger.info(f"Dice input file {input_path} already exists and will be overwritten")
|
||||||
|
input_path.unlink()
|
||||||
|
input_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(input_path, "w") as io:
|
||||||
|
self.write_dice_config(io)
|
||||||
|
|
||||||
|
return input_path
|
||||||
|
|
||||||
def write_dice_config(self, io_writer: TextIO) -> None:
|
def write_dice_config(self, io_writer: TextIO) -> None:
|
||||||
for field in fields(self):
|
for field in fields(self):
|
||||||
key = field.name
|
key = field.name
|
||||||
@@ -38,19 +52,13 @@ class BaseConfig(ABC):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
|
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
|
||||||
seed: int
|
|
||||||
if config.dice.seed is not None:
|
|
||||||
seed = config.dice.seed
|
|
||||||
else:
|
|
||||||
seed = random.randint(0, 2**32 - 1)
|
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
ncores=config.ncores,
|
ncores=int(config.ncores / config.dice.nprocs),
|
||||||
ljname=config.dice.ljname,
|
ljname=config.dice.ljname,
|
||||||
outname=config.dice.outname,
|
outname=config.dice.outname,
|
||||||
nmol=config.dice.nmol,
|
nmol=config.dice.nmol,
|
||||||
temp=config.dice.temp,
|
temp=config.dice.temp,
|
||||||
seed=seed,
|
seed=config.dice.seed,
|
||||||
isave=config.dice.isave,
|
isave=config.dice.isave,
|
||||||
press=config.dice.press,
|
press=config.dice.press,
|
||||||
)
|
)
|
||||||
@@ -109,6 +117,11 @@ class NVTTerConfig(NVTConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def write(self, directory: Path, filename: str = "nvt.ter") -> Path:
|
||||||
|
return super(NVTTerConfig, self).write(
|
||||||
|
directory, filename
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------
|
# -----------------------------------------------------
|
||||||
# NVT PRODUCTION
|
# NVT PRODUCTION
|
||||||
@@ -131,6 +144,11 @@ class NVTEqConfig(NVTConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def write(self, directory: Path, filename: str = "nvt.eq") -> Path:
|
||||||
|
return super(NVTEqConfig, self).write(
|
||||||
|
directory, filename
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------
|
# -----------------------------------------------------
|
||||||
# NPT BASE
|
# NPT BASE
|
||||||
@@ -164,6 +182,11 @@ class NPTTerConfig(NPTConfig):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def write(self, directory: Path, filename: str = "npt.ter") -> Path:
|
||||||
|
return super(NPTTerConfig, self).write(
|
||||||
|
directory, filename
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------
|
# -----------------------------------------------------
|
||||||
# NPT PRODUCTION
|
# NPT PRODUCTION
|
||||||
@@ -183,4 +206,9 @@ class NPTEqConfig(NPTConfig):
|
|||||||
nstep=cls._get_nstep(config, 2),
|
nstep=cls._get_nstep(config, 2),
|
||||||
vstep=config.dice.vstep,
|
vstep=config.dice.vstep,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def write(self, directory: Path, filename: str = "npt.eq") -> Path:
|
||||||
|
return super(NPTEqConfig, self).write(
|
||||||
|
directory, filename
|
||||||
)
|
)
|
||||||
41
diceplayer/dice/dice_wrapper.py
Normal file
41
diceplayer/dice/dice_wrapper.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import subprocess
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
import diceplayer.dice.dice_input as dice_input
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from diceplayer.config import DiceConfig
|
||||||
|
|
||||||
|
|
||||||
|
DICE_FLAG_LINE: Final[int] = -2
|
||||||
|
DICE_END_FLAG: Final[str] = "End of simulation"
|
||||||
|
|
||||||
|
|
||||||
|
class DiceWrapper:
|
||||||
|
def __init__(self, dice_config: DiceConfig, working_directory: Path):
|
||||||
|
self.dice_config = dice_config
|
||||||
|
self.working_directory = working_directory
|
||||||
|
|
||||||
|
def run(self, dice_config: dice_input.BaseConfig) -> None:
|
||||||
|
input_path = dice_config.write(self.working_directory)
|
||||||
|
output_path = input_path.parent / (input_path.name + ".out")
|
||||||
|
|
||||||
|
with open(output_path, "w") as outfile, open(input_path, "r") as infile:
|
||||||
|
exit_status = subprocess.call(
|
||||||
|
self.dice_config.progname, stdin=infile, stdout=outfile
|
||||||
|
)
|
||||||
|
|
||||||
|
if exit_status != 0:
|
||||||
|
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
|
||||||
|
|
||||||
|
with open(output_path, "r") as outfile:
|
||||||
|
line = outfile.readlines()[DICE_FLAG_LINE]
|
||||||
|
if line.strip() == DICE_END_FLAG:
|
||||||
|
return
|
||||||
|
|
||||||
|
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
|
||||||
|
|
||||||
|
def extract_results(self) -> dict:
|
||||||
|
return {}
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from diceplayer.config.player_config import PlayerConfig, RoutineType
|
from diceplayer.config.player_config import PlayerConfig
|
||||||
from diceplayer.dice.dice_handler import DiceHandler
|
from diceplayer.dice.dice_handler import DiceHandler
|
||||||
from diceplayer.logger import logger
|
from diceplayer.logger import logger
|
||||||
from diceplayer.optimization.optimization_handler import OptimizationHandler
|
from diceplayer.optimization.optimization_handler import OptimizationHandler
|
||||||
@@ -22,13 +22,14 @@ class Player:
|
|||||||
continuation = flags.get("continuation", False)
|
continuation = flags.get("continuation", False)
|
||||||
force = flags.get("force", False)
|
force = flags.get("force", False)
|
||||||
|
|
||||||
state: StateModel = self._state_handler.get(self.config, force=force)
|
state = self._state_handler.get(self.config, force=force)
|
||||||
if not continuation and state is not None:
|
if not continuation and state is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Continuation flag is not set. Starting a new simulation and deleting any existing state."
|
"Continuation flag is not set. Starting a new simulation and deleting any existing state."
|
||||||
)
|
)
|
||||||
self._state_handler.delete()
|
self._state_handler.delete()
|
||||||
state = self._state_handler.get(self.config, force=force)
|
state = None
|
||||||
|
|
||||||
|
|
||||||
if state is None:
|
if state is None:
|
||||||
state = StateModel.from_config(self.config)
|
state = StateModel.from_config(self.config)
|
||||||
@@ -44,9 +45,9 @@ class Player:
|
|||||||
if not step_directory.exists():
|
if not step_directory.exists():
|
||||||
step_directory.mkdir(parents=True)
|
step_directory.mkdir(parents=True)
|
||||||
|
|
||||||
state = DiceHandler.run(state, state.current_cycle)
|
state = DiceHandler(step_directory).run(state, state.current_cycle)
|
||||||
|
|
||||||
state = OptimizationHandler.run(state, state.current_cycle)
|
# state = OptimizationHandler.run(state, state.current_cycle)
|
||||||
|
|
||||||
state.current_cycle += 1
|
state.current_cycle += 1
|
||||||
self._state_handler.save(state)
|
self._state_handler.save(state)
|
||||||
|
|||||||
@@ -1,10 +1,15 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from diceplayer.config import PlayerConfig
|
from diceplayer.config import PlayerConfig
|
||||||
from diceplayer.dice.dice_input import NVTEqConfig, NVTTerConfig, NPTTerConfig, NPTEqConfig
|
from diceplayer.dice.dice_input import (
|
||||||
|
NPTEqConfig,
|
||||||
|
NPTTerConfig,
|
||||||
|
NVTEqConfig,
|
||||||
|
NVTTerConfig,
|
||||||
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class TestDiceInput:
|
class TestDiceInput:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
Reference in New Issue
Block a user