Files
DicePlayer/tests/state/test_state_handler.py

119 lines
3.3 KiB
Python

from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
from diceplayer.environment import System
from diceplayer.state.state_handler import StateHandler
from diceplayer.state.state_model import StateModel
import pytest
from pathlib import Path
class TestStateHandler:
    """Tests for StateHandler persistence: save, get, and config-mismatch handling."""

    @pytest.fixture
    def player_config(self) -> PlayerConfig:
        """Return a minimal PlayerConfig with both DICE and Gaussian sections set."""
        return PlayerConfig(
            type="both",
            mem=12,
            max_cyc=100,
            switch_cyc=50,
            ncores=4,
            dice=DiceConfig(
                nprocs=4,
                ljname="test",
                outname="test",
                dens=1.0,
                nmol=[1],
                nstep=[1, 1],
            ),
            gaussian=GaussianConfig(
                level="test",
                qmprog="g16",
                keywords="test",
            ),
        )

    @pytest.fixture
    def state_handler(self, tmp_path: Path) -> StateHandler:
        """Return a StateHandler rooted at this test's temporary directory."""
        return StateHandler(tmp_path)

    @pytest.fixture
    def state(self, player_config: PlayerConfig) -> StateModel:
        """Return an unsaved cycle-0 state built from ``player_config``."""
        return StateModel(
            config=player_config,
            system=System(),
            current_cycle=0,
        )

    @pytest.fixture
    def different_config(self, player_config: PlayerConfig) -> PlayerConfig:
        """Return a copy of ``player_config`` that differs only in ``max_cyc``."""
        return player_config.model_copy(update={"max_cyc": 200})

    @staticmethod
    def _assert_same_state(actual, expected: StateModel) -> None:
        """Assert that a retrieved state exists and matches ``expected`` field by field."""
        assert actual is not None
        assert actual.config == expected.config
        assert actual.system == expected.system
        assert actual.current_cycle == expected.current_cycle

    def test_initialization(self, tmp_path: Path):
        """Constructing a handler on a directory yields a StateHandler instance."""
        # Construct directly (not via the fixture) so the constructor call is
        # the explicit subject of this test.
        state_handler = StateHandler(tmp_path)
        assert isinstance(state_handler, StateHandler)

    def test_save(
        self, tmp_path: Path, state_handler: StateHandler, state: StateModel
    ):
        """save() writes ``state.pkl`` into the handler's directory."""
        state_handler.save(state)
        assert (tmp_path / "state.pkl").exists()

    def test_get_when_empty(
        self, state_handler: StateHandler, player_config: PlayerConfig
    ):
        """get() returns None when nothing has been saved yet."""
        assert state_handler.get(player_config) is None

    def test_get(
        self,
        state_handler: StateHandler,
        state: StateModel,
        player_config: PlayerConfig,
    ):
        """A saved state is returned intact when requested with the same config."""
        state_handler.save(state)
        retrieved_state = state_handler.get(player_config)
        self._assert_same_state(retrieved_state, state)

    def test_get_with_different_config(
        self,
        state_handler: StateHandler,
        state: StateModel,
        different_config: PlayerConfig,
    ):
        """get() returns None when the requested config differs from the saved one."""
        state_handler.save(state)
        assert state_handler.get(different_config) is None

    def test_get_with_different_config_force(
        self,
        state_handler: StateHandler,
        state: StateModel,
        different_config: PlayerConfig,
    ):
        """With ``force=True``, get() returns the saved state despite a config mismatch."""
        state_handler.save(state)
        retrieved_state = state_handler.get(different_config, force=True)
        self._assert_same_state(retrieved_state, state)
        # The stored config is what comes back; the caller's differing config
        # is not substituted into the retrieved state.
        assert retrieved_state.config != different_config