import shutil
import warnings
from itertools import batched, chain, islice
from pathlib import Path
from threading import Thread

from diceplayer.dice.dice_input import (
    NPTEqConfig,
    NPTTerConfig,
    NVTEqConfig,
    NVTTerConfig,
)
from diceplayer.dice.dice_wrapper import (
    DiceEnvironment,
    DiceWrapper,
)
from diceplayer.environment import Atom, Molecule
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel


class DiceHandler:
    """Prepare, run and post-process the DICE simulations of one step.

    Each DICE process gets its own sub-directory of ``<step_directory>/dice``
    and runs in its own thread. The solvent sites picked from every parsed
    configuration are merged into one charge-normalised environment.
    """

    def __init__(self, step_directory: Path):
        # All DICE input/output for this step lives under <step_directory>/dice.
        self.dice_directory = step_directory / "dice"

    def run(self, state: StateModel, cycle: int) -> list[Atom]:
        """Run every DICE process for ``cycle`` inside a fresh ``dice`` directory.

        An existing ``dice`` directory is deleted first, so no file from a
        previous run can leak into this one.

        Returns:
            The aggregated environment atoms (see ``run_simulations``).
        """
        if self.dice_directory.exists():
            logger.info(
                f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
            )
            shutil.rmtree(self.dice_directory)
        self.dice_directory.mkdir(parents=True)

        return self.run_simulations(state, cycle)

    def run_simulations(self, state: StateModel, cycle: int) -> list[Atom]:
        """Run ``state.config.dice.nprocs`` DICE processes in parallel threads.

        Each thread fills its own result list, so no container is shared
        between threads. Results are concatenated in process order, which
        makes the output deterministic.

        Raises:
            RuntimeError: if any process raised (the original exception is
                chained as ``__cause__``), or if a process produced no
                environment at all.
        """
        nprocs = state.config.dice.nprocs
        proc_results: list[list[list[Atom]]] = [[] for _ in range(nprocs)]
        proc_errors: list[Exception | None] = [None] * nprocs

        def worker(proc: int) -> None:
            # An exception raised in a Thread target never reaches join();
            # record it so the calling thread can re-raise it below.
            try:
                self._simulation_process(state, cycle, proc, proc_results[proc])
            except Exception as err:
                proc_errors[proc] = err

        threads = [Thread(target=worker, args=(proc,)) for proc in range(nprocs)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        for proc, err in enumerate(proc_errors):
            if err is not None:
                raise RuntimeError(
                    f"DICE simulation of process {proc:02d} failed"
                ) from err

        # A process may legitimately yield several environments (one per
        # parsed configuration), so count processes, not environments.
        completed = sum(1 for proc_result in proc_results if proc_result)
        if completed != nprocs:
            raise RuntimeError(
                f"Expected {nprocs} simulation results, but got {completed}"
            )

        return self._aggregate_results(state, list(chain.from_iterable(proc_results)))

    def _simulation_process(
        self,
        state: StateModel,
        cycle: int,
        proc: int,
        results: list[list[Atom]],
    ) -> None:
        """Run all DICE stages of process ``proc`` and append its environments.

        Stages run, in order:
          * thermalisation — ``NVTTerConfig``, unless the last-configuration
            branch (``_generate_last_xyz``) is taken;
          * ``nstep`` of length 2: ``NVTEqConfig``;
          * ``nstep`` of length 3: ``NPTTerConfig`` then ``NPTEqConfig``.

        Every environment returned by ``dice.parse_results()`` is filtered by
        ``_filter_environment_sites`` and appended to ``results``.

        Raises:
            ValueError: if ``state.config.dice.nstep`` does not have 2 or 3 entries.
        """
        dice_config = state.config.dice
        n_stages = len(dice_config.nstep)
        if n_stages not in (2, 3):
            # Any other length would silently skip the production stage.
            raise ValueError(
                f"Unsupported number of DICE stages: nstep must have 2 or 3 entries, got {n_stages}"
            )

        proc_directory = self.dice_directory / f"{proc:02d}"
        if proc_directory.exists():
            shutil.rmtree(proc_directory)
        proc_directory.mkdir(parents=True)

        dice = DiceWrapper(dice_config, proc_directory)
        self._generate_phb_file(state, proc_directory)

        # NOTE(review): `cycle >= 0` holds for every non-negative cycle. If
        # cycles are numbered from 0, the NVT thermalisation is skipped even on
        # the very first cycle — confirm the intended condition.
        if dice_config.randominit == "first" and cycle >= 0:
            self._generate_last_xyz(state, proc_directory)
        else:
            dice.run(NVTTerConfig.from_config(state.config))

        if n_stages == 2:
            dice.run(NVTEqConfig.from_config(state.config))
        else:
            dice.run(NPTTerConfig.from_config(state.config))
            dice.run(NPTEqConfig.from_config(state.config))

        results.extend(
            self._filter_environment_sites(state, environment)
            for environment in dice.parse_results()
        )

    @staticmethod
    def _generate_phb_file(state: StateModel, proc_directory: Path) -> None:
        """Write the DICE potential file ``state.config.dice.ljname`` into ``proc_directory``.

        File layout:
          * the combination rule;
          * the number of molecule types (``len(nmol)``);
          * for each molecule, a ``<n_atoms> <molname>`` header followed by one
            line per atom with its ``lbl, na, rx, ry, rz, chg, eps, sig`` fields.
        """
        fstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f}\n"

        lines = [
            f"{state.config.dice.combrule}\n",
            f"{len(state.config.dice.nmol)}\n",
        ]
        for molecule in state.system.molecule:
            lines.append(f"{len(molecule.atom)} {molecule.molname}\n")
            lines.extend(
                fstr.format(
                    atom.lbl,
                    atom.na,
                    atom.rx,
                    atom.ry,
                    atom.rz,
                    atom.chg,
                    atom.eps,
                    atom.sig,
                )
                for atom in molecule.atom
            )

        phb_file = proc_directory / state.config.dice.ljname
        with open(phb_file, "w", encoding="utf-8") as f:
            f.writelines(lines)

    def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None:
        # TODO(review): not implemented — this is currently a no-op, so the
        # `randominit == "first"` branch prepares no starting configuration.
        ...

    @staticmethod
    def _filter_environment_sites(
        state: StateModel, environment: DiceEnvironment
    ) -> list[Atom]:
        """Return the solvent atoms of ``environment`` that lie close to the solute.

        ``environment.items`` is consumed as a flat sequence of sites grouped
        by molecule type:
          * first ``nmol[0]`` copies of the reference molecule
            (``state.system.molecule[0]``) — these sites are skipped;
          * then ``nmol[i]`` copies of each further molecule type.

        Each solvent molecule is rebuilt from its sites, using the parameters
        of its template molecule. It is kept only when both hold:
          * its signature matches the template;
          * its minimum distance to the reference molecule is below the cutoff,
            i.e. half the smallest ``thickness - reference size`` difference
            over the three axes.
        """
        nmol = state.config.dice.nmol
        ref_molecule = state.system.molecule[0]
        ref_sizes = ref_molecule.sizes_of_molecule()
        cutoff = min(
            (environment.thickness[axis] - ref_sizes[axis]) / 2 for axis in range(3)
        )

        sites = iter(environment.items)
        # Skip the reference-molecule sites without materialising them
        # (itertools "consume" recipe).
        ref_n_sites = len(ref_molecule.atom) * nmol[0]
        next(islice(sites, ref_n_sites, ref_n_sites), None)

        picked: list[Atom] = []
        for type_index, template in enumerate(state.system.molecule[1:], start=1):
            n_atoms = len(template.atom)
            type_sites = list(islice(sites, n_atoms * nmol[type_index]))
            template_signature = template.signature()

            for molecule_sites in batched(type_sites, n_atoms):
                candidate = Molecule("ASEC TMP MOLECULE")
                for ref_atom, site in zip(template.atom, molecule_sites):
                    candidate.add_atom(
                        Atom(
                            ref_atom.lbl,
                            ref_atom.na,
                            site.x,
                            site.y,
                            site.z,
                            ref_atom.chg,
                            ref_atom.eps,
                            ref_atom.sig,
                        )
                    )

                candidate_signature = candidate.signature()
                if candidate_signature != template_signature:
                    _message = f"Skipping sites because the molecule signature does not match the reference molecule. Expected {template_signature} but got {candidate_signature}"
                    warnings.warn(_message)
                    logger.warning(_message)
                    continue

                if ref_molecule.minimum_distance(candidate) >= cutoff:
                    continue

                picked.extend(candidate.atom)

        return picked

    @staticmethod
    def _aggregate_results(state: StateModel, results: list[list[Atom]]) -> list[Atom]:
        """Merge per-configuration atom lists into one ASEC environment.

        All collected configurations are superimposed, so each configuration
        must contribute with weight ``1/N`` (``N = len(results)``). Every
        charge is therefore divided by ``N``; the previous code multiplied
        instead. The atoms are updated in place — they are fresh copies built
        by ``_filter_environment_sites`` — and returned as one flat list.

        ``state`` is unused and kept only for interface compatibility.
        """
        n_configs = len(results)
        merged: list[Atom] = []
        for atom in chain.from_iterable(results):
            atom.chg = atom.chg / n_configs
            merged.append(atom)
        return merged