feat: reads potentions from ljc file
This commit is contained in:
@@ -3,7 +3,7 @@ diceplayer:
|
|||||||
switch_cyc: 3
|
switch_cyc: 3
|
||||||
max_cyc: 5
|
max_cyc: 5
|
||||||
mem: 24
|
mem: 24
|
||||||
ncores: 5
|
ncores: 20
|
||||||
qmprog: 'g16'
|
qmprog: 'g16'
|
||||||
lps: no
|
lps: no
|
||||||
ghosts: no
|
ghosts: no
|
||||||
@@ -11,14 +11,15 @@ diceplayer:
|
|||||||
|
|
||||||
dice:
|
dice:
|
||||||
nprocs: 4
|
nprocs: 4
|
||||||
nmol: [1, 100]
|
nmol: [1, 1000]
|
||||||
dens: 1.5
|
dens: 1.5
|
||||||
nstep: [2000, 3000]
|
nstep: [2000, 3000]
|
||||||
isave: 1000
|
isave: 0
|
||||||
outname: 'phb'
|
outname: 'phb'
|
||||||
progname: '~/.local/bin/dice'
|
progname: '~/.local/bin/dice'
|
||||||
ljname: 'phb.ljc'
|
ljname: 'phb.ljc.example'
|
||||||
randominit: 'always'
|
randominit: 'always'
|
||||||
|
seed: 12345
|
||||||
|
|
||||||
gaussian:
|
gaussian:
|
||||||
qmprog: 'g16'
|
qmprog: 'g16'
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
@@ -16,7 +18,7 @@ class DiceConfig(BaseModel):
|
|||||||
..., description="Number of processes to use for the DICE simulations"
|
..., description="Number of processes to use for the DICE simulations"
|
||||||
)
|
)
|
||||||
|
|
||||||
ljname: str = Field(..., description="Name of the Lennard-Jones potential file")
|
ljname: Path = Field(..., description="Name of the Lennard-Jones potential file")
|
||||||
outname: str = Field(
|
outname: str = Field(
|
||||||
..., description="Name of the output file for the simulation results"
|
..., description="Name of the output file for the simulation results"
|
||||||
)
|
)
|
||||||
@@ -47,10 +49,13 @@ class DiceConfig(BaseModel):
|
|||||||
isave: int = Field(1000, description="Frequency of saving the simulation results")
|
isave: int = Field(1000, description="Frequency of saving the simulation results")
|
||||||
press: float = Field(1.0, description="Pressure of the system")
|
press: float = Field(1.0, description="Pressure of the system")
|
||||||
temp: float = Field(300.0, description="Temperature of the system")
|
temp: float = Field(300.0, description="Temperature of the system")
|
||||||
progname: str = Field(
|
progname: Path = Field(
|
||||||
"dice", description="Name of the program to run the simulation"
|
"dice", description="Name of the program to run the simulation"
|
||||||
)
|
)
|
||||||
randominit: str = Field(
|
randominit: str = Field(
|
||||||
"first", description="Method for initializing the random number generator"
|
"first", description="Method for initializing the random number generator"
|
||||||
)
|
)
|
||||||
seed: int = Field(default_factory=lambda: random.randint(0, 2**32 - 1), description="Seed for the random number generator")
|
seed: int = Field(
|
||||||
|
default_factory=lambda: int(1e6 * random.random()),
|
||||||
|
description="Seed for the random number generator",
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
import shutil
|
from diceplayer.dice.dice_input import (
|
||||||
|
NPTEqConfig,
|
||||||
from diceplayer.dice.dice_input import NVTTerConfig, NVTEqConfig, NPTEqConfig, NPTTerConfig
|
NPTTerConfig,
|
||||||
|
NVTEqConfig,
|
||||||
|
NVTTerConfig,
|
||||||
|
)
|
||||||
from diceplayer.dice.dice_wrapper import DiceWrapper
|
from diceplayer.dice.dice_wrapper import DiceWrapper
|
||||||
from diceplayer.logger import logger
|
from diceplayer.logger import logger
|
||||||
from diceplayer.state.state_model import StateModel
|
from diceplayer.state.state_model import StateModel
|
||||||
|
|
||||||
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
@@ -15,7 +19,9 @@ class DiceHandler:
|
|||||||
|
|
||||||
def run(self, state: StateModel, cycle: int) -> StateModel:
|
def run(self, state: StateModel, cycle: int) -> StateModel:
|
||||||
if self.dice_directory.exists():
|
if self.dice_directory.exists():
|
||||||
logger.info(f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state")
|
logger.info(
|
||||||
|
f"Found dice directory: {self.dice_directory}, this directory will be purged for a clean state"
|
||||||
|
)
|
||||||
shutil.rmtree(self.dice_directory)
|
shutil.rmtree(self.dice_directory)
|
||||||
self.dice_directory.mkdir(parents=True)
|
self.dice_directory.mkdir(parents=True)
|
||||||
|
|
||||||
@@ -38,7 +44,9 @@ class DiceHandler:
|
|||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
if len(results) != state.config.dice.nprocs:
|
if len(results) != state.config.dice.nprocs:
|
||||||
raise RuntimeError(f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}")
|
raise RuntimeError(
|
||||||
|
f"Expected {state.config.dice.nprocs} simulation results, but got {len(results)}"
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -47,21 +55,23 @@ class DiceHandler:
|
|||||||
def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
|
def commit_simulation_state(self, state: StateModel, result: dict) -> StateModel:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def _simulation_process(self, state: StateModel, cycle: int, proc: int, results: list[dict]) -> None:
|
def _simulation_process(
|
||||||
|
self, state: StateModel, cycle: int, proc: int, results: list[dict]
|
||||||
|
) -> None:
|
||||||
proc_directory = self.dice_directory / f"{proc:02d}"
|
proc_directory = self.dice_directory / f"{proc:02d}"
|
||||||
if proc_directory.exists():
|
if proc_directory.exists():
|
||||||
shutil.rmtree(proc_directory)
|
shutil.rmtree(proc_directory)
|
||||||
proc_directory.mkdir(parents=True)
|
proc_directory.mkdir(parents=True)
|
||||||
|
|
||||||
dice = DiceWrapper(
|
dice = DiceWrapper(state.config.dice, proc_directory)
|
||||||
state.config.dice, proc_directory
|
|
||||||
)
|
|
||||||
|
|
||||||
if state.config.dice.randominit == "first" and cycle == 0:
|
self._generate_phb_file(state, proc_directory)
|
||||||
|
|
||||||
|
if state.config.dice.randominit == "first" and cycle >= 0:
|
||||||
|
self._generate_last_xyz(state, proc_directory)
|
||||||
|
else:
|
||||||
nvt_ter_config = NVTTerConfig.from_config(state.config)
|
nvt_ter_config = NVTTerConfig.from_config(state.config)
|
||||||
dice.run(nvt_ter_config)
|
dice.run(nvt_ter_config)
|
||||||
else:
|
|
||||||
self._generate_last_xyz(state, proc_directory)
|
|
||||||
|
|
||||||
if len(state.config.dice.nstep) == 2:
|
if len(state.config.dice.nstep) == 2:
|
||||||
nvt_eq_config = NVTEqConfig.from_config(state.config)
|
nvt_eq_config = NVTEqConfig.from_config(state.config)
|
||||||
@@ -76,5 +86,30 @@ class DiceHandler:
|
|||||||
|
|
||||||
results.append(dice.extract_results())
|
results.append(dice.extract_results())
|
||||||
|
|
||||||
def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None:
|
@staticmethod
|
||||||
...
|
def _generate_phb_file(state: StateModel, proc_directory: Path) -> None:
|
||||||
|
fstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f}\n"
|
||||||
|
|
||||||
|
phb_file = proc_directory / state.config.dice.ljname
|
||||||
|
|
||||||
|
with open(phb_file, "w") as f:
|
||||||
|
f.write(f"{state.config.dice.combrule}\n")
|
||||||
|
f.write(f"{len(state.config.dice.nmol)}\n")
|
||||||
|
|
||||||
|
for molecule in state.system.molecule:
|
||||||
|
f.write(f"{len(molecule.atom)} {molecule.molname}\n")
|
||||||
|
for atom in molecule.atom:
|
||||||
|
f.write(
|
||||||
|
fstr.format(
|
||||||
|
atom.lbl,
|
||||||
|
atom.na,
|
||||||
|
atom.rx,
|
||||||
|
atom.ry,
|
||||||
|
atom.rz,
|
||||||
|
atom.chg,
|
||||||
|
atom.eps,
|
||||||
|
atom.sig,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_last_xyz(self, state: StateModel, proc_directory: Path) -> None: ...
|
||||||
|
|||||||
@@ -9,6 +9,30 @@ from pathlib import Path
|
|||||||
from typing import Any, Sequence, TextIO
|
from typing import Any, Sequence, TextIO
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DICE_KEYWORD_ORDER = [
|
||||||
|
"title",
|
||||||
|
"ncores",
|
||||||
|
"ljname",
|
||||||
|
"outname",
|
||||||
|
"nmol",
|
||||||
|
"dens",
|
||||||
|
"temp",
|
||||||
|
"press",
|
||||||
|
"seed",
|
||||||
|
"init",
|
||||||
|
"nstep",
|
||||||
|
"vstep",
|
||||||
|
"mstop",
|
||||||
|
"accum",
|
||||||
|
"iprint",
|
||||||
|
"isave",
|
||||||
|
"irdf",
|
||||||
|
"upbuf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class BaseConfig(ABC):
|
class BaseConfig(ABC):
|
||||||
ncores: int
|
ncores: int
|
||||||
@@ -18,13 +42,14 @@ class BaseConfig(ABC):
|
|||||||
temp: float
|
temp: float
|
||||||
seed: int
|
seed: int
|
||||||
isave: int
|
isave: int
|
||||||
press: float = 1.0
|
|
||||||
|
|
||||||
def write(self, directory: Path, filename: str = "input") -> Path:
|
def write(self, directory: Path, filename: str = "input") -> Path:
|
||||||
input_path = directory / filename
|
input_path = directory / filename
|
||||||
|
|
||||||
if input_path.exists():
|
if input_path.exists():
|
||||||
logger.info(f"Dice input file {input_path} already exists and will be overwritten")
|
logger.info(
|
||||||
|
f"Dice input file {input_path} already exists and will be overwritten"
|
||||||
|
)
|
||||||
input_path.unlink()
|
input_path.unlink()
|
||||||
input_path.parent.mkdir(parents=True, exist_ok=True)
|
input_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@@ -34,13 +59,18 @@ class BaseConfig(ABC):
|
|||||||
return input_path
|
return input_path
|
||||||
|
|
||||||
def write_dice_config(self, io_writer: TextIO) -> None:
|
def write_dice_config(self, io_writer: TextIO) -> None:
|
||||||
for field in fields(self):
|
values = {f.name: getattr(self, f.name) for f in fields(self)}
|
||||||
key = field.name
|
|
||||||
value = getattr(self, key)
|
|
||||||
|
|
||||||
|
for key in DICE_KEYWORD_ORDER:
|
||||||
|
value = values.pop(key, None)
|
||||||
if value is None:
|
if value is None:
|
||||||
continue
|
continue
|
||||||
|
io_writer.write(f"{key} = {self._serialize_value(value)}\n")
|
||||||
|
|
||||||
|
# write any remaining fields (future extensions)
|
||||||
|
for key, value in values.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
io_writer.write(f"{key} = {self._serialize_value(value)}\n")
|
io_writer.write(f"{key} = {self._serialize_value(value)}\n")
|
||||||
|
|
||||||
io_writer.write("$end\n")
|
io_writer.write("$end\n")
|
||||||
@@ -48,7 +78,7 @@ class BaseConfig(ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||||
base_fields = cls._extract_base_fields(config)
|
base_fields = cls._extract_base_fields(config)
|
||||||
return cls(**base_fields, **kwargs)
|
return cls(**(base_fields | kwargs))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
|
def _extract_base_fields(config: PlayerConfig) -> dict[str, Any]:
|
||||||
@@ -60,7 +90,6 @@ class BaseConfig(ABC):
|
|||||||
temp=config.dice.temp,
|
temp=config.dice.temp,
|
||||||
seed=config.dice.seed,
|
seed=config.dice.seed,
|
||||||
isave=config.dice.isave,
|
isave=config.dice.isave,
|
||||||
press=config.dice.press,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -91,10 +120,18 @@ class BaseConfig(ABC):
|
|||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class NVTConfig(BaseConfig):
|
class NVTConfig(BaseConfig):
|
||||||
title: str = "Diceplayer Run - NVT"
|
title: str = "Diceplayer Run - NVT"
|
||||||
dens: float = 0.0
|
dens: float = ...
|
||||||
nstep: int = 0
|
nstep: int = ...
|
||||||
vstep: int = 0
|
vstep: int = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||||
|
return super(NVTConfig, cls).from_config(
|
||||||
|
config,
|
||||||
|
dens=config.dice.dens,
|
||||||
|
nstep=cls._get_nstep(config, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------
|
# -----------------------------------------------------
|
||||||
# NVT THERMALIZATION
|
# NVT THERMALIZATION
|
||||||
@@ -105,12 +142,12 @@ class NVTConfig(BaseConfig):
|
|||||||
class NVTTerConfig(NVTConfig):
|
class NVTTerConfig(NVTConfig):
|
||||||
title: str = "Diceplayer Run - NVT Thermalization"
|
title: str = "Diceplayer Run - NVT Thermalization"
|
||||||
upbuf: int = 360
|
upbuf: int = 360
|
||||||
|
init: str = "yes"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||||
return super(NVTTerConfig, cls).from_config(
|
return super(NVTTerConfig, cls).from_config(
|
||||||
config,
|
config,
|
||||||
dens=config.dice.dens,
|
|
||||||
nstep=cls._get_nstep(config, 0),
|
nstep=cls._get_nstep(config, 0),
|
||||||
upbuf=config.dice.upbuf,
|
upbuf=config.dice.upbuf,
|
||||||
vstep=0,
|
vstep=0,
|
||||||
@@ -118,9 +155,7 @@ class NVTTerConfig(NVTConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def write(self, directory: Path, filename: str = "nvt.ter") -> Path:
|
def write(self, directory: Path, filename: str = "nvt.ter") -> Path:
|
||||||
return super(NVTTerConfig, self).write(
|
return super(NVTTerConfig, self).write(directory, filename)
|
||||||
directory, filename
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------
|
# -----------------------------------------------------
|
||||||
@@ -132,12 +167,12 @@ class NVTTerConfig(NVTConfig):
|
|||||||
class NVTEqConfig(NVTConfig):
|
class NVTEqConfig(NVTConfig):
|
||||||
title: str = "Diceplayer Run - NVT Production"
|
title: str = "Diceplayer Run - NVT Production"
|
||||||
irdf: int = 0
|
irdf: int = 0
|
||||||
|
init: str = "yesreadxyz"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||||
return super(NVTEqConfig, cls).from_config(
|
return super(NVTEqConfig, cls).from_config(
|
||||||
config,
|
config,
|
||||||
dens=config.dice.dens,
|
|
||||||
nstep=cls._get_nstep(config, 1),
|
nstep=cls._get_nstep(config, 1),
|
||||||
irdf=config.dice.irdf,
|
irdf=config.dice.irdf,
|
||||||
vstep=0,
|
vstep=0,
|
||||||
@@ -145,9 +180,7 @@ class NVTEqConfig(NVTConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def write(self, directory: Path, filename: str = "nvt.eq") -> Path:
|
def write(self, directory: Path, filename: str = "nvt.eq") -> Path:
|
||||||
return super(NVTEqConfig, self).write(
|
return super(NVTEqConfig, self).write(directory, filename)
|
||||||
directory, filename
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------
|
# -----------------------------------------------------
|
||||||
@@ -160,6 +193,14 @@ class NPTConfig(BaseConfig):
|
|||||||
title: str = "Diceplayer Run - NPT"
|
title: str = "Diceplayer Run - NPT"
|
||||||
nstep: int = 0
|
nstep: int = 0
|
||||||
vstep: int = 5000
|
vstep: int = 5000
|
||||||
|
press: float = 1.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: PlayerConfig, **kwargs) -> Self:
|
||||||
|
return super(NPTConfig, cls).from_config(
|
||||||
|
config,
|
||||||
|
press=config.dice.press,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------
|
# -----------------------------------------------------
|
||||||
@@ -183,9 +224,7 @@ class NPTTerConfig(NPTConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def write(self, directory: Path, filename: str = "npt.ter") -> Path:
|
def write(self, directory: Path, filename: str = "npt.ter") -> Path:
|
||||||
return super(NPTTerConfig, self).write(
|
return super(NPTTerConfig, self).write(directory, filename)
|
||||||
directory, filename
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------
|
# -----------------------------------------------------
|
||||||
@@ -209,6 +248,4 @@ class NPTEqConfig(NPTConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def write(self, directory: Path, filename: str = "npt.eq") -> Path:
|
def write(self, directory: Path, filename: str = "npt.eq") -> Path:
|
||||||
return super(NPTEqConfig, self).write(
|
return super(NPTEqConfig, self).write(directory, filename)
|
||||||
directory, filename
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
import subprocess
|
|
||||||
from typing import Final
|
|
||||||
|
|
||||||
import diceplayer.dice.dice_input as dice_input
|
import diceplayer.dice.dice_input as dice_input
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from diceplayer.config import DiceConfig
|
from diceplayer.config import DiceConfig
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
|
||||||
DICE_FLAG_LINE: Final[int] = -2
|
DICE_FLAG_LINE: Final[int] = -2
|
||||||
DICE_END_FLAG: Final[str] = "End of simulation"
|
DICE_END_FLAG: Final[str] = "End of simulation"
|
||||||
@@ -22,8 +20,9 @@ class DiceWrapper:
|
|||||||
output_path = input_path.parent / (input_path.name + ".out")
|
output_path = input_path.parent / (input_path.name + ".out")
|
||||||
|
|
||||||
with open(output_path, "w") as outfile, open(input_path, "r") as infile:
|
with open(output_path, "w") as outfile, open(input_path, "r") as infile:
|
||||||
|
bin_path = self.dice_config.progname.expanduser()
|
||||||
exit_status = subprocess.call(
|
exit_status = subprocess.call(
|
||||||
self.dice_config.progname, stdin=infile, stdout=outfile
|
bin_path, stdin=infile, stdout=outfile, cwd=self.working_directory
|
||||||
)
|
)
|
||||||
|
|
||||||
if exit_status != 0:
|
if exit_status != 0:
|
||||||
@@ -38,4 +37,3 @@ class DiceWrapper:
|
|||||||
|
|
||||||
def extract_results(self) -> dict:
|
def extract_results(self) -> dict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from diceplayer.utils.ptable import AtomInfo, PTable
|
from diceplayer.utils.ptable import AtomInfo, PTable
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(slots=True)
|
||||||
class Atom:
|
class Atom:
|
||||||
"""
|
"""
|
||||||
Atom class declaration. This class is used throughout the DicePlayer program to represent atoms.
|
Atom class declaration. This class is used throughout the DicePlayer program to represent atoms.
|
||||||
|
|||||||
@@ -10,10 +10,11 @@ import numpy as np
|
|||||||
import numpy.linalg as linalg
|
import numpy.linalg as linalg
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
from typing_extensions import List, Self, Tuple
|
from typing_extensions import List, Self, Tuple
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import field
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,4 +16,4 @@ class OptimizationHandler:
|
|||||||
if current_cycle < state.config.switch_cyc:
|
if current_cycle < state.config.switch_cyc:
|
||||||
return RoutineType.CHARGE
|
return RoutineType.CHARGE
|
||||||
|
|
||||||
return RoutineType.GEOMETRY
|
return RoutineType.GEOMETRY
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
from diceplayer.config.player_config import PlayerConfig
|
from diceplayer.config.player_config import PlayerConfig
|
||||||
from diceplayer.dice.dice_handler import DiceHandler
|
from diceplayer.dice.dice_handler import DiceHandler
|
||||||
from diceplayer.logger import logger
|
from diceplayer.logger import logger
|
||||||
from diceplayer.optimization.optimization_handler import OptimizationHandler
|
|
||||||
from diceplayer.state.state_handler import StateHandler
|
from diceplayer.state.state_handler import StateHandler
|
||||||
from diceplayer.state.state_model import StateModel
|
from diceplayer.state.state_model import StateModel
|
||||||
|
|
||||||
from typing_extensions import TypedDict, Unpack
|
from typing_extensions import TypedDict, Unpack
|
||||||
|
|
||||||
|
from diceplayer.utils.potential import read_system_from_phb
|
||||||
|
|
||||||
|
|
||||||
class PlayerFlags(TypedDict):
|
class PlayerFlags(TypedDict):
|
||||||
continuation: bool
|
continuation: bool
|
||||||
@@ -30,9 +31,11 @@ class Player:
|
|||||||
self._state_handler.delete()
|
self._state_handler.delete()
|
||||||
state = None
|
state = None
|
||||||
|
|
||||||
|
|
||||||
if state is None:
|
if state is None:
|
||||||
state = StateModel.from_config(self.config)
|
system = read_system_from_phb(self.config)
|
||||||
|
state = StateModel(
|
||||||
|
config=self.config, system=system
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Resuming from existing state.")
|
logger.info("Resuming from existing state.")
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing_extensions import Self
|
|||||||
class StateModel(BaseModel):
|
class StateModel(BaseModel):
|
||||||
config: PlayerConfig
|
config: PlayerConfig
|
||||||
system: System
|
system: System
|
||||||
current_cycle: int
|
current_cycle: int = 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: PlayerConfig) -> Self:
|
def from_config(cls, config: PlayerConfig) -> Self:
|
||||||
|
|||||||
62
diceplayer/utils/potential.py
Normal file
62
diceplayer/utils/potential.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from diceplayer.config import PlayerConfig
|
||||||
|
from diceplayer.environment import System, Molecule, Atom
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from diceplayer.logger import logger
|
||||||
|
from diceplayer.state.state_model import StateModel
|
||||||
|
|
||||||
|
|
||||||
|
def read_system_from_phb(config: PlayerConfig) -> System:
|
||||||
|
phb_file = Path(config.dice.ljname)
|
||||||
|
if not phb_file.exists():
|
||||||
|
raise FileNotFoundError
|
||||||
|
|
||||||
|
ljc_data = phb_file.read_text(encoding="utf-8").splitlines()
|
||||||
|
|
||||||
|
combrule = ljc_data.pop(0).strip()
|
||||||
|
if combrule != config.dice.combrule:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid combrule defined in {phb_file}. Expected the same value configured in the config file"
|
||||||
|
)
|
||||||
|
|
||||||
|
ntypes = ljc_data.pop(0).strip()
|
||||||
|
if not ntypes.isdigit():
|
||||||
|
raise ValueError(f"Invalid ntypes defined in {phb_file}")
|
||||||
|
|
||||||
|
nmol = int(ntypes)
|
||||||
|
if nmol != len(config.dice.nmol):
|
||||||
|
raise ValueError(f"Invalid nmol defined in {phb_file}")
|
||||||
|
|
||||||
|
sys = System()
|
||||||
|
|
||||||
|
for i in range(nmol):
|
||||||
|
nsites, molname = ljc_data.pop(0).split()
|
||||||
|
|
||||||
|
if not nsites.isdigit():
|
||||||
|
raise ValueError(f"Invalid nsites defined in {phb_file}")
|
||||||
|
nsites = int(nsites)
|
||||||
|
|
||||||
|
mol = Molecule(molname)
|
||||||
|
|
||||||
|
for j in range(nsites):
|
||||||
|
_fields = ljc_data.pop(0).split()
|
||||||
|
mol.add_atom(
|
||||||
|
Atom(*_fields)
|
||||||
|
)
|
||||||
|
|
||||||
|
sys.add_type(mol)
|
||||||
|
|
||||||
|
return sys
|
||||||
|
|
||||||
|
|
||||||
|
# def write_phb(phb_file: Path, state: StateModel) -> None:
|
||||||
|
# if phb_file.exists():
|
||||||
|
# raise RuntimeError(f"File {phb_file} already exists")
|
||||||
|
#
|
||||||
|
# with open(phb_file, "w") as f:
|
||||||
|
# f.write(f"{state.config.dice.combrule}\n")
|
||||||
|
# f.write(f"{len(state.system.nmols)}\n")
|
||||||
|
# f.write(f"{state.config.dice.nmol}\n")
|
||||||
|
|
||||||
|
|
||||||
@@ -6,6 +6,7 @@ import pytest
|
|||||||
class TestDiceConfig:
|
class TestDiceConfig:
|
||||||
def test_class_instantiation(self):
|
def test_class_instantiation(self):
|
||||||
dice_dto = DiceConfig(
|
dice_dto = DiceConfig(
|
||||||
|
nprocs=1,
|
||||||
ljname="test",
|
ljname="test",
|
||||||
outname="test",
|
outname="test",
|
||||||
dens=1.0,
|
dens=1.0,
|
||||||
@@ -18,6 +19,7 @@ class TestDiceConfig:
|
|||||||
def test_validate_jname(self):
|
def test_validate_jname(self):
|
||||||
with pytest.raises(ValueError) as ex:
|
with pytest.raises(ValueError) as ex:
|
||||||
DiceConfig(
|
DiceConfig(
|
||||||
|
nprocs=1,
|
||||||
ljname=None,
|
ljname=None,
|
||||||
outname="test",
|
outname="test",
|
||||||
dens=1.0,
|
dens=1.0,
|
||||||
@@ -30,6 +32,7 @@ class TestDiceConfig:
|
|||||||
def test_validate_outname(self):
|
def test_validate_outname(self):
|
||||||
with pytest.raises(ValueError) as ex:
|
with pytest.raises(ValueError) as ex:
|
||||||
DiceConfig(
|
DiceConfig(
|
||||||
|
nprocs=1,
|
||||||
ljname="test",
|
ljname="test",
|
||||||
outname=None,
|
outname=None,
|
||||||
dens=1.0,
|
dens=1.0,
|
||||||
@@ -42,6 +45,7 @@ class TestDiceConfig:
|
|||||||
def test_validate_dens(self):
|
def test_validate_dens(self):
|
||||||
with pytest.raises(ValueError) as ex:
|
with pytest.raises(ValueError) as ex:
|
||||||
DiceConfig(
|
DiceConfig(
|
||||||
|
nprocs=1,
|
||||||
ljname="test",
|
ljname="test",
|
||||||
outname="test",
|
outname="test",
|
||||||
dens=None,
|
dens=None,
|
||||||
@@ -54,6 +58,7 @@ class TestDiceConfig:
|
|||||||
def test_validate_nmol(self):
|
def test_validate_nmol(self):
|
||||||
with pytest.raises(ValueError) as ex:
|
with pytest.raises(ValueError) as ex:
|
||||||
DiceConfig(
|
DiceConfig(
|
||||||
|
nprocs=1,
|
||||||
ljname="test",
|
ljname="test",
|
||||||
outname="test",
|
outname="test",
|
||||||
dens=1.0,
|
dens=1.0,
|
||||||
@@ -66,6 +71,7 @@ class TestDiceConfig:
|
|||||||
def test_validate_nstep(self):
|
def test_validate_nstep(self):
|
||||||
with pytest.raises(ValueError) as ex:
|
with pytest.raises(ValueError) as ex:
|
||||||
DiceConfig(
|
DiceConfig(
|
||||||
|
nprocs=1,
|
||||||
ljname="test",
|
ljname="test",
|
||||||
outname="test",
|
outname="test",
|
||||||
dens=1.0,
|
dens=1.0,
|
||||||
@@ -78,6 +84,7 @@ class TestDiceConfig:
|
|||||||
def test_from_dict(self):
|
def test_from_dict(self):
|
||||||
dice_dto = DiceConfig.model_validate(
|
dice_dto = DiceConfig.model_validate(
|
||||||
{
|
{
|
||||||
|
"nprocs": 1,
|
||||||
"ljname": "test",
|
"ljname": "test",
|
||||||
"outname": "test",
|
"outname": "test",
|
||||||
"dens": 1.0,
|
"dens": 1.0,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ class TestPlayerConfig:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dice_payload(self) -> dict[str, Any]:
|
def dice_payload(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
"nprocs": 4,
|
||||||
"ljname": "test",
|
"ljname": "test",
|
||||||
"outname": "test",
|
"outname": "test",
|
||||||
"dens": 1.0,
|
"dens": 1.0,
|
||||||
@@ -35,7 +36,6 @@ class TestPlayerConfig:
|
|||||||
"mem": 12,
|
"mem": 12,
|
||||||
"max_cyc": 100,
|
"max_cyc": 100,
|
||||||
"switch_cyc": 50,
|
"switch_cyc": 50,
|
||||||
"nprocs": 4,
|
|
||||||
"ncores": 4,
|
"ncores": 4,
|
||||||
"dice": dice_payload,
|
"dice": dice_payload,
|
||||||
"gaussian": gaussian_payload,
|
"gaussian": gaussian_payload,
|
||||||
@@ -57,7 +57,6 @@ class TestPlayerConfig:
|
|||||||
mem=12,
|
mem=12,
|
||||||
max_cyc=100,
|
max_cyc=100,
|
||||||
switch_cyc=50,
|
switch_cyc=50,
|
||||||
nprocs=4,
|
|
||||||
ncores=4,
|
ncores=4,
|
||||||
dice=dice_config,
|
dice=dice_config,
|
||||||
gaussian=gaussian_config,
|
gaussian=gaussian_config,
|
||||||
@@ -75,7 +74,6 @@ class TestPlayerConfig:
|
|||||||
mem=12,
|
mem=12,
|
||||||
max_cyc=100,
|
max_cyc=100,
|
||||||
switch_cyc=50,
|
switch_cyc=50,
|
||||||
nprocs=4,
|
|
||||||
ncores=4,
|
ncores=4,
|
||||||
altsteps=0,
|
altsteps=0,
|
||||||
dice=dice_config,
|
dice=dice_config,
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ class TestDiceInput:
|
|||||||
"mem": 12,
|
"mem": 12,
|
||||||
"max_cyc": 100,
|
"max_cyc": 100,
|
||||||
"switch_cyc": 50,
|
"switch_cyc": 50,
|
||||||
"nprocs": 4,
|
|
||||||
"ncores": 4,
|
"ncores": 4,
|
||||||
"dice": {
|
"dice": {
|
||||||
|
"nprocs": 4,
|
||||||
"ljname": "test",
|
"ljname": "test",
|
||||||
"outname": "test",
|
"outname": "test",
|
||||||
"dens": 1.0,
|
"dens": 1.0,
|
||||||
|
|||||||
@@ -1,113 +0,0 @@
|
|||||||
from unittest import mock
|
|
||||||
|
|
||||||
|
|
||||||
def get_config_example():
|
|
||||||
return """
|
|
||||||
diceplayer:
|
|
||||||
opt: no
|
|
||||||
mem: 12
|
|
||||||
maxcyc: 3
|
|
||||||
ncores: 4
|
|
||||||
nprocs: 4
|
|
||||||
qmprog: 'g16'
|
|
||||||
lps: no
|
|
||||||
ghosts: no
|
|
||||||
altsteps: 20000
|
|
||||||
|
|
||||||
dice:
|
|
||||||
nmol: [1, 50]
|
|
||||||
dens: 0.75
|
|
||||||
nstep: [2000, 3000, 4000]
|
|
||||||
isave: 1000
|
|
||||||
outname: 'phb'
|
|
||||||
progname: '~/.local/bin/dice'
|
|
||||||
ljname: 'phb.ljc'
|
|
||||||
randominit: 'first'
|
|
||||||
|
|
||||||
gaussian:
|
|
||||||
qmprog: 'g16'
|
|
||||||
level: 'MP2/aug-cc-pVDZ'
|
|
||||||
keywords: 'freq'
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_potentials_exemple():
|
|
||||||
return """\
|
|
||||||
*
|
|
||||||
2
|
|
||||||
1 TEST
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
1 PLACEHOLDER
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_potentials_error_combrule():
|
|
||||||
return """\
|
|
||||||
.
|
|
||||||
2
|
|
||||||
1 TEST
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
1 PLACEHOLDER
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_potentials_error_ntypes():
|
|
||||||
return """\
|
|
||||||
*
|
|
||||||
a
|
|
||||||
1 TEST
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
1 PLACEHOLDER
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_potentials_error_ntypes_config():
|
|
||||||
return """\
|
|
||||||
*
|
|
||||||
3
|
|
||||||
1 TEST
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
1 PLACEHOLDER
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_potentials_error_nsites():
|
|
||||||
return """\
|
|
||||||
*
|
|
||||||
2
|
|
||||||
. TEST
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
1 PLACEHOLDER
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_potentials_error_molname():
|
|
||||||
return """\
|
|
||||||
*
|
|
||||||
2
|
|
||||||
1
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
1 PLACEHOLDER
|
|
||||||
1 1 0.000000 0.000000 0.000000 0.000000 0.0000 0.0000
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def mock_open(file, *args, **kwargs):
|
|
||||||
values = {
|
|
||||||
"control.test.yml": get_config_example(),
|
|
||||||
"phb.ljc": get_potentials_exemple(),
|
|
||||||
"phb.error.combrule.ljc": get_potentials_error_combrule(),
|
|
||||||
"phb.error.ntypes.ljc": get_potentials_error_ntypes(),
|
|
||||||
"phb.error.ntypes.config.ljc": get_potentials_error_ntypes_config(),
|
|
||||||
"phb.error.nsites.ljc": get_potentials_error_nsites(),
|
|
||||||
"phb.error.molname.ljc": get_potentials_error_molname(),
|
|
||||||
}
|
|
||||||
if file in values:
|
|
||||||
return mock.mock_open(read_data=values[file])()
|
|
||||||
|
|
||||||
return mock.mock_open(read_data="")()
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
from typing_extensions import List
|
|
||||||
|
|
||||||
import itertools
|
|
||||||
|
|
||||||
|
|
||||||
class MockProc:
|
|
||||||
pid_counter = itertools.count()
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
self.pid = next(MockProc.pid_counter)
|
|
||||||
|
|
||||||
if "exitcode" in kwargs:
|
|
||||||
self.exitcode = kwargs["exitcode"]
|
|
||||||
else:
|
|
||||||
self.exitcode = 0
|
|
||||||
|
|
||||||
self.sentinel = self.pid
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def terminate(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MockConnection:
|
|
||||||
@staticmethod
|
|
||||||
def wait(sentinels: List[int]):
|
|
||||||
return sentinels
|
|
||||||
@@ -16,9 +16,9 @@ class TestStateHandler:
|
|||||||
mem=12,
|
mem=12,
|
||||||
max_cyc=100,
|
max_cyc=100,
|
||||||
switch_cyc=50,
|
switch_cyc=50,
|
||||||
nprocs=4,
|
|
||||||
ncores=4,
|
ncores=4,
|
||||||
dice=DiceConfig(
|
dice=DiceConfig(
|
||||||
|
nprocs=4,
|
||||||
ljname="test",
|
ljname="test",
|
||||||
outname="test",
|
outname="test",
|
||||||
dens=1.0,
|
dens=1.0,
|
||||||
|
|||||||
38
tests/utils/test_potential.py
Normal file
38
tests/utils/test_potential.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from diceplayer.config import PlayerConfig
|
||||||
|
from diceplayer.environment import System
|
||||||
|
from diceplayer.utils.potential import read_system_from_phb
|
||||||
|
|
||||||
|
|
||||||
|
class TestPotential:
|
||||||
|
@pytest.fixture
|
||||||
|
def player_config(self) -> PlayerConfig:
|
||||||
|
return PlayerConfig.model_validate({
|
||||||
|
"type": "both",
|
||||||
|
"mem": 12,
|
||||||
|
"max_cyc": 100,
|
||||||
|
"switch_cyc": 50,
|
||||||
|
"ncores": 4,
|
||||||
|
"dice": {
|
||||||
|
"nprocs": 4,
|
||||||
|
"ljname": "phb.ljc.example",
|
||||||
|
"outname": "test",
|
||||||
|
"dens": 1.0,
|
||||||
|
"nmol": [12, 16],
|
||||||
|
"nstep": [1, 1],
|
||||||
|
},
|
||||||
|
"gaussian": {
|
||||||
|
"level": "test",
|
||||||
|
"qmprog": "g16",
|
||||||
|
"keywords": "test",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_read_phb(self, player_config: PlayerConfig):
|
||||||
|
system = read_system_from_phb(player_config)
|
||||||
|
|
||||||
|
assert isinstance(system, System)
|
||||||
Reference in New Issue
Block a user