Adds Formatter to Project

This commit is contained in:
2023-06-11 16:04:25 -03:00
parent 82f3092f3e
commit c4dae5e8d1
29 changed files with 1151 additions and 721 deletions

View File

@@ -1,24 +1,24 @@
from diceplayer.shared.interface.gaussian_interface import GaussianInterface
from diceplayer.shared.interface.dice_interface import DiceInterface
from diceplayer.shared.utils.dataclass_protocol import Dataclass
from diceplayer import logger
from diceplayer.shared.config.dice_config import DiceConfig
from diceplayer.shared.config.gaussian_config import GaussianDTO
from diceplayer.shared.config.player_config import PlayerConfig
from diceplayer.shared.config.dice_config import DiceConfig
from diceplayer.shared.utils.misc import weekday_date_time
from diceplayer.shared.environment.atom import Atom
from diceplayer.shared.environment.molecule import Molecule
from diceplayer.shared.environment.system import System
from diceplayer.shared.environment.atom import Atom
from diceplayer.shared.interface.dice_interface import DiceInterface
from diceplayer.shared.interface.gaussian_interface import GaussianInterface
from diceplayer.shared.utils.dataclass_protocol import Dataclass
from diceplayer.shared.utils.misc import weekday_date_time
from diceplayer.shared.utils.ptable import atomsymb
from diceplayer import logger
from dataclasses import fields
from typing import Type, Tuple
from pathlib import Path
import pickle
import yaml
import sys
import os
import os
import pickle
import sys
from dataclasses import fields
from pathlib import Path
from typing import Tuple, Type
ENV = ["OMP_STACKSIZE"]
@@ -29,9 +29,7 @@ class Player:
raise ValueError("Must specify either infile or optimization")
elif infile is not None:
self.config = self.set_config(
self.read_keywords(infile)
)
self.config = self.set_config(self.read_keywords(infile))
self.system = System()
@@ -60,7 +58,6 @@ class Player:
)
for cycle in range(self.initial_cycle, self.initial_cycle + self.config.maxcyc):
logger.info(
f"------------------------------------------------------------------------------------------\n"
f" Step # {cycle}\n"
@@ -78,9 +75,7 @@ class Player:
def prepare_system(self):
for i, mol in enumerate(self.system.molecule):
logger.info(
f"Molecule {i + 1} - {mol.molname}"
)
logger.info(f"Molecule {i + 1} - {mol.molname}")
mol.print_mol_info()
logger.info(
@@ -113,7 +108,6 @@ class Player:
geoms_file_path.touch()
def print_keywords(self) -> None:
def log_keywords(config: Dataclass, dto: Type[Dataclass]):
for key in sorted(list(map(lambda f: f.name, fields(dto)))):
if getattr(config, key) is not None:
@@ -162,9 +156,7 @@ class Player:
with open(self.config.dice.ljname) as file:
ljc_data = file.readlines()
else:
raise RuntimeError(
f"Potential file {self.config.dice.ljname} not found."
)
raise RuntimeError(f"Potential file {self.config.dice.ljname} not found.")
combrule = ljc_data.pop(0).split()[0]
if combrule not in ("*", "+"):
@@ -191,7 +183,6 @@ class Player:
)
for i in range(ntypes):
try:
nsites, molname = ljc_data.pop(0).split()[:2]
except ValueError:
@@ -209,10 +200,7 @@ class Player:
atom_fields = ["lbl", "na", "rx", "ry", "rz", "chg", "eps", "sig"]
for j in range(nsites):
new_atom = dict(zip(
atom_fields,
ljc_data.pop(0).split()
))
new_atom = dict(zip(atom_fields, ljc_data.pop(0).split()))
self.system.molecule[i].add_atom(
Atom(**self.validate_atom_dict(i, j, new_atom))
)
@@ -227,16 +215,12 @@ class Player:
)
logger.info(f"Combination rule: {self.config.dice.combrule}")
logger.info(
f"Types of molecules: {len(self.system.molecule)}\n"
)
logger.info(f"Types of molecules: {len(self.system.molecule)}\n")
i = 0
for mol in self.system.molecule:
i += 1
logger.info(
"{} atoms in molecule type {}:".format(len(mol.atom), i)
)
logger.info("{} atoms in molecule type {}:".format(len(mol.atom), i))
logger.info(
"---------------------------------------------------------------------------------"
)
@@ -285,42 +269,37 @@ class Player:
self.gaussian_interface.reset()
if self.config.opt:
if 'position' not in result:
raise RuntimeError(
'Optimization failed. No position found in result.'
)
if "position" not in result:
raise RuntimeError("Optimization failed. No position found in result.")
self.system.update_molecule(result['position'])
self.system.update_molecule(result["position"])
else:
if 'charges' not in result:
if "charges" not in result:
raise RuntimeError(
'Charges optimization failed. No charges found in result.'
"Charges optimization failed. No charges found in result."
)
diff = self.system.molecule[0] \
.update_charges(result['charges'])
diff = self.system.molecule[0].update_charges(result["charges"])
self.system.print_charges_and_dipole(cycle)
self.print_geoms(cycle)
if diff < self.config.gaussian.chg_tol:
logger.info(
f'Charges converged after {cycle} cycles.'
)
logger.info(f"Charges converged after {cycle} cycles.")
raise StopIteration()
def print_geoms(self, cycle: int):
with open(self.config.geoms_file, 'a') as file:
file.write(f'Cycle # {cycle}\n')
with open(self.config.geoms_file, "a") as file:
file.write(f"Cycle # {cycle}\n")
for atom in self.system.molecule[0].atom:
symbol = atomsymb[atom.na]
file.write(
f'{symbol:<2s} {atom.rx:>10.6f} {atom.ry:>10.6f} {atom.rz:>10.6f}\n'
f"{symbol:<2s} {atom.rx:>10.6f} {atom.ry:>10.6f} {atom.rz:>10.6f}\n"
)
file.write('\n')
file.write("\n")
@staticmethod
def validate_atom_dict(molecule_type, molecule_site, atom_dict: dict) -> dict:
@@ -329,69 +308,69 @@ class Player:
if len(atom_dict) < 8:
raise ValueError(
f'Invalid number of fields for site {molecule_site} for molecule type {molecule_type}.'
f"Invalid number of fields for site {molecule_site} for molecule type {molecule_type}."
)
try:
atom_dict['lbl'] = int(atom_dict['lbl'])
atom_dict["lbl"] = int(atom_dict["lbl"])
except Exception:
raise ValueError(
f'Invalid lbl fields for site {molecule_site} for molecule type {molecule_type}.'
f"Invalid lbl fields for site {molecule_site} for molecule type {molecule_type}."
)
try:
atom_dict['na'] = int(atom_dict['na'])
atom_dict["na"] = int(atom_dict["na"])
except Exception:
raise ValueError(
f'Invalid na fields for site {molecule_site} for molecule type {molecule_type}.'
f"Invalid na fields for site {molecule_site} for molecule type {molecule_type}."
)
try:
atom_dict['rx'] = float(atom_dict['rx'])
atom_dict["rx"] = float(atom_dict["rx"])
except Exception:
raise ValueError(
f'Invalid rx fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.'
f"Invalid rx fields for site {molecule_site} for molecule type {molecule_type}. "
f"Value must be a float."
)
try:
atom_dict['ry'] = float(atom_dict['ry'])
atom_dict["ry"] = float(atom_dict["ry"])
except Exception:
raise ValueError(
f'Invalid ry fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.'
f"Invalid ry fields for site {molecule_site} for molecule type {molecule_type}. "
f"Value must be a float."
)
try:
atom_dict['rz'] = float(atom_dict['rz'])
atom_dict["rz"] = float(atom_dict["rz"])
except Exception:
raise ValueError(
f'Invalid rz fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.'
f"Invalid rz fields for site {molecule_site} for molecule type {molecule_type}. "
f"Value must be a float."
)
try:
atom_dict['chg'] = float(atom_dict['chg'])
atom_dict["chg"] = float(atom_dict["chg"])
except Exception:
raise ValueError(
f'Invalid chg fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.'
f"Invalid chg fields for site {molecule_site} for molecule type {molecule_type}. "
f"Value must be a float."
)
try:
atom_dict['eps'] = float(atom_dict['eps'])
atom_dict["eps"] = float(atom_dict["eps"])
except Exception:
raise ValueError(
f'Invalid eps fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.'
f"Invalid eps fields for site {molecule_site} for molecule type {molecule_type}. "
f"Value must be a float."
)
try:
atom_dict['sig'] = float(atom_dict['sig'])
atom_dict["sig"] = float(atom_dict["sig"])
except Exception:
raise ValueError(
f'Invalid sig fields for site {molecule_site} for molecule type {molecule_type}. '
f'Value must be a float.'
f"Invalid sig fields for site {molecule_site} for molecule type {molecule_type}. "
f"Value must be a float."
)
return atom_dict
@@ -400,9 +379,7 @@ class Player:
formatstr = "{:<3d} {:>3d} {:>10.5f} {:>10.5f} {:>10.5f} {:>10.6f} {:>9.5f} {:>7.4f} {:>9.4f}"
mol = self.system.molecule[0]
logger.info(
"{} atoms in molecule type {}:".format(len(mol.atom), 1)
)
logger.info("{} atoms in molecule type {}:".format(len(mol.atom), 1))
logger.info(
"---------------------------------------------------------------------------------"
)
@@ -432,28 +409,21 @@ class Player:
def save_run_in_pickle(self, cycle):
try:
with open('latest-step.pkl', 'wb') as pickle_file:
pickle.dump(
(self.config, self.system, cycle),
pickle_file
)
with open("latest-step.pkl", "wb") as pickle_file:
pickle.dump((self.config, self.system, cycle), pickle_file)
except Exception:
raise RuntimeError(
f'Could not save pickle file latest-step.pkl.'
)
raise RuntimeError(f"Could not save pickle file latest-step.pkl.")
@staticmethod
def load_run_from_pickle() -> Tuple[PlayerConfig, System, int]:
pickle_path = Path("latest-step.pkl")
try:
with open(pickle_path, 'rb') as pickle_file:
with open(pickle_path, "rb") as pickle_file:
save = pickle.load(pickle_file)
return save[0], save[1], save[2] + 1
except Exception:
raise RuntimeError(
f'Could not load pickle file {pickle_path}.'
)
raise RuntimeError(f"Could not load pickle file {pickle_path}.")
@staticmethod
def set_config(data: dict) -> PlayerConfig:
@@ -461,18 +431,16 @@ class Player:
@staticmethod
def read_keywords(infile) -> dict:
with open(infile, 'r') as yml_file:
with open(infile, "r") as yml_file:
config = yaml.load(yml_file, Loader=yaml.SafeLoader)
if "diceplayer" in config:
return config.get("diceplayer")
raise RuntimeError(
f'Could not find diceplayer section in {infile}.'
)
raise RuntimeError(f"Could not find diceplayer section in {infile}.")
@classmethod
def from_file(cls, infile: str) -> 'Player':
def from_file(cls, infile: str) -> "Player":
return cls(infile=infile)
@classmethod