from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
from diceplayer.environment import System
from diceplayer.state.state_handler import StateHandler
from diceplayer.state.state_model import StateModel

import pytest

from pathlib import Path


class TestStateHandler:
    """Tests for saving and retrieving a ``StateModel`` through ``StateHandler``."""

    @pytest.fixture
    def player_config(self) -> PlayerConfig:
        """Provide a fully populated ``PlayerConfig`` shared by the tests."""
        dice_section = DiceConfig(
            ljname="test",
            outname="test",
            dens=1.0,
            nmol=[1],
            nstep=[1, 1],
        )
        gaussian_section = GaussianConfig(
            level="test",
            qmprog="g16",
            keywords="test",
        )
        return PlayerConfig(
            type="both",
            mem=12,
            max_cyc=100,
            switch_cyc=50,
            nprocs=4,
            ncores=4,
            dice=dice_section,
            gaussian=gaussian_section,
        )

    @staticmethod
    def _new_state(config: PlayerConfig) -> StateModel:
        """Build a cycle-zero ``StateModel`` with a default ``System``."""
        return StateModel(
            config=config,
            system=System(),
            current_cycle=0,
        )

    def test_initialization(self, tmp_path: Path):
        """A handler can be constructed from a directory path."""
        handler = StateHandler(tmp_path)

        assert isinstance(handler, StateHandler)

    def test_save(self, tmp_path: Path, player_config: PlayerConfig):
        """After ``save``, a ``state.pkl`` file exists in the target directory."""
        handler = StateHandler(tmp_path)

        handler.save(self._new_state(player_config))

        expected_file = tmp_path / "state.pkl"
        assert expected_file.exists()

    def test_get_when_empty(self, tmp_path: Path, player_config: PlayerConfig):
        """``get`` yields ``None`` when nothing was saved beforehand."""
        handler = StateHandler(tmp_path)

        loaded = handler.get(player_config)

        assert loaded is None

    def test_get(self, tmp_path: Path, player_config: PlayerConfig):
        """A state saved and fetched with the same config compares equal."""
        handler = StateHandler(tmp_path)
        original = self._new_state(player_config)
        handler.save(original)

        loaded = handler.get(player_config)

        assert loaded is not None
        assert loaded.config == original.config
        assert loaded.system == original.system
        assert loaded.current_cycle == original.current_cycle

    def test_get_with_different_config(
        self, tmp_path: Path, player_config: PlayerConfig
    ):
        """``get`` yields ``None`` when the requested config differs from the saved one."""
        handler = StateHandler(tmp_path)
        handler.save(self._new_state(player_config))

        # Same config except for max_cyc.
        altered_config = player_config.model_copy(update={"max_cyc": 200})
        loaded = handler.get(altered_config)

        assert loaded is None

    def test_get_with_different_config_force(
        self, tmp_path: Path, player_config: PlayerConfig
    ):
        """With ``force=True`` the saved state is returned despite a config mismatch."""
        handler = StateHandler(tmp_path)
        original = self._new_state(player_config)
        handler.save(original)

        # Same config except for max_cyc.
        altered_config = player_config.model_copy(update={"max_cyc": 200})
        loaded = handler.get(altered_config, force=True)

        assert loaded is not None
        # The returned state carries the saved config, not the one passed to get().
        assert loaded.config == original.config
        assert loaded.config != altered_config
        assert loaded.system == original.system
        assert loaded.current_cycle == original.current_cycle