refactor: restructure dice environment handling and update Python version requirement
Some checks failed
build and upload / test (3.10) (push) Failing after 1m45s
build and upload / pypi-upload (push) Has been skipped

This commit is contained in:
2026-03-29 17:38:44 -03:00
parent 7e66c98f26
commit 2802f10013
9 changed files with 314 additions and 352 deletions

View File

@@ -1,14 +1,21 @@
import warnings
from diceplayer.dice.dice_input import (
NPTEqConfig,
NPTTerConfig,
NVTEqConfig,
NVTTerConfig,
)
from diceplayer.dice.dice_wrapper import DiceWrapper, DiceEnvironment
from diceplayer.dice.dice_wrapper import (
DiceEnvironment,
DiceWrapper,
)
from diceplayer.environment import Atom, Molecule
from diceplayer.logger import logger
from diceplayer.state.state_model import StateModel
import shutil
from itertools import batched, chain, islice
from pathlib import Path
from threading import Thread
@@ -17,7 +24,7 @@ class DiceHandler:
def __init__(self, step_directory: Path):
self.dice_directory = step_directory / "dice"
def run(self, state: StateModel, cycle: int) -> list[DiceEnvironment]:
def run(self, state: StateModel, cycle: int) -> list[Atom]:
if self.dice_directory.exists():
logger.info(
f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
@@ -27,8 +34,8 @@ class DiceHandler:
return self.run_simulations(state, cycle)
def run_simulations(self, state: StateModel, cycle: int) -> list[DiceEnvironment]:
results = []
def run_simulations(self, state: StateModel, cycle: int) -> list[Atom]:
results: list[list[Atom]] = []
threads = []
for p in range(state.config.dice.nprocs):
@@ -44,12 +51,14 @@ class DiceHandler:
f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}"
)
return [
i for i in [r for r in results]
]
return self._aggregate_results(state, results)
def _simulation_process(
self, state: StateModel, cycle: int, proc: int, results: list[list[DiceEnvironment]]
self,
state: StateModel,
cycle: int,
proc: int,
results: list[list[Atom]],
) -> None:
proc_directory = self.dice_directory / f"{proc:02d}"
if proc_directory.exists():
@@ -77,7 +86,12 @@ class DiceHandler:
npt_eq_config = NPTEqConfig.from_config(state.config)
dice.run(npt_eq_config)
results.append(dice.parse_results(state.system))
results.extend(
[
self._filter_environment_sites(state, environment)
for environment in dice.parse_results()
]
)
@staticmethod
def _generate_phb_file(state: StateModel, proc_directory: Path) -> None:
@@ -106,3 +120,67 @@ class DiceHandler:
)
def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None: ...
@staticmethod
def _filter_environment_sites(
state: StateModel, environment: DiceEnvironment
) -> list[Atom]:
picked_environment = []
ref_molecule = state.system.molecule[0]
ref_molecule_sizes = ref_molecule.sizes_of_molecule()
ref_n_sites = len(ref_molecule.atom) * state.config.dice.nmol[0]
min_distance = min(
(environment.thickness[i] - ref_molecule_sizes[i]) / 2 for i in range(3)
)
site_iter = iter(environment.items)
_ = list(islice(site_iter, ref_n_sites))
for molecule_index, molecule in enumerate(state.system.molecule[1:], start=1):
molecule_n_atoms = len(molecule.atom)
molecule_n_sites = molecule_n_atoms * state.config.dice.nmol[molecule_index]
sites = list(islice(site_iter, molecule_n_sites))
for molecule_sites in batched(sites, molecule_n_atoms):
new_molecule = Molecule("ASEC TMP MOLECULE")
for site_index, atom_site in enumerate(molecule_sites):
new_molecule.add_atom(
Atom(
molecule.atom[site_index].lbl,
molecule.atom[site_index].na,
atom_site.x,
atom_site.y,
atom_site.z,
molecule.atom[site_index].chg,
molecule.atom[site_index].eps,
molecule.atom[site_index].sig,
)
)
if molecule.signature() != new_molecule.signature():
_message = f"Skipping sites because the molecule signature does not match the reference molecule. Expected {molecule.signature()} but got {new_molecule.signature()}"
warnings.warn(_message)
logger.warning(_message)
continue
if ref_molecule.minimum_distance(new_molecule) >= min_distance:
continue
picked_environment.extend(new_molecule.atom)
return picked_environment
@staticmethod
def _aggregate_results(state: StateModel, results: list[list[Atom]]) -> list[Atom]:
norm_factor = round(state.config.dice.nstep[-1] / state.config.dice.isave)
agg_results = []
for atom in chain(*[r for r in results]):
atom.chg = atom.chg * norm_factor
agg_results.append(atom)
return agg_results

View File

@@ -1,16 +1,39 @@
from pydantic import TypeAdapter
import diceplayer.dice.dice_input as dice_input
from diceplayer.config import DiceConfig
from diceplayer.environment import System
from pydantic import TypeAdapter
from pydantic.dataclasses import dataclass
import subprocess
from itertools import islice
from pathlib import Path
from typing import Final
from typing import Final, List, Self
type DiceEnvironment = tuple[str, int, int, int]
DiceEnvironmentAdapter = TypeAdapter(DiceEnvironment)
@dataclass(slots=True, frozen=True)
class DiceEnvironmentItem:
atom: str
x: float
y: float
z: float
DiceEnvironmentItemAdapter = TypeAdapter(DiceEnvironmentItem)
@dataclass(slots=True)
class DiceEnvironment:
number_of_sites: int
thickness: List[float]
items: List[DiceEnvironmentItem]
@classmethod
def new(cls, thickness: List[float]) -> Self:
return cls(number_of_sites=0, thickness=thickness, items=[])
def add_site(self, site: DiceEnvironmentItem):
self.items.append(site)
self.number_of_sites += 1
DICE_FLAG_LINE: Final[int] = -2
@@ -28,7 +51,10 @@ class DiceWrapper:
with open(output_path, "w") as outfile, open(input_path, "r") as infile:
exit_status = subprocess.call(
self.dice_config.progname, stdin=infile, stdout=outfile, cwd=self.working_directory
self.dice_config.progname,
stdin=infile,
stdout=outfile,
cwd=self.working_directory,
)
if exit_status != 0:
@@ -41,22 +67,37 @@ class DiceWrapper:
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
def parse_results(self, system: System) -> list[DiceEnvironment]:
NUMBER_OF_HEADER_LINES = 2
NUMBER_OF_PRIMARY_ATOMS = len(system.molecule[0].atom)
def parse_results(self) -> list[DiceEnvironment]:
results = []
for output_file in sorted(self.working_directory.glob("phb*.xyz")):
with open(output_file, "r") as f:
for _ in range(NUMBER_OF_HEADER_LINES + NUMBER_OF_PRIMARY_ATOMS):
next(f, None)
for line in f:
if line.strip() == "":
break
positions_file = self.working_directory / "phb.xyz"
if not positions_file.exists():
raise RuntimeError(f"Positions file not found at {self.working_directory}")
results.append(
DiceEnvironmentAdapter.validate_python(line.split())
)
with open(positions_file, "r") as f:
while True:
line = f.readline()
if not line.startswith(" "):
break
environment = DiceEnvironment(
number_of_sites=int(line.strip()),
thickness=[float(n) for n in f.readline().split()[-3:]],
items=[],
)
# Skip the comment line
environment.items.extend(
[
DiceEnvironmentItemAdapter.validate_python(
{"atom": site[0], "x": site[1], "y": site[2], "z": site[3]}
)
for atom in islice(f, environment.number_of_sites)
if (site := atom.strip().split()) and len(site) == 4
]
)
results.append(environment)
return results