119 lines
3.3 KiB
Python
119 lines
3.3 KiB
Python
from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
|
|
from diceplayer.environment import System
|
|
from diceplayer.state.state_handler import StateHandler
|
|
from diceplayer.state.state_model import StateModel
|
|
|
|
import pytest
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
class TestStateHandler:
    """Tests for saving and loading ``StateModel`` snapshots via ``StateHandler``.

    Every test runs against pytest's per-test ``tmp_path`` directory, so saved
    state never leaks from one test into another.
    """

    @pytest.fixture
    def player_config(self) -> PlayerConfig:
        """Return the baseline ``PlayerConfig`` that saved states are built from."""
        return PlayerConfig(
            type="both",
            mem=12,
            max_cyc=100,
            switch_cyc=50,
            ncores=4,
            dice=DiceConfig(
                nprocs=4,
                ljname="test",
                outname="test",
                dens=1.0,
                nmol=[1],
                nstep=[1, 1],
            ),
            gaussian=GaussianConfig(
                level="test",
                qmprog="g16",
                keywords="test",
            ),
        )

    @pytest.fixture
    def state_handler(self, tmp_path: Path) -> StateHandler:
        """Return a ``StateHandler`` rooted at this test's temporary directory."""
        return StateHandler(tmp_path)

    @pytest.fixture
    def state(self, player_config: PlayerConfig) -> StateModel:
        """Return a cycle-0 ``StateModel`` built from ``player_config``."""
        return StateModel(
            config=player_config,
            system=System(),
            current_cycle=0,
        )

    @staticmethod
    def _assert_same_state(actual: "StateModel | None", expected: StateModel) -> None:
        """Assert that a state was loaded and matches ``expected`` field by field."""
        assert actual is not None
        assert actual.config == expected.config
        assert actual.system == expected.system
        assert actual.current_cycle == expected.current_cycle

    def test_initialization(self, tmp_path: Path) -> None:
        """Constructing a handler on an existing directory succeeds."""
        state_handler = StateHandler(tmp_path)

        assert isinstance(state_handler, StateHandler)

    def test_save(
        self, tmp_path: Path, state_handler: StateHandler, state: StateModel
    ) -> None:
        """Saving a state writes ``state.pkl`` into the handler's directory."""
        state_handler.save(state)

        assert (tmp_path / "state.pkl").exists()

    def test_get_when_empty(
        self, state_handler: StateHandler, player_config: PlayerConfig
    ) -> None:
        """With nothing saved yet, ``get`` returns ``None``."""
        state = state_handler.get(player_config)

        assert state is None

    def test_get(
        self,
        state_handler: StateHandler,
        player_config: PlayerConfig,
        state: StateModel,
    ) -> None:
        """A saved state round-trips when loaded with the same config."""
        state_handler.save(state)

        retrieved_state = state_handler.get(player_config)

        self._assert_same_state(retrieved_state, state)

    def test_get_with_different_config(
        self,
        state_handler: StateHandler,
        player_config: PlayerConfig,
        state: StateModel,
    ) -> None:
        """Loading with a config that differs from the saved one returns ``None``."""
        state_handler.save(state)

        different_config = player_config.model_copy(update={"max_cyc": 200})
        retrieved_state = state_handler.get(different_config)

        assert retrieved_state is None

    def test_get_with_different_config_force(
        self,
        state_handler: StateHandler,
        player_config: PlayerConfig,
        state: StateModel,
    ) -> None:
        """``force=True`` returns the saved state even when the config differs."""
        state_handler.save(state)

        different_config = player_config.model_copy(update={"max_cyc": 200})
        retrieved_state = state_handler.get(different_config, force=True)

        self._assert_same_state(retrieved_state, state)
        # The forced load must keep the *stored* config, not adopt the one passed in.
        assert retrieved_state.config != different_config