🌐 AI搜索 & 代理 主页
Skip to content

Commit f9ab6c2

Browse files
committed
feat: attach trace context to LiteLLM headers
1 parent 7b2fe14 commit f9ab6c2

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from litellm import Message
5050
from litellm import ModelResponse
5151
from litellm import OpenAIMessageContent
52+
from opentelemetry import trace
5253
from pydantic import BaseModel
5354
from pydantic import Field
5455
from typing_extensions import override
@@ -225,6 +226,39 @@ class UsageMetadataChunk(BaseModel):
225226
class LiteLLMClient:
226227
"""Provides acompletion method (for better testability)."""
227228

229+
@staticmethod
230+
def _build_traceparent() -> Optional[str]:
231+
span_context = trace.get_current_span().get_span_context()
232+
if not span_context.is_valid:
233+
return None
234+
235+
trace_id = f"{span_context.trace_id:032x}"
236+
span_id = f"{span_context.span_id:016x}"
237+
trace_flags = f"{int(span_context.trace_flags):02x}"
238+
return f"00-{trace_id}-{span_id}-{trace_flags}"
239+
240+
@classmethod
241+
def _maybe_add_traceparent_header(
242+
cls, extra_headers: Optional[dict[str, str]]
243+
) -> Optional[dict[str, str]]:
244+
traceparent = cls._build_traceparent()
245+
if not traceparent:
246+
return extra_headers
247+
248+
headers_with_trace = dict(extra_headers) if extra_headers else {}
249+
headers_with_trace["traceparent"] = traceparent
250+
return headers_with_trace
251+
252+
@classmethod
253+
def _attach_traceparent_header(cls, kwargs: Dict[str, Any]) -> None:
254+
updated_headers = cls._maybe_add_traceparent_header(
255+
kwargs.get("extra_headers")
256+
)
257+
if updated_headers is None:
258+
kwargs.pop("extra_headers", None)
259+
else:
260+
kwargs["extra_headers"] = updated_headers
261+
228262
async def acompletion(
229263
self, model, messages, tools, **kwargs
230264
) -> Union[ModelResponse, CustomStreamWrapper]:
@@ -240,6 +274,8 @@ async def acompletion(
240274
The model response as a message.
241275
"""
242276

277+
self._attach_traceparent_header(kwargs)
278+
243279
return await acompletion(
244280
model=model,
245281
messages=messages,
@@ -263,6 +299,8 @@ def completion(
263299
The response from the model.
264300
"""
265301

302+
self._attach_traceparent_header(kwargs)
303+
266304
return completion(
267305
model=model,
268306
messages=messages,

tests/unittests/models/test_litellm.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from google.adk.models.lite_llm import LiteLLMClient
3838
from google.adk.models.lite_llm import TextChunk
3939
from google.adk.models.lite_llm import UsageMetadataChunk
40+
import google.adk.models.lite_llm as lite_llm_module
4041
from google.adk.models.llm_request import LlmRequest
4142
from google.genai import types
4243
import litellm
@@ -48,6 +49,9 @@
4849
from litellm.types.utils import Delta
4950
from litellm.types.utils import ModelResponse
5051
from litellm.types.utils import StreamingChoices
52+
from opentelemetry.trace import SpanContext
53+
from opentelemetry.trace import TraceFlags
54+
from opentelemetry.trace import TraceState
5155
from pydantic import BaseModel
5256
from pydantic import Field
5357
import pytest
@@ -210,6 +214,139 @@
210214
]
211215

212216

217+
class _StubSpan:
218+
219+
def __init__(self, span_context):
220+
self._span_context = span_context
221+
222+
def get_span_context(self):
223+
return self._span_context
224+
225+
226+
def _build_valid_span_context():
227+
return SpanContext(
228+
trace_id=int("0123456789abcdef0123456789abcdef", 16),
229+
span_id=int("abcdef0123456789", 16),
230+
is_remote=False,
231+
trace_flags=TraceFlags(1),
232+
trace_state=TraceState(),
233+
)
234+
235+
236+
def _build_invalid_span_context():
237+
return SpanContext(
238+
trace_id=0,
239+
span_id=0,
240+
is_remote=False,
241+
trace_flags=TraceFlags(0),
242+
trace_state=TraceState(),
243+
)
244+
245+
246+
def test_maybe_add_traceparent_header_with_existing_headers(monkeypatch):
247+
span_context = _build_valid_span_context()
248+
monkeypatch.setattr(
249+
lite_llm_module.trace,
250+
"get_current_span",
251+
lambda: _StubSpan(span_context),
252+
)
253+
254+
headers = {"custom": "header"}
255+
result = LiteLLMClient._maybe_add_traceparent_header(headers)
256+
257+
assert result is not headers
258+
assert result["custom"] == "header"
259+
assert result["traceparent"] == (
260+
"00-0123456789abcdef0123456789abcdef-abcdef0123456789-01"
261+
)
262+
263+
264+
def test_maybe_add_traceparent_header_without_existing_headers(monkeypatch):
265+
span_context = _build_valid_span_context()
266+
monkeypatch.setattr(
267+
lite_llm_module.trace,
268+
"get_current_span",
269+
lambda: _StubSpan(span_context),
270+
)
271+
272+
result = LiteLLMClient._maybe_add_traceparent_header(None)
273+
274+
assert result == {
275+
"traceparent": "00-0123456789abcdef0123456789abcdef-abcdef0123456789-01"
276+
}
277+
278+
279+
def test_maybe_add_traceparent_header_without_active_span(monkeypatch):
280+
span_context = _build_invalid_span_context()
281+
monkeypatch.setattr(
282+
lite_llm_module.trace,
283+
"get_current_span",
284+
lambda: _StubSpan(span_context),
285+
)
286+
287+
headers = {"custom": "value"}
288+
result = LiteLLMClient._maybe_add_traceparent_header(headers)
289+
290+
assert result is headers
291+
292+
293+
@pytest.mark.asyncio
294+
async def test_litellmclient_acompletion_sets_traceparent_header(monkeypatch):
295+
async_mock = AsyncMock(return_value="response")
296+
monkeypatch.setattr(lite_llm_module, "acompletion", async_mock)
297+
298+
def fake_helper(headers):
299+
assert headers == {"existing": "header"}
300+
return {"existing": "header", "traceparent": "tp"}
301+
302+
monkeypatch.setattr(
303+
LiteLLMClient, "_maybe_add_traceparent_header", fake_helper
304+
)
305+
306+
client = LiteLLMClient()
307+
await client.acompletion(
308+
model="test",
309+
messages=[],
310+
tools=None,
311+
extra_headers={"existing": "header"},
312+
custom="value",
313+
)
314+
315+
async_mock.assert_awaited_once()
316+
_, kwargs = async_mock.call_args
317+
assert kwargs["extra_headers"] == {
318+
"existing": "header",
319+
"traceparent": "tp",
320+
}
321+
assert kwargs["custom"] == "value"
322+
323+
324+
def test_litellmclient_completion_sets_traceparent_header(monkeypatch):
325+
sync_mock = Mock(return_value="response")
326+
monkeypatch.setattr(lite_llm_module, "completion", sync_mock)
327+
328+
def fake_helper(headers):
329+
assert headers is None
330+
return {"traceparent": "tp"}
331+
332+
monkeypatch.setattr(
333+
LiteLLMClient, "_maybe_add_traceparent_header", fake_helper
334+
)
335+
336+
client = LiteLLMClient()
337+
client.completion(
338+
model="test",
339+
messages=[],
340+
tools=None,
341+
stream=True,
342+
)
343+
344+
sync_mock.assert_called_once()
345+
_, kwargs = sync_mock.call_args
346+
assert kwargs["extra_headers"] == {"traceparent": "tp"}
347+
assert kwargs["stream"]
348+
349+
213350
class _StructuredOutput(BaseModel):
214351
value: int = Field(description="Value to emit")
215352

0 commit comments

Comments
 (0)