123 lines
4.4 KiB
Python
123 lines
4.4 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from types import SimpleNamespace
|
||
|
|
from unittest.mock import MagicMock
|
||
|
|
|
||
|
|
from relay.anthropic_client import AnthropicClient
|
||
|
|
|
||
|
|
|
||
|
|
def _fake_message(
|
||
|
|
text: str, *, in_tokens: int = 100, out_tokens: int = 50, cache_w: int = 0, cache_r: int = 0
|
||
|
|
):
|
||
|
|
"""Build a stand-in for anthropic.types.Message with .content + .usage."""
|
||
|
|
return SimpleNamespace(
|
||
|
|
content=[SimpleNamespace(type="text", text=text)],
|
||
|
|
usage=SimpleNamespace(
|
||
|
|
input_tokens=in_tokens,
|
||
|
|
output_tokens=out_tokens,
|
||
|
|
cache_creation_input_tokens=cache_w,
|
||
|
|
cache_read_input_tokens=cache_r,
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def _client_with_mock(model: str = "claude-opus-4-7") -> tuple[AnthropicClient, MagicMock]:
    """Return a client wired to a stubbed SDK, plus the stub's ``create`` mock.

    The client's ``_sdk`` attribute is replaced by a namespace exposing only
    ``messages.create``; that mock returns a fake message whose text is
    ``"response text"`` with the helper's default token counts.
    """
    patched_client = AnthropicClient(api_key="sk-fake", model=model)
    create_stub = MagicMock(return_value=_fake_message("response text"))
    messages_stub = SimpleNamespace(create=create_stub)
    patched_client._sdk = SimpleNamespace(messages=messages_stub)
    return patched_client, create_stub
|
||
|
|
|
||
|
|
|
||
|
|
def test_send_returns_assistant_text_and_usage() -> None:
    """send() exposes the reply text and the input/output token counts."""
    client, _ = _client_with_mock()
    conversation = [{"role": "user", "content": "u1"}]

    reply = client.send(system_prompt="sys", messages=conversation)

    # 100 / 50 are the default token counts of the stubbed fake message.
    assert reply.text == "response text"
    assert reply.input_tokens == 100
    assert reply.output_tokens == 50
|
||
|
|
|
||
|
|
|
||
|
|
def test_send_passes_system_prompt_with_cache_control() -> None:
    """send() forwards the system prompt as a block tagged for ephemeral caching.

    Inspects the ``system`` keyword argument handed to the stubbed
    ``messages.create`` and checks that its first entry carries the prompt
    text together with ``cache_control == {"type": "ephemeral"}``.
    """
    client, mock_create = _client_with_mock()

    client.send(system_prompt="SYSTEM PROMPT", messages=[{"role": "user", "content": "u"}])

    # Fail with a clear message if the SDK was never reached; otherwise
    # call_args is None and the lookup below dies with an AttributeError.
    mock_create.assert_called()
    system_arg = mock_create.call_args.kwargs["system"]
    first_block = system_arg[0]
    assert first_block["text"] == "SYSTEM PROMPT"
    assert first_block["cache_control"] == {"type": "ephemeral"}
|
||
|
|
|
||
|
|
|
||
|
|
def test_cost_estimation_opus() -> None:
    """One million input tokens on the default model should cost about $15."""
    client, _ = _client_with_mock()
    # Swap in a response with specific token counts.
    million_input = _fake_message("x", in_tokens=1_000_000, out_tokens=0)
    client._sdk.messages.create = MagicMock(return_value=million_input)

    result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}])

    # Opus: $15/M input → $15
    deviation = abs(result.estimated_cost_usd - 15.0)
    assert deviation < 0.01
|
||
|
|
|
||
|
|
|
||
|
|
def test_cost_estimation_includes_cache_savings() -> None:
    """Cache-write and cache-read tokens are both priced into the estimate."""
    client, _ = _client_with_mock()
    cache_only = _fake_message(
        "x",
        in_tokens=0,
        out_tokens=0,
        cache_w=1_000_000,  # 1M cache write at $18.75/M
        cache_r=1_000_000,  # 1M cache read at $1.50/M
    )
    client._sdk.messages.create = MagicMock(return_value=cache_only)

    result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}])

    # cache_w 1M @ $18.75 + cache_r 1M @ $1.50 = $20.25
    deviation = abs(result.estimated_cost_usd - 20.25)
    assert deviation < 0.01
|
||
|
|
|
||
|
|
|
||
|
|
def test_cost_estimation_unknown_model_falls_back_to_opus() -> None:
    """A model name with no known pricing is costed at the Opus rate.

    Uses the shared ``_client_with_mock`` helper (as the sibling tests do)
    instead of hand-building the client and SDK stub, then swaps in a
    response reporting one million input tokens.
    """
    client, _ = _client_with_mock(model="claude-future-9000")
    client._sdk.messages.create = MagicMock(
        return_value=_fake_message("x", in_tokens=1_000_000, out_tokens=0)
    )

    result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}])

    # Falls back to Opus pricing
    assert abs(result.estimated_cost_usd - 15.0) < 0.01
|
||
|
|
|
||
|
|
|
||
|
|
def test_send_concatenates_multiple_text_blocks() -> None:
    """Text from several ``type="text"`` blocks is joined in order."""
    client, _ = _client_with_mock()
    # Reuse the shared helper for the zeroed usage counters rather than
    # duplicating them here; only the content list is specific to this test.
    multi = _fake_message("", in_tokens=0, out_tokens=0)
    multi.content = [
        SimpleNamespace(type="text", text="hello "),
        SimpleNamespace(type="text", text="world"),
    ]
    client._sdk.messages.create = MagicMock(return_value=multi)

    result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}])

    assert result.text == "hello world"
|
||
|
|
|
||
|
|
|
||
|
|
def test_send_ignores_non_text_blocks() -> None:
    """Blocks whose type is not "text" contribute nothing to the result text."""
    client, _ = _client_with_mock()
    # Reuse the shared helper for the zeroed usage counters rather than
    # duplicating them here; only the content list is specific to this test.
    mixed = _fake_message("", in_tokens=0, out_tokens=0)
    mixed.content = [
        SimpleNamespace(type="text", text="text part"),
        SimpleNamespace(type="tool_use"),  # no .text — would crash if not filtered
    ]
    client._sdk.messages.create = MagicMock(return_value=mixed)

    result = client.send(system_prompt="s", messages=[{"role": "user", "content": "u"}])

    assert result.text == "text part"
|