Files
CrystalPol/crystalpol/polarization.py

111 lines
3.3 KiB
Python

import os
from crystalpol.shared.system.molecule import Molecule
from crystalpol.shared.system.crystal import Crystal
from crystalpol.shared.system.atom import Atom
from crystalpol.shared.config import Config
from crystalpol.gaussian import Gaussian
from typing import List, Tuple
import sys
from crystalpol.shared.utils.log import Log
class Polarization:
    """Iteratively update the atomic charges of a crystal until they converge.

    The crystal is read from a geometry file holding one atom per line
    (``symbol x y z``), grouped into molecules of ``config.n_atoms`` atoms.
    Each cycle asks the Gaussian wrapper for a new list of charges, stores
    them on every atom and logs the cycle; the loop ends once the largest
    absolute charge change falls below ``config.charge_tolerance``.
    """

    __slots__ = ('geom_file', 'outfile', 'config', 'gaussian', 'crystal')

    def __init__(self, geom_file: str, outfile: str, config: Config) -> None:
        """Store the run parameters and build the Gaussian wrapper.

        Args:
            geom_file: Path of the geometry file read by ``read_crystal``.
            outfile: Output file path; stored only, not used in this class.
            config: Run configuration; ``n_atoms`` and ``charge_tolerance``
                are the fields read by this class.
        """
        self.crystal = None  # populated by read_crystal()
        self.geom_file = geom_file
        self.outfile = outfile
        self.config = config
        self.gaussian = Gaussian(config)

    def run(self) -> None:
        """Run charge-update cycles until the charges stop changing.

        Raises:
            RuntimeError: If the geometry file is empty or its atom count is
                not a multiple of ``config.n_atoms``.
        """
        self.gaussian.create_simulation_dir()
        self.read_crystal()

        cycle = 1
        # Start above any possible tolerance so the first cycle always runs.
        max_charge_diff = sys.float_info.max
        while max_charge_diff >= self.config.charge_tolerance:
            max_charge_diff, charge_diff = self.update_crystal_charges(
                self.gaussian.run(cycle, self.crystal),
            )

            # The first cycle has no previous charges to compare with, so the
            # sentinel difference is logged as 0 instead.
            Log.make_run(
                cycle,
                max_charge_diff if cycle != 1 else 0,
                charge_diff,
                self.crystal
            )

            cycle += 1

    def read_crystal(self) -> None:
        """Build ``self.crystal`` from the geometry file.

        The element symbols of the first molecule define the crystal
        structure; every molecule is then added as its own one-molecule cell.

        Raises:
            RuntimeError: If the file holds no atoms, or its atom count is
                not a multiple of ``config.n_atoms``.
        """
        with open(self.geom_file, 'r') as geom_file:
            lines = geom_file.readlines()

        molecules = self._get_molecules_from_lines(lines)
        if not molecules:
            # Without this guard an empty file fails below with a bare
            # IndexError on molecules[0].
            raise RuntimeError(
                "Invalid Geom File, no atoms were found."
            )

        structure = self._get_crystal_structure(molecules[0])
        self.crystal = Crystal([structure])
        for molecule in molecules:
            self.crystal.add_cell([molecule])

    def update_crystal_charges(self, charges: List[float]) -> Tuple[float, list]:
        """Assign new charges to every atom and measure how much they moved.

        ``charges`` is indexed by an atom's position inside its molecule, so
        the same list is applied to every molecule of every cell.

        Args:
            charges: New charge for each atom position of a molecule.

        Returns:
            ``(max_charge_diff, charge_diff)``: ``charge_diff`` lists the
            absolute change of every atom that already had a charge, and
            ``max_charge_diff`` is its largest entry, or
            ``sys.float_info.max`` when no atom had a previous charge.
        """
        charge_diff = []
        for cell in self.crystal:
            for molecule in cell:
                for index, atom in enumerate(molecule):
                    new_charge = charges[index]
                    # Compare against None, not truthiness: a stored charge of
                    # exactly 0.0 is a real value and must be included in the
                    # convergence check.
                    # NOTE(review): assumes an unset charge is None - confirm
                    # against Atom's default for `chg`.
                    if atom.chg is not None:
                        charge_diff.append(abs(atom.chg - new_charge))
                    atom.chg = new_charge

        # Entries are already absolute values, so a plain max() is enough.
        max_charge_diff = max(charge_diff) if charge_diff else sys.float_info.max
        return max_charge_diff, charge_diff

    def _get_molecules_from_lines(self, lines: List[str]) -> List[Molecule]:
        """Group geometry lines into molecules of ``config.n_atoms`` atoms.

        Whitespace-only lines are ignored, so a trailing empty line does not
        invalidate an otherwise well-formed file.

        Args:
            lines: Raw lines of the geometry file, one ``symbol x y z`` each.

        Returns:
            The molecules in file order, named ``Molecule-<index>``; an empty
            list when there are no atom lines.

        Raises:
            RuntimeError: If the number of atom lines is not a multiple of
                ``config.n_atoms``.
        """
        atom_lines = [line for line in lines if line.strip()]

        if len(atom_lines) % self.config.n_atoms != 0:
            raise RuntimeError(
                "Invalid Geom File, the number of atoms doesn't match the number of lines."
            )

        molecules: List[Molecule] = []
        for index, chunk in enumerate(self.split(atom_lines, self.config.n_atoms)):
            mol = Molecule(f"Molecule-{index}")
            for atom_line in chunk:
                symbol, rx, ry, rz = atom_line.split()
                # NOTE(review): the coordinates are handed over as the raw
                # string tokens - confirm Atom converts them to float.
                mol.add_atom(
                    Atom(
                        rx, ry, rz,
                        symbol=symbol.ljust(2),
                    )
                )
            molecules.append(mol)

        return molecules

    @staticmethod
    def _get_crystal_structure(molecule: Molecule) -> List[str]:
        """Return the symbol of every atom in ``molecule``, in order."""
        return [atom.symbol for atom in molecule]

    @staticmethod
    def split(array: List, partitions: int):
        """Yield consecutive slices of ``array`` holding ``partitions`` items.

        The last slice is shorter when ``len(array)`` is not a multiple of
        ``partitions``.
        """
        for start in range(0, len(array), partitions):
            yield array[start: start + partitions]