103 lines
3.0 KiB
Python
103 lines
3.0 KiB
Python
|
|
from __future__ import annotations

import json
import os
import threading
import time
from pathlib import Path

import pytest

from relay.state import InstanceLock, StateError, read_json, write_atomic, write_json_atomic


def test_write_atomic_round_trips(tmp_path: Path) -> None:
    """``write_atomic`` persists exactly the text it was handed."""
    destination = tmp_path / "f.json"
    write_atomic(destination, "hello")
    written = destination.read_text()
    assert written == "hello"


def test_write_atomic_does_not_leave_temp_files(tmp_path: Path) -> None:
    """After ``write_atomic`` returns, no stray files remain beside the target."""
    destination = tmp_path / "f.json"
    write_atomic(destination, "hello")
    # Anything in the directory other than the target is leftover residue.
    leftovers = [entry.name for entry in tmp_path.iterdir() if entry != destination]
    assert leftovers == []


def test_write_json_atomic_round_trips(tmp_path: Path) -> None:
    """A JSON-serialisable value written by ``write_json_atomic`` decodes back unchanged."""
    destination = tmp_path / "value.json"
    expected = [{"k": "v"}, {"k2": "v2"}]
    write_json_atomic(destination, expected)
    decoded = json.loads(destination.read_text())
    assert decoded == expected


def test_read_json_returns_default_when_missing(tmp_path: Path) -> None:
    """Reading a path that does not exist yields the caller-supplied default."""
    missing = tmp_path / "nope.json"
    # Exercise both a list default and a dict default.
    assert read_json(missing, default=[]) == []
    assert read_json(missing, default={"x": 1}) == {"x": 1}


def test_read_json_returns_default_when_empty(tmp_path: Path) -> None:
    """A zero-length file yields the default instead of a parse error."""
    blank = tmp_path / "empty.json"
    blank.write_text("")
    result = read_json(blank, default=[])
    assert result == []


def test_read_json_raises_on_corruption(tmp_path: Path) -> None:
    """Malformed JSON surfaces as ``StateError`` rather than the default."""
    corrupt = tmp_path / "bad.json"
    corrupt.write_text("{not json")
    with pytest.raises(StateError):
        read_json(corrupt, default=[])


def test_instance_lock_acquires_and_releases(tmp_path: Path) -> None:
    """While held, the lock file contains the current process ID."""
    lock_file = tmp_path / ".lock"
    lock = InstanceLock(lock_file)
    lock.acquire()
    try:
        recorded_pid = lock_file.read_text().strip()
        assert recorded_pid == str(os.getpid())
    finally:
        # Always release, even if the PID assertion fails.
        lock.release()


def test_instance_lock_rejects_second_acquirer(tmp_path: Path) -> None:
    """While one ``InstanceLock`` holds the file, a second one cannot acquire it."""
    lock_file = tmp_path / ".lock"
    holder = InstanceLock(lock_file)
    holder.acquire()
    try:
        contender = InstanceLock(lock_file)
        with pytest.raises(StateError):
            contender.acquire()
    finally:
        holder.release()


def test_instance_lock_release_is_idempotent(tmp_path: Path) -> None:
    """Calling ``release`` twice after a single ``acquire`` does not raise."""
    lock = InstanceLock(tmp_path / ".lock")
    lock.acquire()
    for _ in range(2):
        lock.release()  # the second iteration must not raise


def test_concurrent_acquire_serializes(tmp_path: Path) -> None:
    """Two threads contending for the same lock file must never hold it at once.

    Both workers are released from a barrier together so their ``acquire``
    calls genuinely race. Each attempt must either succeed or raise
    ``StateError``. At least one attempt must succeed, and the number of
    simultaneous holders observed must never exceed one.
    """
    lock_path = tmp_path / ".lock"
    start = threading.Barrier(2)
    guard = threading.Lock()  # protects ``active`` / ``max_active``
    active = 0
    max_active = 0
    # list.append is used (rather than ``x += 1`` on shared state) so that
    # recording outcomes from both threads needs no extra locking.
    acquired: list[bool] = []
    rejected: list[StateError] = []

    def worker() -> None:
        nonlocal active, max_active
        lock = InstanceLock(lock_path)
        # Line both threads up so the acquire attempts actually contend.
        # A timeout keeps a broken peer from hanging the test forever.
        start.wait(timeout=5)
        try:
            lock.acquire()
        except StateError as exc:
            rejected.append(exc)
            return
        try:
            with guard:
                active += 1
                max_active = max(max_active, active)
            acquired.append(True)
            # Hold the lock briefly so an overlapping acquire would be observed.
            time.sleep(0.05)
            with guard:
                active -= 1
        finally:
            lock.release()

    threads = [threading.Thread(target=worker) for _ in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=10)
        assert not t.is_alive(), "worker thread did not finish"

    # Every attempt either succeeded or was rejected with StateError; an
    # unexpected exception in a worker would leave these counts short.
    assert len(acquired) + len(rejected) == 2
    # At least one worker must have obtained the lock ...
    assert acquired
    # ... and the lock was never held by both workers at the same time.
    assert max_active == 1