|
37 | 37 | from google.adk.models.lite_llm import LiteLLMClient |
38 | 38 | from google.adk.models.lite_llm import TextChunk |
39 | 39 | from google.adk.models.lite_llm import UsageMetadataChunk |
| 40 | +import google.adk.models.lite_llm as lite_llm_module |
40 | 41 | from google.adk.models.llm_request import LlmRequest |
41 | 42 | from google.genai import types |
42 | 43 | import litellm |
|
48 | 49 | from litellm.types.utils import Delta |
49 | 50 | from litellm.types.utils import ModelResponse |
50 | 51 | from litellm.types.utils import StreamingChoices |
| 52 | +from opentelemetry.trace import SpanContext |
| 53 | +from opentelemetry.trace import TraceFlags |
| 54 | +from opentelemetry.trace import TraceState |
51 | 55 | from pydantic import BaseModel |
52 | 56 | from pydantic import Field |
53 | 57 | import pytest |
|
210 | 214 | ] |
211 | 215 |
|
212 | 216 |
|
| 217 | +class _StubSpan: |
| 218 | + |
| 219 | + def __init__(self, span_context): |
| 220 | + self._span_context = span_context |
| 221 | + |
| 222 | + def get_span_context(self): |
| 223 | + return self._span_context |
| 224 | + |
| 225 | + |
| 226 | +def _build_valid_span_context(): |
| 227 | + return SpanContext( |
| 228 | + trace_id=int("0123456789abcdef0123456789abcdef", 16), |
| 229 | + span_id=int("abcdef0123456789", 16), |
| 230 | + is_remote=False, |
| 231 | + trace_flags=TraceFlags(1), |
| 232 | + trace_state=TraceState(), |
| 233 | + ) |
| 234 | + |
| 235 | + |
| 236 | +def _build_invalid_span_context(): |
| 237 | + return SpanContext( |
| 238 | + trace_id=0, |
| 239 | + span_id=0, |
| 240 | + is_remote=False, |
| 241 | + trace_flags=TraceFlags(0), |
| 242 | + trace_state=TraceState(), |
| 243 | + ) |
| 244 | + |
| 245 | + |
| 246 | +def test_maybe_add_traceparent_header_with_existing_headers(monkeypatch): |
| 247 | + span_context = _build_valid_span_context() |
| 248 | + monkeypatch.setattr( |
| 249 | + lite_llm_module.trace, |
| 250 | + "get_current_span", |
| 251 | + lambda: _StubSpan(span_context), |
| 252 | + ) |
| 253 | + |
| 254 | + headers = {"custom": "header"} |
| 255 | + result = LiteLLMClient._maybe_add_traceparent_header(headers) |
| 256 | + |
| 257 | + assert result is not headers |
| 258 | + assert result["custom"] == "header" |
| 259 | + assert result["traceparent"] == ( |
| 260 | + "00-0123456789abcdef0123456789abcdef-abcdef0123456789-01" |
| 261 | + ) |
| 262 | + |
| 263 | + |
| 264 | +def test_maybe_add_traceparent_header_without_existing_headers(monkeypatch): |
| 265 | + span_context = _build_valid_span_context() |
| 266 | + monkeypatch.setattr( |
| 267 | + lite_llm_module.trace, |
| 268 | + "get_current_span", |
| 269 | + lambda: _StubSpan(span_context), |
| 270 | + ) |
| 271 | + |
| 272 | + result = LiteLLMClient._maybe_add_traceparent_header(None) |
| 273 | + |
| 274 | + assert result == { |
| 275 | + "traceparent": "00-0123456789abcdef0123456789abcdef-abcdef0123456789-01" |
| 276 | + } |
| 277 | + |
| 278 | + |
| 279 | +def test_maybe_add_traceparent_header_without_active_span(monkeypatch): |
| 280 | + span_context = _build_invalid_span_context() |
| 281 | + monkeypatch.setattr( |
| 282 | + lite_llm_module.trace, |
| 283 | + "get_current_span", |
| 284 | + lambda: _StubSpan(span_context), |
| 285 | + ) |
| 286 | + |
| 287 | + headers = {"custom": "value"} |
| 288 | + result = LiteLLMClient._maybe_add_traceparent_header(headers) |
| 289 | + |
| 290 | + assert result is headers |
| 291 | + |
| 292 | + |
| 293 | +@pytest.mark.asyncio |
| 294 | +async def test_litellmclient_acompletion_sets_traceparent_header(monkeypatch): |
| 295 | + async_mock = AsyncMock(return_value="response") |
| 296 | + monkeypatch.setattr(lite_llm_module, "acompletion", async_mock) |
| 297 | + |
| 298 | + def fake_helper(headers): |
| 299 | + assert headers == {"existing": "header"} |
| 300 | + return {"existing": "header", "traceparent": "tp"} |
| 301 | + |
| 302 | + monkeypatch.setattr( |
| 303 | + LiteLLMClient, "_maybe_add_traceparent_header", fake_helper |
| 304 | + ) |
| 305 | + |
| 306 | + client = LiteLLMClient() |
| 307 | + await client.acompletion( |
| 308 | + model="test", |
| 309 | + messages=[], |
| 310 | + tools=None, |
| 311 | + extra_headers={"existing": "header"}, |
| 312 | + custom="value", |
| 313 | + ) |
| 314 | + |
| 315 | + async_mock.assert_awaited_once() |
| 316 | + _, kwargs = async_mock.call_args |
| 317 | + assert kwargs["extra_headers"] == { |
| 318 | + "existing": "header", |
| 319 | + "traceparent": "tp", |
| 320 | + } |
| 321 | + assert kwargs["custom"] == "value" |
| 322 | + |
| 323 | + |
| 324 | +def test_litellmclient_completion_sets_traceparent_header(monkeypatch): |
| 325 | + sync_mock = Mock(return_value="response") |
| 326 | + monkeypatch.setattr(lite_llm_module, "completion", sync_mock) |
| 327 | + |
| 328 | + def fake_helper(headers): |
| 329 | + assert headers is None |
| 330 | + return {"traceparent": "tp"} |
| 331 | + |
| 332 | + monkeypatch.setattr( |
| 333 | + LiteLLMClient, "_maybe_add_traceparent_header", fake_helper |
| 334 | + ) |
| 335 | + |
| 336 | + client = LiteLLMClient() |
| 337 | + client.completion( |
| 338 | + model="test", |
| 339 | + messages=[], |
| 340 | + tools=None, |
| 341 | + stream=True, |
| 342 | + ) |
| 343 | + |
| 344 | + sync_mock.assert_called_once() |
| 345 | + _, kwargs = sync_mock.call_args |
| 346 | + assert kwargs["extra_headers"] == {"traceparent": "tp"} |
| 347 | + assert kwargs["stream"] |
| 348 | + |
| 349 | + |
213 | 350 | class _StructuredOutput(BaseModel): |
214 | 351 | value: int = Field(description="Value to emit") |
215 | 352 |
|
|
0 commit comments