Adds Formatter to Project
This commit is contained in:
@@ -1,24 +1,24 @@
|
||||
from diceplayer.shared.interface.gaussian_interface import GaussianInterface
|
||||
from diceplayer.shared.interface.dice_interface import DiceInterface
|
||||
from diceplayer.shared.utils.dataclass_protocol import Dataclass
|
||||
from diceplayer import logger
|
||||
from diceplayer.shared.config.dice_config import DiceConfig
|
||||
from diceplayer.shared.config.gaussian_config import GaussianDTO
|
||||
from diceplayer.shared.config.player_config import PlayerConfig
|
||||
from diceplayer.shared.config.dice_config import DiceConfig
|
||||
from diceplayer.shared.utils.misc import weekday_date_time
|
||||
from diceplayer.shared.environment.atom import Atom
|
||||
from diceplayer.shared.environment.molecule import Molecule
|
||||
from diceplayer.shared.environment.system import System
|
||||
from diceplayer.shared.environment.atom import Atom
|
||||
from diceplayer.shared.interface.dice_interface import DiceInterface
|
||||
from diceplayer.shared.interface.gaussian_interface import GaussianInterface
|
||||
from diceplayer.shared.utils.dataclass_protocol import Dataclass
|
||||
from diceplayer.shared.utils.misc import weekday_date_time
|
||||
from diceplayer.shared.utils.ptable import atomsymb
|
||||
from diceplayer import logger
|
||||
|
||||
from dataclasses import fields
|
||||
from typing import Type, Tuple
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import yaml
|
||||
import sys
|
||||
import os
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from dataclasses import fields
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Type
|
||||
|
||||
ENV = ["OMP_STACKSIZE"]
|
||||
|
||||
@@ -29,9 +29,7 @@ class Player:
|
||||
raise ValueError("Must specify either infile or optimization")
|
||||
|
||||
elif infile is not None:
|
||||
self.config = self.set_config(
|
||||
self.read_keywords(infile)
|
||||
)
|
||||
self.config = self.set_config(self.read_keywords(infile))
|
||||
|
||||
self.system = System()
|
||||
|
||||
@@ -60,7 +58,6 @@ class Player:
|
||||
)
|
||||
|
||||
for cycle in range(self.initial_cycle, self.initial_cycle + self.config.maxcyc):
|
||||
|
||||
logger.info(
|
||||
f"------------------------------------------------------------------------------------------\n"
|
||||
f" Step # {cycle}\n"
|
||||
@@ -78,9 +75,7 @@ class Player:
|
||||
|
||||
def prepare_system(self):
|
||||
for i, mol in enumerate(self.system.molecule):
|
||||
logger.info(
|
||||
f"Molecule {i + 1} - {mol.molname}"
|
||||
)
|
||||
logger.info(f"Molecule {i + 1} - {mol.molname}")
|
||||
|
||||
mol.print_mol_info()
|
||||
logger.info(
|
||||
@@ -113,7 +108,6 @@ class Player:
|
||||
geoms_file_path.touch()
|
||||
|
||||
def print_keywords(self) -> None:
|
||||
|
||||
def log_keywords(config: Dataclass, dto: Type[Dataclass]):
|
||||
for key in sorted(list(map(lambda f: f.name, fields(dto)))):
|
||||
if getattr(config, key) is not None:
|
||||
@@ -162,9 +156,7 @@ class Player:
|
||||
with open(self.config.dice.ljname) as file:
|
||||
ljc_data = file.readlines()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Potential file {self.config.dice.ljname} not found."
|
||||
)
|
||||
raise RuntimeError(f"Potential file {self.config.dice.ljname} not found.")
|
||||
|
||||
combrule = ljc_data.pop(0).split()[0]
|
||||
if combrule not in ("*", "+"):
|
||||
@@ -191,7 +183,6 @@ class Player:
|
||||
)
|
||||
|
||||
for i in range(ntypes):
|
||||
|
||||
try:
|
||||
nsites, molname = ljc_data.pop(0).split()[:2]
|
||||
except ValueError:
|
||||
@@ -209,10 +200,7 @@ class Player:
|
||||
|
||||
atom_fields = ["lbl", "na", "rx", "ry", "rz", "chg", "eps", "sig"]
|
||||
for j in range(nsites):
|
||||
new_atom = dict(zip(
|
||||
atom_fields,
|
||||
ljc_data.pop(0).split()
|
||||
))
|
||||
new_atom = dict(zip(atom_fields, ljc_data.pop(0).split()))
|
||||
self.system.molecule[i].add_atom(
|
||||
Atom(**self.validate_atom_dict(i, j, new_atom))
|
||||
)
|
||||
@@ -227,16 +215,12 @@ class Player:
|
||||
)
|
||||
|
||||
logger.info(f"Combination rule: {self.config.dice.combrule}")
|
||||
logger.info(
|
||||
f"Types of molecules: {len(self.system.molecule)}\n"
|
||||
)
|
||||
logger.info(f"Types of molecules: {len(self.system.molecule)}\n")
|
||||
|
||||
i = 0
|
||||
for mol in self.system.molecule:
|
||||
i += 1
|
||||
logger.info(
|
||||
"{} atoms in molecule type {}:".format(len(mol.atom), i)
|
||||
)
|
||||
logger.info("{} atoms in molecule type {}:".format(len(mol.atom), i))
|
||||
logger.info(
|
||||
"---------------------------------------------------------------------------------"
|
||||
)
|
||||
@@ -285,42 +269,37 @@ class Player:
|
||||
self.gaussian_interface.reset()
|
||||
|
||||
if self.config.opt:
|
||||
if 'position' not in result:
|
||||
raise RuntimeError(
|
||||
'Optimization failed. No position found in result.'
|
||||
)
|
||||
if "position" not in result:
|
||||
raise RuntimeError("Optimization failed. No position found in result.")
|
||||
|
||||
self.system.update_molecule(result['position'])
|
||||
self.system.update_molecule(result["position"])
|
||||
|
||||
else:
|
||||
if 'charges' not in result:
|
||||
if "charges" not in result:
|
||||
raise RuntimeError(
|
||||
'Charges optimization failed. No charges found in result.'
|
||||
"Charges optimization failed. No charges found in result."
|
||||
)
|
||||
|
||||
diff = self.system.molecule[0] \
|
||||
.update_charges(result['charges'])
|
||||
diff = self.system.molecule[0].update_charges(result["charges"])
|
||||
|
||||
self.system.print_charges_and_dipole(cycle)
|
||||
self.print_geoms(cycle)
|
||||
|
||||
if diff < self.config.gaussian.chg_tol:
|
||||
logger.info(
|
||||
f'Charges converged after {cycle} cycles.'
|
||||
)
|
||||
logger.info(f"Charges converged after {cycle} cycles.")
|
||||
raise StopIteration()
|
||||
|
||||
def print_geoms(self, cycle: int):
|
||||
with open(self.config.geoms_file, 'a') as file:
|
||||
file.write(f'Cycle # {cycle}\n')
|
||||
with open(self.config.geoms_file, "a") as file:
|
||||
file.write(f"Cycle # {cycle}\n")
|
||||
|
||||
for atom in self.system.molecule[0].atom:
|
||||
symbol = atomsymb[atom.na]
|
||||
file.write(
|
||||
f'{symbol:<2s} {atom.rx:>10.6f} {atom.ry:>10.6f} {atom.rz:>10.6f}\n'
|
||||
f"{symbol:<2s} {atom.rx:>10.6f} {atom.ry:>10.6f} {atom.rz:>10.6f}\n"
|
||||
)
|
||||
|
||||
file.write('\n')
|
||||
file.write("\n")
|
||||
|
||||
@staticmethod
|
||||
def validate_atom_dict(molecule_type, molecule_site, atom_dict: dict) -> dict:
|
||||
@@ -329,69 +308,69 @@ class Player:
|
||||
|
||||
if len(atom_dict) < 8:
|
||||
raise ValueError(
|
||||
f'Invalid number of fields for site {molecule_site} for molecule type {molecule_type}.'
|
||||
f"Invalid number of fields for site {molecule_site} for molecule type {molecule_type}."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict['lbl'] = int(atom_dict['lbl'])
|
||||
atom_dict["lbl"] = int(atom_dict["lbl"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'Invalid lbl fields for site {molecule_site} for molecule type {molecule_type}.'
|
||||
f"Invalid lbl fields for site {molecule_site} for molecule type {molecule_type}."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict['na'] = int(atom_dict['na'])
|
||||
atom_dict["na"] = int(atom_dict["na"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'Invalid na fields for site {molecule_site} for molecule type {molecule_type}.'
|
||||
f"Invalid na fields for site {molecule_site} for molecule type {molecule_type}."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict['rx'] = float(atom_dict['rx'])
|
||||
atom_dict["rx"] = float(atom_dict["rx"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'Invalid rx fields for site {molecule_site} for molecule type {molecule_type}. '
|
||||
f'Value must be a float.'
|
||||
f"Invalid rx fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict['ry'] = float(atom_dict['ry'])
|
||||
atom_dict["ry"] = float(atom_dict["ry"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'Invalid ry fields for site {molecule_site} for molecule type {molecule_type}. '
|
||||
f'Value must be a float.'
|
||||
f"Invalid ry fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict['rz'] = float(atom_dict['rz'])
|
||||
atom_dict["rz"] = float(atom_dict["rz"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'Invalid rz fields for site {molecule_site} for molecule type {molecule_type}. '
|
||||
f'Value must be a float.'
|
||||
f"Invalid rz fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict['chg'] = float(atom_dict['chg'])
|
||||
atom_dict["chg"] = float(atom_dict["chg"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'Invalid chg fields for site {molecule_site} for molecule type {molecule_type}. '
|
||||
f'Value must be a float.'
|
||||
f"Invalid chg fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict['eps'] = float(atom_dict['eps'])
|
||||
atom_dict["eps"] = float(atom_dict["eps"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'Invalid eps fields for site {molecule_site} for molecule type {molecule_type}. '
|
||||
f'Value must be a float.'
|
||||
f"Invalid eps fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
try:
|
||||
atom_dict['sig'] = float(atom_dict['sig'])
|
||||
atom_dict["sig"] = float(atom_dict["sig"])
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f'Invalid sig fields for site {molecule_site} for molecule type {molecule_type}. '
|
||||
f'Value must be a float.'
|
||||
f"Invalid sig fields for site {molecule_site} for molecule type {molecule_type}. "
|
||||
f"Value must be a float."
|
||||
)
|
||||
|
||||
return atom_dict
|
||||
@@ -400,9 +379,7 @@ class Player:
|
||||
formatstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f} {:>9.4f}"
|
||||
|
||||
mol = self.system.molecule[0]
|
||||
logger.info(
|
||||
"{} atoms in molecule type {}:".format(len(mol.atom), 1)
|
||||
)
|
||||
logger.info("{} atoms in molecule type {}:".format(len(mol.atom), 1))
|
||||
logger.info(
|
||||
"---------------------------------------------------------------------------------"
|
||||
)
|
||||
@@ -432,28 +409,21 @@ class Player:
|
||||
|
||||
def save_run_in_pickle(self, cycle):
|
||||
try:
|
||||
with open('latest-step.pkl', 'wb') as pickle_file:
|
||||
pickle.dump(
|
||||
(self.config, self.system, cycle),
|
||||
pickle_file
|
||||
)
|
||||
with open("latest-step.pkl", "wb") as pickle_file:
|
||||
pickle.dump((self.config, self.system, cycle), pickle_file)
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
f'Could not save pickle file latest-step.pkl.'
|
||||
)
|
||||
raise RuntimeError(f"Could not save pickle file latest-step.pkl.")
|
||||
|
||||
@staticmethod
|
||||
def load_run_from_pickle() -> Tuple[PlayerConfig, System, int]:
|
||||
pickle_path = Path("latest-step.pkl")
|
||||
try:
|
||||
with open(pickle_path, 'rb') as pickle_file:
|
||||
with open(pickle_path, "rb") as pickle_file:
|
||||
save = pickle.load(pickle_file)
|
||||
return save[0], save[1], save[2] + 1
|
||||
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
f'Could not load pickle file {pickle_path}.'
|
||||
)
|
||||
raise RuntimeError(f"Could not load pickle file {pickle_path}.")
|
||||
|
||||
@staticmethod
|
||||
def set_config(data: dict) -> PlayerConfig:
|
||||
@@ -461,18 +431,16 @@ class Player:
|
||||
|
||||
@staticmethod
|
||||
def read_keywords(infile) -> dict:
|
||||
with open(infile, 'r') as yml_file:
|
||||
with open(infile, "r") as yml_file:
|
||||
config = yaml.load(yml_file, Loader=yaml.SafeLoader)
|
||||
|
||||
if "diceplayer" in config:
|
||||
return config.get("diceplayer")
|
||||
|
||||
raise RuntimeError(
|
||||
f'Could not find diceplayer section in {infile}.'
|
||||
)
|
||||
raise RuntimeError(f"Could not find diceplayer section in {infile}.")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, infile: str) -> 'Player':
|
||||
def from_file(cls, infile: str) -> "Player":
|
||||
return cls(infile=infile)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user