import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from threading import Thread  # noqa: F401  (no longer used by DiceHandler itself)

from diceplayer.dice.dice_input import NVTTerConfig, NVTEqConfig, NPTEqConfig, NPTTerConfig
from diceplayer.dice.dice_wrapper import DiceWrapper
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel


class DiceHandler:
    """Run the DICE stage of a diceplayer step.

    For each configured DICE process, a sub-directory of
    ``<step_directory>/dice`` is created and a ``DiceWrapper`` is driven
    through the thermalization/equilibration runs selected by
    ``state.config``. The per-process results are then aggregated and
    committed back into the state.
    """

    def __init__(self, step_directory: Path):
        """
        Args:
            step_directory: Directory of the current step; all DICE files are
                written under ``step_directory / "dice"``.
        """
        self.dice_directory = step_directory / "dice"

    def run(self, state: StateModel, cycle: int) -> StateModel:
        """Run every DICE simulation for ``cycle`` and return the resulting state.

        Any pre-existing dice directory is deleted first so the run starts
        from a clean directory.

        Args:
            state: Current simulation state (provides ``state.config``).
            cycle: Index of the current cycle.

        Returns:
            The state returned by ``commit_simulation_state``.

        Raises:
            RuntimeError: If any DICE process fails (see ``run_simulations``).
        """
        if self.dice_directory.exists():
            logger.info(f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state")
            shutil.rmtree(self.dice_directory)
        self.dice_directory.mkdir(parents=True)

        simulation_results = self.run_simulations(state, cycle)

        result = self.aggregate_results(simulation_results)

        return self.commit_simulation_state(state, result)

    def run_simulations(self, state: StateModel, cycle: int) -> list[dict]:
        """Run one simulation per DICE process concurrently, one thread each.

        Args:
            state: Current simulation state; ``state.config.dice.nprocs``
                sets the number of processes.
            cycle: Index of the current cycle.

        Returns:
            One result dict per process, ordered by process index
            (``results[p]`` was produced by process ``p``).

        Raises:
            RuntimeError: If any process raised, or the number of results does
                not match ``nprocs``. The first worker exception, if any, is
                chained as ``__cause__``.
        """
        nprocs = state.config.dice.nprocs

        results: list[dict] = []
        failures: list[tuple[int, Exception]] = []

        # ThreadPoolExecutor rejects max_workers <= 0; when nprocs <= 0 nothing
        # is submitted, so a single idle worker is harmless.
        with ThreadPoolExecutor(max_workers=max(1, nprocs)) as executor:
            futures = [
                executor.submit(self._run_process, state, cycle, proc)
                for proc in range(nprocs)
            ]
            # Collect in submission order (not completion order) so the list is
            # deterministic. Future.result() re-raises a worker's exception
            # here instead of letting it vanish inside the thread.
            for proc, future in enumerate(futures):
                try:
                    results.append(future.result())
                except Exception as err:
                    failures.append((proc, err))

        if failures or len(results) != nprocs:
            message = f"Expected {nprocs} simulation results, but got {len(results)}"
            if failures:
                failed = ", ".join(f"{proc:02d}" for proc, _ in failures)
                message += f" (failed processes: {failed})"
            cause = failures[0][1] if failures else None
            raise RuntimeError(message) from cause

        return results

    def aggregate_results(self, simulation_results: list[dict]) -> dict:
        """Combine the per-process simulation results into a single result.

        NOTE(review): not implemented yet; the body is a placeholder, so this
        currently returns ``None``.
        """
        ...

    def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
        """Apply the aggregated ``result`` to ``state``.

        Currently a pass-through: ``state`` is returned unchanged and
        ``result`` is ignored.
        """
        return state

    def _simulation_process(self, state: StateModel, cycle: int, proc: int, results: list[dict]) -> None:
        """Run the simulation for process ``proc`` and append its result to ``results``.

        Retained with its original signature for existing callers;
        ``run_simulations`` calls ``_run_process`` directly so that it can
        keep results in process order and propagate failures.
        """
        results.append(self._run_process(state, cycle, proc))

    def _run_process(self, state: StateModel, cycle: int, proc: int) -> dict:
        """Run the DICE run sequence for a single process and return its results.

        Sequence:
          * cycle 0 with ``randominit == "first"``: an NVT thermalization run;
            otherwise ``_generate_last_xyz`` is called instead;
          * then, for a 2-entry ``nstep``, an NVT equilibration run, or for a
            3-entry ``nstep``, an NPT thermalization followed by an NPT
            equilibration run.

        Args:
            state: Current simulation state.
            cycle: Index of the current cycle.
            proc: Process index; selects the ``<dice_directory>/<proc:02d>``
                working directory.

        Returns:
            Whatever ``DiceWrapper.extract_results()`` returns for this process.

        Raises:
            ValueError: If ``state.config.dice.nstep`` has neither 2 nor 3 entries.
        """
        nstep = state.config.dice.nstep
        # Only 2 entries (NVT) or 3 entries (NPT) select an equilibration run
        # below; any other length would skip it silently and return results
        # that do not come from an equilibrated run.
        if len(nstep) not in (2, 3):
            raise ValueError(f"dice.nstep must have 2 (NVT) or 3 (NPT) entries, got {len(nstep)}")

        proc_directory = self.dice_directory / f"{proc:02d}"
        if proc_directory.exists():
            shutil.rmtree(proc_directory)
        proc_directory.mkdir(parents=True)

        dice = DiceWrapper(
            state.config.dice, proc_directory
        )

        if state.config.dice.randominit == "first" and cycle == 0:
            nvt_ter_config = NVTTerConfig.from_config(state.config)
            dice.run(nvt_ter_config)
        else:
            # NOTE(review): this branch also runs on cycle 0 whenever
            # randominit != "first"; confirm that _generate_last_xyz can
            # handle having no previous cycle.
            self._generate_last_xyz(state, proc_directory)

        if len(nstep) == 2:
            nvt_eq_config = NVTEqConfig.from_config(state.config)
            dice.run(nvt_eq_config)
        else:  # len(nstep) == 3, guaranteed by the check above
            npt_ter_config = NPTTerConfig.from_config(state.config)
            dice.run(npt_ter_config)
            npt_eq_config = NPTEqConfig.from_config(state.config)
            dice.run(npt_eq_config)

        return dice.extract_results()

    def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None:
        """Prepare the starting configuration in ``proc_directory``.

        NOTE(review): not implemented yet; the body is a placeholder. Presumably
        it writes the last configuration of the previous cycle; confirm once
        implemented.
        """
        ...