116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
from diceplayer.dice.dice_input import (
|
|
NPTEqConfig,
|
|
NPTTerConfig,
|
|
NVTEqConfig,
|
|
NVTTerConfig,
|
|
)
|
|
from diceplayer.dice.dice_wrapper import DiceWrapper
|
|
from diceplayer.logger import logger
|
|
from diceplayer.state.state_model import StateModel
|
|
|
|
import shutil
|
|
from pathlib import Path
|
|
from threading import Thread
|
|
|
|
|
|
class DiceHandler:
    """Drive the DICE simulation stage of one step.

    All DICE artefacts are written below ``<step_directory>/dice``, with one
    sub-directory per DICE process (``00``, ``01``, ...).
    """

    def __init__(self, step_directory: Path):
        """Remember where this step's DICE files live.

        Args:
            step_directory: Directory of the current step. DICE files go into
                its ``dice`` sub-directory; nothing is created here.
        """
        self.dice_directory = step_directory / "dice"

    def run(self, state: StateModel, cycle: int) -> StateModel:
        """Run every DICE process for ``cycle`` and return the resulting state.

        An existing ``dice`` directory is deleted first so each call starts
        from a clean directory.

        Args:
            state: Current simulation state (configuration and system).
            cycle: Index of the current cycle.

        Returns:
            Whatever ``commit_simulation_state`` returns.

        Raises:
            RuntimeError: If a DICE process fails or a result is missing.
        """
        if self.dice_directory.exists():
            logger.info(
                f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
            )
            shutil.rmtree(self.dice_directory)
        self.dice_directory.mkdir(parents=True)

        simulation_results = self.run_simulations(state, cycle)

        result = self.aggregate_results(simulation_results)

        return self.commit_simulation_state(state, result)

    def run_simulations(self, state: StateModel, cycle: int) -> list[dict]:
        """Run ``state.config.dice.nprocs`` DICE processes in parallel threads.

        Results are returned ordered by process index, not by the order in
        which the threads happen to finish.

        Args:
            state: Current simulation state, shared read-only by all workers.
            cycle: Index of the current cycle.

        Returns:
            One parsed result per process; ``results[p]`` belongs to process ``p``.

        Raises:
            RuntimeError: If any process raised (every failure is listed in the
                message and the lowest-numbered one is chained as ``__cause__``),
                or if the number of results does not equal ``nprocs``.
        """
        nprocs = state.config.dice.nprocs

        # One bucket per process: worker ``p`` only touches ``buckets[p]``, so
        # the final ordering does not depend on thread scheduling.
        buckets: list[list[dict]] = [[] for _ in range(nprocs)]
        # Exceptions raised inside a Thread are otherwise only printed by
        # threading.excepthook and lost to the caller; keep them to re-raise.
        failures: dict[int, Exception] = {}

        def worker(proc: int) -> None:
            try:
                self._simulation_process(state, cycle, proc, buckets[proc])
            except Exception as exc:
                failures[proc] = exc

        threads = [
            Thread(target=worker, args=(proc,), name=f"dice-{proc:02d}")
            for proc in range(nprocs)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        if failures:
            details = "; ".join(
                f"process {proc:02d}: {failures[proc]!r}" for proc in sorted(failures)
            )
            raise RuntimeError(
                f"{len(failures)} of {nprocs} DICE simulation processes failed: {details}"
            ) from failures[min(failures)]

        results = [result for bucket in buckets for result in bucket]

        # Defensive: each successful worker appends exactly one result.
        if len(results) != state.config.dice.nprocs:
            raise RuntimeError(
                f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}"
            )

        return results

    def aggregate_results(self, simulation_results: list[dict]) -> dict:
        """Combine the per-process results into a single result.

        NOTE(review): stub — the body is ``...``, so this currently returns
        ``None`` despite the ``dict`` annotation.
        """
        ...

    def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
        """Fold ``result`` into ``state``.

        Currently returns ``state`` unchanged; ``result`` is not used yet.
        """
        return state

    def _simulation_process(
        self, state: StateModel, cycle: int, proc: int, results: list[dict]
    ) -> None:
        """Prepare and run the DICE stages for one process.

        Runs in a worker thread started by ``run_simulations``. It works in its
        own ``<dice_directory>/<proc:02d>`` directory, which is recreated empty.
        On success it appends exactly one parsed result to ``results``.

        Args:
            state: Current simulation state.
            cycle: Index of the current cycle.
            proc: Zero-based process index; also names the work directory.
            results: List that receives this process's parsed result.
        """
        proc_directory = self.dice_directory / f"{proc:02d}"
        if proc_directory.exists():
            shutil.rmtree(proc_directory)
        proc_directory.mkdir(parents=True)

        dice = DiceWrapper(state.config.dice, proc_directory)

        self._generate_phb_file(state, proc_directory)

        # NOTE(review): ``cycle >= 0`` holds for every non-negative cycle, so
        # with randominit == "first" the NVT thermalization below never runs.
        # If the intent is "random start only on the first cycle", this should
        # likely compare against the first cycle index instead — confirm.
        if state.config.dice.randominit == "first" and cycle >= 0:
            # Presumably reuses the last configuration instead of a random
            # start; ``_generate_last_xyz`` is still a stub here.
            self._generate_last_xyz(state, proc_directory)
        else:
            nvt_ter_config = NVTTerConfig.from_config(state.config)
            dice.run(nvt_ter_config)

        # The length of ``nstep`` selects the follow-up stages:
        # 2 -> NVT equilibration; 3 -> NPT thermalization + NPT equilibration.
        # NOTE(review): any other length silently runs no further stage.
        if len(state.config.dice.nstep) == 2:
            nvt_eq_config = NVTEqConfig.from_config(state.config)
            dice.run(nvt_eq_config)

        elif len(state.config.dice.nstep) == 3:
            npt_ter_config = NPTTerConfig.from_config(state.config)
            dice.run(npt_ter_config)

            npt_eq_config = NPTEqConfig.from_config(state.config)
            dice.run(npt_eq_config)

        results.append(dice.parse_results(state.system))

    @staticmethod
    def _generate_phb_file(state: StateModel, proc_directory: Path) -> None:
        """Write the DICE potential file ``config.dice.ljname`` into ``proc_directory``.

        Layout written:
            line 1: ``config.dice.combrule``
            line 2: ``len(config.dice.nmol)``
            then, for every molecule in ``state.system.molecule``, a header line
            ``"<number of atoms> <molname>"`` followed by one fixed-width line
            per atom with its ``lbl, na, rx, ry, rz, chg, eps, sig`` fields.

        NOTE(review): the count on line 2 comes from ``nmol`` while the blocks
        come from ``state.system.molecule``; these are assumed to agree.
        """
        fstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f}\n"

        phb_file = proc_directory / state.config.dice.ljname

        with open(phb_file, "w") as f:
            f.write(f"{state.config.dice.combrule}\n")
            f.write(f"{len(state.config.dice.nmol)}\n")

            for molecule in state.system.molecule:
                f.write(f"{len(molecule.atom)} {molecule.molname}\n")
                for atom in molecule.atom:
                    f.write(
                        fstr.format(
                            atom.lbl,
                            atom.na,
                            atom.rx,
                            atom.ry,
                            atom.rz,
                            atom.chg,
                            atom.eps,
                            atom.sig,
                        )
                    )

    def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None:
        """Prepare the starting configuration for ``proc_directory``.

        NOTE(review): stub — the body is ``...`` and does nothing yet.
        """
        ...
|