Files
DicePlayer/diceplayer/dice/dice_handler.py

81 lines
2.8 KiB
Python

import shutil
from concurrent.futures import ThreadPoolExecutor
from diceplayer.dice.dice_input import NVTTerConfig, NVTEqConfig, NPTEqConfig, NPTTerConfig
from diceplayer.dice.dice_wrapper import DiceWrapper
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel
from pathlib import Path
from threading import Thread
class DiceHandler:
    """Run the DICE stage of a step inside ``<step_directory>/dice``.

    One DICE simulation is launched per configured process
    (``state.config.dice.nprocs``); each runs in its own ``NN`` sub-directory.
    """

    def __init__(self, step_directory: Path):
        """Remember where this step's DICE files live.

        Args:
            step_directory: Directory of the current step; simulations are
                run under its ``dice`` sub-directory.
        """
        self.dice_directory = step_directory / "dice"

    def run(self, state: StateModel, cycle: int) -> StateModel:
        """Run all DICE simulations for ``cycle`` and commit the outcome.

        Any pre-existing dice directory is deleted first so every run starts
        from a clean state.

        Args:
            state: Current state; its ``config.dice`` section drives the runs.
            cycle: Index of the current cycle.

        Returns:
            The state returned by :meth:`commit_simulation_state`.

        Raises:
            RuntimeError: Propagated from :meth:`run_simulations`.
        """
        if self.dice_directory.exists():
            logger.info(f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state")
            shutil.rmtree(self.dice_directory)
        self.dice_directory.mkdir(parents=True)

        simulation_results = self.run_simulations(state, cycle)
        result = self.aggregate_results(simulation_results)
        return self.commit_simulation_state(state, result)

    def run_simulations(self, state: StateModel, cycle: int) -> list[dict]:
        """Run one simulation per process index concurrently and collect results.

        Each process index in ``range(state.config.dice.nprocs)`` is handed
        to :meth:`_simulation_process` on a worker thread.

        Args:
            state: Current state; ``state.config.dice.nprocs`` sets how many
                simulations run.
            cycle: Index of the current cycle, forwarded to every worker.

        Returns:
            The per-process results, ordered by process index.

        Raises:
            RuntimeError: If any worker raised (chained from the first
                failure), or if the number of collected results does not
                match ``nprocs``.
        """
        nprocs = state.config.dice.nprocs
        # One private bucket per worker: the flattened output is ordered by
        # process index no matter which thread finishes first.
        buckets: list[list[dict]] = [[] for _ in range(nprocs)]

        # ThreadPoolExecutor rejects max_workers=0; with nprocs == 0 nothing is
        # submitted and an empty list is returned.
        with ThreadPoolExecutor(max_workers=max(nprocs, 1)) as executor:
            futures = [
                executor.submit(self._simulation_process, state, cycle, proc, bucket)
                for proc, bucket in enumerate(buckets)
            ]
        # Leaving the ``with`` block waits for every submitted worker.

        # Futures retain worker exceptions (bare Threads only print them), so
        # failures can be surfaced with their real cause.
        failures = []
        for proc, future in enumerate(futures):
            error = future.exception()
            if error is not None:
                failures.append((proc, error))
        if failures:
            details = ", ".join(f"{proc:02d}: {error!r}" for proc, error in failures)
            raise RuntimeError(
                f"{len(failures)} of {nprocs} dice simulation process(es) failed: {details}"
            ) from failures[0][1]

        results = [result for bucket in buckets for result in bucket]
        if len(results) != nprocs:
            raise RuntimeError(f"Expected {nprocs} simulation results, but got {len(results)}")
        return results

    def aggregate_results(self, simulation_results: list[dict]) -> dict:
        """Combine per-process results into a single result.

        Placeholder: the body is ``...``, so it currently returns ``None``
        despite the ``dict`` annotation.
        """
        ...

    def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
        """Fold ``result`` into ``state``.

        Placeholder: currently returns ``state`` unchanged and ignores
        ``result``.
        """
        return state

    def _simulation_process(self, state: StateModel, cycle: int, proc: int, results: list[dict]) -> None:
        """Run the DICE stages for one process and append its result.

        The simulation runs in ``<dice_directory>/<proc:02d>``, which is
        recreated empty.

        Args:
            state: Current state; ``state.config`` builds the stage configs.
            cycle: Index of the current cycle.
            proc: Process index, used for the sub-directory name.
            results: List that receives ``dice.extract_results()``.
        """
        proc_directory = self.dice_directory / f"{proc:02d}"
        if proc_directory.exists():
            shutil.rmtree(proc_directory)
        proc_directory.mkdir(parents=True)

        dice = DiceWrapper(state.config.dice, proc_directory)

        if state.config.dice.randominit == "first" and cycle == 0:
            # First cycle: run the NVTTerConfig stage (presumably NVT
            # thermalization from a random start).
            nvt_ter_config = NVTTerConfig.from_config(state.config)
            dice.run(nvt_ter_config)
        else:
            self._generate_last_xyz(state, proc_directory)

        # The number of nstep entries selects the ensemble of the remaining stages.
        # NOTE(review): a length other than 2 or 3 runs no further stage and
        # still extracts results; presumably the config model rejects that — confirm.
        nstep_count = len(state.config.dice.nstep)
        if nstep_count == 2:
            nvt_eq_config = NVTEqConfig.from_config(state.config)
            dice.run(nvt_eq_config)
        elif nstep_count == 3:
            npt_ter_config = NPTTerConfig.from_config(state.config)
            dice.run(npt_ter_config)
            npt_eq_config = NPTEqConfig.from_config(state.config)
            dice.run(npt_eq_config)

        results.append(dice.extract_results())

    def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None:
        """Prepare the starting configuration in ``proc_directory``.

        Placeholder: the body is ``...``. Presumably it will write the last
        configuration from a previous cycle — confirm when implemented.
        """
        ...