93 lines
2.6 KiB
Python
93 lines
2.6 KiB
Python
from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
|
|
from diceplayer.dice.dice_handler import DiceHandler
|
|
from diceplayer.dice.dice_wrapper import DiceWrapper
|
|
from diceplayer.state.state_model import StateModel
|
|
from diceplayer.utils.potential import read_system_from_phb
|
|
from tests._assets import ASSETS_DIR
|
|
|
|
import pytest
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
class TestDiceHandler:
    """Tests for the private ``DiceHandler`` helpers.

    The helpers under test are ``_filter_environment_sites`` and
    ``_aggregate_results``.

    Environments come from ``DiceWrapper.parse_results()`` run against
    ``ASSETS_DIR``. The handler itself is rooted in pytest's per-test
    ``tmp_path``.
    """

    @pytest.fixture
    def working_directory(self) -> Path:
        """Return the test-asset directory; ``phb.ljc`` is read from here."""
        return ASSETS_DIR

    @pytest.fixture
    def player_config(self, working_directory: Path) -> PlayerConfig:
        """Build a ``PlayerConfig`` whose DICE potential file is the asset ``phb.ljc``.

        Gaussian settings are placeholders ("test"); none of the tests
        below read them directly.
        """
        return PlayerConfig(
            type="both",
            mem=12,
            max_cyc=100,
            switch_cyc=50,
            ncores=4,
            dice=DiceConfig(
                nprocs=4,
                ljname=working_directory / "phb.ljc",
                outname="test",
                dens=1.0,
                nmol=[1, 200],
                nstep=[1000, 1000],
            ),
            gaussian=GaussianConfig(
                level="test",
                qmprog="g16",
                keywords="test",
            ),
        )

    @pytest.fixture
    def state_model(self, player_config: PlayerConfig) -> StateModel:
        """Create a cycle-0 ``StateModel``.

        The system is loaded by ``read_system_from_phb``.
        """
        return StateModel(
            config=player_config,
            system=read_system_from_phb(player_config),
            current_cycle=0,
        )

    @pytest.fixture
    def dice_wrapper(
        self, player_config: PlayerConfig, working_directory: Path
    ) -> DiceWrapper:
        """Wrap the DICE config over the asset directory so results can be parsed."""
        return DiceWrapper(player_config.dice, working_directory)

    @pytest.fixture
    def dice_handler(self, tmp_path: Path) -> DiceHandler:
        """Create a ``DiceHandler`` rooted in a fresh per-test temporary directory.

        Using ``tmp_path`` keeps anything the handler may write out of
        ``ASSETS_DIR``.
        """
        return DiceHandler(tmp_path)

    def test_filter_environment_sites(
        self,
        dice_handler: DiceHandler,
        dice_wrapper: DiceWrapper,
        state_model: StateModel,
    ) -> None:
        """Filtering must drop at least one site from the first parsed environment."""
        # Indexing raises if parse_results() is empty, so this cannot pass vacuously.
        environment = dice_wrapper.parse_results()[0]

        filtered_environment = dice_handler._filter_environment_sites(
            state_model, environment
        )

        assert len(filtered_environment) < environment.number_of_sites

    def test_aggregate_results(
        self,
        dice_handler: DiceHandler,
        dice_wrapper: DiceWrapper,
        state_model: StateModel,
    ) -> None:
        """Aggregation must keep every filtered entry from every environment."""
        environments = dice_wrapper.parse_results()
        # Without this guard, an empty parse makes both sides of the final
        # assertion 0 and the test passes without exercising anything.
        # len() is used instead of truthiness so array-like results also work.
        assert len(environments) > 0, (
            "expected parse_results() to return at least one environment"
        )

        picked_environments = [
            dice_handler._filter_environment_sites(state_model, env)
            for env in environments
        ]

        aggregated_environment = dice_handler._aggregate_results(
            state_model, picked_environments
        )

        assert len(aggregated_environment) == sum(
            len(env) for env in picked_environments
        )
|