Compare commits
10 Commits
main
...
0470200d00
| Author | SHA1 | Date | |
|---|---|---|---|
|
0470200d00
|
|||
|
0763c4a9e1
|
|||
|
30be88e6b4
|
|||
|
6a154429e9
|
|||
|
4c8cbc821d
|
|||
|
9f22304dd8
|
|||
|
53eb34a83e
|
|||
|
06ae9b41f0
|
|||
|
11ff4c0c21
|
|||
|
c59f0d6516
|
@@ -1 +1 @@
|
||||
3.10
|
||||
3.12
|
||||
|
||||
@@ -1,23 +1,25 @@
|
||||
diceplayer:
|
||||
opt: no
|
||||
type: both
|
||||
switch_cyc: 3
|
||||
max_cyc: 5
|
||||
mem: 24
|
||||
maxcyc: 5
|
||||
ncores: 5
|
||||
nprocs: 4
|
||||
ncores: 20
|
||||
qmprog: 'g16'
|
||||
lps: no
|
||||
ghosts: no
|
||||
altsteps: 2000
|
||||
|
||||
dice:
|
||||
nmol: [1, 100]
|
||||
nprocs: 1
|
||||
nmol: [1, 200]
|
||||
dens: 1.5
|
||||
nstep: [2000, 3000]
|
||||
isave: 1000
|
||||
nstep: [200, 300]
|
||||
isave: 100
|
||||
outname: 'phb'
|
||||
progname: '~/.local/bin/dice'
|
||||
ljname: 'phb.ljc'
|
||||
progname: 'dice'
|
||||
ljname: 'phb.ljc.example'
|
||||
randominit: 'always'
|
||||
seed: 12345
|
||||
|
||||
gaussian:
|
||||
qmprog: 'g16'
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from diceplayer.utils import Logger
|
||||
|
||||
from importlib import metadata
|
||||
|
||||
|
||||
VERSION = metadata.version("diceplayer")
|
||||
|
||||
logger = Logger(__name__)
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
from diceplayer import VERSION, logger
|
||||
from diceplayer.cli import ArgsModel, read_input
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.player import Player
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from importlib import metadata
|
||||
|
||||
|
||||
VERSION = metadata.version("diceplayer")
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Read and store the arguments passed to the program
|
||||
and set the usage and help messages
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(prog="Diceplayer")
|
||||
parser.add_argument(
|
||||
"-c", "--continue", dest="opt_continue", default=False, action="store_true"
|
||||
"-v", "--version", action="version", version="diceplayer-" + VERSION
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--version", action="version", version="diceplayer-" + VERSION
|
||||
"-c", "--continue", dest="continuation", default=False, action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
@@ -34,36 +34,26 @@ def main():
|
||||
metavar="OUTFILE",
|
||||
help="output file of diceplayer [default = run.log]",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--force",
|
||||
dest="force",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="force overwrite existing state file if it exists [default = False]",
|
||||
)
|
||||
args = ArgsModel.from_args(parser.parse_args())
|
||||
|
||||
# Open OUTFILE for writing and print keywords and initial info
|
||||
logger.set_logger(args.outfile, logging.INFO)
|
||||
logger.set_output_file(args.outfile)
|
||||
|
||||
if args.opt_continue:
|
||||
player = Player.from_save()
|
||||
else:
|
||||
player = Player.from_file(args.infile)
|
||||
config: PlayerConfig
|
||||
try:
|
||||
config = read_input(args.infile)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read input file: {e}")
|
||||
return
|
||||
|
||||
player.read_potentials()
|
||||
|
||||
player.create_simulation_dir()
|
||||
player.create_geoms_file()
|
||||
|
||||
player.print_keywords()
|
||||
|
||||
player.print_potentials()
|
||||
|
||||
player.prepare_system()
|
||||
|
||||
player.start()
|
||||
|
||||
logger.info("\n+" + 88 * "-" + "+\n")
|
||||
|
||||
player.print_results()
|
||||
|
||||
logger.info("\n+" + 88 * "-" + "+\n")
|
||||
|
||||
logger.info("Diceplayer finished successfully \n")
|
||||
Player(config).play(continuation=args.continuation, force=args.force)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
5
diceplayer/cli/__init__.py
Normal file
5
diceplayer/cli/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .args_model import ArgsModel
|
||||
from .read_input_file import read_input
|
||||
|
||||
|
||||
__all__ = ["ArgsModel", "read_input"]
|
||||
12
diceplayer/cli/args_model.py
Normal file
12
diceplayer/cli/args_model.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ArgsModel(BaseModel):
|
||||
outfile: str
|
||||
infile: str
|
||||
continuation: bool
|
||||
force: bool
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args):
|
||||
return cls(**vars(args))
|
||||
9
diceplayer/cli/read_input_file.py
Normal file
9
diceplayer/cli/read_input_file.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def read_input(infile) -> PlayerConfig:
|
||||
with open(infile, "r") as f:
|
||||
values = yaml.safe_load(f)
|
||||
return PlayerConfig.model_validate(values["diceplayer"])
|
||||
@@ -0,0 +1,10 @@
|
||||
from .dice_config import DiceConfig
|
||||
from .gaussian_config import GaussianConfig
|
||||
from .player_config import PlayerConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DiceConfig",
|
||||
"GaussianConfig",
|
||||
"PlayerConfig",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import List, Literal
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class DiceConfig(BaseModel):
|
||||
@@ -7,15 +10,22 @@ class DiceConfig(BaseModel):
|
||||
Data Transfer Object for the Dice configuration.
|
||||
"""
|
||||
|
||||
ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
nprocs: int = Field(
|
||||
..., description="Number of processes to use for the DICE simulations"
|
||||
)
|
||||
|
||||
ljname: Path = Field(..., description="Name of the Lennard-Jones potential file")
|
||||
outname: str = Field(
|
||||
..., description="Name of the output file for the simulation results"
|
||||
)
|
||||
dens: float = Field(..., description="Density of the system")
|
||||
nmol: List[int] = Field(
|
||||
nmol: list[int] = Field(
|
||||
..., description="List of the number of molecules for each component"
|
||||
)
|
||||
nstep: List[int] = Field(
|
||||
nstep: list[int] = Field(
|
||||
...,
|
||||
description="List of the number of steps for each component",
|
||||
min_length=2,
|
||||
@@ -25,6 +35,13 @@ class DiceConfig(BaseModel):
|
||||
upbuf: int = Field(
|
||||
360, description="Buffer size for the potential energy calculations"
|
||||
)
|
||||
irdf: int = Field(
|
||||
0,
|
||||
description="Controls the interval of Monte Carlo steps at which configurations are used at computation of radial distribution functions",
|
||||
)
|
||||
vstep: int = Field(
|
||||
5000, description="Frequency of volume change moves in NPT simulations"
|
||||
)
|
||||
combrule: Literal["+", "*"] = Field(
|
||||
"*", description="Combination rule for the Lennard-Jones potential"
|
||||
)
|
||||
@@ -37,4 +54,7 @@ class DiceConfig(BaseModel):
|
||||
randominit: str = Field(
|
||||
"first", description="Method for initializing the random number generator"
|
||||
)
|
||||
seed: int | None = Field(None, description="Seed for the random number generator")
|
||||
seed: int = Field(
|
||||
default_factory=lambda: int(1e6 * random.random()),
|
||||
description="Seed for the random number generator",
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
@@ -7,10 +7,14 @@ class GaussianConfig(BaseModel):
|
||||
Data Transfer Object for the Gaussian configuration.
|
||||
"""
|
||||
|
||||
level: str = Field(..., description="Level of theory for the QM calculations")
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
|
||||
qmprog: Literal["g03", "g09", "g16"] = Field(
|
||||
"g16", description="QM program to use for the calculations"
|
||||
)
|
||||
level: str = Field(..., description="Level of theory for the QM calculations")
|
||||
|
||||
chgmult: list[int] = Field(
|
||||
default_factory=lambda: [0, 1],
|
||||
@@ -20,6 +24,6 @@ class GaussianConfig(BaseModel):
|
||||
"chelpg", description="Population analysis method for the QM calculations"
|
||||
)
|
||||
chg_tol: float = Field(0.01, description="Charge tolerance for the QM calculations")
|
||||
keywords: str = Field(
|
||||
keywords: str | None = Field(
|
||||
None, description="Additional keywords for the QM calculations"
|
||||
)
|
||||
|
||||
@@ -1,27 +1,52 @@
|
||||
from diceplayer.config.dice_config import DiceConfig
|
||||
from diceplayer.config.gaussian_config import GaussianConfig
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Any
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
MIN_STEP = 20000
|
||||
STEP_INCREMENT = 1000
|
||||
|
||||
|
||||
class RoutineType(str, Enum):
|
||||
CHARGE = "charge"
|
||||
GEOMETRY = "geometry"
|
||||
BOTH = "both"
|
||||
|
||||
|
||||
class PlayerConfig(BaseModel):
|
||||
"""
|
||||
Data Transfer Object for the player configuration.
|
||||
Configuration for DICEPlayer simulations.
|
||||
|
||||
Attributes:
|
||||
type: Type of simulation to perform (charge, geometry, or both).
|
||||
max_cyc: Maximum number of cycles for the geometry optimization.
|
||||
switch_cyc: Cycle at which to switch from charge to geometry optimization (if type is "both").
|
||||
mem: Memory configuration for QM calculations.
|
||||
nprocs: Number of processors to use for QM calculations.
|
||||
ncores: Number of cores to use for QM calculations.
|
||||
dice: Configuration parameters specific to DICE simulations.
|
||||
gaussian: Configuration parameters specific to Gaussian calculations.
|
||||
altsteps: Number of steps for the alternate simulation (default: 20000).
|
||||
geoms_file: File name for the geometries output (default: "geoms.xyz").
|
||||
simulation_dir: Directory name for the simulation files (default: "simfiles").
|
||||
"""
|
||||
|
||||
opt: bool = Field(..., description="Whether to perform geometry optimization")
|
||||
maxcyc: int = Field(
|
||||
..., description="Maximum number of cycles for the geometry optimization"
|
||||
model_config = ConfigDict(
|
||||
frozen=True,
|
||||
)
|
||||
nprocs: int = Field(
|
||||
..., description="Number of processors to use for the QM calculations"
|
||||
|
||||
type: RoutineType = Field(..., description="Type of simulation to perform")
|
||||
max_cyc: int = Field(
|
||||
..., description="Maximum number of cycles for the geometry optimization", gt=0
|
||||
)
|
||||
switch_cyc: int = Field(..., description="Switch cycle configuration")
|
||||
|
||||
mem: int = Field(None, description="Memory configuration")
|
||||
ncores: int = Field(
|
||||
..., description="Number of cores to use for the QM calculations"
|
||||
)
|
||||
@@ -29,20 +54,37 @@ class PlayerConfig(BaseModel):
|
||||
dice: DiceConfig = Field(..., description="Dice configuration")
|
||||
gaussian: GaussianConfig = Field(..., description="Gaussian configuration")
|
||||
|
||||
mem: int = Field(None, description="Memory configuration")
|
||||
switchcyc: int = Field(3, description="Switch cycle configuration")
|
||||
qmprog: str = Field("g16", description="QM program to use for the calculations")
|
||||
altsteps: int = Field(
|
||||
20000, description="Number of steps for the alternate simulation"
|
||||
)
|
||||
geoms_file: Path = Field(
|
||||
"geoms.xyz", description="File name for the geometries output"
|
||||
Path("geoms.xyz"), description="File name for the geometries output"
|
||||
)
|
||||
simulation_dir: Path = Field(
|
||||
"simfiles", description="Directory name for the simulation files"
|
||||
Path("simfiles"), description="Directory name for the simulation files"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_altsteps(self) -> Self:
|
||||
self.altsteps = round(max(MIN_STEP, self.altsteps) / 1000) * 1000
|
||||
return self
|
||||
@model_validator(mode="before")
|
||||
@staticmethod
|
||||
def validate_altsteps(fields) -> dict[str, Any]:
|
||||
altsteps = fields.pop("altsteps", MIN_STEP)
|
||||
fields["altsteps"] = (
|
||||
round(max(MIN_STEP, altsteps) / STEP_INCREMENT) * STEP_INCREMENT
|
||||
)
|
||||
return fields
|
||||
|
||||
@model_validator(mode="before")
|
||||
@staticmethod
|
||||
def validate_switch_cyc(fields: dict[str, Any]) -> dict[str, Any]:
|
||||
max_cyc = int(fields.get("max_cyc", 0))
|
||||
switch_cyc = int(fields.get("switch_cyc", max_cyc))
|
||||
|
||||
if fields.get("type") == "both" and not switch_cyc < max_cyc:
|
||||
raise ValueError("switch_cyc must be less than max_cyc when type='both'.")
|
||||
|
||||
if fields.get("type") != "both" and switch_cyc != max_cyc:
|
||||
raise ValueError(
|
||||
"switch_cyc must be equal to max_cyc when type is not 'both'."
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
174
diceplayer/dice/__init__.py
Normal file
174
diceplayer/dice/__init__.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
DICE Monte Carlo Simulation Interface
|
||||
=====================================
|
||||
|
||||
This package provides utilities for configuring and running simulations with
|
||||
the DICE Monte Carlo molecular simulation program.
|
||||
|
||||
DICE performs statistical sampling of molecular systems using the Metropolis
|
||||
Monte Carlo algorithm. Simulations are defined by text input files containing
|
||||
keywords that control the thermodynamic ensemble, system composition, and
|
||||
simulation parameters.
|
||||
|
||||
---------------------------------------------------------------------
|
||||
|
||||
Simulation Ensembles
|
||||
--------------------
|
||||
|
||||
DICE supports multiple statistical ensembles.
|
||||
|
||||
NVT
|
||||
Canonical ensemble where the following properties remain constant:
|
||||
|
||||
- N: number of molecules
|
||||
- V: system volume
|
||||
- T: temperature
|
||||
|
||||
The system density is fixed and the simulation box volume does not change
|
||||
during the simulation.
|
||||
|
||||
NPT
|
||||
Isothermal–isobaric ensemble where the following properties remain constant:
|
||||
|
||||
- N: number of molecules
|
||||
- P: pressure
|
||||
- T: temperature
|
||||
|
||||
The simulation box volume is allowed to fluctuate in order to maintain the
|
||||
target pressure.
|
||||
|
||||
---------------------------------------------------------------------
|
||||
|
||||
Simulation Stages
|
||||
-----------------
|
||||
|
||||
Simulations are typically executed in multiple stages.
|
||||
|
||||
Thermalization (TER)
|
||||
Initial phase where the system relaxes to the desired thermodynamic
|
||||
conditions. Molecular configurations stabilize and the system reaches
|
||||
equilibrium.
|
||||
|
||||
During this stage statistical properties are **not accumulated**.
|
||||
|
||||
Production / Equilibration (EQ)
|
||||
Main sampling phase after the system has equilibrated.
|
||||
|
||||
Statistical properties such as energies, densities, and radial
|
||||
distribution functions are collected and configurations may be saved
|
||||
for later analysis.
|
||||
|
||||
---------------------------------------------------------------------
|
||||
|
||||
Typical Simulation Pipeline
|
||||
---------------------------
|
||||
|
||||
Two common execution workflows are used.
|
||||
|
||||
NVT Simulation
|
||||
Used when the system density is known.
|
||||
|
||||
1. NVT.ter → thermalization at constant density
|
||||
2. NVT.eq → production sampling
|
||||
|
||||
NPT Simulation
|
||||
Used when the equilibrium density is unknown.
|
||||
|
||||
1. NVT.ter → initial thermalization at approximate density
|
||||
2. NPT.ter → pressure relaxation (volume adjustment)
|
||||
3. NPT.eq → production sampling at target pressure
|
||||
|
||||
---------------------------------------------------------------------
|
||||
|
||||
DICE Input Keywords
|
||||
-------------------
|
||||
|
||||
The following keywords are used in the generated input files.
|
||||
|
||||
title
|
||||
Descriptive title printed in the simulation output.
|
||||
|
||||
ncores
|
||||
Number of CPU cores used by the DICE executable.
|
||||
|
||||
ljname
|
||||
File containing Lennard-Jones parameters and molecular topology.
|
||||
|
||||
outname
|
||||
Prefix used for simulation output files.
|
||||
|
||||
nmol
|
||||
Number of molecules of each species in the system.
|
||||
|
||||
dens
|
||||
System density (g/cm³). Used only in NVT simulations or for
|
||||
initialization of NPT runs.
|
||||
|
||||
press
|
||||
Target pressure used in NPT simulations.
|
||||
|
||||
temp
|
||||
Simulation temperature.
|
||||
|
||||
nstep
|
||||
Number of Monte Carlo cycles executed in the simulation stage.
|
||||
|
||||
init
|
||||
Defines how the simulation initializes molecular coordinates.
|
||||
|
||||
yes
|
||||
Random initial configuration.
|
||||
|
||||
no
|
||||
Continue from a previous configuration.
|
||||
|
||||
yesreadxyz
|
||||
Read coordinates from a previously saved XYZ configuration.
|
||||
|
||||
vstep
|
||||
Frequency of volume-change moves in NPT simulations.
|
||||
|
||||
mstop
|
||||
Molecule displacement control flag used internally by DICE.
|
||||
|
||||
accum
|
||||
Enables or disables accumulation of statistical averages.
|
||||
|
||||
iprint
|
||||
Frequency of simulation information printed to the output.
|
||||
|
||||
isave
|
||||
Frequency at which configurations are written to trajectory files.
|
||||
|
||||
irdf
|
||||
Controls calculation of radial distribution functions.
|
||||
|
||||
seed
|
||||
Random number generator seed used by the Monte Carlo algorithm.
|
||||
|
||||
upbuf
|
||||
Buffer size parameter used internally by DICE during thermalization.
|
||||
|
||||
---------------------------------------------------------------------
|
||||
|
||||
Output Files
|
||||
------------
|
||||
|
||||
Important output files produced during the simulation include:
|
||||
|
||||
phb.xyz
|
||||
XYZ trajectory containing sampled molecular configurations.
|
||||
|
||||
last.xyz
|
||||
Final configuration of the simulation, often used as the starting
|
||||
configuration for the next simulation cycle.
|
||||
|
||||
---------------------------------------------------------------------
|
||||
|
||||
References
|
||||
----------
|
||||
|
||||
DICE is a Monte Carlo molecular simulation program developed primarily
|
||||
by researchers at the University of São Paulo (USP) for studying liquids,
|
||||
solutions, and solvation phenomena.
|
||||
"""
|
||||
115
diceplayer/dice/dice_handler.py
Normal file
115
diceplayer/dice/dice_handler.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from diceplayer.dice.dice_input import (
|
||||
NPTEqConfig,
|
||||
NPTTerConfig,
|
||||
NVTEqConfig,
|
||||
NVTTerConfig,
|
||||
)
|
||||
from diceplayer.dice.dice_wrapper import DiceWrapper
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
|
||||
class DiceHandler:
|
||||
def __init__(self, step_directory: Path):
|
||||
self.dice_directory = step_directory / "dice"
|
||||
|
||||
def run(self, state: StateModel, cycle: int) -> StateModel:
|
||||
if self.dice_directory.exists():
|
||||
logger.info(
|
||||
f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
|
||||
)
|
||||
shutil.rmtree(self.dice_directory)
|
||||
self.dice_directory.mkdir(parents=True)
|
||||
|
||||
simulation_results = self.run_simulations(state, cycle)
|
||||
|
||||
result = self.aggregate_results(simulation_results)
|
||||
|
||||
return self.commit_simulation_state(state, result)
|
||||
|
||||
def run_simulations(self, state: StateModel, cycle: int) -> list[dict]:
|
||||
results = []
|
||||
|
||||
threads = []
|
||||
for p in range(state.config.dice.nprocs):
|
||||
t = Thread(target=self._simulation_process, args=(state, cycle, p, results))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
if len(results) != state.config.dice.nprocs:
|
||||
raise RuntimeError(
|
||||
f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def aggregate_results(self, simulation_results: list[dict]) -> dict: ...
|
||||
|
||||
def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
|
||||
return state
|
||||
|
||||
def _simulation_process(
|
||||
self, state: StateModel, cycle: int, proc: int, results: list[dict]
|
||||
) -> None:
|
||||
proc_directory = self.dice_directory / f"{proc:02d}"
|
||||
if proc_directory.exists():
|
||||
shutil.rmtree(proc_directory)
|
||||
proc_directory.mkdir(parents=True)
|
||||
|
||||
dice = DiceWrapper(state.config.dice, proc_directory)
|
||||
|
||||
self._generate_phb_file(state, proc_directory)
|
||||
|
||||
if state.config.dice.randominit == "first" and cycle >= 0:
|
||||
self._generate_last_xyz(state, proc_directory)
|
||||
else:
|
||||
nvt_ter_config = NVTTerConfig.from_config(state.config)
|
||||
dice.run(nvt_ter_config)
|
||||
|
||||
if len(state.config.dice.nstep) == 2:
|
||||
nvt_eq_config = NVTEqConfig.from_config(state.config)
|
||||
dice.run(nvt_eq_config)
|
||||
|
||||
elif len(state.config.dice.nstep) == 3:
|
||||
npt_ter_config = NPTTerConfig.from_config(state.config)
|
||||
dice.run(npt_ter_config)
|
||||
|
||||
npt_eq_config = NPTEqConfig.from_config(state.config)
|
||||
dice.run(npt_eq_config)
|
||||
|
||||
results.append(dice.parse_results(state.system))
|
||||
|
||||
@staticmethod
|
||||
def _generate_phb_file(state: StateModel, proc_directory: Path) -> None:
|
||||
fstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f}\n"
|
||||
|
||||
phb_file = proc_directory / state.config.dice.ljname
|
||||
|
||||
with open(phb_file, "w") as f:
|
||||
f.write(f"{state.config.dice.combrule}\n")
|
||||
f.write(f"{len(state.config.dice.nmol)}\n")
|
||||
|
||||
for molecule in state.system.molecule:
|
||||
f.write(f"{len(molecule.atom)} {molecule.molname}\n")
|
||||
for atom in molecule.atom:
|
||||
f.write(
|
||||
fstr.format(
|
||||
atom.lbl,
|
||||
atom.na,
|
||||
atom.rx,
|
||||
atom.ry,
|
||||
atom.rz,
|
||||
atom.chg,
|
||||
atom.eps,
|
||||
atom.sig,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None: ...
|
||||
257
diceplayer/dice/dice_input.py
Normal file
257
diceplayer/dice/dice_input.py
Normal file
@@ -0,0 +1,257 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.logger import logger
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
import random
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal, TextIO
|
||||
|
||||
|
||||
_ALLOWED_DICE_KEYWORD_IN_ORDER = [
|
||||
"title",
|
||||
"ncores",
|
||||
"ljname",
|
||||
"outname",
|
||||
"nmol",
|
||||
"dens",
|
||||
"temp",
|
||||
"press",
|
||||
"seed",
|
||||
"init",
|
||||
"nstep",
|
||||
"vstep",
|
||||
"mstop",
|
||||
"accum",
|
||||
"iprint",
|
||||
"isave",
|
||||
"irdf",
|
||||
"upbuf",
|
||||
]
|
||||
|
||||
|
||||
class DiceRoutineType(StrEnum):
|
||||
NVT_TER = "nvt.ter"
|
||||
NVT_EQ = "nvt.eq"
|
||||
NPT_TER = "npt.ter"
|
||||
NPT_EQ = "npt.eq"
|
||||
|
||||
|
||||
def get_nstep(config, idx: int) -> int:
|
||||
if len(config.dice.nstep) > idx:
|
||||
return config.dice.nstep[idx]
|
||||
return config.dice.nstep[-1]
|
||||
|
||||
|
||||
def get_seed(config) -> int:
|
||||
return config.dice.seed or random.randint(0, 2**32 - 1)
|
||||
|
||||
|
||||
def get_ncores(config) -> int:
|
||||
return max(1, int(config.ncores / config.dice.nprocs))
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NVT THERMALIZATION
|
||||
# -----------------------------------------------------
|
||||
class NVTTerConfig(BaseModel):
|
||||
type: Literal[DiceRoutineType.NVT_TER] = DiceRoutineType.NVT_TER
|
||||
|
||||
title: str = "NVT Thermalization"
|
||||
ncores: int
|
||||
ljname: str
|
||||
outname: str
|
||||
nmol: list[int]
|
||||
dens: float
|
||||
temp: float
|
||||
seed: int
|
||||
init: Literal["yes"] = "yes"
|
||||
nstep: int
|
||||
vstep: Literal[0] = 0
|
||||
isave: int
|
||||
upbuf: int
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return cls(
|
||||
ncores=get_ncores(config),
|
||||
ljname=str(config.dice.ljname),
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
dens=config.dice.dens,
|
||||
temp=config.dice.temp,
|
||||
seed=get_seed(config),
|
||||
nstep=get_nstep(config, 0),
|
||||
isave=config.dice.isave,
|
||||
upbuf=config.dice.upbuf,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NVT PRODUCTION
|
||||
# -----------------------------------------------------
|
||||
class NVTEqConfig(BaseModel):
|
||||
type: Literal[DiceRoutineType.NVT_EQ] = DiceRoutineType.NVT_EQ
|
||||
|
||||
title: str = "NVT Production"
|
||||
ncores: int
|
||||
ljname: str
|
||||
outname: str
|
||||
nmol: list[int]
|
||||
dens: float
|
||||
temp: float
|
||||
seed: int
|
||||
init: Literal["no", "yesreadxyz"] = "no"
|
||||
nstep: int
|
||||
vstep: int
|
||||
isave: int
|
||||
irdf: Literal[0] = 0
|
||||
upbuf: int
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return cls(
|
||||
ncores=get_ncores(config),
|
||||
ljname=str(config.dice.ljname),
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
dens=config.dice.dens,
|
||||
temp=config.dice.temp,
|
||||
seed=get_seed(config),
|
||||
nstep=get_nstep(config, 1),
|
||||
vstep=config.dice.vstep,
|
||||
isave=max(1, get_nstep(config, 1) // 10),
|
||||
upbuf=config.dice.upbuf,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NPT THERMALIZATION
|
||||
# -----------------------------------------------------
|
||||
class NPTTerConfig(BaseModel):
|
||||
type: Literal[DiceRoutineType.NPT_TER] = DiceRoutineType.NPT_TER
|
||||
|
||||
title: str = "NPT Thermalization"
|
||||
ncores: int
|
||||
ljname: str
|
||||
outname: str
|
||||
nmol: list[int]
|
||||
dens: float
|
||||
temp: float
|
||||
press: float
|
||||
seed: int
|
||||
init: Literal["yes", "yesreadxyz"] = "yes"
|
||||
nstep: int
|
||||
vstep: int
|
||||
isave: int
|
||||
upbuf: int
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return cls(
|
||||
ncores=get_ncores(config),
|
||||
ljname=str(config.dice.ljname),
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
dens=config.dice.dens,
|
||||
temp=config.dice.temp,
|
||||
press=config.dice.press,
|
||||
seed=get_seed(config),
|
||||
nstep=get_nstep(config, 1),
|
||||
vstep=max(1, config.dice.vstep),
|
||||
isave=config.dice.isave,
|
||||
upbuf=config.dice.upbuf,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------
|
||||
# NPT PRODUCTION
|
||||
# -----------------------------------------------------
|
||||
class NPTEqConfig(BaseModel):
|
||||
type: Literal[DiceRoutineType.NPT_EQ] = DiceRoutineType.NPT_EQ
|
||||
|
||||
title: str = "NPT Production"
|
||||
ncores: int
|
||||
ljname: str
|
||||
outname: str
|
||||
nmol: list[int]
|
||||
dens: float
|
||||
temp: float
|
||||
press: float
|
||||
seed: int
|
||||
init: Literal["yes", "yesreadxyz"] = "yes"
|
||||
nstep: int
|
||||
vstep: int
|
||||
isave: int
|
||||
irdf: Literal[0] = 0
|
||||
upbuf: int
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||
return cls(
|
||||
ncores=get_ncores(config),
|
||||
ljname=str(config.dice.ljname),
|
||||
outname=config.dice.outname,
|
||||
nmol=config.dice.nmol,
|
||||
dens=config.dice.dens,
|
||||
temp=config.dice.temp,
|
||||
press=config.dice.press,
|
||||
seed=get_seed(config),
|
||||
nstep=get_nstep(config, 2),
|
||||
vstep=config.dice.vstep,
|
||||
isave=max(1, get_nstep(config, 2) // 10),
|
||||
upbuf=config.dice.upbuf,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
DiceInputConfig = Annotated[
|
||||
NVTTerConfig | NVTEqConfig | NPTTerConfig | NPTEqConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def _serialize_value(value: Any) -> str:
|
||||
if value is None:
|
||||
raise ValueError("DICE configuration cannot serialize None values")
|
||||
|
||||
if isinstance(value, bool):
|
||||
return "yes" if value else "no"
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
return " ".join(str(v) for v in value)
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
def write_dice_config(obj: DiceInputConfig, io_writer: TextIO) -> None:
|
||||
values = {f: getattr(obj, f) for f in obj.__class__.model_fields}
|
||||
|
||||
for key in _ALLOWED_DICE_KEYWORD_IN_ORDER:
|
||||
value = values.pop(key, None)
|
||||
if value is None:
|
||||
continue
|
||||
io_writer.write(f"{key} = {_serialize_value(value)}\n")
|
||||
|
||||
io_writer.write("$end\n")
|
||||
|
||||
|
||||
def write_config(config: DiceInputConfig, directory: Path) -> Path:
|
||||
input_path = directory / config.type
|
||||
|
||||
if input_path.exists():
|
||||
logger.info(
|
||||
f"Dice input file {input_path} already exists and will be overwritten"
|
||||
)
|
||||
input_path.unlink()
|
||||
input_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(input_path, "w") as io:
|
||||
write_dice_config(config, io)
|
||||
|
||||
return input_path
|
||||
43
diceplayer/dice/dice_wrapper.py
Normal file
43
diceplayer/dice/dice_wrapper.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import diceplayer.dice.dice_input as dice_input
|
||||
from diceplayer.config import DiceConfig
|
||||
from diceplayer.environment import System
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
|
||||
DICE_FLAG_LINE: Final[int] = -2
|
||||
DICE_END_FLAG: Final[str] = "End of simulation"
|
||||
|
||||
|
||||
class DiceWrapper:
|
||||
def __init__(self, dice_config: DiceConfig, working_directory: Path):
|
||||
self.dice_config = dice_config
|
||||
self.working_directory = working_directory
|
||||
|
||||
def run(self, dice_config: dice_input.DiceInputConfig) -> None:
|
||||
input_path = dice_input.write_config(dice_config, self.working_directory)
|
||||
output_path = input_path.parent / (input_path.name + ".out")
|
||||
|
||||
with open(output_path, "w") as outfile, open(input_path, "r") as infile:
|
||||
exit_status = subprocess.call(
|
||||
self.dice_config.progname, stdin=infile, stdout=outfile, cwd=self.working_directory
|
||||
)
|
||||
|
||||
if exit_status != 0:
|
||||
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
|
||||
|
||||
with open(output_path, "r") as outfile:
|
||||
line = outfile.readlines()[DICE_FLAG_LINE]
|
||||
if line.strip() == DICE_END_FLAG:
|
||||
return
|
||||
|
||||
raise RuntimeError(f"Dice simulation failed with exit status {exit_status}")
|
||||
|
||||
def parse_results(self, system: System) -> dict:
|
||||
results = {}
|
||||
for output_file in sorted(self.working_directory.glob("phb*.xyz")):
|
||||
...
|
||||
|
||||
return results
|
||||
@@ -1,9 +1,9 @@
|
||||
from diceplayer.utils.ptable import AtomInfo, PTable
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class Atom:
|
||||
"""
|
||||
Atom class declaration. This class is used throughout the DicePlayer program to represent atoms.
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from diceplayer import logger
|
||||
from diceplayer.environment import Atom
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.utils.cache import invalidate_computed_properties
|
||||
from diceplayer.utils.misc import BOHR2ANG, EA_2_DEBYE
|
||||
from diceplayer.utils.ptable import GHOST_NUMBER
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import numpy.linalg as linalg
|
||||
import numpy.typing as npt
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import List, Self, Tuple
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import field
|
||||
from functools import cached_property
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
from .__interface import Interface
|
||||
from .dice_interface import DiceInterface
|
||||
from .gaussian_interface import GaussianInterface
|
||||
|
||||
|
||||
__all__ = ["Interface", "DiceInterface", "GaussianInterface"]
|
||||
@@ -1,26 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.environment.system import System
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Interface(ABC):
|
||||
__slots__ = ["step", "system"]
|
||||
|
||||
def __init__(self):
|
||||
self.system: System | None = None
|
||||
self.step: PlayerConfig | None = None
|
||||
|
||||
@abstractmethod
|
||||
def configure(self, step: PlayerConfig, system: System):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(self, cycle: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
pass
|
||||
@@ -1,389 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from diceplayer import logger
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.environment.system import System
|
||||
from diceplayer.interface import Interface
|
||||
|
||||
from setproctitle import setproctitle
|
||||
from typing_extensions import Final, TextIO
|
||||
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from multiprocessing import Process, connection
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
DICE_END_FLAG: Final[str] = "End of simulation"
|
||||
DICE_FLAG_LINE: Final[int] = -2
|
||||
UMAANG3_TO_GCM3: Final[float] = 1.6605
|
||||
|
||||
MAX_SEED: Final[int] = 4294967295
|
||||
|
||||
|
||||
class DiceInterface(Interface):
|
||||
title = "Diceplayer run"
|
||||
|
||||
def configure(self, step: PlayerConfig, system: System):
|
||||
self.step = step
|
||||
self.system = system
|
||||
|
||||
def start(self, cycle: int):
|
||||
procs = []
|
||||
sentinels = []
|
||||
|
||||
for proc in range(1, self.step.nprocs + 1):
|
||||
p = Process(target=self._simulation_process, args=(cycle, proc))
|
||||
p.start()
|
||||
|
||||
procs.append(p)
|
||||
sentinels.append(p.sentinel)
|
||||
|
||||
while procs:
|
||||
finished = connection.wait(sentinels)
|
||||
for proc_sentinel in finished:
|
||||
i = sentinels.index(proc_sentinel)
|
||||
status = procs[i].exitcode
|
||||
procs.pop(i)
|
||||
sentinels.pop(i)
|
||||
if status != 0:
|
||||
for p in procs:
|
||||
p.terminate()
|
||||
sys.exit(status)
|
||||
|
||||
logger.info("\n")
|
||||
|
||||
def reset(self):
|
||||
del self.step
|
||||
del self.system
|
||||
|
||||
def _simulation_process(self, cycle: int, proc: int):
|
||||
setproctitle(f"diceplayer-step{cycle:0d}-p{proc:0d}")
|
||||
|
||||
try:
|
||||
self._make_proc_dir(cycle, proc)
|
||||
self._make_dice_inputs(cycle, proc)
|
||||
self._run_dice(cycle, proc)
|
||||
except Exception as err:
|
||||
sys.exit(err)
|
||||
|
||||
def _make_proc_dir(self, cycle, proc):
|
||||
simulation_dir = Path(self.step.simulation_dir)
|
||||
if not simulation_dir.exists():
|
||||
simulation_dir.mkdir(parents=True)
|
||||
|
||||
proc_dir = Path(simulation_dir, f"step{cycle:02d}", f"p{proc:02d}")
|
||||
proc_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _make_dice_inputs(self, cycle, proc):
|
||||
proc_dir = Path(self.step.simulation_dir, f"step{cycle:02d}", f"p{proc:02d}")
|
||||
|
||||
self._make_potentials(proc_dir)
|
||||
|
||||
random.seed(self._make_dice_seed())
|
||||
|
||||
# This is logic is used to make the initial configuration file
|
||||
# for the next cycle using the last.xyz file from the previous cycle.
|
||||
if self.step.dice.randominit == "first" and cycle > 1:
|
||||
last_xyz = Path(
|
||||
self.step.simulation_dir,
|
||||
f"step{(cycle - 1):02d}",
|
||||
f"p{proc:02d}",
|
||||
"last.xyz",
|
||||
)
|
||||
if not last_xyz.exists():
|
||||
raise FileNotFoundError(f"File {last_xyz} not found.")
|
||||
|
||||
with open(last_xyz, "r") as last_xyz_file:
|
||||
self._make_init_file(proc_dir, last_xyz_file)
|
||||
last_xyz_file.seek(0)
|
||||
self.step.dice.dens = self._new_density(last_xyz_file)
|
||||
|
||||
else:
|
||||
self._make_nvt_ter(cycle, proc_dir)
|
||||
|
||||
if len(self.step.dice.nstep) == 2:
|
||||
self._make_nvt_eq(cycle, proc_dir)
|
||||
|
||||
elif len(self.step.dice.nstep) == 3:
|
||||
self._make_npt_ter(cycle, proc_dir)
|
||||
self._make_npt_eq(proc_dir)
|
||||
|
||||
def _run_dice(self, cycle: int, proc: int):
|
||||
working_dir = os.getcwd()
|
||||
|
||||
proc_dir = Path(self.step.simulation_dir, f"step{cycle:02d}", f"p{proc:02d}")
|
||||
|
||||
logger.info(
|
||||
f"Simulation process {str(proc_dir)} initiated with pid {os.getpid()}"
|
||||
)
|
||||
|
||||
os.chdir(proc_dir)
|
||||
|
||||
if not (self.step.dice.randominit == "first" and cycle > 1):
|
||||
self.run_dice_file(cycle, proc, "NVT.ter")
|
||||
|
||||
if len(self.step.dice.nstep) == 2:
|
||||
self.run_dice_file(cycle, proc, "NVT.eq")
|
||||
|
||||
elif len(self.step.dice.nstep) == 3:
|
||||
self.run_dice_file(cycle, proc, "NPT.ter")
|
||||
self.run_dice_file(cycle, proc, "NPT.eq")
|
||||
|
||||
os.chdir(working_dir)
|
||||
|
||||
xyz_file = Path(proc_dir, "phb.xyz")
|
||||
last_xyz_file = Path(proc_dir, "last.xyz")
|
||||
|
||||
if xyz_file.exists():
|
||||
shutil.copy(xyz_file, last_xyz_file)
|
||||
else:
|
||||
raise FileNotFoundError(f"File {xyz_file} not found.")
|
||||
|
||||
@staticmethod
|
||||
def _make_dice_seed() -> int:
|
||||
num = time.time()
|
||||
num = (num - int(num)) * 1e6
|
||||
|
||||
num = int((num - int(num)) * 1e6)
|
||||
|
||||
return (os.getpid() * num) % (MAX_SEED + 1)
|
||||
|
||||
def _make_init_file(self, proc_dir: Path, last_xyz_file: TextIO):
|
||||
xyz_lines = last_xyz_file.readlines()
|
||||
|
||||
SECONDARY_MOLECULE_LENGTH = 0
|
||||
for i in range(1, len(self.step.dice.nmol)):
|
||||
SECONDARY_MOLECULE_LENGTH += self.step.dice.nmol[i] * len(
|
||||
self.system.molecule[i].atom
|
||||
)
|
||||
|
||||
xyz_lines = xyz_lines[-SECONDARY_MOLECULE_LENGTH:]
|
||||
|
||||
input_file = Path(proc_dir, self.step.dice.outname + ".xy")
|
||||
with open(input_file, "w") as f:
|
||||
for atom in self.system.molecule[0].atom:
|
||||
f.write(f"{atom.rx:>10.6f} {atom.ry:>10.6f} {atom.rz:>10.6f}\n")
|
||||
|
||||
for line in xyz_lines:
|
||||
atom = line.split()
|
||||
rx = float(atom[1])
|
||||
ry = float(atom[2])
|
||||
rz = float(atom[3])
|
||||
f.write(f"{rx:>10.6f} {ry:>10.6f} {rz:>10.6f}\n")
|
||||
|
||||
f.write("$end")
|
||||
|
||||
def _new_density(self, last_xyz_file: TextIO):
|
||||
last_xyz_lines = last_xyz_file.readlines()
|
||||
|
||||
box = last_xyz_lines[1].split()
|
||||
volume = float(box[-3]) * float(box[-2]) * float(box[-1])
|
||||
|
||||
total_mass = 0
|
||||
for i in range(len(self.system.molecule)):
|
||||
total_mass += self.system.molecule[i].total_mass * self.step.dice.nmol[i]
|
||||
|
||||
density = (total_mass / volume) * UMAANG3_TO_GCM3
|
||||
|
||||
return density
|
||||
|
||||
def _make_nvt_ter(self, cycle, proc_dir):
|
||||
file = Path(proc_dir, "NVT.ter")
|
||||
with open(file, "w") as f:
|
||||
f.write(f"title = {self.title} - NVT Thermalization\n")
|
||||
f.write(f"ncores = {self.step.ncores}\n")
|
||||
f.write(f"ljname = {self.step.dice.ljname}\n")
|
||||
f.write(f"outname = {self.step.dice.outname}\n")
|
||||
|
||||
mol_string = " ".join(str(x) for x in self.step.dice.nmol)
|
||||
f.write(f"nmol = {mol_string}\n")
|
||||
|
||||
f.write(f"dens = {self.step.dice.dens}\n")
|
||||
f.write(f"temp = {self.step.dice.temp}\n")
|
||||
|
||||
if self.step.dice.randominit == "first" and cycle > 1:
|
||||
f.write("init = yesreadxyz\n")
|
||||
f.write(f"nstep = {self.step.altsteps}\n")
|
||||
else:
|
||||
f.write("init = yes\n")
|
||||
f.write(f"nstep = {self.step.dice.nstep[0]}\n")
|
||||
|
||||
f.write("vstep = 0\n")
|
||||
f.write("mstop = 1\n")
|
||||
f.write("accum = no\n")
|
||||
f.write("iprint = 1\n")
|
||||
f.write("isave = 0\n")
|
||||
f.write("irdf = 0\n")
|
||||
|
||||
seed = int(1e6 * random.random())
|
||||
f.write(f"seed = {seed}\n")
|
||||
f.write(f"upbuf = {self.step.dice.upbuf}")
|
||||
|
||||
def _make_nvt_eq(self, cycle, proc_dir):
|
||||
file = Path(proc_dir, "NVT.eq")
|
||||
with open(file, "w") as f:
|
||||
f.write(f"title = {self.title} - NVT Production\n")
|
||||
f.write(f"ncores = {self.step.ncores}\n")
|
||||
f.write(f"ljname = {self.step.dice.ljname}\n")
|
||||
f.write(f"outname = {self.step.dice.outname}\n")
|
||||
|
||||
mol_string = " ".join(str(x) for x in self.step.dice.nmol)
|
||||
f.write(f"nmol = {mol_string}\n")
|
||||
|
||||
f.write(f"dens = {self.step.dice.dens}\n")
|
||||
f.write(f"temp = {self.step.dice.temp}\n")
|
||||
|
||||
if self.step.dice.randominit == "first" and cycle > 1:
|
||||
f.write("init = yesreadxyz\n")
|
||||
else:
|
||||
f.write("init = no\n")
|
||||
|
||||
f.write(f"nstep = {self.step.dice.nstep[1]}\n")
|
||||
|
||||
f.write("vstep = 0\n")
|
||||
f.write("mstop = 1\n")
|
||||
f.write("accum = no\n")
|
||||
f.write("iprint = 1\n")
|
||||
|
||||
f.write(f"isave = {self.step.dice.isave}\n")
|
||||
f.write(f"irdf = {10 * self.step.nprocs}\n")
|
||||
|
||||
seed = int(1e6 * random.random())
|
||||
f.write("seed = {}\n".format(seed))
|
||||
|
||||
def _make_npt_ter(self, cycle, proc_dir):
|
||||
file = Path(proc_dir, "NPT.ter")
|
||||
with open(file, "w") as f:
|
||||
f.write(f"title = {self.title} - NPT Thermalization\n")
|
||||
f.write(f"ncores = {self.step.ncores}\n")
|
||||
f.write(f"ljname = {self.step.dice.ljname}\n")
|
||||
f.write(f"outname = {self.step.dice.outname}\n")
|
||||
|
||||
mol_string = " ".join(str(x) for x in self.step.dice.nmol)
|
||||
f.write(f"nmol = {mol_string}\n")
|
||||
|
||||
f.write(f"press = {self.step.dice.press}\n")
|
||||
f.write(f"temp = {self.step.dice.temp}\n")
|
||||
|
||||
if self.step.dice.randominit == "first" and cycle > 1:
|
||||
f.write("init = yesreadxyz\n")
|
||||
f.write(f"dens = {self.step.dice.dens:<8.4f}\n")
|
||||
f.write(f"vstep = {int(self.step.altsteps / 5)}\n")
|
||||
else:
|
||||
f.write("init = no\n")
|
||||
f.write(f"vstep = {int(self.step.dice.nstep[1] / 5)}\n")
|
||||
|
||||
f.write("nstep = 5\n")
|
||||
f.write("mstop = 1\n")
|
||||
f.write("accum = no\n")
|
||||
f.write("iprint = 1\n")
|
||||
f.write("isave = 0\n")
|
||||
f.write("irdf = 0\n")
|
||||
|
||||
seed = int(1e6 * random.random())
|
||||
f.write(f"seed = {seed}\n")
|
||||
|
||||
def _make_npt_eq(self, proc_dir):
|
||||
file = Path(proc_dir, "NPT.eq")
|
||||
with open(file, "w") as f:
|
||||
f.write(f"title = {self.title} - NPT Production\n")
|
||||
f.write(f"ncores = {self.step.ncores}\n")
|
||||
f.write(f"ljname = {self.step.dice.ljname}\n")
|
||||
f.write(f"outname = {self.step.dice.outname}\n")
|
||||
|
||||
mol_string = " ".join(str(x) for x in self.step.dice.nmol)
|
||||
f.write(f"nmol = {mol_string}\n")
|
||||
|
||||
f.write(f"press = {self.step.dice.press}\n")
|
||||
f.write(f"temp = {self.step.dice.temp}\n")
|
||||
|
||||
f.write("nstep = 5\n")
|
||||
|
||||
f.write(f"vstep = {int(self.step.dice.nstep[2] / 5)}\n")
|
||||
f.write("init = no\n")
|
||||
f.write("mstop = 1\n")
|
||||
f.write("accum = no\n")
|
||||
f.write("iprint = 1\n")
|
||||
f.write(f"isave = {self.step.dice.isave}\n")
|
||||
f.write(f"irdf = {10 * self.step.nprocs}\n")
|
||||
|
||||
seed = int(1e6 * random.random())
|
||||
f.write(f"seed = {seed}\n")
|
||||
|
||||
def _make_potentials(self, proc_dir):
|
||||
fstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f}\n"
|
||||
|
||||
file = Path(proc_dir, self.step.dice.ljname)
|
||||
with open(file, "w") as f:
|
||||
f.write(f"{self.step.dice.combrule}\n")
|
||||
f.write(f"{len(self.step.dice.nmol)}\n")
|
||||
|
||||
nsites_qm = len(self.system.molecule[0].atom)
|
||||
f.write(f"{nsites_qm} {self.system.molecule[0].molname}\n")
|
||||
|
||||
for atom in self.system.molecule[0].atom:
|
||||
f.write(
|
||||
fstr.format(
|
||||
atom.lbl,
|
||||
atom.na,
|
||||
atom.rx,
|
||||
atom.ry,
|
||||
atom.rz,
|
||||
atom.chg,
|
||||
atom.eps,
|
||||
atom.sig,
|
||||
)
|
||||
)
|
||||
|
||||
for mol in self.system.molecule[1:]:
|
||||
f.write(f"{len(mol.atom)} {mol.molname}\n")
|
||||
for atom in mol.atom:
|
||||
f.write(
|
||||
fstr.format(
|
||||
atom.lbl,
|
||||
atom.na,
|
||||
atom.rx,
|
||||
atom.ry,
|
||||
atom.rz,
|
||||
atom.chg,
|
||||
atom.eps,
|
||||
atom.sig,
|
||||
)
|
||||
)
|
||||
|
||||
def run_dice_file(self, cycle: int, proc: int, file_name: str):
|
||||
with (
|
||||
open(Path(file_name), "r") as infile,
|
||||
open(Path(file_name + ".out"), "w") as outfile,
|
||||
):
|
||||
if shutil.which("bash") is not None:
|
||||
exit_status = subprocess.call(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
f"exec -a dice-step{cycle}-p{proc} {self.step.dice.progname} < {infile.name} > {outfile.name}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
exit_status = subprocess.call(
|
||||
self.step.dice.progname, stdin=infile, stdout=outfile
|
||||
)
|
||||
|
||||
if exit_status != 0:
|
||||
raise RuntimeError(
|
||||
f"Dice process step{cycle:02d}-p{proc:02d} did not exit properly"
|
||||
)
|
||||
|
||||
with open(Path(file_name + ".out"), "r") as outfile:
|
||||
flag = outfile.readlines()[DICE_FLAG_LINE].strip()
|
||||
if flag != DICE_END_FLAG:
|
||||
raise RuntimeError(
|
||||
f"Dice process step{cycle:02d}-p{proc:02d} did not exit properly"
|
||||
)
|
||||
|
||||
logger.info(f"Dice {file_name} - step{cycle:02d}-p{proc:02d} exited properly")
|
||||
@@ -1,359 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from diceplayer import logger
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.environment import Atom
|
||||
from diceplayer.environment.molecule import Molecule
|
||||
from diceplayer.environment.system import System
|
||||
from diceplayer.interface import Interface
|
||||
from diceplayer.utils.misc import date_time
|
||||
from diceplayer.utils.ptable import PTable
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from typing_extensions import Any, Dict, List, Tuple
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class GaussianInterface(Interface):
|
||||
def configure(self, step_dto: PlayerConfig, system: System):
|
||||
self.system = system
|
||||
self.step = step_dto
|
||||
|
||||
def start(self, cycle: int) -> Dict[str, NDArray]:
|
||||
self._make_qm_dir(cycle)
|
||||
|
||||
if cycle > 1:
|
||||
self._copy_chk_file_from_previous_step(cycle)
|
||||
|
||||
asec_charges = self.populate_asec_vdw(cycle)
|
||||
self._make_gaussian_input_file(cycle, asec_charges)
|
||||
|
||||
self._run_gaussian(cycle)
|
||||
self._run_formchk(cycle)
|
||||
|
||||
return_value = {}
|
||||
if self.step.opt:
|
||||
# return_value['position'] = np.array(
|
||||
# self._run_optimization(cycle)
|
||||
# )
|
||||
raise NotImplementedError("Optimization not implemented yet.")
|
||||
|
||||
else:
|
||||
return_value["charges"] = np.array(self._read_charges_from_fchk(cycle))
|
||||
|
||||
return return_value
|
||||
|
||||
def reset(self):
|
||||
del self.step
|
||||
del self.system
|
||||
|
||||
def _make_qm_dir(self, cycle: int):
|
||||
qm_dir_path = Path(self.step.simulation_dir, f"step{cycle:02d}", "qm")
|
||||
if not qm_dir_path.exists():
|
||||
qm_dir_path.mkdir()
|
||||
|
||||
def _copy_chk_file_from_previous_step(self, cycle: int):
|
||||
current_chk_file_path = Path(
|
||||
self.step.simulation_dir, f"step{cycle:02d}", "qm", "asec.chk"
|
||||
)
|
||||
if current_chk_file_path.exists():
|
||||
raise FileExistsError(f"File {current_chk_file_path} already exists.")
|
||||
|
||||
previous_chk_file_path = Path(
|
||||
self.step.simulation_dir, f"step{(cycle - 1):02d}", "qm", "asec.chk"
|
||||
)
|
||||
if not previous_chk_file_path.exists():
|
||||
raise FileNotFoundError(f"File {previous_chk_file_path} does not exist.")
|
||||
|
||||
shutil.copy(previous_chk_file_path, current_chk_file_path)
|
||||
|
||||
def populate_asec_vdw(self, cycle: int) -> list[dict]:
|
||||
norm_factor = self._calculate_norm_factor()
|
||||
|
||||
nsitesref = len(self.system.molecule[0].atom)
|
||||
|
||||
nsites_total = self._calculate_total_number_of_sites(nsitesref)
|
||||
|
||||
proc_charges = []
|
||||
for proc in range(1, self.step.nprocs + 1):
|
||||
proc_charges.append(self._read_charges_from_last_step(cycle, proc))
|
||||
|
||||
asec_charges, thickness, picked_mols = self._evaluate_proc_charges(
|
||||
nsites_total, proc_charges
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"In average, {(sum(picked_mols) / norm_factor):^7.2f} molecules\n"
|
||||
f"were selected from each of the {len(picked_mols)} configurations\n"
|
||||
f"of the production simulations to form the ASEC, comprising a shell with\n"
|
||||
f"minimum thickness of {(sum(thickness) / norm_factor):>6.2f} Angstrom\n"
|
||||
)
|
||||
|
||||
for charge in asec_charges:
|
||||
charge["chg"] = charge["chg"] / norm_factor
|
||||
|
||||
return asec_charges
|
||||
|
||||
def _calculate_norm_factor(self) -> int:
|
||||
if self.step.dice.nstep[-1] % self.step.dice.isave == 0:
|
||||
nconfigs = round(self.step.dice.nstep[-1] / self.step.dice.isave)
|
||||
else:
|
||||
nconfigs = int(self.step.dice.nstep[-1] / self.step.dice.isave)
|
||||
|
||||
return nconfigs * self.step.nprocs
|
||||
|
||||
def _calculate_total_number_of_sites(self, nsitesref) -> int:
|
||||
nsites_total = self.step.dice.nmol[0] * nsitesref
|
||||
for i in range(1, len(self.step.dice.nmol)):
|
||||
nsites_total += self.step.dice.nmol[i] * len(self.system.molecule[i].atom)
|
||||
|
||||
return nsites_total
|
||||
|
||||
def _read_charges_from_last_step(self, cycle: int, proc: int) -> list[str]:
|
||||
last_xyz_file_path = Path(
|
||||
self.step.simulation_dir, f"step{cycle:02d}", f"p{proc:02d}", "last.xyz"
|
||||
)
|
||||
if not last_xyz_file_path.exists():
|
||||
raise FileNotFoundError(f"File {last_xyz_file_path} does not exist.")
|
||||
|
||||
with open(last_xyz_file_path, "r") as last_xyz_file:
|
||||
lines = last_xyz_file.readlines()
|
||||
|
||||
return lines
|
||||
|
||||
def _evaluate_proc_charges(
|
||||
self, total_nsites: int, proc_charges: list[list[str]]
|
||||
) -> Tuple[List[Dict[str, float | Any]], List[float], List[int]]:
|
||||
asec_charges = []
|
||||
|
||||
thickness = []
|
||||
picked_mols = []
|
||||
|
||||
for charges in proc_charges:
|
||||
charges_nsites = int(charges.pop(0))
|
||||
if int(charges_nsites) != total_nsites:
|
||||
raise ValueError(
|
||||
"Number of sites does not match total number of sites."
|
||||
)
|
||||
|
||||
thickness.append(self._calculate_proc_thickness(charges))
|
||||
nsites_ref_mol = len(self.system.molecule[0].atom)
|
||||
charges = charges[nsites_ref_mol:]
|
||||
|
||||
mol_count = 0
|
||||
for type in range(len(self.step.dice.nmol)):
|
||||
if type == 0:
|
||||
# Reference Molecule must be ignored from type 0
|
||||
nmols = self.step.dice.nmol[type] - 1
|
||||
else:
|
||||
nmols = self.step.dice.nmol[type]
|
||||
|
||||
for mol in range(nmols):
|
||||
new_molecule = Molecule("ASEC TMP MOLECULE")
|
||||
for site in range(len(self.system.molecule[type].atom)):
|
||||
line = charges.pop(0).split()
|
||||
|
||||
if (
|
||||
line[0].title()
|
||||
!= PTable.get_atomic_symbol(
|
||||
self.system.molecule[type].atom[site].na
|
||||
).strip()
|
||||
):
|
||||
raise SyntaxError(
|
||||
"Error: Invalid Dice Output. Atom type does not match."
|
||||
)
|
||||
|
||||
new_molecule.add_atom(
|
||||
Atom(
|
||||
self.system.molecule[type].atom[site].lbl,
|
||||
self.system.molecule[type].atom[site].na,
|
||||
float(line[1]),
|
||||
float(line[2]),
|
||||
float(line[3]),
|
||||
self.system.molecule[type].atom[site].chg,
|
||||
self.system.molecule[type].atom[site].eps,
|
||||
self.system.molecule[type].atom[site].sig,
|
||||
)
|
||||
)
|
||||
|
||||
distance = self.system.molecule[0].minimum_distance(new_molecule)
|
||||
|
||||
if distance < thickness[-1]:
|
||||
for atom in new_molecule.atom:
|
||||
asec_charges.append(
|
||||
{
|
||||
"lbl": PTable.get_atomic_symbol(atom.na),
|
||||
"rx": atom.rx,
|
||||
"ry": atom.ry,
|
||||
"rz": atom.rz,
|
||||
"chg": atom.chg,
|
||||
}
|
||||
)
|
||||
mol_count += 1
|
||||
|
||||
picked_mols.append(mol_count)
|
||||
|
||||
return asec_charges, thickness, picked_mols
|
||||
|
||||
def _calculate_proc_thickness(self, charges: list[str]) -> float:
|
||||
box = charges.pop(0).split()[-3:]
|
||||
box = [float(box[0]), float(box[1]), float(box[2])]
|
||||
sizes = self.system.molecule[0].sizes_of_molecule()
|
||||
|
||||
return min(
|
||||
[
|
||||
(box[0] - sizes[0]) / 2,
|
||||
(box[1] - sizes[1]) / 2,
|
||||
(box[2] - sizes[2]) / 2,
|
||||
]
|
||||
)
|
||||
|
||||
def _make_gaussian_input_file(self, cycle: int, asec_charges: list[dict]) -> None:
|
||||
gaussian_input_file_path = Path(
|
||||
self.step.simulation_dir, f"step{cycle:02d}", "qm", "asec.gjf"
|
||||
)
|
||||
|
||||
with open(gaussian_input_file_path, "w") as gaussian_input_file:
|
||||
gaussian_input_file.writelines(
|
||||
self._generate_gaussian_input(cycle, asec_charges)
|
||||
)
|
||||
|
||||
def _generate_gaussian_input(
|
||||
self, cycle: int, asec_charges: list[dict]
|
||||
) -> list[str]:
|
||||
gaussian_input = ["%Chk=asec.chk\n"]
|
||||
|
||||
if self.step.mem is not None:
|
||||
gaussian_input.append(f"%Mem={self.step.mem}GB\n")
|
||||
|
||||
gaussian_input.append(f"%Nprocs={self.step.nprocs * self.step.ncores}\n")
|
||||
|
||||
kwords_line = f"#P {self.step.gaussian.level}"
|
||||
|
||||
if self.step.gaussian.keywords:
|
||||
kwords_line += " " + self.step.gaussian.keywords
|
||||
|
||||
if self.step.opt == "yes":
|
||||
kwords_line += " Force"
|
||||
|
||||
kwords_line += " NoSymm"
|
||||
kwords_line += f" Pop={self.step.gaussian.pop} Density=Current"
|
||||
|
||||
if cycle > 1:
|
||||
kwords_line += " Guess=Read"
|
||||
|
||||
gaussian_input.append(textwrap.fill(kwords_line, 90))
|
||||
gaussian_input.append("\n")
|
||||
|
||||
gaussian_input.append("\nForce calculation - Cycle number {}\n".format(cycle))
|
||||
gaussian_input.append("\n")
|
||||
gaussian_input.append(
|
||||
f"{self.step.gaussian.chgmult[0]},{self.step.gaussian.chgmult[1]}\n"
|
||||
)
|
||||
|
||||
for atom in self.system.molecule[0].atom:
|
||||
symbol = PTable.get_atomic_symbol(atom.na)
|
||||
gaussian_input.append(
|
||||
"{:<2s} {:>10.5f} {:>10.5f} {:>10.5f}\n".format(
|
||||
symbol, atom.rx, atom.ry, atom.rz
|
||||
)
|
||||
)
|
||||
|
||||
gaussian_input.append("\n")
|
||||
|
||||
for charge in asec_charges:
|
||||
gaussian_input.append(
|
||||
"{:>10.5f} {:>10.5f} {:>10.5f} {:>11.8f}\n".format(
|
||||
charge["rx"], charge["ry"], charge["rz"], charge["chg"]
|
||||
)
|
||||
)
|
||||
|
||||
gaussian_input.append("\n")
|
||||
|
||||
return gaussian_input
|
||||
|
||||
def _run_gaussian(self, cycle: int) -> None:
|
||||
qm_dir = Path(self.step.simulation_dir, f"step{(cycle):02d}", "qm")
|
||||
|
||||
working_dir = os.getcwd()
|
||||
os.chdir(qm_dir)
|
||||
|
||||
infile = "asec.gjf"
|
||||
|
||||
operation = None
|
||||
if self.step.opt:
|
||||
operation = "forces"
|
||||
else:
|
||||
operation = "charges"
|
||||
|
||||
logger.info(
|
||||
f"Calculation of {operation} initiated with Gaussian on {date_time()}\n"
|
||||
)
|
||||
|
||||
if shutil.which("bash") is not None:
|
||||
exit_status = subprocess.call(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
"exec -a {}-step{} {} {}".format(
|
||||
self.step.gaussian.qmprog,
|
||||
cycle,
|
||||
self.step.gaussian.qmprog,
|
||||
infile,
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
exit_status = subprocess.call([self.step.gaussian.qmprog, infile])
|
||||
|
||||
if exit_status != 0:
|
||||
raise SystemError("Gaussian process did not exit properly")
|
||||
|
||||
logger.info(f"Calculation of {operation} finished on {date_time()}")
|
||||
|
||||
os.chdir(working_dir)
|
||||
|
||||
def _run_formchk(self, cycle: int):
|
||||
qm_dir = Path(self.step.simulation_dir, f"step{(cycle):02d}", "qm")
|
||||
|
||||
work_dir = os.getcwd()
|
||||
os.chdir(qm_dir)
|
||||
|
||||
logger.info("Formatting the checkpoint file... \n")
|
||||
|
||||
exit_status = subprocess.call(
|
||||
["formchk", "asec.chk"], stdout=subprocess.DEVNULL
|
||||
)
|
||||
|
||||
if exit_status != 0:
|
||||
raise SystemError("Formchk process did not exit properly")
|
||||
|
||||
logger.info("Done\n")
|
||||
|
||||
os.chdir(work_dir)
|
||||
|
||||
def _read_charges_from_fchk(self, cycle: int):
|
||||
fchk_file_path = Path("simfiles", f"step{cycle:02d}", "qm", "asec.fchk")
|
||||
with open(fchk_file_path) as fchk:
|
||||
fchkfile = fchk.readlines()
|
||||
|
||||
if self.step.gaussian.pop in ["chelpg", "mk"]:
|
||||
CHARGE_FLAG = "ESP Charges"
|
||||
else:
|
||||
CHARGE_FLAG = "ESP Charges"
|
||||
|
||||
start = fchkfile.pop(0).strip()
|
||||
while start.find(CHARGE_FLAG) != 0: # expression in begining of line
|
||||
start = fchkfile.pop(0).strip()
|
||||
|
||||
charges: List[float] = []
|
||||
while len(charges) < len(self.system.molecule[0].atom):
|
||||
charges.extend([float(x) for x in fchkfile.pop(0).split()])
|
||||
|
||||
return charges
|
||||
4
diceplayer/logger.py
Normal file
4
diceplayer/logger.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from diceplayer.utils import RunLogger
|
||||
|
||||
|
||||
logger = RunLogger("diceplayer")
|
||||
19
diceplayer/optimization/optimization_handler.py
Normal file
19
diceplayer/optimization/optimization_handler.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from diceplayer.config.player_config import RoutineType
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
|
||||
class OptimizationHandler:
|
||||
@staticmethod
|
||||
def run(state: StateModel, current_cycle: int) -> StateModel:
|
||||
print(f"Running Optimization - {current_cycle}")
|
||||
return state
|
||||
|
||||
@staticmethod
|
||||
def _fetch_current_routine(state: StateModel, current_cycle: int) -> RoutineType:
|
||||
if state.config.type != RoutineType.BOTH:
|
||||
return state.config.type
|
||||
|
||||
if current_cycle < state.config.switch_cyc:
|
||||
return RoutineType.CHARGE
|
||||
|
||||
return RoutineType.GEOMETRY
|
||||
@@ -1,465 +1,55 @@
|
||||
from diceplayer import VERSION, logger
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.environment import Atom, Molecule, System
|
||||
from diceplayer.interface import DiceInterface, GaussianInterface
|
||||
from diceplayer.utils import PTable, weekday_date_time
|
||||
from diceplayer.dice.dice_handler import DiceHandler
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
from diceplayer.state.state_model import StateModel
|
||||
from diceplayer.utils.potential import read_system_from_phb
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Tuple
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing_extensions import TypedDict, Unpack
|
||||
|
||||
|
||||
ENV = ["OMP_STACKSIZE"]
|
||||
class PlayerFlags(TypedDict):
|
||||
continuation: bool
|
||||
force: bool
|
||||
|
||||
|
||||
class Player:
|
||||
def __init__(self, infile: str = None, optimization: bool = False):
|
||||
if infile is None and optimization is False:
|
||||
raise ValueError("Must specify either infile or optimization")
|
||||
def __init__(self, config: PlayerConfig):
|
||||
self.config = config
|
||||
self._state_handler = StateHandler(config.simulation_dir)
|
||||
|
||||
elif infile is not None:
|
||||
self.config = self.set_config(self.read_keywords(infile))
|
||||
def play(self, **flags: Unpack[PlayerFlags]):
|
||||
continuation = flags.get("continuation", False)
|
||||
force = flags.get("force", False)
|
||||
|
||||
self.system = System()
|
||||
|
||||
self.initial_cycle = 1
|
||||
|
||||
elif optimization is True:
|
||||
save = self.load_run_from_pickle()
|
||||
|
||||
self.config = save[0]
|
||||
|
||||
self.system = save[1]
|
||||
|
||||
self.initial_cycle = save[2] + 1
|
||||
state = self._state_handler.get(self.config, force=force)
|
||||
if not continuation and state is not None:
|
||||
logger.info(
|
||||
"Continuation flag is not set. Starting a new simulation and deleting any existing state."
|
||||
)
|
||||
self._state_handler.delete()
|
||||
state = None
|
||||
|
||||
if state is None:
|
||||
system = read_system_from_phb(self.config)
|
||||
state = StateModel(config=self.config, system=system)
|
||||
else:
|
||||
raise ValueError("Must specify either infile or config")
|
||||
logger.info("Resuming from existing state.")
|
||||
|
||||
self.dice_interface = DiceInterface()
|
||||
self.gaussian_interface = GaussianInterface()
|
||||
|
||||
def start(self):
|
||||
logger.info(
|
||||
"==========================================================================================\n"
|
||||
"Starting the iterative process.\n"
|
||||
"==========================================================================================\n"
|
||||
)
|
||||
|
||||
for cycle in range(self.initial_cycle, self.initial_cycle + self.config.maxcyc):
|
||||
while state.current_cycle < self.config.max_cyc:
|
||||
logger.info(
|
||||
f"------------------------------------------------------------------------------------------\n"
|
||||
f" Step # {cycle}\n"
|
||||
f"------------------------------------------------------------------------------------------\n"
|
||||
f"Starting cycle {state.current_cycle + 1} of {self.config.max_cyc}."
|
||||
)
|
||||
|
||||
self.dice_start(cycle)
|
||||
step_directory = self.config.simulation_dir / f"{state.current_cycle:02d}"
|
||||
if not step_directory.exists():
|
||||
step_directory.mkdir(parents=True)
|
||||
|
||||
try:
|
||||
self.gaussian_start(cycle)
|
||||
except StopIteration:
|
||||
break
|
||||
state = DiceHandler(step_directory).run(state, state.current_cycle)
|
||||
|
||||
self.save_run_in_pickle(cycle)
|
||||
# state = OptimizationHandler.run(state, state.current_cycle)
|
||||
|
||||
def prepare_system(self):
|
||||
for i, mol in enumerate(self.system.molecule):
|
||||
logger.info(f"Molecule {i + 1} - {mol.molname}")
|
||||
state.current_cycle += 1
|
||||
self._state_handler.save(state)
|
||||
|
||||
mol.print_mol_info()
|
||||
logger.info(
|
||||
"\n Translating and rotating molecule to standard orientation..."
|
||||
)
|
||||
|
||||
mol.rotate_to_standard_orientation()
|
||||
logger.info("\n Done")
|
||||
logger.info("\nNew values:\n")
|
||||
mol.print_mol_info()
|
||||
|
||||
logger.info("\n")
|
||||
|
||||
def create_simulation_dir(self):
|
||||
simulation_dir_path = Path(self.config.simulation_dir)
|
||||
if simulation_dir_path.exists():
|
||||
raise FileExistsError(
|
||||
f"Error: a file or a directory {self.config.simulation_dir} already exists,"
|
||||
f" move or delete the simfiles directory to continue."
|
||||
)
|
||||
simulation_dir_path.mkdir()
|
||||
|
||||
def create_geoms_file(self):
|
||||
geoms_file_path = Path(self.config.geoms_file)
|
||||
if geoms_file_path.exists():
|
||||
raise FileExistsError(
|
||||
f"Error: a file or a directory {self.config.geoms_file} already exists,"
|
||||
f" move or delete the simfiles directory to continue."
|
||||
)
|
||||
geoms_file_path.touch()
|
||||
|
||||
def print_keywords(self) -> None:
|
||||
def log_keywords(config: BaseModel):
|
||||
for key, value in sorted(config.model_dump().items()):
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, list):
|
||||
string = " ".join(str(x) for x in value)
|
||||
logger.info(f"{key} = [ {string} ]")
|
||||
else:
|
||||
logger.info(f"{key} = {value}")
|
||||
|
||||
logger.info(
|
||||
f"##########################################################################################\n"
|
||||
f"############# Welcome to DICEPLAYER version {VERSION} #############\n"
|
||||
f"##########################################################################################\n"
|
||||
)
|
||||
logger.info("Your python version is {}\n".format(sys.version))
|
||||
logger.info("Program started on {}\n".format(weekday_date_time()))
|
||||
logger.info("Environment variables:")
|
||||
for var in ENV:
|
||||
logger.info(
|
||||
"{} = {}\n".format(
|
||||
var, (os.environ[var] if var in os.environ else "Not set")
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"------------------------------------------------------------------------------------------\n"
|
||||
" DICE variables being used in this run:\n"
|
||||
"------------------------------------------------------------------------------------------\n"
|
||||
)
|
||||
|
||||
log_keywords(self.config.dice)
|
||||
|
||||
logger.info(
|
||||
"------------------------------------------------------------------------------------------\n"
|
||||
" GAUSSIAN variables being used in this run:\n"
|
||||
"------------------------------------------------------------------------------------------\n"
|
||||
)
|
||||
|
||||
log_keywords(self.config.gaussian)
|
||||
|
||||
logger.info("\n")
|
||||
|
||||
def read_potentials(self):
|
||||
ljname_path = Path(self.config.dice.ljname)
|
||||
if ljname_path.exists():
|
||||
with open(self.config.dice.ljname) as file:
|
||||
ljc_data = file.readlines()
|
||||
else:
|
||||
raise RuntimeError(f"Potential file {self.config.dice.ljname} not found.")
|
||||
|
||||
combrule = ljc_data.pop(0).split()[0]
|
||||
if combrule not in ("*", "+"):
|
||||
sys.exit(
|
||||
"Error: expected a '*' or a '+' sign in 1st line of file {}".format(
|
||||
self.config.dice.ljname
|
||||
)
|
||||
)
|
||||
self.config.dice.combrule = combrule
|
||||
|
||||
ntypes = ljc_data.pop(0).split()[0]
|
||||
if not ntypes.isdigit():
|
||||
sys.exit(
|
||||
"Error: expected an integer in the 2nd line of file {}".format(
|
||||
self.config.dice.ljname
|
||||
)
|
||||
)
|
||||
ntypes = int(ntypes)
|
||||
|
||||
if ntypes != len(self.config.dice.nmol):
|
||||
sys.exit(
|
||||
f"Error: number of molecule types in file {self.config.dice.ljname} "
|
||||
f"must match that of 'nmol' keyword in config file"
|
||||
)
|
||||
|
||||
for i in range(ntypes):
|
||||
try:
|
||||
nsites, molname = ljc_data.pop(0).split()[:2]
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Error: expected nsites and molname for the molecule type {i + 1}"
|
||||
)
|
||||
|
||||
if not nsites.isdigit():
|
||||
raise ValueError(
|
||||
f"Error: expected nsites to be an integer for molecule type {i + 1}"
|
||||
)
|
||||
|
||||
nsites = int(nsites)
|
||||
self.system.add_type(Molecule(molname))
|
||||
|
||||
atom_fields = ["lbl", "na", "rx", "ry", "rz", "chg", "eps", "sig"]
|
||||
for j in range(nsites):
|
||||
new_atom = dict(zip(atom_fields, ljc_data.pop(0).split()))
|
||||
self.system.molecule[i].add_atom(
|
||||
Atom(**self.validate_atom_dict(i, j, new_atom))
|
||||
)
|
||||
|
||||
def print_potentials(self) -> None:
|
||||
formatstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f} {:>9.4f}"
|
||||
logger.info(
|
||||
"==========================================================================================\n"
|
||||
f" Potential parameters from file {self.config.dice.ljname}:\n"
|
||||
"------------------------------------------------------------------------------------------"
|
||||
"\n"
|
||||
)
|
||||
|
||||
logger.info(f"Combination rule: {self.config.dice.combrule}")
|
||||
logger.info(f"Types of molecules: {len(self.system.molecule)}\n")
|
||||
|
||||
i = 0
|
||||
for mol in self.system.molecule:
|
||||
i += 1
|
||||
logger.info("{} atoms in molecule type {}:".format(len(mol.atom), i))
|
||||
logger.info(
|
||||
"---------------------------------------------------------------------------------"
|
||||
)
|
||||
logger.info(
|
||||
"Lbl AN X Y Z Charge Epsilon Sigma Mass"
|
||||
)
|
||||
logger.info(
|
||||
"---------------------------------------------------------------------------------"
|
||||
)
|
||||
|
||||
for atom in mol.atom:
|
||||
logger.info(
|
||||
formatstr.format(
|
||||
atom.lbl,
|
||||
atom.na,
|
||||
atom.rx,
|
||||
atom.ry,
|
||||
atom.rz,
|
||||
atom.chg,
|
||||
atom.eps,
|
||||
atom.sig,
|
||||
atom.mass,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("\n")
|
||||
|
||||
def dice_start(self, cycle: int):
|
||||
self.dice_interface.configure(
|
||||
self.config,
|
||||
self.system,
|
||||
)
|
||||
|
||||
self.dice_interface.start(cycle)
|
||||
|
||||
self.dice_interface.reset()
|
||||
|
||||
def gaussian_start(self, cycle: int):
|
||||
self.gaussian_interface.configure(
|
||||
self.config,
|
||||
self.system,
|
||||
)
|
||||
|
||||
result = self.gaussian_interface.start(cycle)
|
||||
|
||||
self.gaussian_interface.reset()
|
||||
|
||||
if self.config.opt:
|
||||
if "position" not in result:
|
||||
raise RuntimeError("Optimization failed. No position found in result.")
|
||||
|
||||
else:
|
||||
if "charges" not in result:
|
||||
raise RuntimeError(
|
||||
"Charges optimization failed. No charges found in result."
|
||||
)
|
||||
|
||||
diff = self.system.molecule[0].update_charges(result["charges"])
|
||||
|
||||
self.print_charges_and_dipole(cycle)
|
||||
self.print_geoms(cycle)
|
||||
|
||||
if diff < self.config.gaussian.chg_tol:
|
||||
logger.info(f"Charges converged after {cycle} cycles.")
|
||||
raise StopIteration()
|
||||
|
||||
def print_charges_and_dipole(self, cycle: int) -> None:
|
||||
"""
|
||||
Print the charges and dipole of the molecule in the Output file
|
||||
|
||||
Args:
|
||||
cycle (int): Number of the cycle
|
||||
fh (TextIO): Output file
|
||||
"""
|
||||
|
||||
logger.info("Cycle # {}\n".format(cycle))
|
||||
logger.info("Number of site: {}\n".format(len(self.system.molecule[0].atom)))
|
||||
|
||||
chargesAndDipole = self.system.molecule[0].charges_and_dipole()
|
||||
|
||||
logger.info(
|
||||
"{:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f} {:>10.6f}\n".format(
|
||||
chargesAndDipole[0],
|
||||
chargesAndDipole[1],
|
||||
chargesAndDipole[2],
|
||||
chargesAndDipole[3],
|
||||
chargesAndDipole[4],
|
||||
)
|
||||
)
|
||||
|
||||
def print_geoms(self, cycle: int):
|
||||
with open(self.config.geoms_file, "a") as file:
|
||||
file.write(f"Cycle # {cycle}\n")
|
||||
|
||||
for atom in self.system.molecule[0].atom:
|
||||
symbol = PTable.get_atomic_symbol(atom.na)
|
||||
file.write(
|
||||
f"{symbol:<2s} {atom.rx:>10.6f} {atom.ry:>10.6f} {atom.rz:>10.6f}\n"
|
||||
)
|
||||
|
||||
file.write("\n")
|
||||
|
||||
@staticmethod
|
||||
def validate_atom_dict(molecule_type, molecule_site, atom_dict: dict) -> dict:
|
||||
molecule_type += 1
|
||||
molecule_site += 1
|
||||
|
||||
if len(atom_dict) < 8:
|
||||
raise ValueError(
|
||||
f"Invalid number of fields for site {molecule_site} for molecule type {molecule_type}."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict["lbl"] = int(atom_dict["lbl"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Invalid lbl fields for site {molecule_site} for molecule type {molecule_type}."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict["na"] = int(atom_dict["na"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Invalid na fields for site {molecule_site} for molecule type {molecule_type}."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict["rx"] = float(atom_dict["rx"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Invalid rx fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict["ry"] = float(atom_dict["ry"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Invalid ry fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict["rz"] = float(atom_dict["rz"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Invalid rz fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict["chg"] = float(atom_dict["chg"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Invalid chg fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict["eps"] = float(atom_dict["eps"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Invalid eps fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict["sig"] = float(atom_dict["sig"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Invalid sig fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
return atom_dict
|
||||
|
||||
def print_results(self):
|
||||
formatstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f} {:>9.4f}"
|
||||
|
||||
mol = self.system.molecule[0]
|
||||
logger.info("{} atoms in molecule type {}:".format(len(mol.atom), 1))
|
||||
logger.info(
|
||||
"---------------------------------------------------------------------------------"
|
||||
)
|
||||
logger.info(
|
||||
"Lbl AN X Y Z Charge Epsilon Sigma Mass"
|
||||
)
|
||||
logger.info(
|
||||
"---------------------------------------------------------------------------------"
|
||||
)
|
||||
|
||||
for atom in mol.atom:
|
||||
logger.info(
|
||||
formatstr.format(
|
||||
atom.lbl,
|
||||
atom.na,
|
||||
atom.rx,
|
||||
atom.ry,
|
||||
atom.rz,
|
||||
atom.chg,
|
||||
atom.eps,
|
||||
atom.sig,
|
||||
atom.mass,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("\n")
|
||||
|
||||
def save_run_in_pickle(self, cycle):
|
||||
try:
|
||||
with open("latest-step.pkl", "wb") as pickle_file:
|
||||
pickle.dump((self.config, self.system, cycle), pickle_file)
|
||||
except Exception:
|
||||
raise RuntimeError("Could not save pickle file latest-step.pkl.")
|
||||
|
||||
@staticmethod
|
||||
def load_run_from_pickle() -> Tuple[PlayerConfig, System, int]:
|
||||
pickle_path = Path("latest-step.pkl")
|
||||
try:
|
||||
with open(pickle_path, "rb") as pickle_file:
|
||||
save = pickle.load(pickle_file)
|
||||
return save[0], save[1], save[2] + 1
|
||||
|
||||
except Exception:
|
||||
raise RuntimeError(f"Could not load pickle file {pickle_path}.")
|
||||
|
||||
@staticmethod
|
||||
def set_config(data: dict) -> PlayerConfig:
|
||||
return PlayerConfig.model_validate(data)
|
||||
|
||||
@staticmethod
|
||||
def read_keywords(infile) -> dict:
|
||||
with open(infile, "r") as yml_file:
|
||||
config = yaml.load(yml_file, Loader=yaml.SafeLoader)
|
||||
|
||||
if "diceplayer" in config:
|
||||
return config.get("diceplayer")
|
||||
|
||||
raise RuntimeError(f"Could not find diceplayer section in {infile}.")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, infile: str) -> "Player":
|
||||
return cls(infile=infile)
|
||||
|
||||
@classmethod
|
||||
def from_save(cls):
|
||||
return cls(optimization=True)
|
||||
logger.info("Reached maximum number of cycles. Simulation complete.")
|
||||
|
||||
37
diceplayer/state/state_handler.py
Normal file
37
diceplayer/state/state_handler.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.logger import logger
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class StateHandler:
|
||||
def __init__(self, sim_dir: Path, state_file: str = "state.pkl"):
|
||||
if not sim_dir.exists():
|
||||
sim_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._state_file = sim_dir / state_file
|
||||
|
||||
def get(self, config: PlayerConfig, force=False) -> StateModel | None:
|
||||
if not self._state_file.exists():
|
||||
return None
|
||||
|
||||
with open(self._state_file, mode="rb") as file:
|
||||
data = pickle.load(file)
|
||||
model = StateModel.model_validate(data)
|
||||
|
||||
if config != model.config and not force:
|
||||
logger.warning(
|
||||
"The configuration in the state file does not match the provided configuration."
|
||||
)
|
||||
return None
|
||||
|
||||
return model
|
||||
|
||||
def save(self, state: StateModel) -> None:
|
||||
with self._state_file.open(mode="wb") as f:
|
||||
pickle.dump(state.model_dump(), f)
|
||||
|
||||
def delete(self) -> None:
|
||||
if self._state_file.exists():
|
||||
self._state_file.unlink()
|
||||
19
diceplayer/state/state_model.py
Normal file
19
diceplayer/state/state_model.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class StateModel(BaseModel):
|
||||
config: PlayerConfig
|
||||
system: System
|
||||
current_cycle: int = 0
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: PlayerConfig) -> Self:
|
||||
return cls(
|
||||
config=config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
from .logger import Logger, valid_logger
|
||||
from .logger import RunLogger
|
||||
from .misc import (
|
||||
compress_files_1mb,
|
||||
date_time,
|
||||
@@ -10,8 +10,7 @@ from .ptable import AtomInfo, PTable
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Logger",
|
||||
"valid_logger",
|
||||
"RunLogger",
|
||||
"PTable",
|
||||
"AtomInfo",
|
||||
"weekday_date_time",
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Dataclass(Protocol):
|
||||
__dataclass_fields__: dict
|
||||
@@ -1,72 +1,44 @@
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def valid_logger(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
logger = args[0]
|
||||
assert logger._was_set, "Logger is not set. Please call set_logger() first."
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
H = TypeVar("H", bound=logging.Handler)
|
||||
|
||||
|
||||
class Logger:
|
||||
outfile = None
|
||||
class RunLogger(logging.Logger):
|
||||
def __init__(self, name, level=logging.INFO, stream=sys.stdout):
|
||||
super().__init__(name, level)
|
||||
|
||||
_logger = None
|
||||
self.handlers.clear()
|
||||
|
||||
_was_set = False
|
||||
self.handlers.append(
|
||||
self._configure_handler(logging.StreamHandler(stream), level)
|
||||
)
|
||||
|
||||
def __init__(self, logger_name):
|
||||
if self._logger is None:
|
||||
self._logger = logging.getLogger(logger_name)
|
||||
def set_output_file(self, outfile: Path, level=logging.INFO):
|
||||
for handler in list(self.handlers):
|
||||
if not isinstance(handler, logging.FileHandler):
|
||||
continue
|
||||
self.handlers.remove(handler)
|
||||
|
||||
def set_logger(self, outfile="run.log", level=logging.INFO, stream=None):
|
||||
outfile_path = None
|
||||
if outfile is not None and stream is None:
|
||||
outfile_path = Path(outfile)
|
||||
if outfile_path.exists():
|
||||
outfile_path.rename(str(outfile_path) + ".backup")
|
||||
self.handlers.append(self._create_file_handler(outfile, level))
|
||||
|
||||
if level is not None:
|
||||
self._logger.setLevel(level)
|
||||
@staticmethod
|
||||
def _create_file_handler(file: str | Path, level) -> logging.FileHandler:
|
||||
file = Path(file)
|
||||
|
||||
self._create_handlers(outfile_path, stream)
|
||||
if file.exists():
|
||||
file.rename(file.with_suffix(".log.backup"))
|
||||
|
||||
self._was_set = True
|
||||
handler = logging.FileHandler(file)
|
||||
return RunLogger._configure_handler(handler, level)
|
||||
|
||||
@valid_logger
|
||||
def info(self, message):
|
||||
self._logger.info(message)
|
||||
|
||||
@valid_logger
|
||||
def debug(self, message):
|
||||
self._logger.debug(message)
|
||||
|
||||
@valid_logger
|
||||
def warning(self, message):
|
||||
self._logger.warning(message)
|
||||
|
||||
@valid_logger
|
||||
def error(self, message):
|
||||
self._logger.error(message)
|
||||
|
||||
def _create_handlers(self, outfile_path: Path, stream):
|
||||
handlers = []
|
||||
if outfile_path is not None:
|
||||
handlers.append(logging.FileHandler(outfile_path, mode="a+"))
|
||||
elif stream is not None:
|
||||
handlers.append(logging.StreamHandler(stream))
|
||||
else:
|
||||
handlers.append(logging.StreamHandler())
|
||||
|
||||
for handler in handlers:
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
self._logger.addHandler(handler)
|
||||
|
||||
def close(self):
|
||||
for handler in self._logger.handlers:
|
||||
handler.close()
|
||||
self._logger.removeHandler(handler)
|
||||
@staticmethod
|
||||
def _configure_handler(handler: H, level) -> H:
|
||||
handler.setLevel(level)
|
||||
formatter = logging.Formatter("%(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
return handler
|
||||
|
||||
55
diceplayer/utils/potential.py
Normal file
55
diceplayer/utils/potential.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.environment import Atom, Molecule, System
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def read_system_from_phb(config: PlayerConfig) -> System:
|
||||
phb_file = Path(config.dice.ljname)
|
||||
if not phb_file.exists():
|
||||
raise FileNotFoundError
|
||||
|
||||
ljc_data = phb_file.read_text(encoding="utf-8").splitlines()
|
||||
|
||||
combrule = ljc_data.pop(0).strip()
|
||||
if combrule != config.dice.combrule:
|
||||
raise ValueError(
|
||||
f"Invalid combrule defined in {phb_file}. Expected the same value configured in the config file"
|
||||
)
|
||||
|
||||
ntypes = ljc_data.pop(0).strip()
|
||||
if not ntypes.isdigit():
|
||||
raise ValueError(f"Invalid ntypes defined in {phb_file}")
|
||||
|
||||
nmol = int(ntypes)
|
||||
if nmol != len(config.dice.nmol):
|
||||
raise ValueError(f"Invalid nmol defined in {phb_file}")
|
||||
|
||||
sys = System()
|
||||
|
||||
for i in range(nmol):
|
||||
nsites, molname = ljc_data.pop(0).split()
|
||||
|
||||
if not nsites.isdigit():
|
||||
raise ValueError(f"Invalid nsites defined in {phb_file}")
|
||||
nsites = int(nsites)
|
||||
|
||||
mol = Molecule(molname)
|
||||
|
||||
for j in range(nsites):
|
||||
_fields = ljc_data.pop(0).split()
|
||||
mol.add_atom(Atom(*_fields))
|
||||
|
||||
sys.add_type(mol)
|
||||
|
||||
return sys
|
||||
|
||||
|
||||
# def write_phb(phb_file: Path, state: StateModel) -> None:
|
||||
# if phb_file.exists():
|
||||
# raise RuntimeError(f"File {phb_file} already exists")
|
||||
#
|
||||
# with open(phb_file, "w") as f:
|
||||
# f.write(f"{state.config.dice.combrule}\n")
|
||||
# f.write(f"{len(state.system.nmols)}\n")
|
||||
# f.write(f"{state.config.dice.nmol}\n")
|
||||
@@ -29,7 +29,8 @@ dev = [
|
||||
"black>=24.4.2",
|
||||
"pre-commit>=3.7.1",
|
||||
"poethepoet>=0.27.0",
|
||||
"ruff>=0.15.2"
|
||||
"ruff>=0.15.2",
|
||||
"pytest>=9.0.2",
|
||||
]
|
||||
|
||||
|
||||
|
||||
30
tests/cli/test_read_input_file.py
Normal file
30
tests/cli/test_read_input_file.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import diceplayer
|
||||
from diceplayer.cli import read_input
|
||||
from diceplayer.config import PlayerConfig
|
||||
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestReadInputFile:
|
||||
@pytest.fixture
|
||||
def example_config(self) -> Path:
|
||||
return Path(diceplayer.__path__[0]).parent / "control.example.yml"
|
||||
|
||||
def test_read_input_file(self, example_config: Path):
|
||||
config = read_input(example_config)
|
||||
|
||||
assert config is not None
|
||||
assert isinstance(config, PlayerConfig)
|
||||
|
||||
def test_read_input_non_existing_file(self):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
read_input("nonexistent_file.yml")
|
||||
|
||||
def test_read_input_invalid_yaml(self, tmp_path: Path):
|
||||
invalid_yaml_file = tmp_path / "invalid.yml"
|
||||
invalid_yaml_file.write_text("This is not valid YAML: [unbalanced brackets")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
read_input(invalid_yaml_file)
|
||||
@@ -1,11 +1,12 @@
|
||||
from diceplayer.config.dice_config import DiceConfig
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDiceDto(unittest.TestCase):
|
||||
class TestDiceConfig:
|
||||
def test_class_instantiation(self):
|
||||
dice_dto = DiceConfig(
|
||||
nprocs=1,
|
||||
ljname="test",
|
||||
outname="test",
|
||||
dens=1.0,
|
||||
@@ -13,78 +14,77 @@ class TestDiceDto(unittest.TestCase):
|
||||
nstep=[1, 1],
|
||||
)
|
||||
|
||||
self.assertIsInstance(dice_dto, DiceConfig)
|
||||
assert isinstance(dice_dto, DiceConfig)
|
||||
|
||||
def test_validate_jname(self):
|
||||
with self.assertRaises(ValueError) as ex:
|
||||
with pytest.raises(ValueError) as ex:
|
||||
DiceConfig(
|
||||
nprocs=1,
|
||||
ljname=None,
|
||||
outname="test",
|
||||
dens=1.0,
|
||||
nmol=[1],
|
||||
nstep=[1, 1],
|
||||
)
|
||||
self.assertEqual(
|
||||
ex.exception, "Error: 'ljname' keyword not specified in config file"
|
||||
)
|
||||
|
||||
assert ex.value == "Error: 'ljname' keyword not specified in config file"
|
||||
|
||||
def test_validate_outname(self):
|
||||
with self.assertRaises(ValueError) as ex:
|
||||
with pytest.raises(ValueError) as ex:
|
||||
DiceConfig(
|
||||
nprocs=1,
|
||||
ljname="test",
|
||||
outname=None,
|
||||
dens=1.0,
|
||||
nmol=[1],
|
||||
nstep=[1, 1],
|
||||
)
|
||||
self.assertEqual(
|
||||
ex.exception, "Error: 'outname' keyword not specified in config file"
|
||||
)
|
||||
|
||||
assert ex.value == "Error: 'outname' keyword not specified in config file"
|
||||
|
||||
def test_validate_dens(self):
|
||||
with self.assertRaises(ValueError) as ex:
|
||||
with pytest.raises(ValueError) as ex:
|
||||
DiceConfig(
|
||||
nprocs=1,
|
||||
ljname="test",
|
||||
outname="test",
|
||||
dens=None,
|
||||
nmol=[1],
|
||||
nstep=[1, 1],
|
||||
)
|
||||
self.assertEqual(
|
||||
ex.exception, "Error: 'dens' keyword not specified in config file"
|
||||
)
|
||||
|
||||
assert ex.value == "Error: 'dens' keyword not specified in config file"
|
||||
|
||||
def test_validate_nmol(self):
|
||||
with self.assertRaises(ValueError) as ex:
|
||||
with pytest.raises(ValueError) as ex:
|
||||
DiceConfig(
|
||||
nprocs=1,
|
||||
ljname="test",
|
||||
outname="test",
|
||||
dens=1.0,
|
||||
nmol=0,
|
||||
nstep=[1, 1],
|
||||
)
|
||||
self.assertEqual(
|
||||
ex.exception,
|
||||
"Error: 'nmol' keyword not defined appropriately in config file",
|
||||
)
|
||||
|
||||
assert ex.value == "Error: 'nmol' keyword not specified in config file"
|
||||
|
||||
def test_validate_nstep(self):
|
||||
with self.assertRaises(ValueError) as ex:
|
||||
with pytest.raises(ValueError) as ex:
|
||||
DiceConfig(
|
||||
nprocs=1,
|
||||
ljname="test",
|
||||
outname="test",
|
||||
dens=1.0,
|
||||
nmol=[1],
|
||||
nstep=0,
|
||||
)
|
||||
self.assertEqual(
|
||||
ex.exception,
|
||||
"Error: 'nstep' keyword not defined appropriately in config file",
|
||||
)
|
||||
|
||||
assert ex.value == "Error: 'nstep' keyword not specified in config file"
|
||||
|
||||
def test_from_dict(self):
|
||||
dice_dto = DiceConfig.model_validate(
|
||||
{
|
||||
"nprocs": 1,
|
||||
"ljname": "test",
|
||||
"outname": "test",
|
||||
"dens": 1.0,
|
||||
@@ -93,4 +93,4 @@ class TestDiceDto(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
self.assertIsInstance(dice_dto, DiceConfig)
|
||||
assert isinstance(dice_dto, DiceConfig)
|
||||
@@ -1,9 +1,9 @@
|
||||
from diceplayer.config.gaussian_config import GaussianConfig
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
|
||||
class TestGaussianDTO(unittest.TestCase):
|
||||
class TestGaussianConfig:
|
||||
def test_class_instantiation(self):
|
||||
gaussian_dto = GaussianConfig(
|
||||
level="test",
|
||||
@@ -11,10 +11,10 @@ class TestGaussianDTO(unittest.TestCase):
|
||||
keywords="test",
|
||||
)
|
||||
|
||||
self.assertIsInstance(gaussian_dto, GaussianConfig)
|
||||
assert isinstance(gaussian_dto, GaussianConfig)
|
||||
|
||||
def test_is_valid_qmprog(self):
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
GaussianConfig(
|
||||
level="test",
|
||||
qmprog="test",
|
||||
@@ -22,7 +22,7 @@ class TestGaussianDTO(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_is_valid_level(self):
|
||||
with self.assertRaises(ValueError):
|
||||
with pytest.raises(ValueError):
|
||||
GaussianConfig(
|
||||
level=None,
|
||||
qmprog="g16",
|
||||
@@ -38,8 +38,4 @@ class TestGaussianDTO(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
self.assertIsInstance(gaussian_dto, GaussianConfig)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
assert isinstance(gaussian_dto, GaussianConfig)
|
||||
90
tests/config/test_player_config.py
Normal file
90
tests/config/test_player_config.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from diceplayer.config.dice_config import DiceConfig
|
||||
from diceplayer.config.gaussian_config import GaussianConfig
|
||||
from diceplayer.config.player_config import PlayerConfig, RoutineType
|
||||
|
||||
import pytest
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TestPlayerConfig:
|
||||
@pytest.fixture
|
||||
def dice_payload(self) -> dict[str, Any]:
|
||||
return {
|
||||
"nprocs": 4,
|
||||
"ljname": "test",
|
||||
"outname": "test",
|
||||
"dens": 1.0,
|
||||
"nmol": [1],
|
||||
"nstep": [1, 1],
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def gaussian_payload(self) -> dict[str, Any]:
|
||||
return {
|
||||
"level": "test",
|
||||
"qmprog": "g16",
|
||||
"keywords": "test",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def player_payload(
|
||||
self, dice_payload: dict[str, Any], gaussian_payload: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "both",
|
||||
"mem": 12,
|
||||
"max_cyc": 100,
|
||||
"switch_cyc": 50,
|
||||
"ncores": 4,
|
||||
"dice": dice_payload,
|
||||
"gaussian": gaussian_payload,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def dice_config(self, dice_payload: dict[str, Any]) -> DiceConfig:
|
||||
return DiceConfig.model_validate(dice_payload)
|
||||
|
||||
@pytest.fixture
|
||||
def gaussian_config(self, gaussian_payload: dict[str, Any]):
|
||||
return GaussianConfig.model_validate(gaussian_payload)
|
||||
|
||||
def test_class_instantiation(
|
||||
self, dice_config: DiceConfig, gaussian_config: GaussianConfig
|
||||
):
|
||||
player_dto = PlayerConfig(
|
||||
type=RoutineType.BOTH,
|
||||
mem=12,
|
||||
max_cyc=100,
|
||||
switch_cyc=50,
|
||||
ncores=4,
|
||||
dice=dice_config,
|
||||
gaussian=gaussian_config,
|
||||
)
|
||||
|
||||
assert isinstance(player_dto, PlayerConfig)
|
||||
assert isinstance(player_dto.dice, DiceConfig)
|
||||
assert isinstance(player_dto.gaussian, GaussianConfig)
|
||||
|
||||
def test_min_altsteps(
|
||||
self, dice_config: DiceConfig, gaussian_config: GaussianConfig
|
||||
):
|
||||
player_dto = PlayerConfig(
|
||||
type=RoutineType.BOTH,
|
||||
mem=12,
|
||||
max_cyc=100,
|
||||
switch_cyc=50,
|
||||
ncores=4,
|
||||
altsteps=0,
|
||||
dice=dice_config,
|
||||
gaussian=gaussian_config,
|
||||
)
|
||||
|
||||
assert player_dto.altsteps == 20000
|
||||
|
||||
def test_from_dict(self, player_payload: dict[str, Any]):
|
||||
player_dto = PlayerConfig.model_validate(player_payload)
|
||||
|
||||
assert isinstance(player_dto, PlayerConfig)
|
||||
assert isinstance(player_dto.dice, DiceConfig)
|
||||
assert isinstance(player_dto.gaussian, GaussianConfig)
|
||||
@@ -1,83 +0,0 @@
|
||||
from diceplayer.config.dice_config import DiceConfig
|
||||
from diceplayer.config.gaussian_config import GaussianConfig
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
def get_config_dict():
|
||||
return {
|
||||
"opt": True,
|
||||
"mem": 12,
|
||||
"maxcyc": 100,
|
||||
"nprocs": 4,
|
||||
"ncores": 4,
|
||||
"dice": {
|
||||
"ljname": "test",
|
||||
"outname": "test",
|
||||
"dens": 1.0,
|
||||
"nmol": [1],
|
||||
"nstep": [1, 1],
|
||||
},
|
||||
"gaussian": {
|
||||
"level": "test",
|
||||
"qmprog": "g16",
|
||||
"keywords": "test",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestPlayerDTO(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.dice_dto = DiceConfig(
|
||||
ljname="test",
|
||||
outname="test",
|
||||
dens=1.0,
|
||||
nmol=[1],
|
||||
nstep=[1, 1],
|
||||
)
|
||||
self.gaussian_dto = GaussianConfig(
|
||||
level="test",
|
||||
qmprog="g16",
|
||||
keywords="test",
|
||||
)
|
||||
|
||||
def test_class_instantiation(self):
|
||||
player_dto = PlayerConfig(
|
||||
opt=True,
|
||||
mem=12,
|
||||
maxcyc=100,
|
||||
nprocs=4,
|
||||
ncores=4,
|
||||
dice=self.dice_dto,
|
||||
gaussian=self.gaussian_dto,
|
||||
)
|
||||
|
||||
self.assertIsInstance(player_dto, PlayerConfig)
|
||||
self.assertIsInstance(player_dto.dice, DiceConfig)
|
||||
self.assertIsInstance(player_dto.gaussian, GaussianConfig)
|
||||
|
||||
def test_min_altsteps(self):
|
||||
player_dto = PlayerConfig(
|
||||
opt=True,
|
||||
mem=12,
|
||||
maxcyc=100,
|
||||
nprocs=4,
|
||||
ncores=4,
|
||||
altsteps=100,
|
||||
dice=self.dice_dto,
|
||||
gaussian=self.gaussian_dto,
|
||||
)
|
||||
|
||||
self.assertEqual(player_dto.altsteps, 20000)
|
||||
|
||||
def test_from_dict(self):
|
||||
player_dto = PlayerConfig.model_validate(get_config_dict())
|
||||
|
||||
self.assertIsInstance(player_dto, PlayerConfig)
|
||||
self.assertIsInstance(player_dto.dice, DiceConfig)
|
||||
self.assertIsInstance(player_dto.gaussian, GaussianConfig)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
67
tests/dice/test_dice_input.py
Normal file
67
tests/dice/test_dice_input.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.dice.dice_input import (
|
||||
NPTEqConfig,
|
||||
NPTTerConfig,
|
||||
NVTEqConfig,
|
||||
NVTTerConfig,
|
||||
write_config,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestDiceInput:
|
||||
@pytest.fixture
|
||||
def player_config(self) -> PlayerConfig:
|
||||
return PlayerConfig.model_validate(
|
||||
{
|
||||
"type": "both",
|
||||
"mem": 12,
|
||||
"max_cyc": 100,
|
||||
"switch_cyc": 50,
|
||||
"ncores": 4,
|
||||
"dice": {
|
||||
"nprocs": 4,
|
||||
"ljname": "test",
|
||||
"outname": "test",
|
||||
"dens": 1.0,
|
||||
"nmol": [1],
|
||||
"nstep": [1, 1],
|
||||
},
|
||||
"gaussian": {
|
||||
"level": "test",
|
||||
"qmprog": "g16",
|
||||
"keywords": "test",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def test_generate_nvt_ter_input(self, player_config: PlayerConfig):
|
||||
dice_input = NVTTerConfig.from_config(player_config)
|
||||
|
||||
assert isinstance(dice_input, NVTTerConfig)
|
||||
|
||||
def test_generate_nvt_eq_input(self, player_config: PlayerConfig):
|
||||
dice_input = NVTEqConfig.from_config(player_config)
|
||||
|
||||
assert isinstance(dice_input, NVTEqConfig)
|
||||
|
||||
def test_generate_npt_ter_input(self, player_config: PlayerConfig):
|
||||
dice_input = NPTTerConfig.from_config(player_config)
|
||||
|
||||
assert isinstance(dice_input, NPTTerConfig)
|
||||
|
||||
def test_generate_npt_eq_input(self, player_config: PlayerConfig):
|
||||
dice_input = NPTEqConfig.from_config(player_config)
|
||||
|
||||
assert isinstance(dice_input, NPTEqConfig)
|
||||
|
||||
def test_write_dice_config(self, player_config: PlayerConfig, tmp_path: Path):
|
||||
dice_input = NVTTerConfig.from_config(player_config)
|
||||
|
||||
output_file = tmp_path / dice_input.type
|
||||
write_config(dice_input, tmp_path)
|
||||
|
||||
assert output_file.exists()
|
||||
@@ -1,113 +0,0 @@
|
||||
from unittest import mock
|
||||
|
||||
|
||||
def get_config_example():
|
||||
return """
|
||||
diceplayer:
|
||||
opt: no
|
||||
mem: 12
|
||||
maxcyc: 3
|
||||
ncores: 4
|
||||
nprocs: 4
|
||||
qmprog: 'g16'
|
||||
lps: no
|
||||
ghosts: no
|
||||
altsteps: 20000
|
||||
|
||||
dice:
|
||||
nmol: [1, 50]
|
||||
dens: 0.75
|
||||
nstep: [2000, 3000, 4000]
|
||||
isave: 1000
|
||||
outname: 'phb'
|
||||
progname: '~/.local/bin/dice'
|
||||
ljname: 'phb.ljc'
|
||||
randominit: 'first'
|
||||
|
||||
gaussian:
|
||||
qmprog: 'g16'
|
||||
level: 'MP2/aug-cc-pVDZ'
|
||||
keywords: 'freq'
|
||||
"""
|
||||
|
||||
|
||||
def get_potentials_exemple():
|
||||
return """\
|
||||
*
|
||||
2
|
||||
1 TEST
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
1 PLACEHOLDER
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
"""
|
||||
|
||||
|
||||
def get_potentials_error_combrule():
|
||||
return """\
|
||||
.
|
||||
2
|
||||
1 TEST
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
1 PLACEHOLDER
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
"""
|
||||
|
||||
|
||||
def get_potentials_error_ntypes():
|
||||
return """\
|
||||
*
|
||||
a
|
||||
1 TEST
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
1 PLACEHOLDER
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
"""
|
||||
|
||||
|
||||
def get_potentials_error_ntypes_config():
|
||||
return """\
|
||||
*
|
||||
3
|
||||
1 TEST
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
1 PLACEHOLDER
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
"""
|
||||
|
||||
|
||||
def get_potentials_error_nsites():
|
||||
return """\
|
||||
*
|
||||
2
|
||||
. TEST
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
1 PLACEHOLDER
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
"""
|
||||
|
||||
|
||||
def get_potentials_error_molname():
|
||||
return """\
|
||||
*
|
||||
2
|
||||
1
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
1 PLACEHOLDER
|
||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
||||
"""
|
||||
|
||||
|
||||
def mock_open(file, *args, **kwargs):
|
||||
values = {
|
||||
"control.test.yml": get_config_example(),
|
||||
"phb.ljc": get_potentials_exemple(),
|
||||
"phb.error.combrule.ljc": get_potentials_error_combrule(),
|
||||
"phb.error.ntypes.ljc": get_potentials_error_ntypes(),
|
||||
"phb.error.ntypes.config.ljc": get_potentials_error_ntypes_config(),
|
||||
"phb.error.nsites.ljc": get_potentials_error_nsites(),
|
||||
"phb.error.molname.ljc": get_potentials_error_molname(),
|
||||
}
|
||||
if file in values:
|
||||
return mock.mock_open(read_data=values[file])()
|
||||
|
||||
return mock.mock_open(read_data="")()
|
||||
@@ -1,32 +0,0 @@
|
||||
from typing_extensions import List
|
||||
|
||||
import itertools
|
||||
|
||||
|
||||
class MockProc:
|
||||
pid_counter = itertools.count()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.pid = next(MockProc.pid_counter)
|
||||
|
||||
if "exitcode" in kwargs:
|
||||
self.exitcode = kwargs["exitcode"]
|
||||
else:
|
||||
self.exitcode = 0
|
||||
|
||||
self.sentinel = self.pid
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
def terminate(self):
|
||||
pass
|
||||
|
||||
|
||||
class MockConnection:
|
||||
@staticmethod
|
||||
def wait(sentinels: List[int]):
|
||||
return sentinels
|
||||
@@ -1,643 +0,0 @@
|
||||
from diceplayer import logger
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.environment import Atom, Molecule, System
|
||||
from diceplayer.interface import DiceInterface
|
||||
from tests.mocks.mock_inputs import get_config_example
|
||||
from tests.mocks.mock_proc import MockConnection, MockProc
|
||||
|
||||
import yaml
|
||||
|
||||
import io
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
|
||||
class TestDiceInterface(unittest.TestCase):
|
||||
def setUp(self):
|
||||
logger.set_logger(stream=io.StringIO())
|
||||
|
||||
config = yaml.load(get_config_example(), Loader=yaml.Loader)
|
||||
self.config = PlayerConfig.model_validate(config["diceplayer"])
|
||||
|
||||
def test_class_instantiation(self):
|
||||
dice = DiceInterface()
|
||||
|
||||
self.assertIsInstance(dice, DiceInterface)
|
||||
|
||||
def test_configure(self):
|
||||
dice = DiceInterface()
|
||||
|
||||
self.assertIsNone(dice.step)
|
||||
self.assertIsNone(dice.system)
|
||||
|
||||
# Ignoring the types for testing purposes
|
||||
dice.configure(self.config, System())
|
||||
|
||||
self.assertIsNotNone(dice.step)
|
||||
self.assertIsNotNone(dice.system)
|
||||
|
||||
def test_reset(self):
|
||||
dice = DiceInterface()
|
||||
|
||||
dice.configure(self.config, System())
|
||||
|
||||
self.assertTrue(hasattr(dice, "step"))
|
||||
self.assertTrue(hasattr(dice, "system"))
|
||||
|
||||
dice.reset()
|
||||
|
||||
self.assertFalse(hasattr(dice, "step"))
|
||||
self.assertFalse(hasattr(dice, "system"))
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.Process", MockProc())
|
||||
@mock.patch("diceplayer.interface.dice_interface.connection", MockConnection)
|
||||
def test_start(self):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.start(1)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.connection", MockConnection)
|
||||
@mock.patch("diceplayer.interface.dice_interface.Process", MockProc(exitcode=1))
|
||||
def test_start_with_process_error(self):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
dice.start(1)
|
||||
|
||||
def test_simulation_process_raises_exception(self):
|
||||
dice = DiceInterface()
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
dice._simulation_process(1, 1)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.DiceInterface._make_proc_dir")
|
||||
@mock.patch("diceplayer.interface.dice_interface.DiceInterface._make_dice_inputs")
|
||||
@mock.patch("diceplayer.interface.dice_interface.DiceInterface._run_dice")
|
||||
def test_simulation_process(
|
||||
self, mock_run_dice, mock_make_dice_inputs, mock_make_proc_dir
|
||||
):
|
||||
dice = DiceInterface()
|
||||
|
||||
dice._simulation_process(1, 1)
|
||||
|
||||
self.assertTrue(dice._make_proc_dir.called)
|
||||
self.assertTrue(dice._make_dice_inputs.called)
|
||||
self.assertTrue(dice._run_dice.called)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.Path.mkdir")
|
||||
@mock.patch("diceplayer.interface.dice_interface.Path.exists")
|
||||
def test_make_proc_dir_if_simdir_exists(self, mock_path_exists, mock_path_mkdir):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
mock_path_exists.return_value = False
|
||||
|
||||
dice._make_proc_dir(1, 1)
|
||||
|
||||
self.assertEqual(mock_path_mkdir.call_count, 2)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.Path.mkdir")
|
||||
@mock.patch("diceplayer.interface.dice_interface.Path.exists")
|
||||
def test_make_proc_dir_if_simdir_doesnt_exists(
|
||||
self, mock_path_exists, mock_path_mkdir
|
||||
):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
mock_path_exists.return_value = False
|
||||
|
||||
dice._make_proc_dir(1, 1)
|
||||
|
||||
self.assertEqual(mock_path_mkdir.call_count, 2)
|
||||
|
||||
def test_make_dice_seed(self):
|
||||
seed = DiceInterface._make_dice_seed()
|
||||
|
||||
self.assertIsInstance(seed, int)
|
||||
|
||||
def test_make_dice_inputs_nstep_len_two_with_randoninit_first_cycle_one(self):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.step.dice.nstep = [1, 1]
|
||||
|
||||
dice._make_potentials = mock.Mock()
|
||||
|
||||
dice._make_init_file = mock.Mock()
|
||||
dice._new_density = mock.Mock()
|
||||
|
||||
dice._make_nvt_ter = mock.Mock()
|
||||
dice._make_nvt_eq = mock.Mock()
|
||||
dice._make_npt_ter = mock.Mock()
|
||||
dice._make_npt_eq = mock.Mock()
|
||||
|
||||
dice._make_dice_inputs(1, 1)
|
||||
|
||||
self.assertTrue(dice._make_potentials.called)
|
||||
|
||||
self.assertFalse(dice._make_init_file.called)
|
||||
self.assertFalse(dice._new_density.called)
|
||||
|
||||
self.assertTrue(dice._make_nvt_ter.called)
|
||||
self.assertTrue(dice._make_nvt_eq.called)
|
||||
|
||||
self.assertFalse(dice._make_npt_ter.called)
|
||||
self.assertFalse(dice._make_npt_eq.called)
|
||||
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="test")
|
||||
@mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=True)
|
||||
def test_make_dice_inputs_nstep_len_two_with_randoninit_first_cycle_two(
|
||||
self, mock_path_exists, mock_open
|
||||
):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.step.dice.nstep = [1, 1]
|
||||
|
||||
dice._make_potentials = mock.Mock()
|
||||
|
||||
dice._make_init_file = mock.Mock()
|
||||
dice._new_density = mock.Mock()
|
||||
|
||||
dice._make_nvt_ter = mock.Mock()
|
||||
dice._make_nvt_eq = mock.Mock()
|
||||
dice._make_npt_ter = mock.Mock()
|
||||
dice._make_npt_eq = mock.Mock()
|
||||
|
||||
dice._make_dice_inputs(2, 1)
|
||||
|
||||
self.assertTrue(dice._make_potentials.called)
|
||||
|
||||
self.assertTrue(dice._make_init_file.called)
|
||||
self.assertTrue(dice._new_density.called)
|
||||
|
||||
self.assertFalse(dice._make_nvt_ter.called)
|
||||
self.assertTrue(dice._make_nvt_eq.called)
|
||||
|
||||
self.assertFalse(dice._make_npt_ter.called)
|
||||
self.assertFalse(dice._make_npt_eq.called)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=False)
|
||||
def test_make_dice_inputs_raises_exception_on_last_not_found(
|
||||
self, mock_path_exists
|
||||
):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.step.dice.nstep = [1, 1]
|
||||
|
||||
dice._make_potentials = mock.Mock()
|
||||
|
||||
dice._make_init_file = mock.Mock()
|
||||
dice._new_density = mock.Mock()
|
||||
|
||||
dice._make_nvt_ter = mock.Mock()
|
||||
dice._make_nvt_eq = mock.Mock()
|
||||
dice._make_npt_ter = mock.Mock()
|
||||
dice._make_npt_eq = mock.Mock()
|
||||
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
dice._make_dice_inputs(2, 1)
|
||||
|
||||
def test_make_dice_inputs_nstep_len_three_with_randoninit_first_cycle_one(self):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice._make_potentials = mock.Mock()
|
||||
|
||||
dice._make_init_file = mock.Mock()
|
||||
dice._new_density = mock.Mock()
|
||||
|
||||
dice._make_nvt_ter = mock.Mock()
|
||||
dice._make_nvt_eq = mock.Mock()
|
||||
dice._make_npt_ter = mock.Mock()
|
||||
dice._make_npt_eq = mock.Mock()
|
||||
|
||||
dice._make_dice_inputs(1, 1)
|
||||
|
||||
self.assertTrue(dice._make_potentials.called)
|
||||
|
||||
self.assertFalse(dice._make_init_file.called)
|
||||
self.assertFalse(dice._new_density.called)
|
||||
|
||||
self.assertTrue(dice._make_nvt_ter.called)
|
||||
self.assertFalse(dice._make_nvt_eq.called)
|
||||
|
||||
self.assertTrue(dice._make_npt_ter.called)
|
||||
self.assertTrue(dice._make_npt_eq.called)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.os")
|
||||
@mock.patch("diceplayer.interface.dice_interface.shutil")
|
||||
@mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=True)
|
||||
def test_run_dice_on_first_cycle_run_successful(
|
||||
self, mock_path_exists, mock_shutils, mock_os
|
||||
):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.step.dice.nstep = [1, 1, 1]
|
||||
|
||||
dice.run_dice_file = mock.Mock()
|
||||
|
||||
dice._run_dice(1, 1)
|
||||
|
||||
self.assertTrue(mock_os.getcwd.called)
|
||||
self.assertTrue(mock_os.chdir.called)
|
||||
|
||||
self.assertEqual(dice.run_dice_file.call_count, 3)
|
||||
self.assertTrue(mock_shutils.copy.called)
|
||||
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.step.dice.nstep = [1, 1]
|
||||
|
||||
dice.run_dice_file = mock.Mock()
|
||||
|
||||
dice._run_dice(1, 1)
|
||||
|
||||
self.assertTrue(mock_os.getcwd.called)
|
||||
self.assertTrue(mock_os.chdir.called)
|
||||
|
||||
self.assertEqual(dice.run_dice_file.call_count, 2)
|
||||
self.assertTrue(mock_shutils.copy.called)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.os")
|
||||
@mock.patch("diceplayer.interface.dice_interface.shutil")
|
||||
@mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=True)
|
||||
def test_run_dice_on_second_cycle_run_successful(
|
||||
self, mock_path_exists, mock_shutils, mock_os
|
||||
):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.run_dice_file = mock.Mock()
|
||||
|
||||
dice._run_dice(2, 1)
|
||||
|
||||
self.assertTrue(mock_os.getcwd.called)
|
||||
self.assertTrue(mock_os.chdir.called)
|
||||
|
||||
self.assertEqual(dice.run_dice_file.call_count, 2)
|
||||
self.assertTrue(mock_shutils.copy.called)
|
||||
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.run_dice_file = mock.Mock()
|
||||
|
||||
dice._run_dice(2, 1)
|
||||
|
||||
self.assertTrue(mock_os.getcwd.called)
|
||||
self.assertTrue(mock_os.chdir.called)
|
||||
|
||||
self.assertEqual(dice.run_dice_file.call_count, 2)
|
||||
self.assertTrue(mock_shutils.copy.called)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.os")
|
||||
@mock.patch("diceplayer.interface.dice_interface.shutil")
|
||||
@mock.patch("diceplayer.interface.dice_interface.Path.exists", return_value=False)
|
||||
def test_run_dice_raises_filenotfound_on_invalid_file(
|
||||
self, mock_path_exists, mock_shutils, mock_os
|
||||
):
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.run_dice_file = mock.Mock()
|
||||
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
dice._run_dice(1, 1)
|
||||
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open)
|
||||
def test_make_init_file(self, mock_open):
|
||||
example_atom = Atom(
|
||||
lbl=1,
|
||||
na=1,
|
||||
rx=1.0,
|
||||
ry=1.0,
|
||||
rz=1.0,
|
||||
chg=1.0,
|
||||
eps=1.0,
|
||||
sig=1.0,
|
||||
)
|
||||
|
||||
main_molecule = Molecule("main_molecule")
|
||||
main_molecule.add_atom(example_atom)
|
||||
|
||||
secondary_molecule = Molecule("secondary_molecule")
|
||||
secondary_molecule.add_atom(example_atom)
|
||||
|
||||
system = System()
|
||||
system.add_type(main_molecule)
|
||||
system.add_type(secondary_molecule)
|
||||
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, system)
|
||||
|
||||
dice.step.dice.nmol = [1, 1]
|
||||
|
||||
last_xyz_file = io.StringIO()
|
||||
last_xyz_file.writelines(
|
||||
[
|
||||
" TEST\n",
|
||||
" Configuration number : TEST = TEST TEST TEST\n",
|
||||
" H 1.00000 1.00000 1.00000\n",
|
||||
" H 1.00000 1.00000 1.00000\n",
|
||||
]
|
||||
)
|
||||
last_xyz_file.seek(0)
|
||||
|
||||
dice._make_init_file("test", last_xyz_file)
|
||||
|
||||
mock_handler = mock_open()
|
||||
calls = mock_handler.write.call_args_list
|
||||
|
||||
lines = list(map(lambda x: x[0][0], calls))
|
||||
|
||||
expected_lines = [
|
||||
" 1.000000 1.000000 1.000000\n",
|
||||
" 1.000000 1.000000 1.000000\n",
|
||||
"$end",
|
||||
]
|
||||
|
||||
self.assertEqual(lines, expected_lines)
|
||||
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open)
|
||||
def test_new_density(self, mock_open):
|
||||
example_atom = Atom(
|
||||
lbl=1,
|
||||
na=1,
|
||||
rx=1.0,
|
||||
ry=1.0,
|
||||
rz=1.0,
|
||||
chg=1.0,
|
||||
eps=1.0,
|
||||
sig=1.0,
|
||||
)
|
||||
|
||||
main_molecule = Molecule("main_molecule")
|
||||
main_molecule.add_atom(example_atom)
|
||||
|
||||
secondary_molecule = Molecule("secondary_molecule")
|
||||
secondary_molecule.add_atom(example_atom)
|
||||
|
||||
system = System()
|
||||
system.add_type(main_molecule)
|
||||
system.add_type(secondary_molecule)
|
||||
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, system)
|
||||
|
||||
last_xyz_file = io.StringIO()
|
||||
last_xyz_file.writelines(
|
||||
[
|
||||
" TEST\n",
|
||||
" Configuration number : TEST = 1 1 1\n",
|
||||
" H 1.00000 1.00000 1.00000\n",
|
||||
" H 1.00000 1.00000 1.00000\n",
|
||||
]
|
||||
)
|
||||
last_xyz_file.seek(0)
|
||||
|
||||
density = dice._new_density(last_xyz_file)
|
||||
|
||||
self.assertEqual(density, 85.35451545000001)
|
||||
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open)
|
||||
@mock.patch("diceplayer.interface.dice_interface.random")
|
||||
def test_make_nvt_ter(self, mock_random, mock_open):
|
||||
mock_random.random.return_value = 1
|
||||
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice._make_nvt_ter(1, "test")
|
||||
|
||||
mock_handler = mock_open()
|
||||
calls = mock_handler.write.call_args_list
|
||||
|
||||
lines = list(map(lambda x: x[0][0], calls))
|
||||
|
||||
expected_lines = [
|
||||
"title = Diceplayer run - NVT Thermalization\n",
|
||||
"ncores = 4\n",
|
||||
"ljname = phb.ljc\n",
|
||||
"outname = phb\n",
|
||||
"nmol = 1 50\n",
|
||||
"dens = 0.75\n",
|
||||
"temp = 300.0\n",
|
||||
"init = yes\n",
|
||||
"nstep = 2000\n",
|
||||
"vstep = 0\n",
|
||||
"mstop = 1\n",
|
||||
"accum = no\n",
|
||||
"iprint = 1\n",
|
||||
"isave = 0\n",
|
||||
"irdf = 0\n",
|
||||
"seed = 1000000\n",
|
||||
"upbuf = 360",
|
||||
]
|
||||
|
||||
self.assertEqual(lines, expected_lines)
|
||||
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open)
|
||||
@mock.patch("diceplayer.interface.dice_interface.random")
|
||||
def test_make_nvt_eq(self, mock_random, mock_open):
|
||||
mock_random.random.return_value = 1
|
||||
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice._make_nvt_eq(1, "test")
|
||||
|
||||
mock_handler = mock_open()
|
||||
calls = mock_handler.write.call_args_list
|
||||
|
||||
lines = list(map(lambda x: x[0][0], calls))
|
||||
|
||||
expected_lines = [
|
||||
"title = Diceplayer run - NVT Production\n",
|
||||
"ncores = 4\n",
|
||||
"ljname = phb.ljc\n",
|
||||
"outname = phb\n",
|
||||
"nmol = 1 50\n",
|
||||
"dens = 0.75\n",
|
||||
"temp = 300.0\n",
|
||||
"init = no\n",
|
||||
"nstep = 3000\n",
|
||||
"vstep = 0\n",
|
||||
"mstop = 1\n",
|
||||
"accum = no\n",
|
||||
"iprint = 1\n",
|
||||
"isave = 1000\n",
|
||||
"irdf = 40\n",
|
||||
"seed = 1000000\n",
|
||||
]
|
||||
|
||||
self.assertEqual(lines, expected_lines)
|
||||
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open)
|
||||
@mock.patch("diceplayer.interface.dice_interface.random")
|
||||
def test_make_npt_ter(self, mock_random, mock_open):
|
||||
mock_random.random.return_value = 1
|
||||
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice._make_npt_ter(1, "test")
|
||||
|
||||
mock_handler = mock_open()
|
||||
calls = mock_handler.write.call_args_list
|
||||
|
||||
lines = list(map(lambda x: x[0][0], calls))
|
||||
|
||||
expected_lines = [
|
||||
"title = Diceplayer run - NPT Thermalization\n",
|
||||
"ncores = 4\n",
|
||||
"ljname = phb.ljc\n",
|
||||
"outname = phb\n",
|
||||
"nmol = 1 50\n",
|
||||
"press = 1.0\n",
|
||||
"temp = 300.0\n",
|
||||
"init = no\n",
|
||||
"vstep = 600\n",
|
||||
"nstep = 5\n",
|
||||
"mstop = 1\n",
|
||||
"accum = no\n",
|
||||
"iprint = 1\n",
|
||||
"isave = 0\n",
|
||||
"irdf = 0\n",
|
||||
"seed = 1000000\n",
|
||||
]
|
||||
|
||||
self.assertEqual(lines, expected_lines)
|
||||
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open)
|
||||
@mock.patch("diceplayer.interface.dice_interface.random")
|
||||
def test_make_npt_eq(self, mock_random, mock_open):
|
||||
mock_random.random.return_value = 1
|
||||
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice._make_npt_eq("test")
|
||||
|
||||
mock_handler = mock_open()
|
||||
calls = mock_handler.write.call_args_list
|
||||
|
||||
lines = list(map(lambda x: x[0][0], calls))
|
||||
|
||||
expected_lines = [
|
||||
"title = Diceplayer run - NPT Production\n",
|
||||
"ncores = 4\n",
|
||||
"ljname = phb.ljc\n",
|
||||
"outname = phb\n",
|
||||
"nmol = 1 50\n",
|
||||
"press = 1.0\n",
|
||||
"temp = 300.0\n",
|
||||
"nstep = 5\n",
|
||||
"vstep = 800\n",
|
||||
"init = no\n",
|
||||
"mstop = 1\n",
|
||||
"accum = no\n",
|
||||
"iprint = 1\n",
|
||||
"isave = 1000\n",
|
||||
"irdf = 40\n",
|
||||
"seed = 1000000\n",
|
||||
]
|
||||
|
||||
self.assertEqual(lines, expected_lines)
|
||||
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open)
|
||||
def test_make_potentials(self, mock_open):
|
||||
example_atom = Atom(
|
||||
lbl=1,
|
||||
na=1,
|
||||
rx=1.0,
|
||||
ry=1.0,
|
||||
rz=1.0,
|
||||
chg=1.0,
|
||||
eps=1.0,
|
||||
sig=1.0,
|
||||
)
|
||||
|
||||
main_molecule = Molecule("main_molecule")
|
||||
main_molecule.add_atom(example_atom)
|
||||
|
||||
secondary_molecule = Molecule("secondary_molecule")
|
||||
secondary_molecule.add_atom(example_atom)
|
||||
|
||||
system = System()
|
||||
system.add_type(main_molecule)
|
||||
system.add_type(secondary_molecule)
|
||||
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, system)
|
||||
|
||||
dice._make_potentials("test")
|
||||
|
||||
mock_handler = mock_open()
|
||||
calls = mock_handler.write.call_args_list
|
||||
|
||||
lines = list(map(lambda x: x[0][0], calls))
|
||||
|
||||
expected_lines = [
|
||||
"*\n",
|
||||
"2\n",
|
||||
"1 main_molecule\n",
|
||||
"1 1 1.00000 1.00000 1.00000 1.000000 1.00000 1.0000\n",
|
||||
"1 secondary_molecule\n",
|
||||
"1 1 1.00000 1.00000 1.00000 1.000000 1.00000 1.0000\n",
|
||||
]
|
||||
|
||||
self.assertEqual(lines, expected_lines)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.subprocess")
|
||||
@mock.patch(
|
||||
"builtins.open",
|
||||
new_callable=mock.mock_open,
|
||||
read_data="End of simulation\nBLABLA",
|
||||
)
|
||||
def test_run_dice_file(self, mock_open, mock_subprocess):
|
||||
mock_subprocess.call.return_value = 0
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
dice.run_dice_file(1, 1, "test")
|
||||
|
||||
self.assertTrue(mock_subprocess.call.called)
|
||||
self.assertTrue(mock_open.called)
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.subprocess")
|
||||
@mock.patch("builtins.open", new_callable=mock.mock_open, read_data="Error\nBLABLA")
|
||||
def test_run_dice_file_raises_runtime_error_on_dice_file(
|
||||
self, mock_open, mock_subprocess
|
||||
):
|
||||
mock_subprocess.call.return_value = 0
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
dice.run_dice_file(1, 1, "test")
|
||||
|
||||
@mock.patch("diceplayer.interface.dice_interface.subprocess")
|
||||
@mock.patch(
|
||||
"builtins.open",
|
||||
new_callable=mock.mock_open,
|
||||
read_data="End of simulation\nBLABLA",
|
||||
)
|
||||
def test_run_dice_file_raises_runtime_error_of_dice_exit_code(
|
||||
self, mock_open, mock_subprocess
|
||||
):
|
||||
mock_subprocess.call.return_value = 1
|
||||
dice = DiceInterface()
|
||||
dice.configure(self.config, System())
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
dice.run_dice_file(1, 1, "test")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,115 +0,0 @@
|
||||
from diceplayer import logger
|
||||
from diceplayer.config.player_config import PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
from diceplayer.interface import GaussianInterface
|
||||
from tests.mocks.mock_inputs import get_config_example
|
||||
|
||||
import yaml
|
||||
|
||||
import io
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
|
||||
class TestGaussianInterface(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
logger.set_logger(stream=io.StringIO())
|
||||
|
||||
config = yaml.load(get_config_example(), Loader=yaml.Loader)
|
||||
self.config = PlayerConfig.model_validate(config["diceplayer"])
|
||||
|
||||
def test_class_instantiation(self):
|
||||
gaussian_interface = GaussianInterface()
|
||||
self.assertIsInstance(gaussian_interface, GaussianInterface)
|
||||
|
||||
def test_configure(self):
|
||||
gaussian_interface = GaussianInterface()
|
||||
|
||||
self.assertIsNone(gaussian_interface.step)
|
||||
self.assertIsNone(gaussian_interface.system)
|
||||
|
||||
gaussian_interface.configure(self.config, System())
|
||||
|
||||
self.assertIsNotNone(gaussian_interface.step)
|
||||
self.assertIsNotNone(gaussian_interface.system)
|
||||
|
||||
def test_reset(self):
|
||||
gaussian_interface = GaussianInterface()
|
||||
|
||||
gaussian_interface.configure(self.config, System())
|
||||
|
||||
self.assertIsNotNone(gaussian_interface.step)
|
||||
self.assertIsNotNone(gaussian_interface.system)
|
||||
|
||||
gaussian_interface.reset()
|
||||
|
||||
self.assertFalse(hasattr(gaussian_interface, "step"))
|
||||
self.assertFalse(hasattr(gaussian_interface, "system"))
|
||||
|
||||
@mock.patch("diceplayer.interface.gaussian_interface.Path.mkdir")
|
||||
@mock.patch("diceplayer.interface.gaussian_interface.Path.exists")
|
||||
def test_make_qm_dir(self, mock_exists, mock_mkdir):
|
||||
mock_exists.return_value = False
|
||||
|
||||
gaussian_interface = GaussianInterface()
|
||||
gaussian_interface.configure(self.config, System())
|
||||
|
||||
gaussian_interface._make_qm_dir(1)
|
||||
|
||||
mock_exists.assert_called_once()
|
||||
mock_mkdir.assert_called_once()
|
||||
|
||||
@mock.patch("diceplayer.interface.gaussian_interface.shutil.copy")
|
||||
@mock.patch("diceplayer.interface.gaussian_interface.Path.exists")
|
||||
def test_copy_chk_file_from_previous_step(self, mock_exists, mock_copy):
|
||||
gaussian_interface = GaussianInterface()
|
||||
gaussian_interface.configure(self.config, System())
|
||||
|
||||
mock_exists.side_effect = [False, True]
|
||||
|
||||
gaussian_interface._copy_chk_file_from_previous_step(2)
|
||||
|
||||
self.assertTrue(mock_exists.called)
|
||||
self.assertTrue(mock_copy.called)
|
||||
|
||||
@mock.patch("diceplayer.interface.gaussian_interface.shutil.copy")
|
||||
@mock.patch("diceplayer.interface.gaussian_interface.Path.exists")
|
||||
def test_copy_chk_file_from_previous_step_no_previous_step(
|
||||
self, mock_exists, mock_copy
|
||||
):
|
||||
gaussian_interface = GaussianInterface()
|
||||
gaussian_interface.configure(self.config, System())
|
||||
|
||||
mock_exists.side_effect = [False, False]
|
||||
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
gaussian_interface._copy_chk_file_from_previous_step(2)
|
||||
|
||||
@mock.patch("diceplayer.interface.gaussian_interface.shutil.copy")
|
||||
@mock.patch("diceplayer.interface.gaussian_interface.Path.exists")
|
||||
def test_copy_chk_file_from_previous_step_current_exists(
|
||||
self, mock_exists, mock_copy
|
||||
):
|
||||
gaussian_interface = GaussianInterface()
|
||||
gaussian_interface.configure(self.config, System())
|
||||
|
||||
mock_exists.side_effect = [True, True]
|
||||
|
||||
with self.assertRaises(FileExistsError):
|
||||
gaussian_interface._copy_chk_file_from_previous_step(2)
|
||||
|
||||
# def test_start(self):
|
||||
# gaussian_interface = GaussianInterface()
|
||||
# gaussian_interface.configure(self.config, System())
|
||||
#
|
||||
# gaussian_interface._make_qm_dir = mock.Mock()
|
||||
# gaussian_interface._copy_chk_file_from_previous_step = mock.Mock()
|
||||
#
|
||||
# gaussian_interface.start(2)
|
||||
#
|
||||
# gaussian_interface._make_qm_dir.assert_called_once_with(2)
|
||||
# gaussian_interface._copy_chk_file_from_previous_step.assert_called_once_with(2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,132 +0,0 @@
|
||||
from diceplayer.utils import Logger, valid_logger
|
||||
|
||||
import io
|
||||
import logging
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
|
||||
class TestValidateLogger(unittest.TestCase):
|
||||
def test_validate_logger(self):
|
||||
class MockLogger:
|
||||
_was_set = True
|
||||
|
||||
@valid_logger
|
||||
def test_func(self):
|
||||
pass
|
||||
|
||||
MockLogger().test_func()
|
||||
|
||||
def test_validate_logger_exception(self):
|
||||
class MockLogger:
|
||||
_was_set = False
|
||||
|
||||
@valid_logger
|
||||
def test_func(self):
|
||||
pass
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
MockLogger().test_func()
|
||||
|
||||
|
||||
class TestLogger(unittest.TestCase):
|
||||
def test_class_instantiation(self):
|
||||
logger = Logger("test")
|
||||
|
||||
self.assertIsInstance(logger, Logger)
|
||||
|
||||
@mock.patch("builtins.open", mock.mock_open())
|
||||
def test_set_logger_to_file(self):
|
||||
logger = Logger("test")
|
||||
|
||||
logger.set_logger(stream=io.StringIO())
|
||||
|
||||
self.assertIsNotNone(logger._logger)
|
||||
self.assertEqual(logger._logger.name, "test")
|
||||
|
||||
def test_set_logger_to_stream(self):
|
||||
logger = Logger("test")
|
||||
|
||||
logger.set_logger(stream=io.StringIO())
|
||||
|
||||
self.assertIsNotNone(logger._logger)
|
||||
self.assertEqual(logger._logger.name, "test")
|
||||
|
||||
@mock.patch("builtins.open", mock.mock_open())
|
||||
@mock.patch("diceplayer.utils.logger.Path.exists")
|
||||
@mock.patch("diceplayer.utils.logger.Path.rename")
|
||||
def test_set_logger_if_file_exists(self, mock_rename, mock_exists):
|
||||
logger = Logger("test")
|
||||
|
||||
mock_exists.return_value = True
|
||||
logger.set_logger()
|
||||
|
||||
self.assertTrue(mock_rename.called)
|
||||
self.assertIsNotNone(logger._logger)
|
||||
self.assertEqual(logger._logger.name, "test")
|
||||
|
||||
@mock.patch("builtins.open", mock.mock_open())
|
||||
@mock.patch("diceplayer.utils.logger.Path.exists")
|
||||
@mock.patch("diceplayer.utils.logger.Path.rename")
|
||||
def test_set_logger_if_file_not_exists(self, mock_rename, mock_exists):
|
||||
logger = Logger("test")
|
||||
|
||||
mock_exists.return_value = False
|
||||
logger.set_logger()
|
||||
|
||||
self.assertFalse(mock_rename.called)
|
||||
self.assertIsNotNone(logger._logger)
|
||||
self.assertEqual(logger._logger.name, "test")
|
||||
|
||||
@mock.patch("builtins.open", mock.mock_open())
|
||||
def test_close(self):
|
||||
logger = Logger("test")
|
||||
|
||||
logger.set_logger()
|
||||
logger.close()
|
||||
|
||||
self.assertEqual(len(logger._logger.handlers), 0)
|
||||
|
||||
@mock.patch("builtins.open", mock.mock_open())
|
||||
def test_info(self):
|
||||
logger = Logger("test")
|
||||
logger.set_logger(stream=io.StringIO())
|
||||
|
||||
with self.assertLogs(level="INFO") as cm:
|
||||
logger.info("test")
|
||||
|
||||
self.assertEqual(cm.output, ["INFO:test:test"])
|
||||
|
||||
@mock.patch("builtins.open", mock.mock_open())
|
||||
def test_debug(self):
|
||||
logger = Logger("test")
|
||||
logger.set_logger(stream=io.StringIO(), level=logging.DEBUG)
|
||||
|
||||
with self.assertLogs(level="DEBUG") as cm:
|
||||
logger.debug("test")
|
||||
|
||||
self.assertEqual(cm.output, ["DEBUG:test:test"])
|
||||
|
||||
@mock.patch("builtins.open", mock.mock_open())
|
||||
def test_warning(self):
|
||||
logger = Logger("test")
|
||||
logger.set_logger(stream=io.StringIO())
|
||||
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
logger.warning("test")
|
||||
|
||||
self.assertEqual(cm.output, ["WARNING:test:test"])
|
||||
|
||||
@mock.patch("builtins.open", mock.mock_open())
|
||||
def test_error(self):
|
||||
logger = Logger("test")
|
||||
logger.set_logger(stream=io.StringIO())
|
||||
|
||||
with self.assertLogs(level="ERROR") as cm:
|
||||
logger.error("test")
|
||||
|
||||
self.assertEqual(cm.output, ["ERROR:test:test"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
0
tests/state/__init__.py
Normal file
0
tests/state/__init__.py
Normal file
118
tests/state/test_state_handler.py
Normal file
118
tests/state/test_state_handler.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestStateHandler:
|
||||
@pytest.fixture
|
||||
def player_config(self) -> PlayerConfig:
|
||||
return PlayerConfig(
|
||||
type="both",
|
||||
mem=12,
|
||||
max_cyc=100,
|
||||
switch_cyc=50,
|
||||
ncores=4,
|
||||
dice=DiceConfig(
|
||||
nprocs=4,
|
||||
ljname="test",
|
||||
outname="test",
|
||||
dens=1.0,
|
||||
nmol=[1],
|
||||
nstep=[1, 1],
|
||||
),
|
||||
gaussian=GaussianConfig(
|
||||
level="test",
|
||||
qmprog="g16",
|
||||
keywords="test",
|
||||
),
|
||||
)
|
||||
|
||||
def test_initialization(self, tmp_path: Path):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
assert isinstance(state_handler, StateHandler)
|
||||
|
||||
def test_save(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
assert (tmp_path / "state.pkl").exists()
|
||||
|
||||
def test_get_when_empty(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = state_handler.get(player_config)
|
||||
|
||||
assert state is None
|
||||
|
||||
def test_get(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
retrieved_state = state_handler.get(player_config)
|
||||
|
||||
assert retrieved_state is not None
|
||||
assert retrieved_state.config == state.config
|
||||
assert retrieved_state.system == state.system
|
||||
assert retrieved_state.current_cycle == state.current_cycle
|
||||
|
||||
def test_get_with_different_config(
|
||||
self, tmp_path: Path, player_config: PlayerConfig
|
||||
):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
different_config = player_config.model_copy(update={"max_cyc": 200})
|
||||
|
||||
retrieved_state = state_handler.get(different_config)
|
||||
|
||||
assert retrieved_state is None
|
||||
|
||||
def test_get_with_different_config_force(
|
||||
self, tmp_path: Path, player_config: PlayerConfig
|
||||
):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
different_config = player_config.model_copy(update={"max_cyc": 200})
|
||||
|
||||
retrieved_state = state_handler.get(different_config, force=True)
|
||||
|
||||
assert retrieved_state is not None
|
||||
assert retrieved_state.config == state.config
|
||||
assert retrieved_state.config != different_config
|
||||
assert retrieved_state.system == state.system
|
||||
assert retrieved_state.current_cycle == state.current_cycle
|
||||
@@ -1,396 +0,0 @@
|
||||
from diceplayer import logger
|
||||
from diceplayer.player import Player
|
||||
from tests.mocks.mock_inputs import mock_open
|
||||
|
||||
import io
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
|
||||
class TestPlayer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
logger.set_logger(stream=io.StringIO())
|
||||
|
||||
@mock.patch("builtins.open", mock_open)
|
||||
def test_class_instantiation(self):
|
||||
# This file does not exist and it will be mocked
|
||||
player = Player("control.test.yml")
|
||||
|
||||
self.assertIsInstance(player, Player)
|
||||
|
||||
@mock.patch("builtins.open", mock_open)
|
||||
def test_start(self):
|
||||
player = Player("control.test.yml")
|
||||
|
||||
player.gaussian_start = mock.MagicMock()
|
||||
player.dice_start = mock.MagicMock()
|
||||
|
||||
player.start()
|
||||
|
||||
self.assertEqual(player.dice_start.call_count, 3)
|
||||
self.assertEqual(player.gaussian_start.call_count, 3)
|
||||
|
||||
@mock.patch("builtins.open", mock_open)
|
||||
@mock.patch("diceplayer.player.Path")
|
||||
def test_create_simulation_dir_if_already_exists(self, mock_path):
|
||||
player = Player("control.test.yml")
|
||||
mock_path.return_value.exists.return_value = True
|
||||
|
||||
with self.assertRaises(FileExistsError):
|
||||
player.create_simulation_dir()
|
||||
|
||||
self.assertTrue(mock_path.called)
|
||||
|
||||
@mock.patch("builtins.open", mock_open)
|
||||
@mock.patch("diceplayer.player.Path")
|
||||
def test_create_simulation_dir_if_not_exists(self, mock_path):
|
||||
player = Player("control.test.yml")
|
||||
mock_path.return_value.exists.return_value = False
|
||||
|
||||
player.create_simulation_dir()
|
||||
|
||||
self.assertTrue(mock_path.called)
|
||||
|
||||
@mock.patch("builtins.open", mock_open)
|
||||
@mock.patch("diceplayer.player.VERSION", "test")
|
||||
@mock.patch("diceplayer.player.sys")
|
||||
@mock.patch("diceplayer.player.weekday_date_time")
|
||||
def test_print_keywords(self, mock_date_func, mock_sys):
|
||||
player = Player("control.test.yml")
|
||||
|
||||
mock_sys.version = "TEST"
|
||||
mock_date_func.return_value = "00 Test 0000 at 00:00:00"
|
||||
|
||||
with self.assertLogs() as cm:
|
||||
player.print_keywords()
|
||||
|
||||
expected_output = [
|
||||
"INFO:diceplayer:##########################################################################################\n############# Welcome to DICEPLAYER version test #############\n##########################################################################################\n",
|
||||
"INFO:diceplayer:Your python version is TEST\n",
|
||||
"INFO:diceplayer:Program started on 00 Test 0000 at 00:00:00\n",
|
||||
"INFO:diceplayer:Environment variables:",
|
||||
"INFO:diceplayer:OMP_STACKSIZE = Not set\n",
|
||||
"INFO:diceplayer:------------------------------------------------------------------------------------------\n DICE variables being used in this run:\n------------------------------------------------------------------------------------------\n",
|
||||
"INFO:diceplayer:combrule = *",
|
||||
"INFO:diceplayer:dens = 0.75",
|
||||
"INFO:diceplayer:isave = 1000",
|
||||
"INFO:diceplayer:ljname = phb.ljc",
|
||||
"INFO:diceplayer:nmol = [ 1 50 ]",
|
||||
"INFO:diceplayer:nstep = [ 2000 3000 4000 ]",
|
||||
"INFO:diceplayer:outname = phb",
|
||||
"INFO:diceplayer:press = 1.0",
|
||||
"INFO:diceplayer:progname = ~/.local/bin/dice",
|
||||
"INFO:diceplayer:randominit = first",
|
||||
"INFO:diceplayer:temp = 300.0",
|
||||
"INFO:diceplayer:upbuf = 360",
|
||||
"INFO:diceplayer:------------------------------------------------------------------------------------------\n GAUSSIAN variables being used in this run:\n------------------------------------------------------------------------------------------\n",
|
||||
"INFO:diceplayer:chg_tol = 0.01",
|
||||
"INFO:diceplayer:chgmult = [ 0 1 ]",
|
||||
"INFO:diceplayer:keywords = freq",
|
||||
"INFO:diceplayer:level = MP2/aug-cc-pVDZ",
|
||||
"INFO:diceplayer:pop = chelpg",
|
||||
"INFO:diceplayer:qmprog = g16",
|
||||
"INFO:diceplayer:\n",
|
||||
]
|
||||
|
||||
self.assertEqual(cm.output, expected_output)
|
||||
|
||||
def test_validate_atom_dict(self):
|
||||
with self.assertRaises(ValueError) as context:
|
||||
Player.validate_atom_dict(
|
||||
molecule_type=0,
|
||||
molecule_site=0,
|
||||
atom_dict={
|
||||
"lbl": 0,
|
||||
"na": 1,
|
||||
"rx": 1.0,
|
||||
"ry": 1.0,
|
||||
"rz": 1.0,
|
||||
"chg": 1.0,
|
||||
"eps": 1.0,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Invalid number of fields for site 1 for molecule type 1.",
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
Player.validate_atom_dict(
|
||||
molecule_type=0,
|
||||
molecule_site=0,
|
||||
atom_dict={
|
||||
"lbl": "",
|
||||
"na": 1,
|
||||
"rx": 1.0,
|
||||
"ry": 1.0,
|
||||
"rz": 1.0,
|
||||
"chg": 1.0,
|
||||
"eps": 1.0,
|
||||
"sig": 1.0,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception), "Invalid lbl fields for site 1 for molecule type 1."
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
Player.validate_atom_dict(
|
||||
molecule_type=0,
|
||||
molecule_site=0,
|
||||
atom_dict={
|
||||
"lbl": 1.0,
|
||||
"na": "",
|
||||
"rx": 1.0,
|
||||
"ry": 1.0,
|
||||
"rz": 1.0,
|
||||
"chg": 1.0,
|
||||
"eps": 1.0,
|
||||
"sig": 1.0,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception), "Invalid na fields for site 1 for molecule type 1."
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
Player.validate_atom_dict(
|
||||
molecule_type=0,
|
||||
molecule_site=0,
|
||||
atom_dict={
|
||||
"lbl": 1.0,
|
||||
"na": 1,
|
||||
"rx": "",
|
||||
"ry": 1.0,
|
||||
"rz": 1.0,
|
||||
"chg": 1.0,
|
||||
"eps": 1.0,
|
||||
"sig": 1.0,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Invalid rx fields for site 1 for molecule type 1. Value must be a float.",
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
Player.validate_atom_dict(
|
||||
molecule_type=0,
|
||||
molecule_site=0,
|
||||
atom_dict={
|
||||
"lbl": 1.0,
|
||||
"na": 1,
|
||||
"rx": 1.0,
|
||||
"ry": "",
|
||||
"rz": 1.0,
|
||||
"chg": 1.0,
|
||||
"eps": 1.0,
|
||||
"sig": 1.0,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Invalid ry fields for site 1 for molecule type 1. Value must be a float.",
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
Player.validate_atom_dict(
|
||||
molecule_type=0,
|
||||
molecule_site=0,
|
||||
atom_dict={
|
||||
"lbl": 1.0,
|
||||
"na": 1,
|
||||
"rx": 1.0,
|
||||
"ry": 1.0,
|
||||
"rz": "",
|
||||
"chg": 1.0,
|
||||
"eps": 1.0,
|
||||
"sig": 1.0,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Invalid rz fields for site 1 for molecule type 1. Value must be a float.",
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
Player.validate_atom_dict(
|
||||
molecule_type=0,
|
||||
molecule_site=0,
|
||||
atom_dict={
|
||||
"lbl": 1.0,
|
||||
"na": 1,
|
||||
"rx": 1.0,
|
||||
"ry": 1.0,
|
||||
"rz": 1.0,
|
||||
"chg": "",
|
||||
"eps": 1.0,
|
||||
"sig": 1.0,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Invalid chg fields for site 1 for molecule type 1. Value must be a float.",
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
Player.validate_atom_dict(
|
||||
molecule_type=0,
|
||||
molecule_site=0,
|
||||
atom_dict={
|
||||
"lbl": 1.0,
|
||||
"na": 1,
|
||||
"rx": 1.0,
|
||||
"ry": 1.0,
|
||||
"rz": 1.0,
|
||||
"chg": 1.0,
|
||||
"eps": "",
|
||||
"sig": 1.0,
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Invalid eps fields for site 1 for molecule type 1. Value must be a float.",
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
Player.validate_atom_dict(
|
||||
molecule_type=0,
|
||||
molecule_site=0,
|
||||
atom_dict={
|
||||
"lbl": 1.0,
|
||||
"na": 1,
|
||||
"rx": 1.0,
|
||||
"ry": 1.0,
|
||||
"rz": 1.0,
|
||||
"chg": 1.0,
|
||||
"eps": 1.0,
|
||||
"sig": "",
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Invalid sig fields for site 1 for molecule type 1. Value must be a float.",
|
||||
)
|
||||
|
||||
@mock.patch("builtins.open", mock_open)
|
||||
@mock.patch("diceplayer.player.Path.exists", return_value=True)
|
||||
def test_read_potentials(self, mock_path_exists):
|
||||
player = Player("control.test.yml")
|
||||
|
||||
player.read_potentials()
|
||||
|
||||
self.assertEqual(player.system.molecule[0].molname, "TEST")
|
||||
self.assertEqual(len(player.system.molecule[0].atom), 1)
|
||||
|
||||
self.assertEqual(player.system.molecule[1].molname, "PLACEHOLDER")
|
||||
self.assertEqual(len(player.system.molecule[1].atom), 1)
|
||||
|
||||
@mock.patch("builtins.open", mock_open)
|
||||
@mock.patch("diceplayer.player.Path.exists")
|
||||
def test_read_potentials_error(self, mock_path_exists):
|
||||
player = Player("control.test.yml")
|
||||
|
||||
# Testing file not found error
|
||||
mock_path_exists.return_value = False
|
||||
with self.assertRaises(RuntimeError) as context:
|
||||
player.read_potentials()
|
||||
|
||||
self.assertEqual(str(context.exception), "Potential file phb.ljc not found.")
|
||||
|
||||
# Enabling file found for next tests
|
||||
mock_path_exists.return_value = True
|
||||
|
||||
# Testing combrule error
|
||||
with self.assertRaises(SystemExit) as context:
|
||||
player.config.dice.ljname = "phb.error.combrule.ljc"
|
||||
player.read_potentials()
|
||||
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Error: expected a '*' or a '+' sign in 1st line of file phb.error.combrule.ljc",
|
||||
)
|
||||
|
||||
# Testing ntypes error
|
||||
with self.assertRaises(SystemExit) as context:
|
||||
player.config.dice.ljname = "phb.error.ntypes.ljc"
|
||||
player.read_potentials()
|
||||
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Error: expected an integer in the 2nd line of file phb.error.ntypes.ljc",
|
||||
)
|
||||
|
||||
# Testing ntypes error on config
|
||||
with self.assertRaises(SystemExit) as context:
|
||||
player.config.dice.ljname = "phb.error.ntypes.config.ljc"
|
||||
player.read_potentials()
|
||||
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Error: number of molecule types in file phb.error.ntypes.config.ljc "
|
||||
"must match that of 'nmol' keyword in config file",
|
||||
)
|
||||
|
||||
# Testing nsite error
|
||||
with self.assertRaises(ValueError) as context:
|
||||
player.config.dice.ljname = "phb.error.nsites.ljc"
|
||||
player.read_potentials()
|
||||
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Error: expected nsites to be an integer for molecule type 1",
|
||||
)
|
||||
|
||||
# Testing molname error
|
||||
with self.assertRaises(ValueError) as context:
|
||||
player.config.dice.ljname = "phb.error.molname.ljc"
|
||||
player.read_potentials()
|
||||
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Error: expected nsites and molname for the molecule type 1",
|
||||
)
|
||||
|
||||
@mock.patch("builtins.open", mock_open)
|
||||
@mock.patch("diceplayer.player.Path.exists", return_value=True)
|
||||
def test_print_potentials(self, mock_path_exists):
|
||||
player = Player("control.test.yml")
|
||||
player.read_potentials()
|
||||
|
||||
with self.assertLogs(level="INFO") as context:
|
||||
player.print_potentials()
|
||||
|
||||
expected_output = [
|
||||
"INFO:diceplayer:==========================================================================================\n Potential parameters from file phb.ljc:\n------------------------------------------------------------------------------------------\n",
|
||||
"INFO:diceplayer:Combination rule: *",
|
||||
"INFO:diceplayer:Types of molecules: 2\n",
|
||||
"INFO:diceplayer:1 atoms in molecule type 1:",
|
||||
"INFO:diceplayer:---------------------------------------------------------------------------------",
|
||||
"INFO:diceplayer:Lbl AN X Y Z Charge Epsilon Sigma Mass",
|
||||
"INFO:diceplayer:---------------------------------------------------------------------------------",
|
||||
"INFO:diceplayer:1 1 0.00000 0.00000 0.00000 0.000000 0.00000 0.0000 1.0079",
|
||||
"INFO:diceplayer:\n",
|
||||
"INFO:diceplayer:1 atoms in molecule type 2:",
|
||||
"INFO:diceplayer:---------------------------------------------------------------------------------",
|
||||
"INFO:diceplayer:Lbl AN X Y Z Charge Epsilon Sigma Mass",
|
||||
"INFO:diceplayer:---------------------------------------------------------------------------------",
|
||||
"INFO:diceplayer:1 1 0.00000 0.00000 0.00000 0.000000 0.00000 0.0000 1.0079",
|
||||
"INFO:diceplayer:\n",
|
||||
]
|
||||
|
||||
self.assertEqual(context.output, expected_output)
|
||||
|
||||
@mock.patch("builtins.open", mock_open)
|
||||
def test_dice_start(self):
|
||||
player = Player("control.test.yml")
|
||||
player.dice_interface = mock.MagicMock()
|
||||
player.dice_interface.start = mock.MagicMock()
|
||||
|
||||
player.dice_start(1)
|
||||
|
||||
player.dice_interface.start.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
0
tests/utils/__init__.py
Normal file
0
tests/utils/__init__.py
Normal file
37
tests/utils/test_potential.py
Normal file
37
tests/utils/test_potential.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from diceplayer.config import PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
from diceplayer.utils.potential import read_system_from_phb
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestPotential:
|
||||
@pytest.fixture
|
||||
def player_config(self) -> PlayerConfig:
|
||||
return PlayerConfig.model_validate(
|
||||
{
|
||||
"type": "both",
|
||||
"mem": 12,
|
||||
"max_cyc": 100,
|
||||
"switch_cyc": 50,
|
||||
"ncores": 4,
|
||||
"dice": {
|
||||
"nprocs": 4,
|
||||
"ljname": "phb.ljc.example",
|
||||
"outname": "test",
|
||||
"dens": 1.0,
|
||||
"nmol": [12, 16],
|
||||
"nstep": [1, 1],
|
||||
},
|
||||
"gaussian": {
|
||||
"level": "test",
|
||||
"qmprog": "g16",
|
||||
"keywords": "test",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def test_read_phb(self, player_config: PlayerConfig):
|
||||
system = read_system_from_phb(player_config)
|
||||
|
||||
assert isinstance(system, System)
|
||||
59
uv.lock
generated
59
uv.lock
generated
@@ -231,6 +231,7 @@ dev = [
|
||||
{ name = "isort" },
|
||||
{ name = "poethepoet" },
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pytest" },
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
@@ -251,6 +252,7 @@ dev = [
|
||||
{ name = "isort", specifier = ">=5.13.2" },
|
||||
{ name = "poethepoet", specifier = ">=0.27.0" },
|
||||
{ name = "pre-commit", specifier = ">=3.7.1" },
|
||||
{ name = "pytest", specifier = ">=9.0.2" },
|
||||
{ name = "ruff", specifier = ">=0.15.2" },
|
||||
]
|
||||
|
||||
@@ -263,6 +265,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.24.3"
|
||||
@@ -281,6 +295,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/58/40fbbcefeda82364720eba5cf2270f98496bdfa19ea75b4cccae79c698e6/identify-2.6.16-py2.py3-none-any.whl", hash = "sha256:391ee4d77741d994189522896270b787aed8670389bfd60f326d677d64a6dfb0", size = 99202, upload-time = "2026-01-12T18:58:56.627Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "isort"
|
||||
version = "8.0.1"
|
||||
@@ -491,6 +514,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/48/31/05e764397056194206169869b50cf2fee4dbbbc71b344705b9c0d878d4d8/platformdirs-4.9.2-py3-none-any.whl", hash = "sha256:9170634f126f8efdae22fb58ae8a0eaa86f38365bc57897a6c4f781d1f5875bd", size = 21168, upload-time = "2026-02-16T03:56:08.891Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "poethepoet"
|
||||
version = "0.42.1"
|
||||
@@ -654,6 +686,33 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "9.0.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
{ name = "pygments" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-discovery"
|
||||
version = "1.1.0"
|
||||
|
||||
Reference in New Issue
Block a user