Files
DicePlayer/diceplayer/dice/dice_handler.py
Vitor Hideyoshi 2802f10013
Some checks failed
build and upload / test (3.10) (push) Failing after 1m45s
build and upload / pypi-upload (push) Has been skipped
refactor: restructure dice environment handling and update Python version requirement
2026-03-29 17:57:51 -03:00

187 lines
6.4 KiB
Python

import warnings
from diceplayer.dice.dice_input import (
NPTEqConfig,
NPTTerConfig,
NVTEqConfig,
NVTTerConfig,
)
from diceplayer.dice.dice_wrapper import (
DiceEnvironment,
DiceWrapper,
)
from diceplayer.environment import Atom, Molecule
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel
import shutil
from itertools import batched, chain, islice
from pathlib import Path
from threading import Thread
class DiceHandler:
def __init__(self, step_directory: Path):
self.dice_directory = step_directory / "dice"
def run(self, state: StateModel, cycle: int) -> list[Atom]:
if self.dice_directory.exists():
logger.info(
f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
)
shutil.rmtree(self.dice_directory)
self.dice_directory.mkdir(parents=True)
return self.run_simulations(state, cycle)
def run_simulations(self, state: StateModel, cycle: int) -> list[Atom]:
results: list[list[Atom]] = []
threads = []
for p in range(state.config.dice.nprocs):
t = Thread(target=self._simulation_process, args=(state, cycle, p, results))
threads.append(t)
t.start()
for t in threads:
t.join()
if len(results) != state.config.dice.nprocs:
raise RuntimeError(
f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}"
)
return self._aggregate_results(state, results)
def _simulation_process(
self,
state: StateModel,
cycle: int,
proc: int,
results: list[list[Atom]],
) -> None:
proc_directory = self.dice_directory / f"{proc:02d}"
if proc_directory.exists():
shutil.rmtree(proc_directory)
proc_directory.mkdir(parents=True)
dice = DiceWrapper(state.config.dice, proc_directory)
self._generate_phb_file(state, proc_directory)
if state.config.dice.randominit == "first" and cycle >= 0:
self._generate_last_xyz(state, proc_directory)
else:
nvt_ter_config = NVTTerConfig.from_config(state.config)
dice.run(nvt_ter_config)
if len(state.config.dice.nstep) == 2:
nvt_eq_config = NVTEqConfig.from_config(state.config)
dice.run(nvt_eq_config)
elif len(state.config.dice.nstep) == 3:
npt_ter_config = NPTTerConfig.from_config(state.config)
dice.run(npt_ter_config)
npt_eq_config = NPTEqConfig.from_config(state.config)
dice.run(npt_eq_config)
results.extend(
[
self._filter_environment_sites(state, environment)
for environment in dice.parse_results()
]
)
@staticmethod
def _generate_phb_file(state: StateModel, proc_directory: Path) -> None:
fstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f}\n"
phb_file = proc_directory / state.config.dice.ljname
with open(phb_file, "w") as f:
f.write(f"{state.config.dice.combrule}\n")
f.write(f"{len(state.config.dice.nmol)}\n")
for molecule in state.system.molecule:
f.write(f"{len(molecule.atom)} {molecule.molname}\n")
for atom in molecule.atom:
f.write(
fstr.format(
atom.lbl,
atom.na,
atom.rx,
atom.ry,
atom.rz,
atom.chg,
atom.eps,
atom.sig,
)
)
def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None: ...
@staticmethod
def _filter_environment_sites(
state: StateModel, environment: DiceEnvironment
) -> list[Atom]:
picked_environment = []
ref_molecule = state.system.molecule[0]
ref_molecule_sizes = ref_molecule.sizes_of_molecule()
ref_n_sites = len(ref_molecule.atom) * state.config.dice.nmol[0]
min_distance = min(
(environment.thickness[i] - ref_molecule_sizes[i]) / 2 for i in range(3)
)
site_iter = iter(environment.items)
_ = list(islice(site_iter, ref_n_sites))
for molecule_index, molecule in enumerate(state.system.molecule[1:], start=1):
molecule_n_atoms = len(molecule.atom)
molecule_n_sites = molecule_n_atoms * state.config.dice.nmol[molecule_index]
sites = list(islice(site_iter, molecule_n_sites))
for molecule_sites in batched(sites, molecule_n_atoms):
new_molecule = Molecule("ASEC TMP MOLECULE")
for site_index, atom_site in enumerate(molecule_sites):
new_molecule.add_atom(
Atom(
molecule.atom[site_index].lbl,
molecule.atom[site_index].na,
atom_site.x,
atom_site.y,
atom_site.z,
molecule.atom[site_index].chg,
molecule.atom[site_index].eps,
molecule.atom[site_index].sig,
)
)
if molecule.signature() != new_molecule.signature():
_message = f"Skipping sites because the molecule signature does not match the reference molecule. Expected {molecule.signature()} but got {new_molecule.signature()}"
warnings.warn(_message)
logger.warning(_message)
continue
if ref_molecule.minimum_distance(new_molecule) >= min_distance:
continue
picked_environment.extend(new_molecule.atom)
return picked_environment
@staticmethod
def _aggregate_results(state: StateModel, results: list[list[Atom]]) -> list[Atom]:
norm_factor = round(state.config.dice.nstep[-1] / state.config.dice.isave)
agg_results = []
for atom in chain(*[r for r in results]):
atom.chg = atom.chg * norm_factor
agg_results.append(atom)
return agg_results