diff --git a/README.md b/README.md index f31156c..9b25c34 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Relay daemon between Claude Code instances and a Claude.ai chat-equivalent sessi When CC produces output that would normally be pasted to a Claude.ai chat for review or a decision, the daemon does the relay automatically: 1. CC drops a JSON envelope into `queue/`. -2. Daemon picks oldest-first, appends to a running conversation history, calls the Anthropic API. +2. Daemon picks oldest-first, appends to a running conversation history, calls the Anthropic API with prompt caching on the system prompt. 3. If the response contains `[NEEDS-JC]` in its first 200 characters, the daemon pauses and notifies via [ntfy.sh](https://ntfy.sh). 4. Otherwise, the response is written to `dispatch/<session_id>/input.txt` for the originating CC session to consume. @@ -18,7 +18,7 @@ JC can override at any time by writing to `state/jc_input.txt`. ```sh git clone git@localhost:AC/risv3-relay.git cd risv3-relay -python3 -m venv .venv +python3.14 -m venv .venv .venv/bin/pip install -e '.[dev]' cp .env.example .env # edit .env — add ANTHROPIC_API_KEY @@ -30,16 +30,107 @@ cp .env.example .env .venv/bin/python -m relay run ``` -The daemon prints the ntfy subscription URL on startup. Subscribe to it from your phone/laptop to receive `needs_jc` and error notifications. +On first boot the daemon generates a random ntfy topic, persists it back to `.env`, and prints the subscription URL: + +``` +ntfy topic: https://ntfy.sh/<topic> +Subscribe on phone/laptop to receive needs_jc + error alerts. +``` + +Subscribe at that URL on phone/laptop to get pings on `needs_jc` and errors. The topic is functionally a password — anyone subscribed receives the messages, so don't share it. + +## CC-side protocol + +A CC session integrates with the daemon by **dropping queue envelopes** and **polling its dispatch input**. 
+ +### Sending CC output to chat-Claude + +Drop a JSON file into `queue/`: + +```json +{ + "session_id": "session-1", + "timestamp": "2026-05-02T15:30:00Z", + "content": "Sub-PR A is open at #438. Tests pass. Awaiting review." +} +``` + +File name doesn't matter (use `--.json` for sortability). The daemon picks oldest-first by file mtime and processes one entry per loop tick. + +### Receiving chat-Claude responses + +Poll your session's dispatch directory: + +```sh +while true; do + if [ -s dispatch/session-1/input.txt ]; then + cat dispatch/session-1/input.txt + rm dispatch/session-1/input.txt + fi + sleep 1 +done +``` + +Deletion is the acknowledgement. The daemon will not write a new `input.txt` until the previous one is consumed (deleted). + +## JC operations + +Status: + +```sh +.venv/bin/python -m relay status +``` + +ntfy URL: + +```sh +.venv/bin/python -m relay topic +``` + +Override at any time — the daemon picks up `state/jc_input.txt` on the next tick (≤ 1 second): + +| Format | Effect | +|---|---| +| `@session-1: do X` | Direct dispatch to `session-1`. No API call. Body after the prefix becomes the input. Clears `needs_jc` if set. | +| `(any text without @prefix)` | Treated as the next chat-side turn. The daemon sends it through the API, dispatches the response to the originating session (or `sessions[0]` if not in queue context). Clears `needs_jc` if set. 
| ## Project layout -- `relay/` — Python package -- `tests/` — pytest tests +- `relay/` — Python package (config, state, conversation, anthropic_client, queue, dispatch, ntfy, daemon, __main__) +- `tests/` — pytest tests; `pytest -m real_api` opts into live-API smoke - `queue/`, `dispatch/`, `state/`, `logs/` — runtime directories created on first run; gitignored -- `config.yaml` — registered CC sessions and per-session settings +- `config.yaml` — registered CC sessions, system prompt, summarization prompt; auto-seeded on first run - `.env` — secrets and per-host overrides; gitignored -## Status +## Trust model -First-PR scope: daemon skeleton, queue + dispatch loop, single-CC-session integration, basic logging, ntfy notifications, conversation history with summarization. See `docs/` and PR descriptions for follow-up scope (status web UI, multi-session, error recovery, cost tracking, systemd unit). +- The daemon is **pure transport** between CC and chat-Claude. It does not make merge decisions, override CC's existing rules, or run arbitrary commands. +- The Anthropic API call is the only outbound integration besides ntfy. The system prompt and summarization prompt live in `config.yaml`; edit them to shape chat-Claude's behavior. +- JC's `jc_input.txt` is authoritative — anything written there is treated as the next chat-side turn (or a direct dispatch with the `@session-N:` prefix). + +## Status of the project + +First-PR scope (this repo's `main` after merge): daemon skeleton, queue + dispatch loop, single-CC-session integration, basic logging, ntfy notifications, conversation history with summarization, prompt caching on the system prompt, per-call cost estimation logged. + +Follow-up PRs will add: status web UI (Flask endpoint), multi-session config and per-session prompts, exponential backoff on transient API errors, systemd unit, cost-tracking dashboard. 
+ +## Development + +Run the unit suite: + +```sh +.venv/bin/python -m pytest +``` + +Run the live-API smoke (cost ~$0.0001 against Haiku 4.5; needs a billed `ANTHROPIC_API_KEY`): + +```sh +.venv/bin/python -m pytest -m real_api +``` + +Lint / format: + +```sh +.venv/bin/ruff check relay/ tests/ +.venv/bin/ruff format relay/ tests/ +``` diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..1d5deab --- /dev/null +++ b/config.yaml @@ -0,0 +1,15 @@ +system_prompt: 'You are the chat-side counterpart to Claude Code instances working + on the risv3 project. CC sends you its progress; you respond as a project lead would: + check decisions, ask clarifying questions, approve or correct. When a CC turn raises + a question only JC (the human owner) can answer, begin your response with the literal + token [NEEDS-JC] so the relay daemon pauses and notifies JC. Otherwise reply normally + and the relay forwards your reply to the originating CC session.' +summarization_prompt: Summarize the conversation so far. Preserve project context, + decisions made, open work items, and any outstanding [NEEDS-JC] questions. The summary + will replace earlier turns in the conversation history; the most recent turns will + be retained verbatim. Be specific where specificity matters (file paths, issue numbers, + decisions); be brief on routine back-and-forth. +sessions: +- session_id: session-1 + working_dir: null + description: Default CC session diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8f175e5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,45 @@ +[build-system] +requires = ["setuptools>=68"] +build-backend = "setuptools.build_meta" + +[project] +name = "risv3-relay" +version = "0.1.0" +description = "Relay daemon between Claude Code instances and a Claude.ai chat-equivalent session via the Anthropic API." 
+requires-python = ">=3.10" +authors = [{name = "AC"}] +dependencies = [ + "anthropic>=0.42,<1", + "flask>=3.0,<4", + "requests>=2.32,<3", + "pyyaml>=6.0,<7", + "python-dotenv>=1.0,<2", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3", + "ruff>=0.6", +] + +[project.scripts] +relay = "relay.__main__:main" + +[tool.setuptools.packages.find] +include = ["relay*"] + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "W", "I", "B", "UP", "RUF"] +ignore = ["E501"] # line-length already set + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-m 'not real_api'" +filterwarnings = ["ignore::DeprecationWarning"] +markers = [ + "real_api: hits the live Anthropic API; gated on ANTHROPIC_API_KEY env var. Run with `pytest -m real_api` to opt in", +] diff --git a/relay/__init__.py b/relay/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/relay/__main__.py b/relay/__main__.py new file mode 100644 index 0000000..3d620dc --- /dev/null +++ b/relay/__main__.py @@ -0,0 +1,75 @@ +"""CLI entry point: ``relay ``.""" + +from __future__ import annotations + +import argparse +import json +import logging +import signal +import sys + +from relay.config import load_settings +from relay.daemon import Daemon +from relay.logs import configure as configure_logs +from relay.ntfy import topic_url +from relay.state import read_json + + +def _cmd_run(args: argparse.Namespace) -> int: + settings = load_settings() + configure_logs(settings.logs_dir, level=logging.DEBUG if args.verbose else logging.INFO) + daemon = Daemon(settings) + + def _stop(signum: int, _frame: object) -> None: + logging.getLogger(__name__).info("received signal %s; shutting down", signum) + daemon.stop() + + signal.signal(signal.SIGINT, _stop) + signal.signal(signal.SIGTERM, _stop) + + try: + daemon.run() + except RuntimeError as exc: + # Most commonly the lock file: another daemon is running. 
def main(argv: list[str] | None = None) -> int:
    """Parse the command line and run the selected ``relay`` subcommand.

    Args:
        argv: Argument vector without the program name. ``None`` lets
            ``argparse`` fall back to ``sys.argv[1:]``.

    Returns:
        The exit code returned by the subcommand handler, or ``1`` when the
        handler aborts with a ``RuntimeError``.
    """
    parser = argparse.ArgumentParser(prog="relay", description="risv3-relay daemon")
    sub = parser.add_subparsers(dest="cmd", required=True)

    run = sub.add_parser("run", help="Run the daemon in the foreground")
    run.add_argument("-v", "--verbose", action="store_true")
    run.set_defaults(func=_cmd_run)

    status = sub.add_parser("status", help="Print the current daemon status")
    status.set_defaults(func=_cmd_status)

    topic = sub.add_parser("topic", help="Print the ntfy subscription URL")
    topic.set_defaults(func=_cmd_topic)

    args = parser.parse_args(argv)
    try:
        return args.func(args)
    except RuntimeError as exc:
        # Every subcommand calls load_settings() outside any handler, and
        # load_settings() raises RuntimeError when ANTHROPIC_API_KEY is
        # missing. Report that the same way _cmd_run reports a RuntimeError
        # from the daemon instead of dumping a traceback on the user.
        print(f"error: {exc}", file=sys.stderr)
        return 1
def _price_for(model: str) -> dict[str, float]:
    """Return the per-million-token price row used to estimate cost for ``model``.

    A model that has no entry in ``_PRICES_PER_MILLION`` is priced with the
    ``claude-opus-4-7`` row, the most expensive row in the table, so the
    estimate for an unrecognized model errs high rather than low. The
    substitution happens silently; this function does not log.
    """
    worst_case_row = _PRICES_PER_MILLION["claude-opus-4-7"]
    return _PRICES_PER_MILLION.get(model, worst_case_row)
The daemon is responsible for appending the + assistant text it gets back into history before the next call. + """ + + response = self._sdk.messages.create( + model=self.model, + max_tokens=self.max_output_tokens, + system=[ + { + "type": "text", + "text": system_prompt, + "cache_control": {"type": "ephemeral"}, + } + ], + messages=messages, + ) + + text = "".join( + block.text for block in response.content if getattr(block, "type", None) == "text" + ) + usage = response.usage + in_tokens = int(getattr(usage, "input_tokens", 0) or 0) + out_tokens = int(getattr(usage, "output_tokens", 0) or 0) + cache_create = int(getattr(usage, "cache_creation_input_tokens", 0) or 0) + cache_read = int(getattr(usage, "cache_read_input_tokens", 0) or 0) + + prices = _price_for(self.model) + cost = ( + in_tokens * prices["input"] + + out_tokens * prices["output"] + + cache_create * prices["cache_write"] + + cache_read * prices["cache_read"] + ) / 1_000_000.0 + + return TurnResult( + text=text, + input_tokens=in_tokens, + output_tokens=out_tokens, + cache_creation_input_tokens=cache_create, + cache_read_input_tokens=cache_read, + estimated_cost_usd=cost, + raw=response, + ) diff --git a/relay/config.py b/relay/config.py new file mode 100644 index 0000000..9332439 --- /dev/null +++ b/relay/config.py @@ -0,0 +1,163 @@ +"""Settings + config.yaml loader. + +The relay reads two config sources: + +1. ``.env`` for secrets (API key) and per-host overrides (status port, + history cap). The ntfy topic is auto-generated on first run and + written back to ``.env`` so it persists across restarts. + +2. ``config.yaml`` for project-level settings: the registered CC + sessions, the system prompt, the summarization prompt. First-PR + default config is generated if the file doesn't exist, with one + ``session-1`` registered. + +Settings are read once at startup; the daemon does not hot-reload them. 
@dataclass(frozen=True)
class SessionConfig:
    """One registered CC session, as listed under ``sessions`` in config.yaml.

    Instances are immutable (frozen dataclass). ``load_settings`` builds one
    per entry of the ``sessions`` list; only ``session_id`` is required there,
    the other two keys are read with ``.get`` and default to ``None``.
    """

    # Identifier of the session; load_settings coerces the YAML value to str.
    session_id: str
    # Optional; stored as given in config.yaml. Not consumed anywhere in the
    # code visible here -- presumably the CC session's checkout; TODO confirm.
    working_dir: str | None = None
    # Optional free-text label from config.yaml.
    description: str | None = None
= field(default_factory=tuple) + + +def _ensure_runtime_dirs(repo_root: Path) -> None: + for sub in ("queue", "dispatch", "state", "logs"): + (repo_root / sub).mkdir(exist_ok=True) + + +def _generate_ntfy_topic() -> str: + """Cryptographically random 16-character topic. Functionally a password.""" + return secrets.token_urlsafe(12) + + +def _load_or_init_ntfy_topic(env_values: dict[str, str | None]) -> str: + raw = (env_values.get("NTFY_TOPIC") or "").strip() + # Defensive: dotenv treats everything after = as the value, so an + # inline `#` comment ends up as the topic. Treat any value that + # starts with `#` (or contains whitespace, which a real topic never + # does) as empty. + if raw and not raw.startswith("#") and " " not in raw: + return raw + topic = _generate_ntfy_topic() + set_key(str(ENV_FILE), "NTFY_TOPIC", topic, quote_mode="never") + return topic + + +def _load_yaml_config() -> dict: + if not CONFIG_FILE.exists(): + default = { + "system_prompt": DEFAULT_SYSTEM_PROMPT, + "summarization_prompt": DEFAULT_SUMMARIZATION_PROMPT, + "sessions": [ + { + "session_id": "session-1", + "working_dir": None, + "description": "Default CC session", + } + ], + } + CONFIG_FILE.write_text(yaml.safe_dump(default, sort_keys=False)) + return default + with CONFIG_FILE.open() as f: + return yaml.safe_load(f) or {} + + +def load_settings() -> Settings: + """Read .env + config.yaml. 
Mutates .env on first run to record the ntfy topic.""" + + env_values = dotenv_values(ENV_FILE) + api_key = ( + env_values.get("ANTHROPIC_API_KEY") or os.environ.get("ANTHROPIC_API_KEY") or "" + ).strip() + if not api_key: + raise RuntimeError(f"ANTHROPIC_API_KEY missing from {ENV_FILE}") + + model = (env_values.get("ANTHROPIC_MODEL") or "claude-opus-4-7").strip() + status_port = int((env_values.get("STATUS_PORT") or "8765").strip()) + history_char_cap = int((env_values.get("HISTORY_CHAR_CAP") or "400000").strip()) + ntfy_topic = _load_or_init_ntfy_topic(env_values) + + yaml_cfg = _load_yaml_config() + system_prompt = str(yaml_cfg.get("system_prompt") or DEFAULT_SYSTEM_PROMPT) + summarization_prompt = str(yaml_cfg.get("summarization_prompt") or DEFAULT_SUMMARIZATION_PROMPT) + sessions_cfg = yaml_cfg.get("sessions") or [] + sessions = tuple( + SessionConfig( + session_id=str(s["session_id"]), + working_dir=s.get("working_dir"), + description=s.get("description"), + ) + for s in sessions_cfg + ) + + _ensure_runtime_dirs(REPO_ROOT) + + return Settings( + api_key=api_key, + model=model, + ntfy_topic=ntfy_topic, + status_port=status_port, + history_char_cap=history_char_cap, + repo_root=REPO_ROOT, + queue_dir=REPO_ROOT / "queue", + dispatch_dir=REPO_ROOT / "dispatch", + state_dir=REPO_ROOT / "state", + logs_dir=REPO_ROOT / "logs", + system_prompt=system_prompt, + summarization_prompt=summarization_prompt, + sessions=sessions, + ) diff --git a/relay/conversation.py b/relay/conversation.py new file mode 100644 index 0000000..3d34361 --- /dev/null +++ b/relay/conversation.py @@ -0,0 +1,99 @@ +"""Conversation history with summarization. + +History shape: list of turns ``{role, content, ts, session_id?}`` where +``role`` is ``"user"`` or ``"assistant"``. The user role represents +either CC output or JC override; the assistant role represents the +chat-Claude response. + +Summarization fires when the total content character count exceeds +``HISTORY_CHAR_CAP``. 
class Conversation:
    """In-memory conversation history mirrored to a JSON file on disk.

    The history is a list of ``Turn`` objects. Every mutating method
    (``append``, ``replace_with_summary``, ``reset``) rewrites the whole
    history file through ``write_json_atomic``.
    """

    def __init__(self, history_path: Path):
        """Load any existing history from ``history_path``.

        Raises:
            ValueError: if the file's top-level JSON value is not a list.
        """
        self._path = history_path
        raw = read_json(self._path, default=[])
        if not isinstance(raw, list):
            raise ValueError(f"Expected list at {self._path}, got {type(raw).__name__}")
        self._turns: list[Turn] = [Turn(**dict(t)) for t in raw]

    @property
    def turns(self) -> list[Turn]:
        """Return a copy of the turn list; mutating it does not change history."""
        return list(self._turns)

    def append(
        self, role: str, content: str, *, session_id: str | None = None, meta: str | None = None
    ) -> Turn:
        """Add one turn to the end of the history, persist, and return it."""
        turn = Turn(role=role, content=content, session_id=session_id, meta=meta)
        self._turns.append(turn)
        self._persist()
        return turn

    def replace_with_summary(self, summary_text: str) -> None:
        """Collapse history into one summary turn plus the most recent turns.

        The new history is ``[summary, *recent]`` where ``recent`` is the last
        ``RECENT_TURNS_KEPT`` conversational turns, or all of them when there
        are fewer. When there are fewer, the history does not shrink; the
        summary is simply prepended.

        Turns tagged ``meta="summarize_request"`` are never retained. Such a
        turn only exists to elicit the summary. If it were kept it would be
        the final message of the next API call, and the model would answer
        the summarization prompt again instead of the pending user turn.
        """
        conversational = [t for t in self._turns if t.meta != "summarize_request"]
        # Slicing already copes with short lists, so the pending (unanswered)
        # user turn at the tail is always preserved.
        recent = conversational[-RECENT_TURNS_KEPT:]
        summary_turn = Turn(role="assistant", content=summary_text, meta="summary")
        self._turns = [summary_turn, *recent]
        self._persist()

    def total_chars(self) -> int:
        """Return the summed length of every turn's ``content``."""
        return sum(len(t.content) for t in self._turns)

    def needs_summarization(self, cap: int) -> bool:
        """Return True when ``total_chars()`` is strictly greater than ``cap``."""
        return self.total_chars() > cap

    def to_api_messages(self) -> list[dict[str, str]]:
        """Return the history as role/content dicts for the Messages API."""
        return [t.to_api_message() for t in self._turns]

    def _persist(self) -> None:
        """Write the full history to ``self._path``."""
        write_json_atomic(self._path, [asdict(t) for t in self._turns])

    # Test/debug helper
    def reset(self) -> None:
        """Drop every turn and persist the empty history."""
        self._turns = []
        self._persist()
# Literal marker that routes a response to JC; the daemon only looks for it
# within the first NEEDS_JC_SCAN_CHARS characters of the response text.
NEEDS_JC_TOKEN = "[NEEDS-JC]"
NEEDS_JC_SCAN_CHARS = 200
# File names inside the state directory.
JC_INPUT_FILE = "jc_input.txt"
STATUS_FILE = "status.json"
STUCK_QUEUE_THRESHOLD_SEC = 600  # 10 min per spec
STUCK_QUEUE_REPEAT_SEC = 600  # don't re-notify more often than this
LOOP_SLEEP_SEC = 1.0
# "@<session-id>: " prefix of a direct-dispatch JC override. The session id is
# captured in the named group "session", which _handle_jc_input reads via
# match.group("session"); the group name is required for that lookup (and for
# the pattern to compile at all).
DISPATCH_PREFIX = re.compile(r"^@(?P<session>[A-Za-z0-9_-]+):\s*", re.MULTILINE)
def run(self) -> None:
    """Hold the instance lock and run the polling loop until ``stop()`` is called.

    Each iteration runs one ``_tick()``, writes the status file, then sleeps
    ``LOOP_SLEEP_SEC``. An exception escaping ``_tick()`` is logged with its
    traceback and the loop continues. The lock is released on the way out,
    whether the loop ends normally or an exception propagates.

    Errors from ``self.lock.acquire()`` propagate to the caller
    (``__main__`` treats ``RuntimeError`` from ``run`` as "another daemon is
    running").
    """
    # A tick that fails persistently fails once per LOOP_SLEEP_SEC. Without a
    # floor on the alert interval that would be a high-priority push every
    # second, so error pushes are spaced out. Every failure is still logged.
    error_alert_min_interval_sec = 600.0
    last_error_alert: float | None = None

    self.lock.acquire()
    try:
        self._announce_startup()
        while not self._stop:
            try:
                self._tick()
            except Exception:
                logger.exception("uncaught error in daemon loop; continuing")
                now = time.monotonic()
                if (
                    last_error_alert is None
                    or now - last_error_alert >= error_alert_min_interval_sec
                ):
                    last_error_alert = now
                    self._notify_error(
                        "Daemon loop error",
                        "An uncaught exception was logged. Check logs/relay.log.",
                    )
            self._persist_status()
            time.sleep(LOOP_SLEEP_SEC)
    finally:
        self.lock.release()
def _send_chat_turn(
    self, *, user_role_content: str, originating_session: str | None, source: str
) -> None:
    """Run one chat-side turn: record the prompt, call the API, route the reply.

    Args:
        user_role_content: Text sent as the user-role message (CC output
            from the queue, or a JC chat-side override).
        originating_session: Session the reply is dispatched to. ``None``
            falls back to the first configured session.
        source: Short label used only in the log line (callers pass
            ``"queue"`` or ``"jc"``).

    Exceptions raised by the API call are not caught here; they propagate
    to the caller.
    """
    # Append the user-side turn to history before the API call so a crash
    # mid-call doesn't lose the prompt. The flip side: when the call fails,
    # _handle_queue_entry leaves the entry in the queue and it comes back
    # through here on a later tick. If history already ends with this exact
    # unanswered user turn, appending again would stack a duplicate prompt on
    # every retry, so skip it.
    history = self.conversation.turns
    pending = history[-1] if history else None
    already_recorded = (
        pending is not None
        and pending.role == "user"
        and pending.meta is None
        and pending.content == user_role_content
        and pending.session_id == originating_session
    )
    if not already_recorded:
        self.conversation.append("user", user_role_content, session_id=originating_session)

    # Summarize if we've outgrown the cap.
    if self.conversation.needs_summarization(self.settings.history_char_cap):
        self._summarize()

    # API call.
    result = self.client.send(
        system_prompt=self.settings.system_prompt,
        messages=self.conversation.to_api_messages(),
    )
    self._record_usage(result)
    logger.info(
        "[%s] api turn ok: in=%d out=%d cache_w=%d cache_r=%d cost=$%.4f",
        source,
        result.input_tokens,
        result.output_tokens,
        result.cache_creation_input_tokens,
        result.cache_read_input_tokens,
        result.estimated_cost_usd,
    )

    # Append assistant response.
    self.conversation.append("assistant", result.text)

    # Route: NEEDS-JC pause vs dispatch.
    if self._contains_needs_jc(result.text):
        self._enter_needs_jc(result.text)
        return

    target_session = originating_session or self._fallback_session_id()
    if not target_session:
        logger.warning(
            "no originating session and no fallback session in config; chat reply dropped"
        )
        return
    self.dispatch.queue_or_write(target_session, result.text)
    self.status.last_dispatch_at = datetime.now(timezone.utc).isoformat(timespec="seconds")
    self.status.last_dispatch_session = target_session
def _persist_status(self) -> None:
    """Refresh the derived status counters and write the status file.

    ``queue_depth`` is the number of entries of any kind directly inside the
    queue directory (0 when the directory does not exist). The history
    counters come from the conversation. The result of
    ``DaemonStatus.as_dict()`` is then written to ``STATUS_FILE`` inside the
    state directory.
    """
    queue_dir = self.settings.queue_dir
    depth = 0
    if queue_dir.exists():
        depth = sum(1 for _ in queue_dir.iterdir())

    status = self.status
    status.queue_depth = depth
    status.history_chars = self.conversation.total_chars()
    status.history_turns = len(self.conversation.turns)

    target = self.settings.state_dir / STATUS_FILE
    write_json_atomic(target, status.as_dict())
logger.warning("queue stuck: oldest entry is %.0fs old", age) + notify( + self.settings.ntfy_topic, + title="relay queue stuck", + message=f"oldest entry is {int(age)}s old; daemon may be paused or the API failing.", + priority="high", + tags=["warning"], + ) + + def _notify_error(self, title: str, message: str) -> None: + notify(self.settings.ntfy_topic, title=title, message=message, priority="high", tags=["x"]) diff --git a/relay/dispatch.py b/relay/dispatch.py new file mode 100644 index 0000000..577d7df --- /dev/null +++ b/relay/dispatch.py @@ -0,0 +1,71 @@ +"""Dispatch writer. + +Each registered CC session has a directory ``dispatch//``. +The daemon delivers a chat-side response by writing it to +``dispatch//input.txt``. CC's polling loop reads, acts on +the content, and **deletes** the file as the acknowledgement. + +The "only write when prior is consumed" rule means the daemon must not +overwrite a pending dispatch — otherwise CC could miss a turn. If the +daemon has new content for a session that hasn't yet consumed the +previous one, it queues internally and waits. 
+""" + +from __future__ import annotations + +import logging +from collections import defaultdict, deque +from dataclasses import dataclass, field +from pathlib import Path + +from relay.state import write_atomic + +logger = logging.getLogger(__name__) + + +@dataclass +class DispatchManager: + dispatch_dir: Path + _pending: dict[str, deque[str]] = field(default_factory=lambda: defaultdict(deque)) + + def session_dir(self, session_id: str) -> Path: + return self.dispatch_dir / session_id + + def input_path(self, session_id: str) -> Path: + return self.session_dir(session_id) / "input.txt" + + def has_pending(self, session_id: str) -> bool: + return bool(self._pending[session_id]) + + def session_input_present(self, session_id: str) -> bool: + """True iff the session's input.txt exists (CC hasn't consumed yet).""" + return self.input_path(session_id).exists() + + def queue_or_write(self, session_id: str, content: str) -> bool: + """Try to deliver content. If session's input.txt is still present, + queue internally and return False; otherwise write and return True. + """ + self._pending[session_id].append(content) + return self._flush_session(session_id) + + def flush_all(self) -> int: + """Try to flush queued dispatches for every session. 
Returns count delivered.""" + delivered = 0 + for session_id in list(self._pending.keys()): + while self._pending[session_id]: + if not self._flush_session(session_id): + break + delivered += 1 + return delivered + + def _flush_session(self, session_id: str) -> bool: + if not self._pending[session_id]: + return False + if self.session_input_present(session_id): + return False + next_content = self._pending[session_id].popleft() + path = self.input_path(session_id) + path.parent.mkdir(parents=True, exist_ok=True) + write_atomic(path, next_content) + logger.info("dispatched %d chars to %s -> %s", len(next_content), session_id, path) + return True diff --git a/relay/logs.py b/relay/logs.py new file mode 100644 index 0000000..9fa9d91 --- /dev/null +++ b/relay/logs.py @@ -0,0 +1,46 @@ +"""Logging setup. + +Root logger writes to stdout (so tmux/systemd captures it) and to +``logs/relay-YYYY-MM-DD.log`` with daily rotation. Format includes +timestamp, level, logger name, and message. +""" + +from __future__ import annotations + +import logging +import sys +from logging.handlers import TimedRotatingFileHandler +from pathlib import Path + + +def configure(logs_dir: Path, level: int = logging.INFO) -> None: + logs_dir.mkdir(parents=True, exist_ok=True) + fmt = logging.Formatter( + fmt="%(asctime)s %(levelname)-7s %(name)-22s %(message)s", + datefmt="%Y-%m-%dT%H:%M:%S", + ) + + root = logging.getLogger() + root.setLevel(level) + + # Remove pre-existing handlers (idempotent across reload during tests) + for handler in list(root.handlers): + root.removeHandler(handler) + + stdout = logging.StreamHandler(stream=sys.stdout) + stdout.setFormatter(fmt) + root.addHandler(stdout) + + file_handler = TimedRotatingFileHandler( + filename=str(logs_dir / "relay.log"), + when="midnight", + backupCount=14, + encoding="utf-8", + utc=True, + ) + file_handler.setFormatter(fmt) + root.addHandler(file_handler) + + # Quiet libraries that are too chatty at INFO + for noisy in ("urllib3", "httpx", 
"httpcore", "anthropic"): + logging.getLogger(noisy).setLevel(logging.WARNING) diff --git a/relay/ntfy.py b/relay/ntfy.py new file mode 100644 index 0000000..2f75903 --- /dev/null +++ b/relay/ntfy.py @@ -0,0 +1,59 @@ +"""ntfy.sh notifications. + +Topic is loaded from settings (auto-generated on first run). The topic +is functionally a password — anyone subscribed to the topic URL +receives notifications. We use cryptographically-random topics +(secrets.token_urlsafe(12)) to make brute-force discovery impractical. + +Notifications are best-effort: a failure to deliver (network down, +ntfy.sh outage) is logged but does NOT block the daemon's main loop. +""" + +from __future__ import annotations + +import logging + +import requests + +logger = logging.getLogger(__name__) + +NTFY_BASE = "https://ntfy.sh" + + +def topic_url(topic: str) -> str: + return f"{NTFY_BASE}/{topic}" + + +def notify( + topic: str, + title: str, + message: str, + *, + priority: str = "default", + tags: list[str] | None = None, +) -> bool: + """Post one notification. Returns True on HTTP 200, False otherwise. + + Never raises on transport errors — this path runs from the daemon's + main loop and a failed notification should not stop work. + """ + if not topic: + logger.warning("ntfy topic is empty; skipping notification: %s", title) + return False + headers = { + "Title": title, + "Priority": priority, + } + if tags: + headers["Tags"] = ",".join(tags) + try: + resp = requests.post( + topic_url(topic), data=message.encode("utf-8"), headers=headers, timeout=10 + ) + except requests.RequestException as exc: + logger.warning("ntfy delivery failed for topic : %s", exc) + return False + if resp.status_code != 200: + logger.warning("ntfy non-200 (%s): %s", resp.status_code, resp.text[:200]) + return False + return True diff --git a/relay/queue.py b/relay/queue.py new file mode 100644 index 0000000..b65fe9a --- /dev/null +++ b/relay/queue.py @@ -0,0 +1,129 @@ +"""Queue intake. 
+ +CC sessions drop JSON envelopes into ``queue/``. The envelope shape: + + { + "session_id": "session-1", + "timestamp": "2026-05-02T12:34:56Z", + "content": "..." + } + +The daemon picks the oldest entry by mtime, validates the envelope, +returns it, and deletes the file once processing succeeds. A +malformed envelope is moved to ``queue/.rejected/`` so the daemon +doesn't retry it forever. +""" + +from __future__ import annotations + +import json +import logging +import shutil +import time +from dataclasses import dataclass +from pathlib import Path + +logger = logging.getLogger(__name__) + + +class QueueError(RuntimeError): + pass + + +@dataclass(frozen=True) +class QueueEntry: + path: Path + session_id: str + timestamp: str + content: str + mtime: float + + +def _validate_envelope(payload: object, *, path: Path) -> tuple[str, str, str]: + if not isinstance(payload, dict): + raise QueueError(f"{path}: envelope must be a JSON object") + for required in ("session_id", "timestamp", "content"): + if required not in payload: + raise QueueError(f"{path}: missing required key '{required}'") + session_id = str(payload["session_id"]).strip() + timestamp = str(payload["timestamp"]).strip() + content = str(payload["content"]) + if not session_id: + raise QueueError(f"{path}: session_id is empty") + if not timestamp: + raise QueueError(f"{path}: timestamp is empty") + if not content.strip(): + raise QueueError(f"{path}: content is empty") + return session_id, timestamp, content + + +def _reject(path: Path, reason: str) -> None: + rejected_dir = path.parent / ".rejected" + rejected_dir.mkdir(exist_ok=True) + target = rejected_dir / path.name + shutil.move(str(path), str(target)) + logger.warning("rejected queue entry %s -> %s: %s", path, target, reason) + + +def list_pending(queue_dir: Path) -> list[Path]: + """Return queue entries oldest-first, excluding the .rejected dir.""" + if not queue_dir.exists(): + return [] + entries = [ + p + for p in queue_dir.iterdir() + if 
p.is_file() and p.suffix == ".json" and not p.name.startswith(".") + ] + entries.sort(key=lambda p: p.stat().st_mtime) + return entries + + +def take_oldest(queue_dir: Path) -> QueueEntry | None: + """Read and validate the oldest queue entry. Returns None if queue empty. + + Rejects malformed envelopes by moving them to .rejected/. The caller + is responsible for calling ``ack(entry)`` once processing succeeds + to delete the file; until then it remains in the queue. + """ + pending = list_pending(queue_dir) + if not pending: + return None + path = pending[0] + try: + raw = path.read_text(encoding="utf-8") + payload = json.loads(raw) + session_id, timestamp, content = _validate_envelope(payload, path=path) + except (OSError, json.JSONDecodeError) as exc: + _reject(path, f"unreadable / not JSON: {exc}") + return None + except QueueError as exc: + _reject(path, str(exc)) + return None + + return QueueEntry( + path=path, + session_id=session_id, + timestamp=timestamp, + content=content, + mtime=path.stat().st_mtime, + ) + + +def ack(entry: QueueEntry) -> None: + """Delete the queue file. Called by the daemon after the turn is dispatched.""" + try: + entry.path.unlink() + except FileNotFoundError: + pass + + +def stuck_age_seconds(queue_dir: Path) -> float: + """How long the oldest pending entry has been waiting. 0 if queue empty. + + Used by the daemon's heartbeat to fire a ntfy alert when the queue + is stuck (e.g., persistent API failure). + """ + pending = list_pending(queue_dir) + if not pending: + return 0.0 + return time.time() - pending[0].stat().st_mtime diff --git a/relay/state.py b/relay/state.py new file mode 100644 index 0000000..fb02504 --- /dev/null +++ b/relay/state.py @@ -0,0 +1,100 @@ +"""Atomic state-file I/O and instance lock. + +The conversation history lives at ``state/conversation.json``. Mutated +in-memory by the daemon and written via temp+rename to avoid partial +writes if the process is killed mid-write. 
A ``state/.lock`` advisory +file stops two daemons from running against the same directory. +""" + +from __future__ import annotations + +import errno +import fcntl +import json +import os +from dataclasses import dataclass +from pathlib import Path +from tempfile import NamedTemporaryFile + + +class StateError(RuntimeError): + pass + + +@dataclass +class InstanceLock: + """Holds a flock on state/.lock for the daemon's lifetime. + + Released automatically on process exit (kernel closes the fd) or by + calling ``release()``. ``acquire()`` raises StateError if another + daemon already holds the lock. + """ + + lock_path: Path + _fd: int | None = None + + def acquire(self) -> None: + self._fd = os.open(self.lock_path, os.O_RDWR | os.O_CREAT, 0o600) + try: + fcntl.flock(self._fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + except OSError as exc: + os.close(self._fd) + self._fd = None + if exc.errno in {errno.EAGAIN, errno.EACCES}: + raise StateError( + f"Another daemon is holding {self.lock_path}; refusing to start" + ) from exc + raise + os.write(self._fd, str(os.getpid()).encode()) + + def release(self) -> None: + if self._fd is None: + return + try: + fcntl.flock(self._fd, fcntl.LOCK_UN) + finally: + os.close(self._fd) + self._fd = None + + +def write_atomic(path: Path, data: str) -> None: + """Atomically write text to ``path`` via temp file + rename. + + Crash-safe: a partial write leaves the temp file but does not + overwrite the target. fsync the data file (not the directory) so the + rename atomicity gives us durability up to the OS-level rename + barrier. 
+ """ + + path.parent.mkdir(parents=True, exist_ok=True) + with NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=str(path.parent), + prefix=f".{path.name}.", + suffix=".tmp", + delete=False, + ) as tmp: + tmp.write(data) + tmp.flush() + os.fsync(tmp.fileno()) + tmp_path = Path(tmp.name) + os.replace(tmp_path, path) + + +def write_json_atomic(path: Path, value: object) -> None: + write_atomic(path, json.dumps(value, indent=2, ensure_ascii=False, sort_keys=False)) + + +def read_json(path: Path, default: object) -> object: + """Read JSON or return default when the file is missing or empty.""" + + if not path.exists(): + return default + raw = path.read_text(encoding="utf-8").strip() + if not raw: + return default + try: + return json.loads(raw) + except json.JSONDecodeError as exc: + raise StateError(f"Corrupt JSON at {path}: {exc}") from exc diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_anthropic_client.py b/tests/test_anthropic_client.py new file mode 100644 index 0000000..1461313 --- /dev/null +++ b/tests/test_anthropic_client.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +from relay.anthropic_client import AnthropicClient + + +def _fake_message( + text: str, *, in_tokens: int = 100, out_tokens: int = 50, cache_w: int = 0, cache_r: int = 0 +): + """Build a stand-in for anthropic.types.Message with .content + .usage.""" + return SimpleNamespace( + content=[SimpleNamespace(type="text", text=text)], + usage=SimpleNamespace( + input_tokens=in_tokens, + output_tokens=out_tokens, + cache_creation_input_tokens=cache_w, + cache_read_input_tokens=cache_r, + ), + ) + + +def _client_with_mock(model: str = "claude-opus-4-7") -> tuple[AnthropicClient, MagicMock]: + client = AnthropicClient(api_key="sk-fake", model=model) + mock_create = MagicMock(return_value=_fake_message("response text")) + client._sdk = 
SimpleNamespace(messages=SimpleNamespace(create=mock_create)) + return client, mock_create + + +def test_send_returns_assistant_text_and_usage() -> None: + client, _ = _client_with_mock() + result = client.send(system_prompt="sys", messages=[{"role": "user", "content": "u1"}]) + assert result.text == "response text" + assert result.input_tokens == 100 + assert result.output_tokens == 50 + + +def test_send_passes_system_prompt_with_cache_control() -> None: + client, mock_create = _client_with_mock() + client.send(system_prompt="SYSTEM PROMPT", messages=[{"role": "user", "content": "u"}]) + call = mock_create.call_args + system_arg = call.kwargs["system"] + assert system_arg[0]["text"] == "SYSTEM PROMPT" + assert system_arg[0]["cache_control"] == {"type": "ephemeral"} + + +def test_cost_estimation_opus() -> None: + client, _ = _client_with_mock() + # Override response with specific token counts + client._sdk.messages.create = MagicMock( + return_value=_fake_message("x", in_tokens=1_000_000, out_tokens=0) + ) + result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}]) + # Opus: $15/M input → $15 + assert abs(result.estimated_cost_usd - 15.0) < 0.01 + + +def test_cost_estimation_includes_cache_savings() -> None: + client, _ = _client_with_mock() + client._sdk.messages.create = MagicMock( + return_value=_fake_message( + "x", + in_tokens=0, + out_tokens=0, + cache_w=1_000_000, # 1M cache write at $18.75/M + cache_r=1_000_000, # 1M cache read at $1.50/M + ) + ) + result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}]) + # cache_w 1M @ $18.75 + cache_r 1M @ $1.50 = $20.25 + assert abs(result.estimated_cost_usd - 20.25) < 0.01 + + +def test_cost_estimation_unknown_model_falls_back_to_opus() -> None: + client = AnthropicClient(api_key="sk-fake", model="claude-future-9000") + client._sdk = SimpleNamespace( + messages=SimpleNamespace( + create=MagicMock(return_value=_fake_message("x", in_tokens=1_000_000, out_tokens=0)) + 
) + ) + result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}]) + # Falls back to Opus pricing + assert abs(result.estimated_cost_usd - 15.0) < 0.01 + + +def test_send_concatenates_multiple_text_blocks() -> None: + client, _ = _client_with_mock() + multi = SimpleNamespace( + content=[ + SimpleNamespace(type="text", text="hello "), + SimpleNamespace(type="text", text="world"), + ], + usage=SimpleNamespace( + input_tokens=0, + output_tokens=0, + cache_creation_input_tokens=0, + cache_read_input_tokens=0, + ), + ) + client._sdk.messages.create = MagicMock(return_value=multi) + result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}]) + assert result.text == "hello world" + + +def test_send_ignores_non_text_blocks() -> None: + client, _ = _client_with_mock() + mixed = SimpleNamespace( + content=[ + SimpleNamespace(type="text", text="text part"), + SimpleNamespace(type="tool_use"), # no .text — would crash if not filtered + ], + usage=SimpleNamespace( + input_tokens=0, + output_tokens=0, + cache_creation_input_tokens=0, + cache_read_input_tokens=0, + ), + ) + client._sdk.messages.create = MagicMock(return_value=mixed) + result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}]) + assert result.text == "text part" diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..c199674 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import patch + +import pytest + + +def _write_env(tmp_path: Path, body: str) -> Path: + env = tmp_path / ".env" + env.write_text(body) + return env + + +def test_load_settings_reads_env_and_creates_runtime_dirs( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + _write_env( + tmp_path, + "ANTHROPIC_API_KEY=sk-ant-test\nANTHROPIC_MODEL=claude-haiku-4-5-20251001\nNTFY_TOPIC=topictopic12\n", + ) + 
monkeypatch.setattr("relay.config.REPO_ROOT", tmp_path) + monkeypatch.setattr("relay.config.ENV_FILE", tmp_path / ".env") + monkeypatch.setattr("relay.config.CONFIG_FILE", tmp_path / "config.yaml") + + from relay.config import load_settings # import after monkeypatch + + s = load_settings() + assert s.api_key == "sk-ant-test" + assert s.model == "claude-haiku-4-5-20251001" + assert s.ntfy_topic == "topictopic12" + for d in (s.queue_dir, s.dispatch_dir, s.state_dir, s.logs_dir): + assert d.exists() + + +def test_load_settings_generates_ntfy_topic_when_blank( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + _write_env(tmp_path, "ANTHROPIC_API_KEY=sk-ant-test\nNTFY_TOPIC=\n") + monkeypatch.setattr("relay.config.REPO_ROOT", tmp_path) + monkeypatch.setattr("relay.config.ENV_FILE", tmp_path / ".env") + monkeypatch.setattr("relay.config.CONFIG_FILE", tmp_path / "config.yaml") + + from relay.config import load_settings + + with patch("relay.config._generate_ntfy_topic", return_value="generatedtopic1"): + s = load_settings() + assert s.ntfy_topic == "generatedtopic1" + # Persisted back to .env + env_after = (tmp_path / ".env").read_text() + assert "NTFY_TOPIC=generatedtopic1" in env_after + + +def test_load_settings_treats_inline_hash_comment_as_blank( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Defensive: dotenv parses 'KEY= # comment' as 'KEY' = ' # comment'. + + The settings loader must recognise that as effectively blank so a + fresh topic is generated. 
+ """ + _write_env(tmp_path, "ANTHROPIC_API_KEY=sk-ant-test\nNTFY_TOPIC= # generated on first run\n") + monkeypatch.setattr("relay.config.REPO_ROOT", tmp_path) + monkeypatch.setattr("relay.config.ENV_FILE", tmp_path / ".env") + monkeypatch.setattr("relay.config.CONFIG_FILE", tmp_path / "config.yaml") + + from relay.config import load_settings + + with patch("relay.config._generate_ntfy_topic", return_value="newtopic"): + s = load_settings() + assert s.ntfy_topic == "newtopic" + + +def test_missing_api_key_raises(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + _write_env(tmp_path, "ANTHROPIC_MODEL=claude-opus-4-7\n") + monkeypatch.setattr("relay.config.REPO_ROOT", tmp_path) + monkeypatch.setattr("relay.config.ENV_FILE", tmp_path / ".env") + monkeypatch.setattr("relay.config.CONFIG_FILE", tmp_path / "config.yaml") + # Also ensure env var doesn't leak in + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + + from relay.config import load_settings + + with pytest.raises(RuntimeError, match="ANTHROPIC_API_KEY"): + load_settings() + + +def test_default_config_yaml_is_seeded_when_missing( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + _write_env(tmp_path, "ANTHROPIC_API_KEY=sk-ant-test\nNTFY_TOPIC=t1234\n") + monkeypatch.setattr("relay.config.REPO_ROOT", tmp_path) + monkeypatch.setattr("relay.config.ENV_FILE", tmp_path / ".env") + monkeypatch.setattr("relay.config.CONFIG_FILE", tmp_path / "config.yaml") + + from relay.config import load_settings + + s = load_settings() + cfg_path = tmp_path / "config.yaml" + assert cfg_path.exists() + assert any(sess.session_id == "session-1" for sess in s.sessions) diff --git a/tests/test_conversation.py b/tests/test_conversation.py new file mode 100644 index 0000000..4bae2ee --- /dev/null +++ b/tests/test_conversation.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import json +from pathlib import Path + +from relay.conversation import RECENT_TURNS_KEPT, Conversation, render_for_log + + +def 
test_append_and_read_back(tmp_path: Path) -> None: + convo = Conversation(tmp_path / "c.json") + convo.append("user", "hello", session_id="session-1") + convo.append("assistant", "hi back") + assert [t.role for t in convo.turns] == ["user", "assistant"] + persisted = json.loads((tmp_path / "c.json").read_text()) + assert persisted[0]["session_id"] == "session-1" + + +def test_total_chars_sums_content(tmp_path: Path) -> None: + convo = Conversation(tmp_path / "c.json") + convo.append("user", "abc") + convo.append("assistant", "defg") + assert convo.total_chars() == 7 + + +def test_needs_summarization_threshold(tmp_path: Path) -> None: + convo = Conversation(tmp_path / "c.json") + convo.append("user", "x" * 100) + assert not convo.needs_summarization(200) + assert convo.needs_summarization(50) + + +def test_replace_with_summary_keeps_recent_turns(tmp_path: Path) -> None: + convo = Conversation(tmp_path / "c.json") + for i in range(RECENT_TURNS_KEPT + 5): + convo.append("user", f"u{i}") + convo.append("assistant", f"a{i}") + convo.replace_with_summary("SUMMARY") + assert convo.turns[0].role == "assistant" + assert convo.turns[0].content == "SUMMARY" + assert convo.turns[0].meta == "summary" + # 1 summary + RECENT_TURNS_KEPT verbatim + assert len(convo.turns) == 1 + RECENT_TURNS_KEPT + + +def test_replace_with_summary_persists(tmp_path: Path) -> None: + convo = Conversation(tmp_path / "c.json") + for i in range(20): + convo.append("user", f"u{i}") + convo.replace_with_summary("S") + reloaded = Conversation(tmp_path / "c.json") + assert reloaded.turns[0].content == "S" + + +def test_to_api_messages_strips_metadata(tmp_path: Path) -> None: + convo = Conversation(tmp_path / "c.json") + convo.append("user", "x", session_id="s1") + convo.append("assistant", "y") + msgs = convo.to_api_messages() + assert msgs == [{"role": "user", "content": "x"}, {"role": "assistant", "content": "y"}] + + +def test_render_for_log_truncates_long_content(tmp_path: Path) -> None: + convo = 
Conversation(tmp_path / "c.json") + convo.append("user", "a" * 1000) + rendered = render_for_log(convo.turns[0], max_chars=50) + assert len(rendered["content"]) < 100 + assert "more chars" in rendered["content"] diff --git a/tests/test_daemon_loop.py b/tests/test_daemon_loop.py new file mode 100644 index 0000000..ac283d2 --- /dev/null +++ b/tests/test_daemon_loop.py @@ -0,0 +1,229 @@ +"""End-to-end loop test with a fake AnthropicClient. + +Drives one daemon `_tick` at a time so we can inspect state after each +event. Tests: + +- queue entry → API call → dispatch +- response containing [NEEDS-JC] → state=needs_jc, no dispatch +- jc_input.txt with @session-N: prefix → direct dispatch, no API call +- jc_input.txt without prefix → next chat-side turn (API call), clears needs_jc +- summarization triggers when history exceeds the cap +- malformed queue file is rejected, doesn't break the loop +""" + +from __future__ import annotations + +import json +from dataclasses import replace +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from relay.anthropic_client import TurnResult +from relay.config import SessionConfig, Settings +from relay.daemon import Daemon + + +@pytest.fixture +def settings(tmp_path: Path) -> Settings: + s = Settings( + api_key="sk-fake", + model="claude-opus-4-7", + ntfy_topic="test-topic", + status_port=0, + history_char_cap=400_000, + repo_root=tmp_path, + queue_dir=tmp_path / "queue", + dispatch_dir=tmp_path / "dispatch", + state_dir=tmp_path / "state", + logs_dir=tmp_path / "logs", + system_prompt="SYS", + summarization_prompt="please summarize", + sessions=(SessionConfig(session_id="session-1"),), + ) + for d in (s.queue_dir, s.dispatch_dir, s.state_dir, s.logs_dir): + d.mkdir(parents=True, exist_ok=True) + return s + + +def _fake_turn_result(text: str, *, cost: float = 0.001) -> TurnResult: + return TurnResult( + text=text, + input_tokens=100, + output_tokens=50, + 
cache_creation_input_tokens=0, + cache_read_input_tokens=0, + estimated_cost_usd=cost, + raw=SimpleNamespace(), + ) + + +def _drop_queue(settings: Settings, name: str, content: str, session_id: str = "session-1") -> Path: + path = settings.queue_dir / name + path.write_text( + json.dumps( + {"session_id": session_id, "timestamp": "2026-05-02T00:00:00Z", "content": content} + ) + ) + return path + + +def _build_daemon(settings: Settings, *, response_text: str) -> Daemon: + daemon = Daemon(settings) + daemon.client.send = lambda **kwargs: _fake_turn_result(response_text) # type: ignore[assignment] + return daemon + + +def test_queue_entry_flows_through_to_dispatch(settings: Settings) -> None: + _drop_queue(settings, "a.json", "first message") + daemon = _build_daemon(settings, response_text="ok cool reply") + + with patch("relay.daemon.notify"): + daemon._tick() + + dispatched = settings.dispatch_dir / "session-1" / "input.txt" + assert dispatched.read_text() == "ok cool reply" + # Queue entry consumed + assert list(settings.queue_dir.glob("*.json")) == [] + # History recorded + assert daemon.conversation.total_chars() > 0 + + +def test_needs_jc_token_pauses_dispatch(settings: Settings) -> None: + _drop_queue(settings, "a.json", "i need a decision") + daemon = _build_daemon(settings, response_text="[NEEDS-JC] should I pick option A or B?") + + with patch("relay.daemon.notify") as ntfy: + daemon._tick() + + assert daemon.status.state == "needs_jc" + assert daemon.status.last_needs_jc_text is not None + # No dispatch happened + dispatched = settings.dispatch_dir / "session-1" / "input.txt" + assert not dispatched.exists() + # ntfy was called with the high priority alert + assert any("paused" in c.kwargs.get("title", "").lower() for c in ntfy.call_args_list) + + +def test_needs_jc_pause_blocks_subsequent_queue_entries(settings: Settings) -> None: + _drop_queue(settings, "a.json", "first") + daemon = _build_daemon(settings, response_text="[NEEDS-JC] need decision") + 
with patch("relay.daemon.notify"): + daemon._tick() + + # New queue entry while paused + _drop_queue(settings, "b.json", "second") + with patch("relay.daemon.notify"): + daemon._tick() + + # Second queue entry NOT consumed + pending = list(settings.queue_dir.glob("*.json")) + assert len(pending) == 1 + + +def test_jc_input_prefix_dispatches_directly_without_api_call(settings: Settings) -> None: + daemon = _build_daemon(settings, response_text="should not be called") + api_calls: list[object] = [] + daemon.client.send = lambda **kwargs: (api_calls.append(kwargs), _fake_turn_result("x"))[1] # type: ignore[assignment] + + (settings.state_dir / "jc_input.txt").write_text("@session-1: do this thing") + with patch("relay.daemon.notify"): + daemon._tick() + + dispatched = settings.dispatch_dir / "session-1" / "input.txt" + assert dispatched.read_text() == "do this thing" + assert api_calls == [] # no API call for direct dispatch + assert not (settings.state_dir / "jc_input.txt").exists() + + +def test_jc_input_clears_needs_jc(settings: Settings) -> None: + daemon = _build_daemon(settings, response_text="[NEEDS-JC] question") + _drop_queue(settings, "a.json", "trigger needs_jc") + with patch("relay.daemon.notify"): + daemon._tick() + assert daemon.status.state == "needs_jc" + + # JC writes a non-prefix response — treated as next chat turn. 
+ daemon.client.send = lambda **kwargs: _fake_turn_result("post-jc reply") # type: ignore[assignment] + (settings.state_dir / "jc_input.txt").write_text("here's the answer: do A") + with patch("relay.daemon.notify"): + daemon._tick() + + assert daemon.status.state == "running" + dispatched = settings.dispatch_dir / "session-1" / "input.txt" + assert dispatched.read_text() == "post-jc reply" + + +def test_summarization_triggers_at_cap(settings: Settings) -> None: + tiny = replace(settings, history_char_cap=200) + daemon = Daemon(tiny) + + # Pre-load history to exceed the cap + daemon.conversation.append("user", "u" * 100) + daemon.conversation.append("assistant", "a" * 200) + assert daemon.conversation.needs_summarization(200) + + daemon.client.send = lambda **kwargs: _fake_turn_result("CONDENSED SUMMARY") # type: ignore[assignment] + + _drop_queue(tiny, "a.json", "next user turn") + with patch("relay.daemon.notify"): + daemon._tick() + + # First turn after cap was hit invokes summarization, then sends the user + # turn through the same call. The summary turn is now the head. 
+ assert daemon.conversation.turns[0].meta == "summary" + assert daemon.conversation.turns[0].content == "CONDENSED SUMMARY" + + +def test_malformed_queue_file_does_not_crash_loop(settings: Settings) -> None: + bad = settings.queue_dir / "broken.json" + bad.write_text("{not json") + daemon = _build_daemon(settings, response_text="x") + + with patch("relay.daemon.notify"): + daemon._tick() + + # Bad file moved to .rejected, no crash + assert not bad.exists() + assert (settings.queue_dir / ".rejected" / "broken.json").exists() + + +def test_dispatch_queues_when_session_input_pending(settings: Settings) -> None: + """Two queue entries for the same session and the first dispatch hasn't been consumed yet.""" + daemon = _build_daemon(settings, response_text="reply 1") + + _drop_queue(settings, "a.json", "msg 1") + with patch("relay.daemon.notify"): + daemon._tick() + assert (settings.dispatch_dir / "session-1" / "input.txt").read_text() == "reply 1" + + daemon.client.send = lambda **kwargs: _fake_turn_result("reply 2") # type: ignore[assignment] + _drop_queue(settings, "b.json", "msg 2") + with patch("relay.daemon.notify"): + daemon._tick() + + # First dispatch still present (CC hasn't consumed); second is queued. 
+ assert (settings.dispatch_dir / "session-1" / "input.txt").read_text() == "reply 1" + assert daemon.dispatch.has_pending("session-1") + + # CC consumes + (settings.dispatch_dir / "session-1" / "input.txt").unlink() + with patch("relay.daemon.notify"): + daemon._tick() + assert (settings.dispatch_dir / "session-1" / "input.txt").read_text() == "reply 2" + + +def test_status_persists_to_disk(settings: Settings) -> None: + daemon = _build_daemon(settings, response_text="ok") + _drop_queue(settings, "a.json", "hi") + with patch("relay.daemon.notify"): + daemon._tick() + daemon._persist_status() + + status_path = settings.state_dir / "status.json" + payload = json.loads(status_path.read_text()) + assert payload["state"] == "running" + assert payload["last_dispatch_session"] == "session-1" + assert payload["history_turns"] >= 2 # user + assistant diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py new file mode 100644 index 0000000..9c44a01 --- /dev/null +++ b/tests/test_dispatch.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from pathlib import Path + +from relay.dispatch import DispatchManager + + +def test_writes_input_when_session_dir_empty(tmp_path: Path) -> None: + mgr = DispatchManager(tmp_path / "dispatch") + delivered = mgr.queue_or_write("session-1", "hello") + assert delivered is True + assert (tmp_path / "dispatch" / "session-1" / "input.txt").read_text() == "hello" + + +def test_does_not_overwrite_pending_input(tmp_path: Path) -> None: + mgr = DispatchManager(tmp_path / "dispatch") + mgr.queue_or_write("session-1", "first") + delivered = mgr.queue_or_write("session-1", "second") + assert delivered is False + # First file unchanged + assert (tmp_path / "dispatch" / "session-1" / "input.txt").read_text() == "first" + assert mgr.has_pending("session-1") + + +def test_flush_all_delivers_after_consumer_deletes(tmp_path: Path) -> None: + mgr = DispatchManager(tmp_path / "dispatch") + mgr.queue_or_write("session-1", "first") + 
mgr.queue_or_write("session-1", "second") + # Consumer "consumes" first + (tmp_path / "dispatch" / "session-1" / "input.txt").unlink() + delivered = mgr.flush_all() + assert delivered == 1 + assert (tmp_path / "dispatch" / "session-1" / "input.txt").read_text() == "second" + assert not mgr.has_pending("session-1") + + +def test_independent_sessions_do_not_block_each_other(tmp_path: Path) -> None: + mgr = DispatchManager(tmp_path / "dispatch") + mgr.queue_or_write("session-1", "a") # blocks on session-1 + delivered = mgr.queue_or_write("session-2", "b") # session-2 unaffected + assert delivered is True + assert (tmp_path / "dispatch" / "session-2" / "input.txt").read_text() == "b" + + +def test_input_path_resolution(tmp_path: Path) -> None: + mgr = DispatchManager(tmp_path / "dispatch") + p = mgr.input_path("session-1") + assert p == tmp_path / "dispatch" / "session-1" / "input.txt" diff --git a/tests/test_ntfy.py b/tests/test_ntfy.py new file mode 100644 index 0000000..a2662cc --- /dev/null +++ b/tests/test_ntfy.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from relay.ntfy import notify, topic_url + + +def test_topic_url_builds_correctly() -> None: + assert topic_url("abc123") == "https://ntfy.sh/abc123" + + +def test_notify_returns_false_on_empty_topic() -> None: + assert notify("", "title", "msg") is False + + +def test_notify_posts_with_headers_and_body() -> None: + with patch("relay.ntfy.requests.post") as post: + post.return_value = MagicMock(status_code=200) + ok = notify("topic", "the title", "body text", priority="high", tags=["warning"]) + assert ok is True + args, kwargs = post.call_args + assert args[0] == "https://ntfy.sh/topic" + assert kwargs["data"] == b"body text" + assert kwargs["headers"]["Title"] == "the title" + assert kwargs["headers"]["Priority"] == "high" + assert kwargs["headers"]["Tags"] == "warning" + + +def test_notify_returns_false_on_non_200() -> None: + with 
patch("relay.ntfy.requests.post") as post: + post.return_value = MagicMock(status_code=500, text="oops") + assert notify("topic", "t", "m") is False + + +def test_notify_returns_false_on_network_error() -> None: + import requests + + with patch("relay.ntfy.requests.post", side_effect=requests.ConnectionError("down")): + assert notify("topic", "t", "m") is False diff --git a/tests/test_queue.py b/tests/test_queue.py new file mode 100644 index 0000000..949385a --- /dev/null +++ b/tests/test_queue.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import json +import time +from pathlib import Path + +from relay.queue import ack, list_pending, stuck_age_seconds, take_oldest + + +def _drop(queue_dir: Path, name: str, payload: dict) -> Path: + queue_dir.mkdir(parents=True, exist_ok=True) + path = queue_dir / name + path.write_text(json.dumps(payload)) + return path + + +def test_take_oldest_returns_none_on_empty_queue(tmp_path: Path) -> None: + assert take_oldest(tmp_path / "queue") is None + + +def test_take_oldest_returns_oldest_by_mtime(tmp_path: Path) -> None: + queue_dir = tmp_path / "queue" + older = _drop(queue_dir, "a.json", {"session_id": "s1", "timestamp": "t1", "content": "first"}) + time.sleep(0.01) + newer = _drop(queue_dir, "b.json", {"session_id": "s2", "timestamp": "t2", "content": "second"}) + entry = take_oldest(queue_dir) + assert entry is not None + assert entry.path == older + assert entry.session_id == "s1" + assert entry.content == "first" + assert newer.exists() + + +def test_invalid_json_is_rejected(tmp_path: Path) -> None: + queue_dir = tmp_path / "queue" + queue_dir.mkdir() + bad = queue_dir / "bad.json" + bad.write_text("not json") + assert take_oldest(queue_dir) is None + assert not bad.exists() + assert (queue_dir / ".rejected" / "bad.json").exists() + + +def test_envelope_missing_keys_is_rejected(tmp_path: Path) -> None: + queue_dir = tmp_path / "queue" + _drop(queue_dir, "x.json", {"session_id": "s1"}) # no timestamp / content + 
assert take_oldest(queue_dir) is None + assert (queue_dir / ".rejected" / "x.json").exists() + + +def test_envelope_blank_content_is_rejected(tmp_path: Path) -> None: + queue_dir = tmp_path / "queue" + _drop(queue_dir, "x.json", {"session_id": "s1", "timestamp": "t", "content": " "}) + assert take_oldest(queue_dir) is None + + +def test_ack_deletes_file(tmp_path: Path) -> None: + queue_dir = tmp_path / "queue" + _drop(queue_dir, "a.json", {"session_id": "s1", "timestamp": "t", "content": "x"}) + entry = take_oldest(queue_dir) + assert entry is not None + ack(entry) + assert not entry.path.exists() + + +def test_list_pending_skips_rejected_dir(tmp_path: Path) -> None: + queue_dir = tmp_path / "queue" + queue_dir.mkdir() + rejected = queue_dir / ".rejected" + rejected.mkdir() + (rejected / "old.json").write_text("{}") + a = _drop(queue_dir, "a.json", {"session_id": "s", "timestamp": "t", "content": "x"}) + pending = list_pending(queue_dir) + assert pending == [a] + + +def test_stuck_age_returns_zero_on_empty_queue(tmp_path: Path) -> None: + assert stuck_age_seconds(tmp_path / "queue") == 0 + + +def test_stuck_age_reports_oldest_age(tmp_path: Path) -> None: + queue_dir = tmp_path / "queue" + path = _drop(queue_dir, "a.json", {"session_id": "s", "timestamp": "t", "content": "x"}) + # Backdate the file + old_mtime = time.time() - 60 + import os + + os.utime(path, (old_mtime, old_mtime)) + age = stuck_age_seconds(queue_dir) + assert age >= 59 diff --git a/tests/test_real_api.py b/tests/test_real_api.py new file mode 100644 index 0000000..f9e564d --- /dev/null +++ b/tests/test_real_api.py @@ -0,0 +1,57 @@ +"""Live-API smoke test. + +Runs only when ANTHROPIC_API_KEY is set in the env (or .env). Sends one +real turn, verifies we get back a non-empty assistant text and usage +counts. Marked ``real_api`` so it is excluded from the default suite — +opt in with ``pytest -m real_api``. 
+""" + +from __future__ import annotations + +import os + +import pytest +from dotenv import dotenv_values + +from relay.anthropic_client import AnthropicClient +from relay.config import REPO_ROOT + + +def _api_key_available() -> str | None: + env_values = dotenv_values(REPO_ROOT / ".env") + return ( + env_values.get("ANTHROPIC_API_KEY") or os.environ.get("ANTHROPIC_API_KEY") or "" + ).strip() or None + + +@pytest.mark.real_api +@pytest.mark.skipif(_api_key_available() is None, reason="ANTHROPIC_API_KEY not configured") +def test_real_round_trip_against_haiku() -> None: + """One round-trip against the Haiku model — cheapest available, ~$0.0001 per call. + + Skips (rather than fails) on credit-balance errors so the test is + honest about why it didn't run. Real authentication errors (401) + still fail loudly. + """ + import anthropic + + api_key = _api_key_available() + assert api_key is not None + client = AnthropicClient( + api_key=api_key, model="claude-haiku-4-5-20251001", max_output_tokens=64 + ) + try: + result = client.send( + system_prompt="You are a terse echo. 
Reply with exactly 'pong' and nothing else.", + messages=[{"role": "user", "content": "ping"}], + ) + except anthropic.BadRequestError as exc: + msg = str(exc).lower() + if "credit balance" in msg or "billing" in msg: + pytest.skip(f"Anthropic API key has no credit balance: {exc}") + raise + + assert result.text.strip().lower() == "pong" + assert result.input_tokens > 0 + assert result.output_tokens > 0 + assert result.estimated_cost_usd > 0 diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 0000000..018a1d5 --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import json +import os +import threading +from pathlib import Path + +import pytest + +from relay.state import InstanceLock, StateError, read_json, write_atomic, write_json_atomic + + +def test_write_atomic_round_trips(tmp_path: Path) -> None: + target = tmp_path / "f.json" + write_atomic(target, "hello") + assert target.read_text() == "hello" + + +def test_write_atomic_does_not_leave_temp_files(tmp_path: Path) -> None: + target = tmp_path / "f.json" + write_atomic(target, "hello") + siblings = [p.name for p in tmp_path.iterdir() if p != target] + assert siblings == [] + + +def test_write_json_atomic_round_trips(tmp_path: Path) -> None: + target = tmp_path / "value.json" + payload = [{"k": "v"}, {"k2": "v2"}] + write_json_atomic(target, payload) + assert json.loads(target.read_text()) == payload + + +def test_read_json_returns_default_when_missing(tmp_path: Path) -> None: + assert read_json(tmp_path / "nope.json", default=[]) == [] + assert read_json(tmp_path / "nope.json", default={"x": 1}) == {"x": 1} + + +def test_read_json_returns_default_when_empty(tmp_path: Path) -> None: + target = tmp_path / "empty.json" + target.write_text("") + assert read_json(target, default=[]) == [] + + +def test_read_json_raises_on_corruption(tmp_path: Path) -> None: + target = tmp_path / "bad.json" + target.write_text("{not json") + with 
pytest.raises(StateError): + read_json(target, default=[]) + + +def test_instance_lock_acquires_and_releases(tmp_path: Path) -> None: + lock = InstanceLock(tmp_path / ".lock") + lock.acquire() + try: + # PID written to file + assert (tmp_path / ".lock").read_text().strip() == str(os.getpid()) + finally: + lock.release() + + +def test_instance_lock_rejects_second_acquirer(tmp_path: Path) -> None: + lock1 = InstanceLock(tmp_path / ".lock") + lock1.acquire() + try: + lock2 = InstanceLock(tmp_path / ".lock") + with pytest.raises(StateError): + lock2.acquire() + finally: + lock1.release() + + +def test_instance_lock_release_is_idempotent(tmp_path: Path) -> None: + lock = InstanceLock(tmp_path / ".lock") + lock.acquire() + lock.release() + lock.release() # second release should not raise + + +def test_concurrent_acquire_serializes(tmp_path: Path) -> None: + """Confirm flock semantics: two threads acquiring the same lock can't overlap.""" + lock_path = tmp_path / ".lock" + held: list[bool] = [] + error_count = [0] + + def worker(): + lock = InstanceLock(lock_path) + try: + lock.acquire() + held.append(True) + lock.release() + except StateError: + error_count[0] += 1 + + threads = [threading.Thread(target=worker) for _ in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + + # At least one succeeded; the other either succeeded after the first + # released, or failed to acquire (depending on scheduling). + assert sum(held) + error_count[0] == 2