from pathlib import Path

import pytest

from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
from diceplayer.dice.dice_handler import DiceHandler
from diceplayer.dice.dice_wrapper import DiceWrapper
from diceplayer.state.state_model import StateModel
from diceplayer.utils.potential import read_system_from_phb
from tests._assets import ASSETS_DIR


class TestDiceHandler:
    """Tests for the private filtering and aggregation helpers of ``DiceHandler``."""

    @pytest.fixture
    def working_directory(self) -> Path:
        """Return the test-asset directory used as the DICE working directory."""
        return ASSETS_DIR

    @pytest.fixture
    def player_config(self, working_directory: Path) -> PlayerConfig:
        """Build a minimal player configuration.

        The DICE Lennard-Jones file (``ljname``) points at ``phb.ljc`` inside
        the asset directory; the Gaussian section holds placeholder values.
        """
        return PlayerConfig(
            type="both",
            mem=12,
            max_cyc=100,
            switch_cyc=50,
            ncores=4,
            dice=DiceConfig(
                nprocs=4,
                ljname=working_directory / "phb.ljc",
                outname="test",
                dens=1.0,
                nmol=[1, 200],
                nstep=[1000, 1000],
            ),
            gaussian=GaussianConfig(
                level="test",
                qmprog="g16",
                keywords="test",
            ),
        )

    @pytest.fixture
    def state_model(self, player_config: PlayerConfig) -> StateModel:
        """Return a cycle-0 state whose system is built by ``read_system_from_phb``."""
        return StateModel(
            config=player_config,
            system=read_system_from_phb(player_config),
            current_cycle=0,
        )

    @pytest.fixture
    def dice_wrapper(
        self, player_config: PlayerConfig, working_directory: Path
    ) -> DiceWrapper:
        """Return a ``DiceWrapper`` bound to the asset directory.

        NOTE(review): ``parse_results`` presumably reads pre-generated DICE
        outputs stored in ``working_directory``; confirm against DiceWrapper.
        """
        return DiceWrapper(player_config.dice, working_directory)

    @pytest.fixture
    def dice_handler(self, tmp_path: Path) -> DiceHandler:
        """Return a ``DiceHandler`` rooted in pytest's per-test temporary directory."""
        return DiceHandler(tmp_path)

    def test_filter_environment_sites(
        self,
        dice_handler: DiceHandler,
        dice_wrapper: DiceWrapper,
        state_model: StateModel,
    ) -> None:
        """Check that filtering removes at least one site from a parsed environment."""
        environment = dice_wrapper.parse_results()[0]

        filtered_environment = dice_handler._filter_environment_sites(
            state_model, environment
        )

        assert len(filtered_environment) < environment.number_of_sites

    def test_aggregate_results(
        self,
        dice_handler: DiceHandler,
        dice_wrapper: DiceWrapper,
        state_model: StateModel,
    ) -> None:
        """Check that aggregation keeps every site of every filtered environment."""
        environments = dice_wrapper.parse_results()
        # Guard against a vacuous pass: with no environments the expected
        # total below would be 0 and the final check would prove nothing.
        assert environments, "expected parse_results() to yield at least one environment"

        picked_environments = [
            dice_handler._filter_environment_sites(state_model, env)
            for env in environments
        ]

        aggregated_environment = dice_handler._aggregate_results(
            state_model, picked_environments
        )

        assert len(aggregated_environment) == sum(
            len(env) for env in picked_environments
        )