feat: improves and initilize player pipeline
This commit is contained in:
0
tests/cli/__init__.py
Normal file
0
tests/cli/__init__.py
Normal file
30
tests/cli/test_read_input_file.py
Normal file
30
tests/cli/test_read_input_file.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import diceplayer
|
||||
from diceplayer.cli import read_input
|
||||
from diceplayer.config import PlayerConfig
|
||||
|
||||
import pytest
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestReadInputFile:
|
||||
@pytest.fixture
|
||||
def example_config(self) -> Path:
|
||||
return Path(diceplayer.__path__[0]).parent / "control.example.yml"
|
||||
|
||||
def test_read_input_file(self, example_config: Path):
|
||||
config = read_input(example_config)
|
||||
|
||||
assert config is not None
|
||||
assert isinstance(config, PlayerConfig)
|
||||
|
||||
def test_read_input_non_existing_file(self):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
read_input("nonexistent_file.yml")
|
||||
|
||||
def test_read_input_invalid_yaml(self, tmp_path: Path):
|
||||
invalid_yaml_file = tmp_path / "invalid.yml"
|
||||
invalid_yaml_file.write_text("This is not valid YAML: [unbalanced brackets")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
read_input(invalid_yaml_file)
|
||||
@@ -1,10 +1,11 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from diceplayer.config import DiceConfig, GaussianConfig, PlayerConfig
|
||||
from diceplayer.environment import System
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
from diceplayer.state.state_model import StateModel
|
||||
|
||||
import pytest
|
||||
|
||||
from diceplayer.config import PlayerConfig, DiceConfig, GaussianConfig
|
||||
from diceplayer.state.state_handler import StateHandler
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestStateHandler:
|
||||
@@ -30,18 +31,87 @@ class TestStateHandler:
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_handler_initialization(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
state_handler = StateHandler(tmpdir_path)
|
||||
def test_initialization(self, tmp_path: Path):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
assert isinstance(state_handler, StateHandler)
|
||||
assert isinstance(state_handler, StateHandler)
|
||||
|
||||
def test_state_handler_get_state(self, player_config: PlayerConfig):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir_path = Path(tmpdir)
|
||||
state_handler = StateHandler(tmpdir_path)
|
||||
def test_save(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = state_handler.get_state(player_config)
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
assert state is None
|
||||
state_handler.save(state)
|
||||
|
||||
assert (tmp_path / "state.pkl").exists()
|
||||
|
||||
def test_get_when_empty(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = state_handler.get(player_config)
|
||||
|
||||
assert state is None
|
||||
|
||||
def test_get(self, tmp_path: Path, player_config: PlayerConfig):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
retrieved_state = state_handler.get(player_config)
|
||||
|
||||
assert retrieved_state is not None
|
||||
assert retrieved_state.config == state.config
|
||||
assert retrieved_state.system == state.system
|
||||
assert retrieved_state.current_cycle == state.current_cycle
|
||||
|
||||
def test_get_with_different_config(
|
||||
self, tmp_path: Path, player_config: PlayerConfig
|
||||
):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
different_config = player_config.model_copy(update={"opt": False})
|
||||
|
||||
retrieved_state = state_handler.get(different_config)
|
||||
|
||||
assert retrieved_state is None
|
||||
|
||||
def test_get_with_different_config_force(
|
||||
self, tmp_path: Path, player_config: PlayerConfig
|
||||
):
|
||||
state_handler = StateHandler(tmp_path)
|
||||
|
||||
state = StateModel(
|
||||
config=player_config,
|
||||
system=System(),
|
||||
current_cycle=0,
|
||||
)
|
||||
|
||||
state_handler.save(state)
|
||||
|
||||
different_config = player_config.model_copy(update={"opt": False})
|
||||
|
||||
retrieved_state = state_handler.get(different_config, force=True)
|
||||
|
||||
assert retrieved_state is not None
|
||||
assert retrieved_state.config == state.config
|
||||
assert retrieved_state.config != different_config
|
||||
assert retrieved_state.system == state.system
|
||||
assert retrieved_state.current_cycle == state.current_cycle
|
||||
|
||||
Reference in New Issue
Block a user