import os from crystalpol.shared.system.molecule import Molecule from crystalpol.shared.system.crystal import Crystal from crystalpol.shared.system.atom import Atom from crystalpol.shared.config import Config from crystalpol.gaussian import Gaussian from typing import List import sys class Polarization: __slots__ = ('geom_file', 'outfile', 'config', 'gaussian', 'crystal') def __init__(self, geom_file: str, outfile: str, config: Config) -> None: self.crystal = None self.geom_file = geom_file self.outfile = outfile self.config = config self.gaussian = Gaussian(config) def run(self): self.gaussian.create_simulation_dir() self.read_crystal() cycle = 1 charge_diff = sys.float_info.max while charge_diff > self.config.charge_tolerance: charge_diff = self.update_crystal_charges( self.gaussian.run(cycle, self.crystal), ) cycle += 1 def read_crystal(self) -> None: with open(self.geom_file, 'r') as geom_file: lines = geom_file.readlines() molecules = self._get_molecules_from_lines(lines) structure = self._get_crystal_structure(molecules[0]) self.crystal = Crystal([structure]) for molecule in molecules: self.crystal.add_cell([molecule]) def update_crystal_charges(self, charges: List[float]) -> float: charge_diff = [] for cell in self.crystal: for molecule in cell: for index, atom in enumerate(molecule): if atom.chg: charge_diff.append(abs(atom.chg - charges[index])) atom.chg = charges[index] return abs(max(charge_diff, key=abs)) if charge_diff else sys.float_info.max def _get_molecules_from_lines(self, lines: List[str]) -> List[Molecule]: if (len(lines) % self.config.n_atoms) == 0: molecules: List[Molecule] = [] for index, molecule in enumerate(self.split(lines, self.config.n_atoms)): mol = Molecule(f"Molecule-{index}") for atom_line in molecule: symbol, rx, ry, rz = tuple(atom_line.split()) mol.add_atom( Atom( rx, ry, rz, symbol=symbol.ljust(2), ) ) molecules.append(mol) return molecules else: raise RuntimeError( "Invalid Geom File, the number of atoms doesn't match the number of lines." ) @staticmethod def _get_crystal_structure(molecule: Molecule) -> List[str]: structure: List[str] = [] for atom in molecule: structure.append(atom.symbol) return structure @staticmethod def split(array: List, partitions: int): for i in range(0, len(array), partitions): yield array[i: i + partitions]