refactor: update dice handling and optimization flow to return structured results

This commit is contained in:
2026-03-25 11:05:23 -03:00
parent 0470200d00
commit 7e66c98f26
5 changed files with 53 additions and 26 deletions

View File

@@ -13,10 +13,11 @@ diceplayer:
nprocs: 1 nprocs: 1
nmol: [1, 200] nmol: [1, 200]
dens: 1.5 dens: 1.5
nstep: [200, 300] nstep: [200, 200]
isave: 100 vstep: 1000
isave: 30
outname: 'phb' outname: 'phb'
progname: 'dice' progname: '/home/hideyoshi/.local/bin/dice'
ljname: 'phb.ljc.example' ljname: 'phb.ljc.example'
randominit: 'always' randominit: 'always'
seed: 12345 seed: 12345

View File

@@ -4,7 +4,7 @@ from diceplayer.dice.dice_input import (
NVTEqConfig, NVTEqConfig,
NVTTerConfig, NVTTerConfig,
) )
from diceplayer.dice.dice_wrapper import DiceWrapper from diceplayer.dice.dice_wrapper import DiceWrapper, DiceEnvironment
from diceplayer.logger import logger from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel from diceplayer.state.state_model import StateModel
@@ -17,7 +17,7 @@ class DiceHandler:
def __init__(self, step_directory: Path): def __init__(self, step_directory: Path):
self.dice_directory = step_directory / "dice" self.dice_directory = step_directory / "dice"
def run(self, state: StateModel, cycle: int) -> StateModel: def run(self, state: StateModel, cycle: int) -> list[DiceEnvironment]:
if self.dice_directory.exists(): if self.dice_directory.exists():
logger.info( logger.info(
f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state" f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
@@ -25,13 +25,9 @@ class DiceHandler:
shutil.rmtree(self.dice_directory) shutil.rmtree(self.dice_directory)
self.dice_directory.mkdir(parents=True) self.dice_directory.mkdir(parents=True)
simulation_results = self.run_simulations(state, cycle) return self.run_simulations(state, cycle)
result = self.aggregate_results(simulation_results) def run_simulations(self, state: StateModel, cycle: int) -> list[DiceEnvironment]:
return self.commit_simulation_state(state, result)
def run_simulations(self, state: StateModel, cycle: int) -> list[dict]:
results = [] results = []
threads = [] threads = []
@@ -48,15 +44,12 @@ class DiceHandler:
f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}" f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}"
) )
return results return [
i for i in [r for r in results]
def aggregate_results(self, simulation_results: list[dict]) -> dict: ... ]
def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
return state
def _simulation_process( def _simulation_process(
self, state: StateModel, cycle: int, proc: int, results: list[dict] self, state: StateModel, cycle: int, proc: int, results: list[list[DiceEnvironment]]
) -> None: ) -> None:
proc_directory = self.dice_directory / f"{proc:02d}" proc_directory = self.dice_directory / f"{proc:02d}"
if proc_directory.exists(): if proc_directory.exists():

View File

@@ -1,3 +1,5 @@
from pydantic import TypeAdapter
import diceplayer.dice.dice_input as dice_input import diceplayer.dice.dice_input as dice_input
from diceplayer.config import DiceConfig from diceplayer.config import DiceConfig
from diceplayer.environment import System from diceplayer.environment import System
@@ -7,6 +9,10 @@ from pathlib import Path
from typing import Final from typing import Final
type DiceEnvironment = tuple[str, int, int, int]
DiceEnvironmentAdapter = TypeAdapter(DiceEnvironment)
DICE_FLAG_LINE: Final[int] = -2 DICE_FLAG_LINE: Final[int] = -2
DICE_END_FLAG: Final[str] = "End of simulation" DICE_END_FLAG: Final[str] = "End of simulation"
@@ -35,9 +41,22 @@ class DiceWrapper:
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}") raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
def parse_results(self, system: System) -> dict: def parse_results(self, system: System) -> list[DiceEnvironment]:
results = {} NUMBER_OF_HEADER_LINES = 2
NUMBER_OF_PRIMARY_ATOMS = len(system.molecule[0].atom)
results = []
for output_file in sorted(self.working_directory.glob("phb*.xyz")): for output_file in sorted(self.working_directory.glob("phb*.xyz")):
... with open(output_file, "r") as f:
for _ in range(NUMBER_OF_HEADER_LINES + NUMBER_OF_PRIMARY_ATOMS):
next(f, None)
for line in f:
if line.strip() == "":
break
results.append(
DiceEnvironmentAdapter.validate_python(line.split())
)
return results return results

View File

@@ -1,11 +1,20 @@
from pathlib import Path
from diceplayer.config.player_config import RoutineType from diceplayer.config.player_config import RoutineType
from diceplayer.dice.dice_wrapper import DiceEnvironment
from diceplayer.state.state_model import StateModel from diceplayer.state.state_model import StateModel
class OptimizationHandler: class OptimizationHandler:
@staticmethod def __init__(self, step_directory: Path):
def run(state: StateModel, current_cycle: int) -> StateModel: self.dice_directory = step_directory / "dice"
print(f"Running Optimization - {current_cycle}")
def run(self, state: StateModel, current_cycle: int, dice_environment: list[DiceEnvironment]) -> StateModel:
routine = self._fetch_current_routine(state, current_cycle)
print(f"Running Optimization - {current_cycle} - {routine}")
print(dice_environment)
return state return state
@staticmethod @staticmethod

View File

@@ -1,6 +1,7 @@
from diceplayer.config.player_config import PlayerConfig from diceplayer.config.player_config import PlayerConfig
from diceplayer.dice.dice_handler import DiceHandler from diceplayer.dice.dice_handler import DiceHandler
from diceplayer.logger import logger from diceplayer.logger import logger
from diceplayer.optimization.optimization_handler import OptimizationHandler
from diceplayer.state.state_handler import StateHandler from diceplayer.state.state_handler import StateHandler
from diceplayer.state.state_model import StateModel from diceplayer.state.state_model import StateModel
from diceplayer.utils.potential import read_system_from_phb from diceplayer.utils.potential import read_system_from_phb
@@ -45,9 +46,13 @@ class Player:
if not step_directory.exists(): if not step_directory.exists():
step_directory.mkdir(parents=True) step_directory.mkdir(parents=True)
state = DiceHandler(step_directory).run(state, state.current_cycle) dice_environment = DiceHandler(step_directory).run(
state, state.current_cycle
)
# state = OptimizationHandler.run(state, state.current_cycle) state = OptimizationHandler(step_directory).run(
state, state.current_cycle, dice_environment
)
state.current_cycle += 1 state.current_cycle += 1
self._state_handler.save(state) self._state_handler.save(state)