refactor: update dice handling and optimization flow to return structured results
This commit is contained in:
@@ -13,10 +13,11 @@ diceplayer:
|
|||||||
nprocs: 1
|
nprocs: 1
|
||||||
nmol: [1, 200]
|
nmol: [1, 200]
|
||||||
dens: 1.5
|
dens: 1.5
|
||||||
nstep: [200, 300]
|
nstep: [200, 200]
|
||||||
isave: 100
|
vstep: 1000
|
||||||
|
isave: 30
|
||||||
outname: 'phb'
|
outname: 'phb'
|
||||||
progname: 'dice'
|
progname: '/home/hideyoshi/.local/bin/dice'
|
||||||
ljname: 'phb.ljc.example'
|
ljname: 'phb.ljc.example'
|
||||||
randominit: 'always'
|
randominit: 'always'
|
||||||
seed: 12345
|
seed: 12345
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from diceplayer.dice.dice_input import (
|
|||||||
NVTEqConfig,
|
NVTEqConfig,
|
||||||
NVTTerConfig,
|
NVTTerConfig,
|
||||||
)
|
)
|
||||||
from diceplayer.dice.dice_wrapper import DiceWrapper
|
from diceplayer.dice.dice_wrapper import DiceWrapper, DiceEnvironment
|
||||||
from diceplayer.logger import logger
|
from diceplayer.logger import logger
|
||||||
from diceplayer.state.state_model import StateModel
|
from diceplayer.state.state_model import StateModel
|
||||||
|
|
||||||
@@ -17,7 +17,7 @@ class DiceHandler:
|
|||||||
def __init__(self, step_directory: Path):
|
def __init__(self, step_directory: Path):
|
||||||
self.dice_directory = step_directory / "dice"
|
self.dice_directory = step_directory / "dice"
|
||||||
|
|
||||||
def run(self, state: StateModel, cycle: int) -> StateModel:
|
def run(self, state: StateModel, cycle: int) -> list[DiceEnvironment]:
|
||||||
if self.dice_directory.exists():
|
if self.dice_directory.exists():
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
|
f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
|
||||||
@@ -25,13 +25,9 @@ class DiceHandler:
|
|||||||
shutil.rmtree(self.dice_directory)
|
shutil.rmtree(self.dice_directory)
|
||||||
self.dice_directory.mkdir(parents=True)
|
self.dice_directory.mkdir(parents=True)
|
||||||
|
|
||||||
simulation_results = self.run_simulations(state, cycle)
|
return self.run_simulations(state, cycle)
|
||||||
|
|
||||||
result = self.aggregate_results(simulation_results)
|
def run_simulations(self, state: StateModel, cycle: int) -> list[DiceEnvironment]:
|
||||||
|
|
||||||
return self.commit_simulation_state(state, result)
|
|
||||||
|
|
||||||
def run_simulations(self, state: StateModel, cycle: int) -> list[dict]:
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
threads = []
|
threads = []
|
||||||
@@ -48,15 +44,12 @@ class DiceHandler:
|
|||||||
f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}"
|
f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return [
|
||||||
|
i for i in [r for r in results]
|
||||||
def aggregate_results(self, simulation_results: list[dict]) -> dict: ...
|
]
|
||||||
|
|
||||||
def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
|
|
||||||
return state
|
|
||||||
|
|
||||||
def _simulation_process(
|
def _simulation_process(
|
||||||
self, state: StateModel, cycle: int, proc: int, results: list[dict]
|
self, state: StateModel, cycle: int, proc: int, results: list[list[DiceEnvironment]]
|
||||||
) -> None:
|
) -> None:
|
||||||
proc_directory = self.dice_directory / f"{proc:02d}"
|
proc_directory = self.dice_directory / f"{proc:02d}"
|
||||||
if proc_directory.exists():
|
if proc_directory.exists():
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
import diceplayer.dice.dice_input as dice_input
|
import diceplayer.dice.dice_input as dice_input
|
||||||
from diceplayer.config import DiceConfig
|
from diceplayer.config import DiceConfig
|
||||||
from diceplayer.environment import System
|
from diceplayer.environment import System
|
||||||
@@ -7,6 +9,10 @@ from pathlib import Path
|
|||||||
from typing import Final
|
from typing import Final
|
||||||
|
|
||||||
|
|
||||||
|
type DiceEnvironment = tuple[str, int, int, int]
|
||||||
|
DiceEnvironmentAdapter = TypeAdapter(DiceEnvironment)
|
||||||
|
|
||||||
|
|
||||||
DICE_FLAG_LINE: Final[int] = -2
|
DICE_FLAG_LINE: Final[int] = -2
|
||||||
DICE_END_FLAG: Final[str] = "End of simulation"
|
DICE_END_FLAG: Final[str] = "End of simulation"
|
||||||
|
|
||||||
@@ -35,9 +41,22 @@ class DiceWrapper:
|
|||||||
|
|
||||||
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
|
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
|
||||||
|
|
||||||
def parse_results(self, system: System) -> dict:
|
def parse_results(self, system: System) -> list[DiceEnvironment]:
|
||||||
results = {}
|
NUMBER_OF_HEADER_LINES = 2
|
||||||
|
NUMBER_OF_PRIMARY_ATOMS = len(system.molecule[0].atom)
|
||||||
|
|
||||||
|
results = []
|
||||||
for output_file in sorted(self.working_directory.glob("phb*.xyz")):
|
for output_file in sorted(self.working_directory.glob("phb*.xyz")):
|
||||||
...
|
with open(output_file, "r") as f:
|
||||||
|
for _ in range(NUMBER_OF_HEADER_LINES + NUMBER_OF_PRIMARY_ATOMS):
|
||||||
|
next(f, None)
|
||||||
|
|
||||||
|
for line in f:
|
||||||
|
if line.strip() == "":
|
||||||
|
break
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
DiceEnvironmentAdapter.validate_python(line.split())
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -1,11 +1,20 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from diceplayer.config.player_config import RoutineType
|
from diceplayer.config.player_config import RoutineType
|
||||||
|
from diceplayer.dice.dice_wrapper import DiceEnvironment
|
||||||
from diceplayer.state.state_model import StateModel
|
from diceplayer.state.state_model import StateModel
|
||||||
|
|
||||||
|
|
||||||
class OptimizationHandler:
|
class OptimizationHandler:
|
||||||
@staticmethod
|
def __init__(self, step_directory: Path):
|
||||||
def run(state: StateModel, current_cycle: int) -> StateModel:
|
self.dice_directory = step_directory / "dice"
|
||||||
print(f"Running Optimization - {current_cycle}")
|
|
||||||
|
def run(self, state: StateModel, current_cycle: int, dice_environment: list[DiceEnvironment]) -> StateModel:
|
||||||
|
routine = self._fetch_current_routine(state, current_cycle)
|
||||||
|
print(f"Running Optimization - {current_cycle} - {routine}")
|
||||||
|
|
||||||
|
print(dice_environment)
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from diceplayer.config.player_config import PlayerConfig
|
from diceplayer.config.player_config import PlayerConfig
|
||||||
from diceplayer.dice.dice_handler import DiceHandler
|
from diceplayer.dice.dice_handler import DiceHandler
|
||||||
from diceplayer.logger import logger
|
from diceplayer.logger import logger
|
||||||
|
from diceplayer.optimization.optimization_handler import OptimizationHandler
|
||||||
from diceplayer.state.state_handler import StateHandler
|
from diceplayer.state.state_handler import StateHandler
|
||||||
from diceplayer.state.state_model import StateModel
|
from diceplayer.state.state_model import StateModel
|
||||||
from diceplayer.utils.potential import read_system_from_phb
|
from diceplayer.utils.potential import read_system_from_phb
|
||||||
@@ -45,9 +46,13 @@ class Player:
|
|||||||
if not step_directory.exists():
|
if not step_directory.exists():
|
||||||
step_directory.mkdir(parents=True)
|
step_directory.mkdir(parents=True)
|
||||||
|
|
||||||
state = DiceHandler(step_directory).run(state, state.current_cycle)
|
dice_environment = DiceHandler(step_directory).run(
|
||||||
|
state, state.current_cycle
|
||||||
|
)
|
||||||
|
|
||||||
# state = OptimizationHandler.run(state, state.current_cycle)
|
state = OptimizationHandler(step_directory).run(
|
||||||
|
state, state.current_cycle, dice_environment
|
||||||
|
)
|
||||||
|
|
||||||
state.current_cycle += 1
|
state.current_cycle += 1
|
||||||
self._state_handler.save(state)
|
self._state_handler.save(state)
|
||||||
|
|||||||
Reference in New Issue
Block a user