refactor: restructure dice environment handling and update Python version requirement
This commit is contained in:
@@ -1,14 +1,21 @@
|
||||
import warnings
|
||||
|
||||
from diceplayer.dice.dice_input import (
|
||||
NPTEqConfig,
|
||||
NPTTerConfig,
|
||||
NVTEqConfig,
|
||||
NVTTerConfig,
|
||||
)
|
||||
from diceplayer.dice.dice_wrapper import DiceWrapper, DiceEnvironment
|
||||
from diceplayer.dice.dice_wrapper import (
|
||||
DiceEnvironment,
|
||||
DiceWrapper,
|
||||
)
|
||||
from diceplayer.environment import Atom, Molecule
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
import shutil
|
||||
from itertools import batched, chain, islice
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
@@ -17,7 +24,7 @@ class DiceHandler:
|
||||
def __init__(self, step_directory: Path):
|
||||
self.dice_directory = step_directory / "dice"
|
||||
|
||||
def run(self, state: StateModel, cycle: int) -> list[DiceEnvironment]:
|
||||
def run(self, state: StateModel, cycle: int) -> list[Atom]:
|
||||
if self.dice_directory.exists():
|
||||
logger.info(
|
||||
f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
|
||||
@@ -27,8 +34,8 @@ class DiceHandler:
|
||||
|
||||
return self.run_simulations(state, cycle)
|
||||
|
||||
def run_simulations(self, state: StateModel, cycle: int) -> list[DiceEnvironment]:
|
||||
results = []
|
||||
def run_simulations(self, state: StateModel, cycle: int) -> list[Atom]:
|
||||
results: list[list[Atom]] = []
|
||||
|
||||
threads = []
|
||||
for p in range(state.config.dice.nprocs):
|
||||
@@ -44,12 +51,14 @@ class DiceHandler:
|
||||
f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}"
|
||||
)
|
||||
|
||||
return [
|
||||
i for i in [r for r in results]
|
||||
]
|
||||
return self._aggregate_results(state, results)
|
||||
|
||||
def _simulation_process(
|
||||
self, state: StateModel, cycle: int, proc: int, results: list[list[DiceEnvironment]]
|
||||
self,
|
||||
state: StateModel,
|
||||
cycle: int,
|
||||
proc: int,
|
||||
results: list[list[Atom]],
|
||||
) -> None:
|
||||
proc_directory = self.dice_directory / f"{proc:02d}"
|
||||
if proc_directory.exists():
|
||||
@@ -77,7 +86,12 @@ class DiceHandler:
|
||||
npt_eq_config = NPTEqConfig.from_config(state.config)
|
||||
dice.run(npt_eq_config)
|
||||
|
||||
results.append(dice.parse_results(state.system))
|
||||
results.extend(
|
||||
[
|
||||
self._filter_environment_sites(state, environment)
|
||||
for environment in dice.parse_results()
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_phb_file(state: StateModel, proc_directory: Path) -> None:
|
||||
@@ -106,3 +120,67 @@ class DiceHandler:
|
||||
)
|
||||
|
||||
def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def _filter_environment_sites(
|
||||
state: StateModel, environment: DiceEnvironment
|
||||
) -> list[Atom]:
|
||||
picked_environment = []
|
||||
|
||||
ref_molecule = state.system.molecule[0]
|
||||
ref_molecule_sizes = ref_molecule.sizes_of_molecule()
|
||||
ref_n_sites = len(ref_molecule.atom) * state.config.dice.nmol[0]
|
||||
|
||||
min_distance = min(
|
||||
(environment.thickness[i] - ref_molecule_sizes[i]) / 2 for i in range(3)
|
||||
)
|
||||
|
||||
site_iter = iter(environment.items)
|
||||
_ = list(islice(site_iter, ref_n_sites))
|
||||
|
||||
for molecule_index, molecule in enumerate(state.system.molecule[1:], start=1):
|
||||
molecule_n_atoms = len(molecule.atom)
|
||||
molecule_n_sites = molecule_n_atoms * state.config.dice.nmol[molecule_index]
|
||||
|
||||
sites = list(islice(site_iter, molecule_n_sites))
|
||||
|
||||
for molecule_sites in batched(sites, molecule_n_atoms):
|
||||
new_molecule = Molecule("ASEC TMP MOLECULE")
|
||||
|
||||
for site_index, atom_site in enumerate(molecule_sites):
|
||||
new_molecule.add_atom(
|
||||
Atom(
|
||||
molecule.atom[site_index].lbl,
|
||||
molecule.atom[site_index].na,
|
||||
atom_site.x,
|
||||
atom_site.y,
|
||||
atom_site.z,
|
||||
molecule.atom[site_index].chg,
|
||||
molecule.atom[site_index].eps,
|
||||
molecule.atom[site_index].sig,
|
||||
)
|
||||
)
|
||||
|
||||
if molecule.signature() != new_molecule.signature():
|
||||
_message = f"Skipping sites because the molecule signature does not match the reference molecule. Expected {molecule.signature()} but got {new_molecule.signature()}"
|
||||
warnings.warn(_message)
|
||||
logger.warning(_message)
|
||||
continue
|
||||
|
||||
if ref_molecule.minimum_distance(new_molecule) >= min_distance:
|
||||
continue
|
||||
|
||||
picked_environment.extend(new_molecule.atom)
|
||||
|
||||
return picked_environment
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_results(state: StateModel, results: list[list[Atom]]) -> list[Atom]:
|
||||
norm_factor = round(state.config.dice.nstep[-1] / state.config.dice.isave)
|
||||
|
||||
agg_results = []
|
||||
for atom in chain(*[r for r in results]):
|
||||
atom.chg = atom.chg * norm_factor
|
||||
agg_results.append(atom)
|
||||
|
||||
return agg_results
|
||||
|
||||
@@ -1,16 +1,39 @@
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
import diceplayer.dice.dice_input as dice_input
|
||||
from diceplayer.config import DiceConfig
|
||||
from diceplayer.environment import System
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import subprocess
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
from typing import Final, List, Self
|
||||
|
||||
|
||||
type DiceEnvironment = tuple[str, int, int, int]
|
||||
DiceEnvironmentAdapter = TypeAdapter(DiceEnvironment)
|
||||
@dataclass(slots=True, frozen=True)
|
||||
class DiceEnvironmentItem:
|
||||
atom: str
|
||||
x: float
|
||||
y: float
|
||||
z: float
|
||||
|
||||
|
||||
DiceEnvironmentItemAdapter = TypeAdapter(DiceEnvironmentItem)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DiceEnvironment:
|
||||
number_of_sites: int
|
||||
thickness: List[float]
|
||||
items: List[DiceEnvironmentItem]
|
||||
|
||||
@classmethod
|
||||
def new(cls, thickness: List[float]) -> Self:
|
||||
return cls(number_of_sites=0, thickness=thickness, items=[])
|
||||
|
||||
def add_site(self, site: DiceEnvironmentItem):
|
||||
self.items.append(site)
|
||||
self.number_of_sites += 1
|
||||
|
||||
|
||||
DICE_FLAG_LINE: Final[int] = -2
|
||||
@@ -28,7 +51,10 @@ class DiceWrapper:
|
||||
|
||||
with open(output_path, "w") as outfile, open(input_path, "r") as infile:
|
||||
exit_status = subprocess.call(
|
||||
self.dice_config.progname, stdin=infile, stdout=outfile, cwd=self.working_directory
|
||||
self.dice_config.progname,
|
||||
stdin=infile,
|
||||
stdout=outfile,
|
||||
cwd=self.working_directory,
|
||||
)
|
||||
|
||||
if exit_status != 0:
|
||||
@@ -41,22 +67,37 @@ class DiceWrapper:
|
||||
|
||||
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
|
||||
|
||||
def parse_results(self, system: System) -> list[DiceEnvironment]:
|
||||
NUMBER_OF_HEADER_LINES = 2
|
||||
NUMBER_OF_PRIMARY_ATOMS = len(system.molecule[0].atom)
|
||||
|
||||
def parse_results(self) -> list[DiceEnvironment]:
|
||||
results = []
|
||||
for output_file in sorted(self.working_directory.glob("phb*.xyz")):
|
||||
with open(output_file, "r") as f:
|
||||
for _ in range(NUMBER_OF_HEADER_LINES + NUMBER_OF_PRIMARY_ATOMS):
|
||||
next(f, None)
|
||||
|
||||
for line in f:
|
||||
if line.strip() == "":
|
||||
break
|
||||
positions_file = self.working_directory / "phb.xyz"
|
||||
if not positions_file.exists():
|
||||
raise RuntimeError(f"Positions file not found at {self.working_directory}")
|
||||
|
||||
results.append(
|
||||
DiceEnvironmentAdapter.validate_python(line.split())
|
||||
)
|
||||
with open(positions_file, "r") as f:
|
||||
while True:
|
||||
line = f.readline()
|
||||
if not line.startswith(" "):
|
||||
break
|
||||
|
||||
environment = DiceEnvironment(
|
||||
number_of_sites=int(line.strip()),
|
||||
thickness=[float(n) for n in f.readline().split()[-3:]],
|
||||
items=[],
|
||||
)
|
||||
|
||||
# Skip the comment line
|
||||
|
||||
environment.items.extend(
|
||||
[
|
||||
DiceEnvironmentItemAdapter.validate_python(
|
||||
{"atom": site[0], "x": site[1], "y": site[2], "z": site[3]}
|
||||
)
|
||||
for atom in islice(f, environment.number_of_sites)
|
||||
if (site := atom.strip().split()) and len(site) == 4
|
||||
]
|
||||
)
|
||||
|
||||
results.append(environment)
|
||||
|
||||
return results
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
from diceplayer.environment import Atom
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.utils.cache import invalidate_computed_properties
|
||||
from diceplayer.utils.misc import BOHR2ANG, EA_2_DEBYE
|
||||
from diceplayer.utils.ptable import GHOST_NUMBER
|
||||
|
||||
@@ -15,7 +14,6 @@ from typing_extensions import List, Self, Tuple
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from dataclasses import field
|
||||
from functools import cached_property
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,11 +32,11 @@ class Molecule:
|
||||
molname: str
|
||||
atom: List[Atom] = field(default_factory=list)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def total_mass(self) -> float:
|
||||
return sum(atom.mass for atom in self.atom)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def com(self) -> npt.NDArray[np.float64]:
|
||||
com = np.zeros(3)
|
||||
|
||||
@@ -49,7 +47,7 @@ class Molecule:
|
||||
|
||||
return com
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def inertia_tensor(self) -> npt.NDArray[np.float64]:
|
||||
"""
|
||||
Calculates the inertia tensor of the molecule.
|
||||
@@ -79,7 +77,6 @@ class Molecule:
|
||||
|
||||
return inertia_tensor
|
||||
|
||||
@invalidate_computed_properties()
|
||||
def add_atom(self, a: Atom) -> None:
|
||||
"""
|
||||
Adds Atom instance to the molecule.
|
||||
@@ -90,7 +87,6 @@ class Molecule:
|
||||
|
||||
self.atom.append(a)
|
||||
|
||||
@invalidate_computed_properties()
|
||||
def remove_atom(self, a: Atom) -> None:
|
||||
"""
|
||||
Removes Atom instance from the molecule.
|
||||
@@ -101,7 +97,6 @@ class Molecule:
|
||||
|
||||
self.atom.remove(a)
|
||||
|
||||
@invalidate_computed_properties()
|
||||
def move_center_of_mass_to_origin(self) -> None:
|
||||
"""
|
||||
Updated positions based on the center of mass of the molecule
|
||||
@@ -111,7 +106,6 @@ class Molecule:
|
||||
atom.ry -= self.com[1]
|
||||
atom.rz -= self.com[2]
|
||||
|
||||
@invalidate_computed_properties()
|
||||
def rotate_to_standard_orientation(self) -> None:
|
||||
"""
|
||||
Rotates the molecule to the standard orientation
|
||||
@@ -315,3 +309,12 @@ class Molecule:
|
||||
diff = coords_a[:, None, :] - coords_b[None, :, :]
|
||||
d2 = np.sum(diff**2, axis=-1)
|
||||
return np.sqrt(d2.min())
|
||||
|
||||
def signature(self) -> List[int]:
|
||||
"""
|
||||
Returns the signature of the molecule, which is a list of the number of atoms of each type in the molecule.
|
||||
|
||||
Returns:
|
||||
List[int]: signature of the molecule
|
||||
"""
|
||||
return [a.lbl for a in self.atom]
|
||||
|
||||
@@ -1,19 +1,28 @@
|
||||
from pathlib import Path
|
||||
|
||||
from diceplayer.config.player_config import RoutineType
|
||||
from diceplayer.dice.dice_wrapper import DiceEnvironment
|
||||
from diceplayer.dice.dice_wrapper import DiceEnvironmentItem
|
||||
from diceplayer.environment import Atom
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class OptimizationHandler:
|
||||
def __init__(self, step_directory: Path):
|
||||
self.dice_directory = step_directory / "dice"
|
||||
self.optimization_directory = step_directory / "optimization"
|
||||
|
||||
def run(self, state: StateModel, current_cycle: int, dice_environment: list[DiceEnvironment]) -> StateModel:
|
||||
def run(
|
||||
self,
|
||||
state: StateModel,
|
||||
current_cycle: int,
|
||||
dice_environment: list[Atom],
|
||||
) -> StateModel:
|
||||
routine = self._fetch_current_routine(state, current_cycle)
|
||||
print(f"Running Optimization - {current_cycle} - {routine}")
|
||||
logger.info(
|
||||
f"Running Optimization - {current_cycle} - {routine} - {self.optimization_directory}"
|
||||
)
|
||||
|
||||
print(dice_environment)
|
||||
logger.info(dice_environment)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
Reference in New Issue
Block a user