feat: improves and initilize player pipeline
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
import pytest
|
||||
|
||||
from diceplayer.config import PlayerConfig, DiceConfig, GaussianConfig
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestStateHandler:
|
||||
@@ -30,18 +31,87 @@ class TestStateHandler:
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_handler_initialization(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
state_handler = StateHandler(tmpdir_path)
|
||||
def test_initialization(self, tmp_path: Path):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
assert isinstance(state_handler, StateHandler)
|
||||
assert isinstance(state_handler, StateHandler)
|
||||
|
||||
def test_state_handler_get_state(self, player_config: PlayerConfig):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
state_handler = StateHandler(tmpdir_path)
|
||||
def test_save(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = state_handler.get_state(player_config)
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
assert state is None
|
||||
state_handler.save(state)
|
||||
|
||||
assert (tmp_path / "state.pkl").exists()
|
||||
|
||||
def test_get_when_empty(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = state_handler.get(player_config)
|
||||
|
||||
assert state is None
|
||||
|
||||
def test_get(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
retrieved_state = state_handler.get(player_config)
|
||||
|
||||
assert retrieved_state is not None
|
||||
assert retrieved_state.config == state.config
|
||||
assert retrieved_state.system == state.system
|
||||
assert retrieved_state.current_cycle == state.current_cycle
|
||||
|
||||
def test_get_with_different_config(
|
||||
self, tmp_path: Path, player_config: PlayerConfig
|
||||
):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
different_config = player_config.model_copy(update={"opt": False})
|
||||
|
||||
retrieved_state = state_handler.get(different_config)
|
||||
|
||||
assert retrieved_state is None
|
||||
|
||||
def test_get_with_different_config_force(
|
||||
self, tmp_path: Path, player_config: PlayerConfig
|
||||
):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
different_config = player_config.model_copy(update={"opt": False})
|
||||
|
||||
retrieved_state = state_handler.get(different_config, force=True)
|
||||
|
||||
assert retrieved_state is not None
|
||||
assert retrieved_state.config == state.config
|
||||
assert retrieved_state.config != different_config
|
||||
assert retrieved_state.system == state.system
|
||||
assert retrieved_state.current_cycle == state.current_cycle
|
||||
|
||||
Reference in New Issue
Block a user