diff --git a/.github/workflows/analyze-releases-for-adk-docs-updates.yml b/.github/workflows/analyze-releases-for-adk-docs-updates.yml
index c8a86fac66..21414ae534 100644
--- a/.github/workflows/analyze-releases-for-adk-docs-updates.yml
+++ b/.github/workflows/analyze-releases-for-adk-docs-updates.yml
@@ -16,10 +16,10 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
- name: Set up Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: '3.11'
diff --git a/.github/workflows/check-file-contents.yml b/.github/workflows/check-file-contents.yml
index 861e2247a0..bb575e0f20 100644
--- a/.github/workflows/check-file-contents.yml
+++ b/.github/workflows/check-file-contents.yml
@@ -24,7 +24,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
with:
fetch-depth: 2
diff --git a/.github/workflows/copybara-pr-handler.yml b/.github/workflows/copybara-pr-handler.yml
index 670389c527..4ca3c48803 100644
--- a/.github/workflows/copybara-pr-handler.yml
+++ b/.github/workflows/copybara-pr-handler.yml
@@ -25,7 +25,7 @@ jobs:
steps:
- name: Check for Copybara commits and close PRs
- uses: actions/github-script@v7
+ uses: actions/github-script@v8
with:
github-token: ${{ secrets.ADK_TRIAGE_AGENT }}
script: |
diff --git a/.github/workflows/discussion_answering.yml b/.github/workflows/discussion_answering.yml
index 71c06ba9f6..d9bfffc361 100644
--- a/.github/workflows/discussion_answering.yml
+++ b/.github/workflows/discussion_answering.yml
@@ -15,10 +15,10 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
- name: Set up Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: '3.11'
diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml
index e1a087742c..b8b24da5ce 100644
--- a/.github/workflows/isort.yml
+++ b/.github/workflows/isort.yml
@@ -26,12 +26,12 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
with:
fetch-depth: 2
- name: Set up Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: '3.x'
diff --git a/.github/workflows/pr-triage.yml b/.github/workflows/pr-triage.yml
index a8a2082094..55b088b505 100644
--- a/.github/workflows/pr-triage.yml
+++ b/.github/workflows/pr-triage.yml
@@ -20,10 +20,10 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
- name: Set up Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: '3.11'
diff --git a/.github/workflows/pyink.yml b/.github/workflows/pyink.yml
index ef9e72e453..0822757fa0 100644
--- a/.github/workflows/pyink.yml
+++ b/.github/workflows/pyink.yml
@@ -26,12 +26,12 @@ jobs:
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
with:
fetch-depth: 2
- name: Set up Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: '3.x'
diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml
index 42b6174813..3fc6bd943f 100644
--- a/.github/workflows/python-unit-tests.yml
+++ b/.github/workflows/python-unit-tests.yml
@@ -25,14 +25,14 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+ python-version: ["3.10", "3.11", "3.12", "3.13"]
steps:
- name: Checkout code
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
@@ -48,14 +48,6 @@ jobs:
- name: Run unit tests with pytest
run: |
source .venv/bin/activate
- if [[ "${{ matrix.python-version }}" == "3.9" ]]; then
- pytest tests/unittests \
- --ignore=tests/unittests/a2a \
- --ignore=tests/unittests/tools/mcp_tool \
- --ignore=tests/unittests/artifacts/test_artifact_service.py \
- --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py
- else
- pytest tests/unittests \
- --ignore=tests/unittests/artifacts/test_artifact_service.py \
- --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py
- fi
\ No newline at end of file
+ pytest tests/unittests \
+ --ignore=tests/unittests/artifacts/test_artifact_service.py \
+ --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py
\ No newline at end of file
diff --git a/.github/workflows/stale-bot.yml b/.github/workflows/stale-bot.yml
new file mode 100644
index 0000000000..6948b56459
--- /dev/null
+++ b/.github/workflows/stale-bot.yml
@@ -0,0 +1,43 @@
+name: ADK Stale Issue Auditor
+
+on:
+ workflow_dispatch:
+
+ schedule:
+ # This runs at 6:00 AM UTC (10 PM PST)
+ - cron: '0 6 * * *'
+
+jobs:
+ audit-stale-issues:
+ runs-on: ubuntu-latest
+ timeout-minutes: 60
+
+ permissions:
+ issues: write
+ contents: read
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v5
+
+ - name: Set up Python
+ uses: actions/setup-python@v6
+ with:
+ python-version: '3.11'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install requests google-adk
+
+ - name: Run Auditor Agent Script
+ env:
+ GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }}
+ GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
+ OWNER: ${{ github.repository_owner }}
+ REPO: adk-python
+ CONCURRENCY_LIMIT: 3
+ LLM_MODEL_NAME: "gemini-2.5-flash"
+ PYTHONPATH: contributing/samples
+
+ run: python -m adk_stale_agent.main
\ No newline at end of file
diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml
index 57e729e9b5..46153f413a 100644
--- a/.github/workflows/triage.yml
+++ b/.github/workflows/triage.yml
@@ -2,21 +2,32 @@ name: ADK Issue Triaging Agent
on:
issues:
- types: [opened, reopened]
+ types: [opened, labeled]
+ schedule:
+ # Run every 6 hours to triage untriaged issues
+ - cron: '0 */6 * * *'
jobs:
agent-triage-issues:
runs-on: ubuntu-latest
+ # Run for:
+ # - Scheduled runs (batch processing)
+ # - New issues (need component labeling)
+ # - Issues labeled with "planned" (need owner assignment)
+ if: >-
+ github.event_name == 'schedule' ||
+ github.event.action == 'opened' ||
+ github.event.label.name == 'planned'
permissions:
issues: write
contents: read
steps:
- name: Checkout repository
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
- name: Set up Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: '3.11'
@@ -30,8 +41,8 @@ jobs:
GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }}
GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
GOOGLE_GENAI_USE_VERTEXAI: 0
- OWNER: 'google'
- REPO: 'adk-python'
+ OWNER: ${{ github.repository_owner }}
+ REPO: ${{ github.event.repository.name }}
INTERACTIVE: 0
EVENT_NAME: ${{ github.event_name }} # 'issues', 'schedule', etc.
ISSUE_NUMBER: ${{ github.event.issue.number }}
diff --git a/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml b/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml
index 9b1f042917..bce7598c2f 100644
--- a/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml
+++ b/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml
@@ -13,7 +13,7 @@ jobs:
steps:
- name: Checkout repository
- uses: actions/checkout@v4
+ uses: actions/checkout@v5
- name: Clone adk-docs repository
run: git clone https://github.com/google/adk-docs.git /tmp/adk-docs
@@ -22,7 +22,7 @@ jobs:
run: git clone https://github.com/google/adk-python.git /tmp/adk-python
- name: Set up Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: '3.11'
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ced7b7026b..93dc505adc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,146 @@
# Changelog
+## [1.20.0](https://github.com/google/adk-python/compare/v1.19.0...v1.20.0) (2025-12-01)
+
+
+### Features
+* **[Core]**
+ * Add enum constraint to `agent_name` for `transfer_to_agent` ([4a42d0d](https://github.com/google/adk-python/commit/4a42d0d9d81b7aab98371427f70a7707dbfb8bc4))
+ * Add validation for unique sub-agent names ([#3557](https://github.com/google/adk-python/issues/3557)) ([2247a45](https://github.com/google/adk-python/commit/2247a45922afdf0a733239b619f45601d9b325ec))
+ * Support streaming function call arguments in progressive SSE streaming feature ([786aaed](https://github.com/google/adk-python/commit/786aaed335e1ce64b7e92dff2f4af8316b2ef593))
+
+* **[Models]**
+ * Enable multi-provider support for Claude and LiteLLM ([d29261a](https://github.com/google/adk-python/commit/d29261a3dc9c5a603feef27ea657c4a03bb8a089))
+
+* **[Tools]**
+ * Create APIRegistryToolset to add tools from Cloud API registry to agent ([ec4ccd7](https://github.com/google/adk-python/commit/ec4ccd718feeadeb6b2b59fcc0e9ff29a4fd0bac))
+ * Add an option to disallow propagating runner plugins to AgentTool runner ([777dba3](https://github.com/google/adk-python/commit/777dba3033a9a14667fb009ba017f648177be41d))
+
+* **[Web]**
+ * Added an endpoint to list apps with details ([b57fe5f](https://github.com/google/adk-python/commit/b57fe5f4598925ec7592917bb32c7f0d6eca287a))
+
+
+### Bug Fixes
+
+* Allow image parts in user messages for Anthropic Claude ([5453b5b](https://github.com/google/adk-python/commit/5453b5bfdedc91d9d668c9eac39e3bb009a7bbbf))
+* Mark the `Content` as non-empty if its first part contains text, inline_data, file_data, or a function call/response ([631b583](https://github.com/google/adk-python/commit/631b58336d36bfd93e190582be34069613d38559))
+* Fix double response processing in `base_llm_flow.py` where, in Bidi-streaming (live) mode, the multi-agent structure caused duplicated responses after tool calling ([cf21ca3](https://github.com/google/adk-python/commit/cf21ca358478919207049695ba6b31dc6e0b2673))
+* Fix out of bounds error in _run_async_impl ([8fc6128](https://github.com/google/adk-python/commit/8fc6128b62ba576480d196d4a2597564fd0a7006))
+* Fix paths for public docs ([cd54f48](https://github.com/google/adk-python/commit/cd54f48fed0c87b54fb19743c9c75e790c5d9135))
+* Ensure request bodies without explicit names are named 'body' ([084c2de](https://github.com/google/adk-python/commit/084c2de0dac84697906e2b4beebf008bbd9ae8e1)), closes [#2213](https://github.com/google/adk-python/issues/2213)
+* Optimize Stale Agent with GraphQL and Search API to resolve 429 Quota errors ([cb19d07](https://github.com/google/adk-python/commit/cb19d0714c90cd578551753680f39d8d6076c79b))
+* Update AgentTool to use Agent's description when input_schema is provided in FunctionDeclaration ([52674e7](https://github.com/google/adk-python/commit/52674e7fac6b7689f0e3871d41c4523e13471a7e))
+* Update LiteLLM system instruction role from "developer" to "system" ([2e1f730](https://github.com/google/adk-python/commit/2e1f730c3bc0eb454b76d7f36b7b9f1da7304cfe)), closes [#3657](https://github.com/google/adk-python/issues/3657)
+* Update session last update time when appending events ([a3e4ad3](https://github.com/google/adk-python/commit/a3e4ad3cd130714affcaa880f696aeb498cd93af)), closes [#2721](https://github.com/google/adk-python/issues/2721)
+* Update the retry_on_closed_resource decorator to retry on all errors ([a3aa077](https://github.com/google/adk-python/commit/a3aa07722a7de3e08807e86fd10f28938f0b267d))
+* Fix Windows path handling and normalize cross-platform path resolution in `AgentLoader` ([a1c09b7](https://github.com/google/adk-python/commit/a1c09b724bb37513eaabaff9643eeaa68014f14d))
+
+
+### Documentation
+
+* Add Code Wiki badge to README ([caf23ac](https://github.com/google/adk-python/commit/caf23ac49fe08bc7f625c61eed4635c26852c3ba))
+
+
+## [1.19.0](https://github.com/google/adk-python/compare/v1.18.0...v1.19.0) (2025-11-19)
+
+### Features
+
+* **[Core]**
+ * Add `id` and `custom_metadata` fields to `MemoryEntry` ([4dd28a3](https://github.com/google/adk-python/commit/4dd28a3970d0f76c571caf80b3e1bea1b79e9dde))
+ * Add progressive SSE streaming feature ([a5ac1d5](https://github.com/google/adk-python/commit/a5ac1d5e14f5ce7cd875d81a494a773710669dc1))
+ * Add a2a_request_meta_provider to RemoteAgent init ([d12468e](https://github.com/google/adk-python/commit/d12468ee5a2b906b6699ccdb94c6a5a4c2822465))
+ * Add feature decorator for the feature registry system ([871da73](https://github.com/google/adk-python/commit/871da731f1c09c6a62d51b137d9d2e7c9fb3897a))
+  * Breaking: Raise minimum Python version to 3.10 ([8402832](https://github.com/google/adk-python/commit/840283228ee77fb3dbd737cfe7eb8736d9be5ec8))
+ * Refactor and rename BigQuery agent analytics plugin ([6b14f88](https://github.com/google/adk-python/commit/6b14f887262722ccb85dcd6cef9c0e9b103cfa6e))
+ * Pass custom_metadata through forwarding artifact service ([c642f13](https://github.com/google/adk-python/commit/c642f13f216fb64bc93ac46c1c57702c8a2add8c))
+ * Update save_files_as_artifacts_plugin to never keep inline data ([857de04](https://github.com/google/adk-python/commit/857de04debdeba421075c2283c9bd8518d586624))
+
+* **[Evals]**
+ * Add support for InOrder and AnyOrder match in ToolTrajectoryAvgScore Metric ([e2d3b2d](https://github.com/google/adk-python/commit/e2d3b2d862f7fc93807d16089307d4df25367a24))
+
+* **[Integrations]**
+ * Enhance BQ Plugin Schema, Error Handling, and Logging ([5ac5129](https://github.com/google/adk-python/commit/5ac5129fb01913516d6f5348a825ca83d024d33a))
+ * Schema Enhancements with Descriptions, Partitioning, and Truncation Indicator ([7c993b0](https://github.com/google/adk-python/commit/7c993b01d1b9d582b4e2348f73c0591d47bf2f3a))
+
+* **[Services]**
+ * Add file-backed artifact service ([99ca6aa](https://github.com/google/adk-python/commit/99ca6aa6e6b4027f37d091d9c93da6486def20d7))
+ * Add service factory for configurable session and artifact backends ([a12ae81](https://github.com/google/adk-python/commit/a12ae812d367d2d00ab246f85a73ed679dd3828a))
+ * Add SqliteSessionService and a migration script to migrate existing DB using DatabaseSessionService to SqliteSessionService ([e218254](https://github.com/google/adk-python/commit/e2182544952c0174d1a8307fbba319456dca748b))
+ * Add transcription fields to session events ([3ad30a5](https://github.com/google/adk-python/commit/3ad30a58f95b8729f369d00db799546069d7b23a))
+ * Full async implementation of DatabaseSessionService ([7495941](https://github.com/google/adk-python/commit/74959414d8ded733d584875a49fb4638a12d3ce5))
+
+* **[Models]**
+ * Add experimental feature to use `parameters_json_schema` and `response_json_schema` for McpTool ([1dd97f5](https://github.com/google/adk-python/commit/1dd97f5b45226c25e4c51455c78ebf3ff56ab46a))
+ * Add support for parsing inline JSON tool calls in LiteLLM responses ([22eb7e5](https://github.com/google/adk-python/commit/22eb7e5b06c9e048da5bb34fe7ae9135d00acb4e))
+ * Expose artifact URLs to the model when available ([e3caf79](https://github.com/google/adk-python/commit/e3caf791395ce3cc0b10410a852be6e7b0d8d3b1))
+
+* **[Tools]**
+ * Add BigQuery related label handling ([ffbab4c](https://github.com/google/adk-python/commit/ffbab4cf4ed6ceb313241c345751214d3c0e11ce))
+ * Allow setting max_billed_bytes in BigQuery tools config ([ffbb0b3](https://github.com/google/adk-python/commit/ffbb0b37e128de50ebf57d76cba8b743a8b970d5))
+ * Propagate `application_name` set for the BigQuery Tools as BigQuery job labels ([f13a11e](https://github.com/google/adk-python/commit/f13a11e1dc27c5aa46345154fbe0eecfe1690cbb))
+ * Set per-tool user agent in BQ calls and tool label in BQ jobs ([c0be1df](https://github.com/google/adk-python/commit/c0be1df0521cfd4b84585f404d4385b80d08ba59))
+
+* **[Observability]**
+ * Migrate BigQuery logging to Storage Write API ([a2ce34a](https://github.com/google/adk-python/commit/a2ce34a0b9a8403f830ff637d0e2094e82dee8e7))
+
+### Bug Fixes
+
+* Add `jsonschema` dependency for Agent Builder config validation ([0fa7e46](https://github.com/google/adk-python/commit/0fa7e4619d589dc834f7508a18bc2a3b93ec7fd9))
+* Add None check for `event` in `remote_a2a_agent.py` ([744f94f](https://github.com/google/adk-python/commit/744f94f0c8736087724205bbbad501640b365270))
+* Add vertexai initialization for code being deployed to AgentEngine ([b8e4aed](https://github.com/google/adk-python/commit/b8e4aedfbf0eb55b34599ee24e163b41072a699c))
+* Change LiteLLM content and tool parameter handling ([a19be12](https://github.com/google/adk-python/commit/a19be12c1f04bb62a8387da686499857c24b45c0))
+* Change name for builder agent ([131d39c](https://github.com/google/adk-python/commit/131d39c3db1ae25e3911fa7f72afbe05e24a1c37))
+* Ensure event compaction completes by awaiting task ([b5f5df9](https://github.com/google/adk-python/commit/b5f5df9fa8f616b855c186fcef45bade00653c77))
+* Fix deploy to cloud run on Windows ([29fea7e](https://github.com/google/adk-python/commit/29fea7ec1fb27989f07c90494b2d6acbe76c03d8))
+* Fix error handling when MCP server is unreachable ([ee8106b](https://github.com/google/adk-python/commit/ee8106be77f253e3687e72ae0e236687d254965c))
+* Fix error when query job destination is None ([0ccc43c](https://github.com/google/adk-python/commit/0ccc43cf49dc0882dc896455d6603a602d8a28e7))
+* Improve logic for checking if an MCP session is disconnected ([a754c96](https://github.com/google/adk-python/commit/a754c96d3c4fd00f9c2cd924fc428b68cc5115fb))
+* Fix McpToolset crashing with anyio.BrokenResourceError ([8e0648d](https://github.com/google/adk-python/commit/8e0648df23d0694afd3e245ec4a3c41aa935120a))
+* Safely handle `FunctionDeclaration` without a `required` attribute ([93aad61](https://github.com/google/adk-python/commit/93aad611983dc1daf415d3a73105db45bbdd1988))
+* Fix status code in error message in RestApiTool ([9b75456](https://github.com/google/adk-python/commit/9b754564b3cc5a06ad0c6ae2cd2d83082f9f5943))
+* Use `async for` to loop through the event iterator to get all events in vertex_ai_session_service ([9211f4c](https://github.com/google/adk-python/commit/9211f4ce8cc6d918df314d6a2ff13da2e0ef35fa))
+* Fix DeprecationWarning when using the send method ([2882995](https://github.com/google/adk-python/commit/28829952890c39dbdb4463b2b67ff241d0e9ef6d))
+* Improve logic for checking if an MCP session is disconnected ([a48a1a9](https://github.com/google/adk-python/commit/a48a1a9e889d4126e6f30b56c93718dfbacef624))
+* Improve handling of partial and complete transcriptions in live calls ([1819ecb](https://github.com/google/adk-python/commit/1819ecb4b8c009d02581c2d060fae49cd7fdf653))
+* Keep vertex session event after the session update time ([0ec0195](https://github.com/google/adk-python/commit/0ec01956e86df6ae8e6553c70e410f1f8238ba88))
+* Let part converters also return multiple parts so they can support more usecases ([824ab07](https://github.com/google/adk-python/commit/824ab072124e037cc373c493f43de38f8b61b534))
+* Load agent/app before creating session ([236f562](https://github.com/google/adk-python/commit/236f562cd275f84837be46f7dfb0065f85425169))
+* Remove app name from FileArtifactService directory structure ([12db84f](https://github.com/google/adk-python/commit/12db84f5cd6d8b6e06142f6f6411f6b78ff3f177))
+* Remove hardcoded `google-cloud-aiplatform` version in agent engine requirements ([e15e19d](https://github.com/google/adk-python/commit/e15e19da05ee1b763228467e83f6f73e0eced4b5))
+* Stop updating write mode in the global settings during tool execution ([5adbf95](https://github.com/google/adk-python/commit/5adbf95a0ab0657dd7df5c4a6bac109d424d436e))
+* Update description for `load_artifacts` tool ([c485889](https://github.com/google/adk-python/commit/c4858896ff085bedcfbc42b2010af8bd78febdd0))
+
+### Improvements
+
+* Add BigQuery related label handling ([ffbab4c](https://github.com/google/adk-python/commit/ffbab4cf4ed6ceb313241c345751214d3c0e11ce))
+* Add demo for rewind ([8eb1bdb](https://github.com/google/adk-python/commit/8eb1bdbc58dc709006988f5b6eec5fda25bd0c89))
+* Add debug logging for live connection ([5d5708b](https://github.com/google/adk-python/commit/5d5708b2ab26cb714556311c490b4d6f0a1f9666))
+* Add debug logging for missing function call events ([f3d6fcf](https://github.com/google/adk-python/commit/f3d6fcf44411d07169c14ae12189543f44f96c27))
+* Add default retry options as a fallback for llm_requests made during evals ([696852a](https://github.com/google/adk-python/commit/696852a28095a024cbe76413ee7617356e19a9e3))
+* Add plugin for returning GenAI Parts from tools into the model request ([116b26c](https://github.com/google/adk-python/commit/116b26c33e166bf1a22964e2b67013907fbfcb80))
+* Add support for abstract types in AFC ([2efc184](https://github.com/google/adk-python/commit/2efc184a46173529bdfc622db0d6f3866e7ee778))
+* Add support for structured output schemas in LiteLLM models ([7ea4aed](https://github.com/google/adk-python/commit/7ea4aed35ba70ec5a38dc1b3b0a9808183c2bab1))
+* Add tests for `max_query_result_rows` in BigQuery tool config ([fd33610](https://github.com/google/adk-python/commit/fd33610e967ad814bc02422f5d14dae046bee833))
+* Add type hints in `cleanup_unused_files.py` ([2dea573](https://github.com/google/adk-python/commit/2dea5733b759a7a07d74f36a4d6da7b081afc732))
+* Add util to build our llms.txt and llms-full.txt files
+* ADK changes ([f1f4467](https://github.com/google/adk-python/commit/f1f44675e4a86b75e72cfd838efd8a0399f23e24))
+* Defer import of `google.cloud.storage` in `GCSArtifactService` ([999af55](https://github.com/google/adk-python/commit/999af5588005e7b29451bdbf9252265187ca992d))
+* Defer import of `live`, `Client` and `_transformers` in `google.genai` ([22c6dbe](https://github.com/google/adk-python/commit/22c6dbe83cd1a8900d0ac6fd23d2092f095189fa))
+* Enhance the messaging with possible fixes for RESOURCE_EXHAUSTED errors from Gemini ([b2c45f8](https://github.com/google/adk-python/commit/b2c45f8d910eb7bca4805c567279e65aff72b58a))
+* Improve gepa tau-bench colab for external use ([e02f177](https://github.com/google/adk-python/commit/e02f177790d9772dd253c9102b80df1a9418aa7f))
+* Improve gepa voter agent demo colab ([d118479](https://github.com/google/adk-python/commit/d118479ccf3a970ce9b24ac834b4b6764edb5de4))
+* Lazy import DatabaseSessionService in the adk/sessions/ module ([5f05749](https://github.com/google/adk-python/commit/5f057498a274d3b3db0be0866f04d5225334f54a))
+* Move adk_agent_builder_assistant to built_in_agents ([b2b7f2d](https://github.com/google/adk-python/commit/b2b7f2d6aa5b919a00a92abaf2543993746e939e))
+* Plumb memory service from LocalEvalService to EvaluationGenerator ([dc3f60c](https://github.com/google/adk-python/commit/dc3f60cc939335da49399a69c0b4abc0e7f25ea4))
+* Remove the unrealistic TODO comment about visibility management ([e511eb1](https://github.com/google/adk-python/commit/e511eb1f70f2a3fccc9464ddaf54d0165db22feb))
+* Return agent state regardless of whether ctx.is_resumable is set ([d6b928b](https://github.com/google/adk-python/commit/d6b928bdf7cdbf8f1925d4c5227c7d580093348e))
+* Stop logging the full content of LLM blobs ([0826755](https://github.com/google/adk-python/commit/082675546f501a70f4bc8969b9431a2e4808bd13))
+* Update ADK web to match main branch ([14e3802](https://github.com/google/adk-python/commit/14e3802643a2d8ce436d030734fafd163080a1ad))
+* Update agent instructions and retry limit in `plugin_reflect_tool_retry` sample ([01bac62](https://github.com/google/adk-python/commit/01bac62f0c14cce5d454a389b64a9f44a03a3673))
+* Update conformance test CLI to handle long-running tool calls ([dd706bd](https://github.com/google/adk-python/commit/dd706bdc4563a2a815459482237190a63994cb6f))
+* Update Gemini Live model names in live bidi streaming sample ([aa77834](https://github.com/google/adk-python/commit/aa77834e2ecd4b77dfb4e689ef37549b3ebd6134))
+
+
## [1.18.0](https://github.com/google/adk-python/compare/v1.17.0...v1.18.0) (2025-11-05)
### Features
diff --git a/README.md b/README.md
index 3113b1ccdf..9046cc49a3 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
[](https://pypi.org/project/google-adk/)
[](https://github.com/google/adk-python/actions/workflows/python-unit-tests.yml)
[](https://www.reddit.com/r/agentdevelopmentkit/)
-[](https://deepwiki.com/google/adk-python)
+
diff --git a/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt b/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt
new file mode 100644
index 0000000000..8f5f585ff6
--- /dev/null
+++ b/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt
@@ -0,0 +1,68 @@
+You are a highly intelligent repository auditor for '{OWNER}/{REPO}'.
+Your job is to analyze a specific issue and report findings before taking action.
+
+**Primary Directive:** Ignore any events from users ending in `[bot]`.
+**Reporting Directive:** Output a concise summary starting with "Analysis for Issue #[number]:".
+
+**THRESHOLDS:**
+- Stale Threshold: {stale_threshold_days} days.
+- Close Threshold: {close_threshold_days} days.
+
+**WORKFLOW:**
+1. **Context Gathering**: Call `get_issue_state`.
+2. **Decision**: Follow this strict decision tree using the data returned by the tool.
+
+--- **DECISION TREE** ---
+
+**STEP 1: CHECK IF ALREADY STALE**
+- **Condition**: Is `is_stale` (from tool) **True**?
+- **Action**:
+ - **Check Role**: Look at `last_action_role`.
+
+ - **IF 'author' OR 'other_user'**:
+ - **Context**: The user has responded. The issue is now ACTIVE.
+ - **Action 1**: Call `remove_label_from_issue` with '{STALE_LABEL_NAME}'.
+ - **Action 2 (ALERT CHECK)**: Look at `maintainer_alert_needed`.
+ - **IF True**: User edited description silently.
+ -> **Action**: Call `alert_maintainer_of_edit`.
+ - **IF False**: User commented normally. No alert needed.
+ - **Report**: "Analysis for Issue #[number]: ACTIVE. User activity detected. Removed stale label."
+
+ - **IF 'maintainer'**:
+ - **Check Time**: Check `days_since_stale_label`.
+ - **If `days_since_stale_label` > {close_threshold_days}**:
+ - **Action**: Call `close_as_stale`.
+ - **Report**: "Analysis for Issue #[number]: STALE. Close threshold met. Closing."
+ - **Else**:
+ - **Report**: "Analysis for Issue #[number]: STALE. Waiting for close threshold. No action."
+
+**STEP 2: CHECK IF ACTIVE (NOT STALE)**
+- **Condition**: `is_stale` is **False**.
+- **Action**:
+ - **Check Role**: If `last_action_role` is 'author' or 'other_user':
+ - **Context**: The issue is Active.
+ - **Action (ALERT CHECK)**: Look at `maintainer_alert_needed`.
+ - **IF True**: The user edited the description silently, and we haven't alerted yet.
+ -> **Action**: Call `alert_maintainer_of_edit`.
+ -> **Report**: "Analysis for Issue #[number]: ACTIVE. Silent update detected (Description Edit). Alerted maintainer."
+ - **IF False**:
+ -> **Report**: "Analysis for Issue #[number]: ACTIVE. Last action was by user. No action."
+
+ - **Check Role**: If `last_action_role` is 'maintainer':
+ - **Proceed to STEP 3.**
+
+**STEP 3: ANALYZE MAINTAINER INTENT**
+- **Context**: The last person to act was a Maintainer.
+- **Action**: Read the text in `last_comment_text`.
+ - **Question Check**: Does the text ask a question, request clarification, ask for logs, or suggest trying a fix?
+ - **Time Check**: Is `days_since_activity` > {stale_threshold_days}?
+
+ - **DECISION**:
+ - **IF (Question == YES) AND (Time == YES)**:
+ - **Action**: Call `add_stale_label_and_comment`.
+ - **Check**: If '{REQUEST_CLARIFICATION_LABEL}' is not in `current_labels`, call `add_label_to_issue` for it.
+ - **Report**: "Analysis for Issue #[number]: STALE. Maintainer asked question [days_since_activity] days ago. Marking stale."
+ - **IF (Question == YES) BUT (Time == NO)**:
+ - **Report**: "Analysis for Issue #[number]: PENDING. Maintainer asked question, but threshold not met yet. No action."
+ - **IF (Question == NO)** (e.g., "I am working on this"):
+ - **Report**: "Analysis for Issue #[number]: ACTIVE. Maintainer gave status update (not a question). No action."
\ No newline at end of file
diff --git a/contributing/samples/adk_stale_agent/README.md b/contributing/samples/adk_stale_agent/README.md
new file mode 100644
index 0000000000..afc47b11cc
--- /dev/null
+++ b/contributing/samples/adk_stale_agent/README.md
@@ -0,0 +1,89 @@
+# ADK Stale Issue Auditor Agent
+
+This directory contains an autonomous, **GraphQL-powered** agent designed to audit a GitHub repository for stale issues. It maintains repository hygiene by ensuring that open items stay actionable and receive timely responses.
+
+Unlike traditional "Stale Bots" that only look at timestamps, this agent uses a **Unified History Trace** and an **LLM (Large Language Model)** to understand the *context* of a conversation. It distinguishes a maintainer asking a question (a stale candidate) from a maintainer providing a status update (active).
+
+---
+
+## Core Logic & Features
+
+The agent operates as a "Repository Auditor," proactively scanning open issues using a high-efficiency decision tree.
+
+### 1. Smart State Verification (GraphQL)
+Instead of making multiple expensive API calls, the agent uses a single **GraphQL** query per issue to reconstruct the entire history of the conversation. It combines:
+* **Comments**
+* **Description/Body Edits** ("Ghost Edits")
+* **Title Renames**
+* **State Changes** (Reopens)
+
+It sorts these events chronologically to determine the **Last Active Actor**.
+
+### 2. The "Last Actor" Rule
+The agent follows a precise logic flow based on who acted last:
+
+* **If Author/User acted last:** The issue is **ACTIVE**.
+ * This includes comments, title changes, and *silent* description edits.
+ * **Action:** The agent immediately removes the `stale` label.
+ * **Silent Update Alert:** If the user edited the description but *did not* comment, the agent posts a specific alert: *"Notification: The author has updated the issue description..."* to ensure maintainers are notified (since GitHub does not trigger notifications for body edits).
+ * **Spam Prevention:** The agent checks if it has already alerted about a specific silent edit to avoid spamming the thread.
+
+* **If Maintainer acted last:** The issue is **POTENTIALLY STALE**.
+ * The agent passes the text of the maintainer's last comment to the LLM.
+
+### 3. Semantic Intent Analysis (LLM)
+If the maintainer was the last person to speak, the LLM analyzes the comment text to determine intent:
+* **Question/Request:** "Can you provide logs?" / "Please try v2.0."
+ * **Verdict:** **STALE** (Waiting on Author).
+ * **Action:** If the time threshold is met, the agent adds the `stale` label. It also checks for the `request clarification` label and adds it if missing.
+* **Status Update:** "We are working on a fix." / "Added to backlog."
+ * **Verdict:** **ACTIVE** (Waiting on Maintainer).
+ * **Action:** No action taken. The issue remains open without stale labels.
+
+### 4. Lifecycle Management
+* **Marking Stale:** After `STALE_HOURS_THRESHOLD` (default: 7 days) of inactivity following a maintainer's question.
+* **Closing:** After `CLOSE_HOURS_AFTER_STALE_THRESHOLD` (default: 7 days) of continued inactivity while marked stale.
+
+---
+
+## Performance & Safety
+
+* **GraphQL Optimized:** Fetches comments, edits, labels, and timeline events in a single network request to minimize latency and API quota usage.
+* **Search API Filtering:** Uses the GitHub Search API to filter out recently created issues up front, so the bot doesn't waste cycles analyzing brand-new issues.
+* **Rate Limit Aware:** Includes intelligent sleeping and retry logic (exponential backoff) to handle GitHub API rate limits (HTTP 429) gracefully.
+* **Execution Metrics:** Logs the time taken and API calls consumed for every issue processed.
+
+---
+
+## Configuration
+
+The agent is configured via environment variables, typically set as secrets in GitHub Actions.
+
+### Required Secrets
+
+| Secret Name | Description |
+| :--- | :--- |
+| `GITHUB_TOKEN` | A GitHub Personal Access Token (PAT) or Service Account Token with `repo` scope. |
+| `GOOGLE_API_KEY` | An API key for the Google AI (Gemini) model used for reasoning. |
+
+### Optional Configuration
+
+These variables control the timing thresholds and model selection.
+
+| Variable Name | Description | Default |
+| :--- | :--- | :--- |
+| `STALE_HOURS_THRESHOLD` | Hours of inactivity after a maintainer's question before marking as `stale`. | `168` (7 days) |
+| `CLOSE_HOURS_AFTER_STALE_THRESHOLD` | Hours after being marked `stale` before the issue is closed. | `168` (7 days) |
+| `LLM_MODEL_NAME`| The specific Gemini model version to use. | `gemini-2.5-flash` |
+| `OWNER` | Repository owner (auto-detected in Actions). | (Environment dependent) |
+| `REPO` | Repository name (auto-detected in Actions). | (Environment dependent) |
+
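+In GitHub Actions, these are typically passed through the run step's `env:` block. A minimal sketch, mirroring the `stale-bot.yml` workflow added in this change (the threshold overrides shown here are illustrative, not the workflow's actual values):
+
+```yaml
+env:
+  GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }}
+  GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
+  OWNER: ${{ github.repository_owner }}
+  REPO: adk-python
+  LLM_MODEL_NAME: "gemini-2.5-flash"
+  # Optional overrides, in hours; both default to 168 (7 days).
+  STALE_HOURS_THRESHOLD: 336
+  CLOSE_HOURS_AFTER_STALE_THRESHOLD: 168
+```
+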
+---
+
+## Deployment
+
+To deploy this agent, a GitHub Actions workflow file (`.github/workflows/stale-bot.yml`) is recommended.
+
+### Directory Structure Note
+Because this agent resides within the `adk-python` repository under `contributing/samples`, the workflow must set `PYTHONPATH` so the `adk_stale_agent` package can be imported and run as a module, as sketched below.
+
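+A minimal sketch of the job's steps, matching the `stale-bot.yml` workflow added in this change:
+
+```yaml
+steps:
+  - uses: actions/checkout@v5
+  - uses: actions/setup-python@v6
+    with:
+      python-version: '3.11'
+  - name: Install dependencies
+    run: pip install requests google-adk
+  - name: Run Auditor Agent Script
+    env:
+      # Secrets and thresholds as described in the Configuration section, plus:
+      PYTHONPATH: contributing/samples  # lets `adk_stale_agent` be imported as a package
+    run: python -m adk_stale_agent.main
+```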
diff --git a/contributing/samples/adk_stale_agent/__init__.py b/contributing/samples/adk_stale_agent/__init__.py
new file mode 100644
index 0000000000..c48963cdc7
--- /dev/null
+++ b/contributing/samples/adk_stale_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/contributing/samples/adk_stale_agent/agent.py b/contributing/samples/adk_stale_agent/agent.py
new file mode 100644
index 0000000000..8769adc193
--- /dev/null
+++ b/contributing/samples/adk_stale_agent/agent.py
@@ -0,0 +1,597 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import datetime
+from datetime import timezone
+import logging
+import os
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from adk_stale_agent.settings import CLOSE_HOURS_AFTER_STALE_THRESHOLD
+from adk_stale_agent.settings import GITHUB_BASE_URL
+from adk_stale_agent.settings import GRAPHQL_COMMENT_LIMIT
+from adk_stale_agent.settings import GRAPHQL_EDIT_LIMIT
+from adk_stale_agent.settings import GRAPHQL_TIMELINE_LIMIT
+from adk_stale_agent.settings import LLM_MODEL_NAME
+from adk_stale_agent.settings import OWNER
+from adk_stale_agent.settings import REPO
+from adk_stale_agent.settings import REQUEST_CLARIFICATION_LABEL
+from adk_stale_agent.settings import STALE_HOURS_THRESHOLD
+from adk_stale_agent.settings import STALE_LABEL_NAME
+from adk_stale_agent.utils import delete_request
+from adk_stale_agent.utils import error_response
+from adk_stale_agent.utils import get_request
+from adk_stale_agent.utils import patch_request
+from adk_stale_agent.utils import post_request
+import dateutil.parser
+from google.adk.agents.llm_agent import Agent
+from requests.exceptions import RequestException
+
+logger = logging.getLogger("google_adk." + __name__)
+
+# --- Constants ---
+# Used to detect if the bot has already posted an alert to avoid spamming.
+BOT_ALERT_SIGNATURE = (
+ "**Notification:** The author has updated the issue description"
+)
+
+# --- Global Cache ---
+_MAINTAINERS_CACHE: Optional[List[str]] = None
+
+
+def _get_cached_maintainers() -> List[str]:
+ """
+ Fetches the list of repository maintainers.
+
+ This function relies on `utils.get_request` for network resilience.
+ `get_request` is configured with an HTTPAdapter that automatically performs
+ exponential backoff retries (up to 6 times) for 5xx errors and rate limits.
+
+ If the retries are exhausted or the data format is invalid, this function
+ raises a RuntimeError to prevent the bot from running with incorrect permissions.
+
+ Returns:
+ List[str]: A list of GitHub usernames with push access.
+
+ Raises:
+ RuntimeError: If the API fails after all retries or returns invalid data.
+ """
+ global _MAINTAINERS_CACHE
+ if _MAINTAINERS_CACHE is not None:
+ return _MAINTAINERS_CACHE
+
+ logger.info("Initializing Maintainers Cache...")
+
+ try:
+ url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/collaborators"
+ params = {"permission": "push"}
+
+ data = get_request(url, params)
+
+ if isinstance(data, list):
+ _MAINTAINERS_CACHE = [u["login"] for u in data if "login" in u]
+ logger.info(f"Cached {len(_MAINTAINERS_CACHE)} maintainers.")
+ return _MAINTAINERS_CACHE
+ else:
+ logger.error(
+ f"Invalid API response format: Expected list, got {type(data)}"
+ )
+ raise ValueError(f"GitHub API returned non-list data: {data}")
+
+ except Exception as e:
+ logger.critical(
+ f"FATAL: Failed to verify repository maintainers. Error: {e}"
+ )
+ raise RuntimeError(
+ "Maintainer verification failed. processing aborted."
+ ) from e
+
+
+def load_prompt_template(filename: str) -> str:
+ """
+ Loads the raw text content of a prompt file.
+
+ Args:
+ filename (str): The name of the file (e.g., 'PROMPT_INSTRUCTION.txt').
+
+ Returns:
+ str: The file content.
+ """
+ file_path = os.path.join(os.path.dirname(__file__), filename)
+ with open(file_path, "r") as f:
+ return f.read()
+
+
+PROMPT_TEMPLATE = load_prompt_template("PROMPT_INSTRUCTION.txt")
+
+
+def _fetch_graphql_data(item_number: int) -> Dict[str, Any]:
+ """
+ Executes the GraphQL query to fetch raw issue data, including comments,
+ edits, and timeline events.
+
+ Args:
+ item_number (int): The GitHub issue number.
+
+ Returns:
+ Dict[str, Any]: The raw 'issue' object from the GraphQL response.
+
+ Raises:
+ RequestException: If the GraphQL query returns errors or the issue is not found.
+ """
+ query = """
+ query($owner: String!, $name: String!, $number: Int!, $commentLimit: Int!, $timelineLimit: Int!, $editLimit: Int!) {
+ repository(owner: $owner, name: $name) {
+ issue(number: $number) {
+ author { login }
+ createdAt
+ labels(first: 20) { nodes { name } }
+
+ comments(last: $commentLimit) {
+ nodes {
+ author { login }
+ body
+ createdAt
+ lastEditedAt
+ }
+ }
+
+ userContentEdits(last: $editLimit) {
+ nodes {
+ editor { login }
+ editedAt
+ }
+ }
+
+ timelineItems(itemTypes: [LABELED_EVENT, RENAMED_TITLE_EVENT, REOPENED_EVENT], last: $timelineLimit) {
+ nodes {
+ __typename
+ ... on LabeledEvent {
+ createdAt
+ actor { login }
+ label { name }
+ }
+ ... on RenamedTitleEvent {
+ createdAt
+ actor { login }
+ }
+ ... on ReopenedEvent {
+ createdAt
+ actor { login }
+ }
+ }
+ }
+ }
+ }
+ }
+ """
+
+ variables = {
+ "owner": OWNER,
+ "name": REPO,
+ "number": item_number,
+ "commentLimit": GRAPHQL_COMMENT_LIMIT,
+ "editLimit": GRAPHQL_EDIT_LIMIT,
+ "timelineLimit": GRAPHQL_TIMELINE_LIMIT,
+ }
+
+ response = post_request(
+ f"{GITHUB_BASE_URL}/graphql", {"query": query, "variables": variables}
+ )
+
+ if "errors" in response:
+ raise RequestException(f"GraphQL Error: {response['errors'][0]['message']}")
+
+ data = response.get("data", {}).get("repository", {}).get("issue", {})
+ if not data:
+ raise RequestException(f"Issue #{item_number} not found.")
+
+ return data
+
+
+def _build_history_timeline(
+ data: Dict[str, Any],
+) -> Tuple[List[Dict[str, Any]], List[datetime], Optional[datetime]]:
+ """
+ Parses raw GraphQL data into a unified, chronologically sorted history list.
+ Also extracts specific event times needed for logic checks.
+
+ Args:
+ data (Dict[str, Any]): The raw issue data from `_fetch_graphql_data`.
+
+ Returns:
+ Tuple[List[Dict], List[datetime], Optional[datetime]]:
+ - history: A list of normalized event dictionaries sorted by time.
+ - label_events: A list of timestamps when the stale label was applied.
+ - last_bot_alert_time: Timestamp of the last bot silent-edit alert (if any).
+ """
+ issue_author = data.get("author", {}).get("login")
+ history = []
+ label_events = []
+ last_bot_alert_time = None
+
+ # 1. Baseline: Issue Creation
+ history.append({
+ "type": "created",
+ "actor": issue_author,
+ "time": dateutil.parser.isoparse(data["createdAt"]),
+ "data": None,
+ })
+
+ # 2. Process Comments
+ for c in data.get("comments", {}).get("nodes", []):
+ if not c:
+ continue
+
+ actor = c.get("author", {}).get("login")
+ c_body = c.get("body", "")
+ c_time = dateutil.parser.isoparse(c.get("createdAt"))
+
+ # Track bot alerts for spam prevention
+ if BOT_ALERT_SIGNATURE in c_body:
+ if last_bot_alert_time is None or c_time > last_bot_alert_time:
+ last_bot_alert_time = c_time
+
+ if actor and not actor.endswith("[bot]"):
+ # Use edit time if available, otherwise creation time
+ e_time = c.get("lastEditedAt")
+ actual_time = dateutil.parser.isoparse(e_time) if e_time else c_time
+ history.append({
+ "type": "commented",
+ "actor": actor,
+ "time": actual_time,
+ "data": c_body,
+ })
+
+ # 3. Process Body Edits ("Ghost Edits")
+ for e in data.get("userContentEdits", {}).get("nodes", []):
+ if not e:
+ continue
+ actor = e.get("editor", {}).get("login")
+ if actor and not actor.endswith("[bot]"):
+ history.append({
+ "type": "edited_description",
+ "actor": actor,
+ "time": dateutil.parser.isoparse(e.get("editedAt")),
+ "data": None,
+ })
+
+ # 4. Process Timeline Events
+ for t in data.get("timelineItems", {}).get("nodes", []):
+ if not t:
+ continue
+
+ etype = t.get("__typename")
+ actor = t.get("actor", {}).get("login")
+ time_val = dateutil.parser.isoparse(t.get("createdAt"))
+
+ if etype == "LabeledEvent":
+ if t.get("label", {}).get("name") == STALE_LABEL_NAME:
+ label_events.append(time_val)
+ continue
+
+ if actor and not actor.endswith("[bot]"):
+ pretty_type = (
+ "renamed_title" if etype == "RenamedTitleEvent" else "reopened"
+ )
+ history.append({
+ "type": pretty_type,
+ "actor": actor,
+ "time": time_val,
+ "data": None,
+ })
+
+ # Sort chronologically
+ history.sort(key=lambda x: x["time"])
+ return history, label_events, last_bot_alert_time
+
+
+def _replay_history_to_find_state(
+ history: List[Dict[str, Any]], maintainers: List[str], issue_author: str
+) -> Dict[str, Any]:
+ """
+ Replays the unified event history to determine the absolute last actor and their role.
+
+ Args:
+ history (List[Dict]): Chronologically sorted list of events.
+ maintainers (List[str]): List of maintainer usernames.
+ issue_author (str): Username of the issue author.
+
+ Returns:
+ Dict[str, Any]: A dictionary containing the last state of the issue:
+ - last_action_role (str): 'author', 'maintainer', or 'other_user'.
+ - last_activity_time (datetime): Timestamp of the last human action.
+ - last_action_type (str): The type of the last action (e.g., 'commented').
+ - last_comment_text (Optional[str]): The text of the last comment.
+ """
+ last_action_role = "author"
+ last_activity_time = history[0]["time"]
+ last_action_type = "created"
+ last_comment_text = None
+
+ for event in history:
+ actor = event["actor"]
+ etype = event["type"]
+
+ role = "other_user"
+ if actor == issue_author:
+ role = "author"
+ elif actor in maintainers:
+ role = "maintainer"
+
+ last_action_role = role
+ last_activity_time = event["time"]
+ last_action_type = etype
+
+ # Only store text if it was a comment (resets on other events like labels/edits)
+ if etype == "commented":
+ last_comment_text = event["data"]
+ else:
+ last_comment_text = None
+
+ return {
+ "last_action_role": last_action_role,
+ "last_activity_time": last_activity_time,
+ "last_action_type": last_action_type,
+ "last_comment_text": last_comment_text,
+ }
+
+
+def get_issue_state(item_number: int) -> Dict[str, Any]:
+ """
+ Retrieves the comprehensive state of a GitHub issue using GraphQL.
+
+ This function orchestrates the fetching, parsing, and analysis of the issue's
+ history to determine if it is stale, active, or pending maintainer review.
+
+ Args:
+ item_number (int): The GitHub issue number.
+
+ Returns:
+ Dict[str, Any]: A comprehensive state dictionary for the LLM agent.
+ Contains keys such as 'last_action_role', 'is_stale', 'days_since_activity',
+ and 'maintainer_alert_needed'.
+ """
+ try:
+ maintainers = _get_cached_maintainers()
+
+ # 1. Fetch
+ raw_data = _fetch_graphql_data(item_number)
+
+ issue_author = raw_data.get("author", {}).get("login")
+ labels_list = [
+ l["name"] for l in raw_data.get("labels", {}).get("nodes", [])
+ ]
+
+ # 2. Parse & Sort
+ history, label_events, last_bot_alert_time = _build_history_timeline(
+ raw_data
+ )
+
+ # 3. Analyze (Replay)
+ state = _replay_history_to_find_state(history, maintainers, issue_author)
+
+ # 4. Final Calculations & Alert Logic
+ current_time = datetime.now(timezone.utc)
+ days_since_activity = (
+ current_time - state["last_activity_time"]
+ ).total_seconds() / 86400
+
+ # Stale Checks
+ is_stale = STALE_LABEL_NAME in labels_list
+ days_since_stale_label = 0.0
+ if is_stale and label_events:
+ latest_label_time = max(label_events)
+ days_since_stale_label = (
+ current_time - latest_label_time
+ ).total_seconds() / 86400
+
+ # Silent Edit Alert Logic
+ maintainer_alert_needed = False
+ if (
+ state["last_action_role"] in ["author", "other_user"]
+ and state["last_action_type"] == "edited_description"
+ ):
+ if (
+ last_bot_alert_time
+ and last_bot_alert_time > state["last_activity_time"]
+ ):
+ logger.info(
+ f"#{item_number}: Silent edit detected, but Bot already alerted. No"
+ " spam."
+ )
+ else:
+ maintainer_alert_needed = True
+ logger.info(f"#{item_number}: Silent edit detected. Alert needed.")
+
+ logger.debug(
+ f"#{item_number} VERDICT: Role={state['last_action_role']}, "
+ f"Idle={days_since_activity:.2f}d"
+ )
+
+ return {
+ "status": "success",
+ "last_action_role": state["last_action_role"],
+ "last_action_type": state["last_action_type"],
+ "maintainer_alert_needed": maintainer_alert_needed,
+ "is_stale": is_stale,
+ "days_since_activity": days_since_activity,
+ "days_since_stale_label": days_since_stale_label,
+ "last_comment_text": state["last_comment_text"],
+ "current_labels": labels_list,
+ "stale_threshold_days": STALE_HOURS_THRESHOLD / 24,
+ "close_threshold_days": CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24,
+ }
+
+ except RequestException as e:
+ return error_response(f"Network Error: {e}")
+ except Exception as e:
+ logger.error(
+ f"Unexpected error analyzing #{item_number}: {e}", exc_info=True
+ )
+ return error_response(f"Analysis Error: {e}")
+
+
+# --- Tool Definitions ---
+
+
+def _format_days(hours: float) -> str:
+ """
+ Formats a duration in hours into a clean day string.
+
+ Example:
+ 168.0 -> "7"
+ 12.0 -> "0.5"
+ """
+ days = hours / 24
+ return f"{days:.1f}" if days % 1 != 0 else f"{int(days)}"
+
+
+def add_label_to_issue(item_number: int, label_name: str) -> dict[str, Any]:
+ """
+ Adds a label to the issue.
+
+ Args:
+ item_number (int): The GitHub issue number.
+ label_name (str): The name of the label to add.
+ """
+ logger.debug(f"Adding label '{label_name}' to issue #{item_number}.")
+ url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels"
+ try:
+ post_request(url, [label_name])
+ return {"status": "success"}
+ except RequestException as e:
+ return error_response(f"Error adding label: {e}")
+
+
+def remove_label_from_issue(
+ item_number: int, label_name: str
+) -> dict[str, Any]:
+ """
+ Removes a label from the issue.
+
+ Args:
+ item_number (int): The GitHub issue number.
+ label_name (str): The name of the label to remove.
+ """
+ logger.debug(f"Removing label '{label_name}' from issue #{item_number}.")
+ url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels/{label_name}"
+ try:
+ delete_request(url)
+ return {"status": "success"}
+ except RequestException as e:
+ return error_response(f"Error removing label: {e}")
+
+
+def add_stale_label_and_comment(item_number: int) -> dict[str, Any]:
+ """
+ Marks the issue as stale with a comment and label.
+
+ Args:
+ item_number (int): The GitHub issue number.
+ """
+ stale_days_str = _format_days(STALE_HOURS_THRESHOLD)
+ close_days_str = _format_days(CLOSE_HOURS_AFTER_STALE_THRESHOLD)
+
+ comment = (
+ "This issue has been automatically marked as stale because it has not"
+ f" had recent activity for {stale_days_str} days after a maintainer"
+ " requested clarification. It will be closed if no further activity"
+ f" occurs within {close_days_str} days."
+ )
+ try:
+ post_request(
+ f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/comments",
+ {"body": comment},
+ )
+ post_request(
+ f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels",
+ [STALE_LABEL_NAME],
+ )
+ return {"status": "success"}
+ except RequestException as e:
+ return error_response(f"Error marking issue as stale: {e}")
+
+
+def alert_maintainer_of_edit(item_number: int) -> dict[str, Any]:
+ """
+ Posts a comment alerting maintainers of a silent description update.
+
+ Args:
+ item_number (int): The GitHub issue number.
+ """
+ # Uses the constant signature to ensure detection logic in get_issue_state works.
+ comment = f"{BOT_ALERT_SIGNATURE}. Maintainers, please review."
+ try:
+ post_request(
+ f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/comments",
+ {"body": comment},
+ )
+ return {"status": "success"}
+ except RequestException as e:
+ return error_response(f"Error posting alert: {e}")
+
+
+def close_as_stale(item_number: int) -> dict[str, Any]:
+ """
+ Closes the issue as not planned/stale.
+
+ Args:
+ item_number (int): The GitHub issue number.
+ """
+ days_str = _format_days(CLOSE_HOURS_AFTER_STALE_THRESHOLD)
+
+ comment = (
+ "This has been automatically closed because it has been marked as stale"
+ f" for over {days_str} days."
+ )
+ try:
+ post_request(
+ f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/comments",
+ {"body": comment},
+ )
+ patch_request(
+ f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}",
+ {"state": "closed"},
+ )
+ return {"status": "success"}
+ except RequestException as e:
+ return error_response(f"Error closing issue: {e}")
+
+
+root_agent = Agent(
+ model=LLM_MODEL_NAME,
+ name="adk_repository_auditor_agent",
+ description="Audits open issues.",
+ instruction=PROMPT_TEMPLATE.format(
+ OWNER=OWNER,
+ REPO=REPO,
+ STALE_LABEL_NAME=STALE_LABEL_NAME,
+ REQUEST_CLARIFICATION_LABEL=REQUEST_CLARIFICATION_LABEL,
+ stale_threshold_days=STALE_HOURS_THRESHOLD / 24,
+ close_threshold_days=CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24,
+ ),
+ tools=[
+ add_label_to_issue,
+ add_stale_label_and_comment,
+ alert_maintainer_of_edit,
+ close_as_stale,
+ get_issue_state,
+ remove_label_from_issue,
+ ],
+)
diff --git a/contributing/samples/adk_stale_agent/main.py b/contributing/samples/adk_stale_agent/main.py
new file mode 100644
index 0000000000..d4fe58dd63
--- /dev/null
+++ b/contributing/samples/adk_stale_agent/main.py
@@ -0,0 +1,195 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+import time
+from typing import Tuple
+
+from adk_stale_agent.agent import root_agent
+from adk_stale_agent.settings import CONCURRENCY_LIMIT
+from adk_stale_agent.settings import OWNER
+from adk_stale_agent.settings import REPO
+from adk_stale_agent.settings import SLEEP_BETWEEN_CHUNKS
+from adk_stale_agent.settings import STALE_HOURS_THRESHOLD
+from adk_stale_agent.utils import get_api_call_count
+from adk_stale_agent.utils import get_old_open_issue_numbers
+from adk_stale_agent.utils import reset_api_call_count
+from google.adk.cli.utils import logs
+from google.adk.runners import InMemoryRunner
+from google.genai import types
+
+logs.setup_adk_logger(level=logging.INFO)
+logger = logging.getLogger("google_adk." + __name__)
+
+APP_NAME = "stale_bot_app"
+USER_ID = "stale_bot_user"
+
+
+async def process_single_issue(issue_number: int) -> Tuple[float, int]:
+ """
+ Processes a single GitHub issue using the AI agent and logs execution metrics.
+
+ Args:
+ issue_number (int): The GitHub issue number to audit.
+
+ Returns:
+ Tuple[float, int]: A tuple containing:
+ - duration (float): Time taken to process the issue in seconds.
+ - api_calls (int): The number of API calls made during this specific execution.
+
+  Note:
+    Catches generic exceptions and logs them so a single failure does not stop the batch.
+ """
+ start_time = time.perf_counter()
+
+ start_api_calls = get_api_call_count()
+
+ logger.info(f"Processing Issue #{issue_number}...")
+ logger.debug(f"#{issue_number}: Initializing runner and session.")
+
+ try:
+ runner = InMemoryRunner(agent=root_agent, app_name=APP_NAME)
+ session = await runner.session_service.create_session(
+ user_id=USER_ID, app_name=APP_NAME
+ )
+
+ prompt_text = f"Audit Issue #{issue_number}."
+ prompt_message = types.Content(
+ role="user", parts=[types.Part(text=prompt_text)]
+ )
+
+ logger.debug(f"#{issue_number}: Sending prompt to agent.")
+
+ async for event in runner.run_async(
+ user_id=USER_ID, session_id=session.id, new_message=prompt_message
+ ):
+ if (
+ event.content
+ and event.content.parts
+ and hasattr(event.content.parts[0], "text")
+ ):
+ text = event.content.parts[0].text
+ if text:
+ clean_text = text[:150].replace("\n", " ")
+ logger.info(f"#{issue_number} Decision: {clean_text}...")
+
+ except Exception as e:
+ logger.error(f"Error processing issue #{issue_number}: {e}", exc_info=True)
+
+ duration = time.perf_counter() - start_time
+
+ end_api_calls = get_api_call_count()
+ issue_api_calls = end_api_calls - start_api_calls
+
+ logger.info(
+ f"Issue #{issue_number} finished in {duration:.2f}s "
+ f"with ~{issue_api_calls} API calls."
+ )
+
+ return duration, issue_api_calls
+
+
+async def main():
+ """
+ Main entry point to run the stale issue bot concurrently.
+
+ Fetches old issues and processes them in batches to respect API rate limits
+ and concurrency constraints.
+ """
+ logger.info(f"--- Starting Stale Bot for {OWNER}/{REPO} ---")
+ logger.info(f"Concurrency level set to {CONCURRENCY_LIMIT}")
+
+ reset_api_call_count()
+
+ filter_days = STALE_HOURS_THRESHOLD / 24
+ logger.debug(f"Fetching issues older than {filter_days:.2f} days...")
+
+ try:
+ all_issues = get_old_open_issue_numbers(OWNER, REPO, days_old=filter_days)
+ except Exception as e:
+ logger.critical(f"Failed to fetch issue list: {e}", exc_info=True)
+ return
+
+ total_count = len(all_issues)
+
+ search_api_calls = get_api_call_count()
+
+ if total_count == 0:
+ logger.info("No issues matched the criteria. Run finished.")
+ return
+
+ logger.info(
+ f"Found {total_count} issues to process. "
+ f"(Initial search used {search_api_calls} API calls)."
+ )
+
+ total_processing_time = 0.0
+ total_issue_api_calls = 0
+ processed_count = 0
+
+ # Process the list in chunks of size CONCURRENCY_LIMIT
+ for i in range(0, total_count, CONCURRENCY_LIMIT):
+ chunk = all_issues[i : i + CONCURRENCY_LIMIT]
+ current_chunk_num = i // CONCURRENCY_LIMIT + 1
+
+ logger.info(
+ f"--- Starting chunk {current_chunk_num}: Processing issues {chunk} ---"
+ )
+
+ tasks = [process_single_issue(issue_num) for issue_num in chunk]
+
+ results = await asyncio.gather(*tasks)
+
+ for duration, api_calls in results:
+ total_processing_time += duration
+ total_issue_api_calls += api_calls
+
+ processed_count += len(chunk)
+ logger.info(
+ f"--- Finished chunk {current_chunk_num}. Progress:"
+ f" {processed_count}/{total_count} ---"
+ )
+
+ if (i + CONCURRENCY_LIMIT) < total_count:
+ logger.debug(
+ f"Sleeping for {SLEEP_BETWEEN_CHUNKS}s to respect rate limits..."
+ )
+ await asyncio.sleep(SLEEP_BETWEEN_CHUNKS)
+
+ total_api_calls_for_run = search_api_calls + total_issue_api_calls
+ avg_time_per_issue = (
+ total_processing_time / total_count if total_count > 0 else 0
+ )
+
+ logger.info("--- Stale Agent Run Finished ---")
+ logger.info(f"Successfully processed {processed_count} issues.")
+ logger.info(f"Total API calls made this run: {total_api_calls_for_run}")
+ logger.info(
+ f"Average processing time per issue: {avg_time_per_issue:.2f} seconds."
+ )
+
+
+if __name__ == "__main__":
+ start_time = time.perf_counter()
+
+ try:
+ asyncio.run(main())
+ except KeyboardInterrupt:
+ logger.warning("Bot execution interrupted manually.")
+ except Exception as e:
+ logger.critical(f"Unexpected fatal error: {e}", exc_info=True)
+
+ duration = time.perf_counter() - start_time
+ logger.info(f"Full audit finished in {duration/60:.2f} minutes.")
diff --git a/contributing/samples/adk_stale_agent/settings.py b/contributing/samples/adk_stale_agent/settings.py
new file mode 100644
index 0000000000..599c6ef2ea
--- /dev/null
+++ b/contributing/samples/adk_stale_agent/settings.py
@@ -0,0 +1,63 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from dotenv import load_dotenv
+
+# Load environment variables from a .env file for local testing
+load_dotenv(override=True)
+
+# --- GitHub API Configuration ---
+GITHUB_BASE_URL = "https://api.github.com"
+GITHUB_TOKEN = os.getenv("GITHUB_TOKEN")
+if not GITHUB_TOKEN:
+ raise ValueError("GITHUB_TOKEN environment variable not set")
+
+OWNER = os.getenv("OWNER", "google")
+REPO = os.getenv("REPO", "adk-python")
+LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash")
+
+STALE_LABEL_NAME = "stale"
+REQUEST_CLARIFICATION_LABEL = "request clarification"
+
+# --- THRESHOLDS IN HOURS ---
+# Default: 168 hours (7 days)
+# The number of hours of inactivity after a maintainer comment before an issue is marked as stale.
+STALE_HOURS_THRESHOLD = float(os.getenv("STALE_HOURS_THRESHOLD", 168))
+
+# Default: 168 hours (7 days)
+# The number of hours of inactivity after an issue is marked 'stale' before it is closed.
+CLOSE_HOURS_AFTER_STALE_THRESHOLD = float(
+ os.getenv("CLOSE_HOURS_AFTER_STALE_THRESHOLD", 168)
+)
+
+# --- Performance Configuration ---
+# The number of issues to process concurrently.
+# Higher values are faster but increase the immediate rate of API calls.
+CONCURRENCY_LIMIT = int(os.getenv("CONCURRENCY_LIMIT", 3))
+
+# --- GraphQL Query Limits ---
+# The number of most recent comments to fetch for context analysis.
+GRAPHQL_COMMENT_LIMIT = int(os.getenv("GRAPHQL_COMMENT_LIMIT", 30))
+
+# The number of most recent description edits to fetch.
+GRAPHQL_EDIT_LIMIT = int(os.getenv("GRAPHQL_EDIT_LIMIT", 10))
+
+# The number of most recent timeline events (labels, renames, reopens) to fetch.
+GRAPHQL_TIMELINE_LIMIT = int(os.getenv("GRAPHQL_TIMELINE_LIMIT", 20))
+
+# --- Rate Limiting ---
+# Time in seconds to wait between processing chunks.
+SLEEP_BETWEEN_CHUNKS = float(os.getenv("SLEEP_BETWEEN_CHUNKS", 1.5))
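+
+# Example `.env` for local testing (illustrative values; only GITHUB_TOKEN is required):
+#   GITHUB_TOKEN=<your-github-token>
+#   OWNER=google
+#   REPO=adk-python
+#   STALE_HOURS_THRESHOLD=168
+#   CONCURRENCY_LIMIT=3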
diff --git a/contributing/samples/adk_stale_agent/utils.py b/contributing/samples/adk_stale_agent/utils.py
new file mode 100644
index 0000000000..a396c22ac7
--- /dev/null
+++ b/contributing/samples/adk_stale_agent/utils.py
@@ -0,0 +1,260 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import datetime
+from datetime import timedelta
+from datetime import timezone
+import logging
+import threading
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from adk_stale_agent.settings import GITHUB_TOKEN
+from adk_stale_agent.settings import STALE_HOURS_THRESHOLD
+import dateutil.parser
+import requests
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+
+logger = logging.getLogger("google_adk." + __name__)
+
+# --- API Call Counter for Monitoring ---
+_api_call_count = 0
+_counter_lock = threading.Lock()
+
+
+def get_api_call_count() -> int:
+ """
+ Returns the total number of API calls made since the last reset.
+
+ Returns:
+ int: The global count of API calls.
+ """
+ with _counter_lock:
+ return _api_call_count
+
+
+def reset_api_call_count() -> None:
+ """Resets the global API call counter to zero."""
+ global _api_call_count
+ with _counter_lock:
+ _api_call_count = 0
+
+
+def _increment_api_call_count() -> None:
+ """
+ Atomically increments the global API call counter.
+ Required because the agent may run tools in parallel threads.
+ """
+ global _api_call_count
+ with _counter_lock:
+ _api_call_count += 1
+
+
+# --- Production-Ready HTTP Session with Exponential Backoff ---
+
+# Configure the retry strategy:
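+# Up to 6 retries on the listed status codes, with exponentially growing sleeps
+# between attempts (controlled by backoff_factor=2).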
+retry_strategy = Retry(
+ total=6,
+ backoff_factor=2,
+ status_forcelist=[429, 500, 502, 503, 504],
+ allowed_methods=[
+ "HEAD",
+ "GET",
+ "POST",
+ "PUT",
+ "DELETE",
+ "OPTIONS",
+ "TRACE",
+ "PATCH",
+ ],
+)
+
+adapter = HTTPAdapter(max_retries=retry_strategy)
+
+# Create a single, reusable Session object for connection pooling
+_session = requests.Session()
+_session.mount("https://", adapter)
+_session.mount("http://", adapter)
+
+_session.headers.update({
+ "Authorization": f"token {GITHUB_TOKEN}",
+ "Accept": "application/vnd.github.v3+json",
+})
+
+
+def get_request(url: str, params: Optional[Dict[str, Any]] = None) -> Any:
+ """
+ Sends a GET request to the GitHub API with automatic retries.
+
+ Args:
+ url (str): The URL endpoint.
+ params (Optional[Dict[str, Any]]): Query parameters.
+
+ Returns:
+ Any: The JSON response parsed into a dict or list.
+
+ Raises:
+ requests.exceptions.RequestException: If retries are exhausted.
+ """
+ _increment_api_call_count()
+ try:
+ response = _session.get(url, params=params or {}, timeout=60)
+ response.raise_for_status()
+ return response.json()
+ except requests.exceptions.RequestException as e:
+ logger.error(f"GET request failed for {url}: {e}")
+ raise
+
+
+def post_request(url: str, payload: Any) -> Any:
+ """
+ Sends a POST request to the GitHub API with automatic retries.
+
+ Args:
+ url (str): The URL endpoint.
+ payload (Any): The JSON payload.
+
+ Returns:
+ Any: The JSON response.
+ """
+ _increment_api_call_count()
+ try:
+ response = _session.post(url, json=payload, timeout=60)
+ response.raise_for_status()
+ return response.json()
+ except requests.exceptions.RequestException as e:
+ logger.error(f"POST request failed for {url}: {e}")
+ raise
+
+
+def patch_request(url: str, payload: Any) -> Any:
+ """
+ Sends a PATCH request to the GitHub API with automatic retries.
+
+ Args:
+ url (str): The URL endpoint.
+ payload (Any): The JSON payload.
+
+ Returns:
+ Any: The JSON response.
+ """
+ _increment_api_call_count()
+ try:
+ response = _session.patch(url, json=payload, timeout=60)
+ response.raise_for_status()
+ return response.json()
+ except requests.exceptions.RequestException as e:
+ logger.error(f"PATCH request failed for {url}: {e}")
+ raise
+
+
+def delete_request(url: str) -> Any:
+ """
+ Sends a DELETE request to the GitHub API with automatic retries.
+
+ Args:
+ url (str): The URL endpoint.
+
+ Returns:
+ Any: A success dict if 204, else the JSON response.
+ """
+ _increment_api_call_count()
+ try:
+ response = _session.delete(url, timeout=60)
+ response.raise_for_status()
+ if response.status_code == 204:
+ return {"status": "success", "message": "Deletion successful."}
+ return response.json()
+ except requests.exceptions.RequestException as e:
+ logger.error(f"DELETE request failed for {url}: {e}")
+ raise
+
+
+def error_response(error_message: str) -> Dict[str, Any]:
+ """
+ Creates a standardized error response dictionary for tool outputs.
+
+ Args:
+ error_message (str): The error details.
+
+ Returns:
+ Dict[str, Any]: Standardized error object.
+ """
+ return {"status": "error", "message": error_message}
+
+
+def get_old_open_issue_numbers(
+ owner: str, repo: str, days_old: Optional[float] = None
+) -> List[int]:
+ """
+ Finds open issues older than the specified threshold using server-side filtering.
+
+ OPTIMIZATION:
+ Instead of fetching ALL issues and filtering in Python (which wastes API calls),
+ this uses the GitHub Search API `created:<date>` qualifier to filter issues server-side.
diff --git a/contributing/samples/adk_triaging_agent/agent.py b/contributing/samples/adk_triaging_agent/agent.py
-def list_unlabeled_issues(issue_count: int) -> dict[str, Any]:
- """List most recent `issue_count` number of unlabeled issues in the repo.
+def list_untriaged_issues(issue_count: int) -> dict[str, Any]:
+ """List open issues that need triaging.
+
+ Returns issues that need any of the following actions:
+ 1. Issues without component labels (need labeling + type setting)
+ 2. Issues with 'planned' label but no assignee (need owner assignment)
Args:
issue_count: number of issues to return
Returns:
The status of this request, with a list of issues when successful.
+ Each issue includes flags indicating what actions are needed.
"""
url = f"{GITHUB_BASE_URL}/search/issues"
- query = f"repo:{OWNER}/{REPO} is:open is:issue no:label"
+ query = f"repo:{OWNER}/{REPO} is:open is:issue"
params = {
"q": query,
"sort": "created",
"order": "desc",
- "per_page": issue_count,
+ "per_page": 100, # Fetch more to filter
"page": 1,
}
@@ -101,27 +109,48 @@ def list_unlabeled_issues(issue_count: int) -> dict[str, Any]:
response = get_request(url, params)
except requests.exceptions.RequestException as e:
return error_response(f"Error: {e}")
- issues = response.get("items", None)
+ issues = response.get("items", [])
- unlabeled_issues = []
+ component_labels = set(LABEL_TO_OWNER.keys())
+ untriaged_issues = []
for issue in issues:
- if not issue.get("labels", None):
- unlabeled_issues.append(issue)
- return {"status": "success", "issues": unlabeled_issues}
-
-
-def add_label_and_owner_to_issue(
- issue_number: int, label: str
-) -> dict[str, Any]:
- """Add the specified label and owner to the given issue number.
+ issue_labels = {label["name"] for label in issue.get("labels", [])}
+ assignees = issue.get("assignees", [])
+
+ existing_component_labels = issue_labels & component_labels
+ has_component = bool(existing_component_labels)
+ has_planned = "planned" in issue_labels
+
+ # Determine what actions are needed
+ needs_component_label = not has_component
+ needs_owner = has_planned and not assignees
+
+ # Include issue if it needs any action
+ if needs_component_label or needs_owner:
+ issue["has_planned_label"] = has_planned
+ issue["has_component_label"] = has_component
+ issue["existing_component_label"] = (
+ list(existing_component_labels)[0]
+ if existing_component_labels
+ else None
+ )
+ issue["needs_component_label"] = needs_component_label
+ issue["needs_owner"] = needs_owner
+ untriaged_issues.append(issue)
+ if len(untriaged_issues) >= issue_count:
+ break
+ return {"status": "success", "issues": untriaged_issues}
+
+
+def add_label_to_issue(issue_number: int, label: str) -> dict[str, Any]:
+ """Add the specified component label to the given issue number.
Args:
issue_number: issue number of the GitHub issue.
label: label to assign
Returns:
- The the status of this request, with the applied label and assigned owner
- when successful.
+ The status of this request, with the applied label when successful.
"""
print(f"Attempting to add label '{label}' to issue #{issue_number}")
if label not in LABEL_TO_OWNER:
@@ -139,15 +168,38 @@ def add_label_and_owner_to_issue(
except requests.exceptions.RequestException as e:
return error_response(f"Error: {e}")
+ return {
+ "status": "success",
+ "message": response,
+ "applied_label": label,
+ }
+
+
+def add_owner_to_issue(issue_number: int, label: str) -> dict[str, Any]:
+ """Assign an owner to the issue based on the component label.
+
+ This should only be called for issues that have the 'planned' label.
+
+ Args:
+ issue_number: issue number of the GitHub issue.
+ label: component label that determines the owner to assign
+
+ Returns:
+ The status of this request, with the assigned owner when successful.
+ """
+ print(
+ f"Attempting to assign owner for label '{label}' to issue #{issue_number}"
+ )
+ if label not in LABEL_TO_OWNER:
+ return error_response(
+ f"Error: Label '{label}' is not a valid component label."
+ )
+
owner = LABEL_TO_OWNER.get(label, None)
if not owner:
return {
"status": "warning",
- "message": (
- f"{response}\n\nLabel '{label}' does not have an owner. Will not"
- " assign."
- ),
- "applied_label": label,
+ "message": f"Label '{label}' does not have an owner. Will not assign.",
}
assignee_url = (
@@ -163,7 +215,6 @@ def add_label_and_owner_to_issue(
return {
"status": "success",
"message": response,
- "applied_label": label,
"assigned_owner": owner,
}
@@ -217,31 +268,50 @@ def change_issue_type(issue_number: int, issue_type: str) -> dict[str, Any]:
- If it's about Model Context Protocol (e.g. MCP tool, MCP toolset, MCP session management etc.), label it with both "mcp" and "tools".
- If it's about A2A integrations or workflows, label it with "a2a".
- If it's about BigQuery integrations, label it with "bq".
+ - If it's about workflow agents or workflow execution, label it with "workflow".
+ - If it's about authentication, label it with "auth".
- If you can't find an appropriate labels for the issue, follow the previous instruction that starts with "IMPORTANT:".
- Call the `add_label_and_owner_to_issue` tool to label the issue, which will also assign the issue to the owner of the label.
+ ## Triaging Workflow
+
+ Each issue will have flags indicating what actions are needed:
+ - `needs_component_label`: true if the issue needs a component label
+ - `needs_owner`: true if the issue needs an owner assigned (has 'planned' label but no assignee)
+
+ For each issue, perform ONLY the required actions based on the flags:
+
+ 1. **If `needs_component_label` is true**:
+ - Use `add_label_to_issue` to add the appropriate component label
+ - Use `change_issue_type` to set the issue type:
+ - Bug report → "Bug"
+ - Feature request → "Feature"
+ - Otherwise → do not change the issue type
+
+ 2. **If `needs_owner` is true**:
+ - Use `add_owner_to_issue` to assign an owner based on the component label
+ - Note: If the issue already has a component label (`has_component_label: true`), use that existing label to determine the owner
- After you label the issue, call the `change_issue_type` tool to change the issue type:
- - If the issue is a bug report, change the issue type to "Bug".
- - If the issue is a feature request, change the issue type to "Feature".
- - Otherwise, **do not change the issue type**.
+ Do NOT add a component label if `needs_component_label` is false.
+ Do NOT assign an owner if `needs_owner` is false.
Response quality requirements:
- Summarize the issue in your own words without leaving template
placeholders (never output text like "[fill in later]").
- Justify the chosen label with a short explanation referencing the issue
details.
- - Mention the assigned owner when a label maps to one.
+ - Mention the assigned owner only when you actually assign one (i.e., when
+ the issue has the 'planned' label).
- If no label is applied, clearly state why.
Present the following in an easy to read format highlighting issue number and your label.
- the issue summary in a few sentence
- your label recommendation and justification
- - the owner of the label if you assign the issue to an owner
+ - the owner of the label if you assign the issue to an owner (only for planned issues)
""",
tools=[
- list_unlabeled_issues,
- add_label_and_owner_to_issue,
+ list_untriaged_issues,
+ add_label_to_issue,
+ add_owner_to_issue,
change_issue_type,
],
)
diff --git a/contributing/samples/adk_triaging_agent/main.py b/contributing/samples/adk_triaging_agent/main.py
index 317f5893e2..3a2d4da570 100644
--- a/contributing/samples/adk_triaging_agent/main.py
+++ b/contributing/samples/adk_triaging_agent/main.py
@@ -16,6 +16,7 @@
import time
from adk_triaging_agent import agent
+from adk_triaging_agent.agent import LABEL_TO_OWNER
from adk_triaging_agent.settings import EVENT_NAME
from adk_triaging_agent.settings import GITHUB_BASE_URL
from adk_triaging_agent.settings import ISSUE_BODY
@@ -37,21 +38,49 @@
async def fetch_specific_issue_details(issue_number: int):
- """Fetches details for a single issue if it's unlabelled."""
+ """Fetches details for a single issue if it needs triaging."""
url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}"
print(f"Fetching details for specific issue: {url}")
try:
issue_data = get_request(url)
- if not issue_data.get("labels", None):
- print(f"Issue #{issue_number} is unlabelled. Proceeding.")
+ labels = issue_data.get("labels", [])
+ label_names = {label["name"] for label in labels}
+ assignees = issue_data.get("assignees", [])
+
+ # Check issue state
+ component_labels = set(LABEL_TO_OWNER.keys())
+ has_planned = "planned" in label_names
+ existing_component_labels = label_names & component_labels
+ has_component = bool(existing_component_labels)
+ has_assignee = len(assignees) > 0
+
+ # Determine what actions are needed
+ needs_component_label = not has_component
+ needs_owner = has_planned and not has_assignee
+
+ if needs_component_label or needs_owner:
+ print(
+ f"Issue #{issue_number} needs triaging. "
+ f"needs_component_label={needs_component_label}, "
+ f"needs_owner={needs_owner}"
+ )
return {
"number": issue_data["number"],
"title": issue_data["title"],
"body": issue_data.get("body", ""),
+ "has_planned_label": has_planned,
+ "has_component_label": has_component,
+ "existing_component_label": (
+ list(existing_component_labels)[0]
+ if existing_component_labels
+ else None
+ ),
+ "needs_component_label": needs_component_label,
+ "needs_owner": needs_owner,
}
else:
- print(f"Issue #{issue_number} is already labelled. Skipping.")
+ print(f"Issue #{issue_number} is already fully triaged. Skipping.")
return None
except requests.exceptions.RequestException as e:
print(f"Error fetching issue #{issue_number}: {e}")
@@ -108,26 +137,32 @@ async def main():
specific_issue = await fetch_specific_issue_details(issue_number)
if specific_issue is None:
print(
- f"No unlabelled issue details found for #{issue_number} or an error"
- " occurred. Skipping agent interaction."
+ f"No issue details found for #{issue_number} that needs triaging,"
+ " or an error occurred. Skipping agent interaction."
)
return
issue_title = ISSUE_TITLE or specific_issue["title"]
issue_body = ISSUE_BODY or specific_issue["body"]
+ needs_component_label = specific_issue.get("needs_component_label", True)
+ needs_owner = specific_issue.get("needs_owner", False)
+ existing_component_label = specific_issue.get("existing_component_label")
+
prompt = (
- f"A new GitHub issue #{issue_number} has been opened or"
- f' reopened. Title: "{issue_title}"\nBody:'
- f' "{issue_body}"\n\nBased on the rules, recommend an'
- " appropriate label and its justification."
- " Then, use the 'add_label_to_issue' tool to apply the label "
- "directly to this issue. Only label it, do not"
- " process any other issues."
+ f"Triage GitHub issue #{issue_number}.\n\n"
+ f'Title: "{issue_title}"\n'
+ f'Body: "{issue_body}"\n\n'
+ f"Issue state: needs_component_label={needs_component_label}, "
+ f"needs_owner={needs_owner}, "
+ f"existing_component_label={existing_component_label}"
)
else:
print(f"EVENT: Processing batch of issues (event: {EVENT_NAME}).")
issue_count = parse_number_string(ISSUE_COUNT_TO_PROCESS, default_value=3)
- prompt = f"Please triage the most recent {issue_count} issues."
+ prompt = (
+ f"Please use 'list_untriaged_issues' to find {issue_count} issues that"
+ " need triaging, then triage each one according to your instructions."
+ )
response = await call_agent_async(runner, USER_ID, session.id, prompt)
print(f"<<<< Agent Final Output: {response}\n")
diff --git a/contributing/samples/api_registry_agent/README.md b/contributing/samples/api_registry_agent/README.md
new file mode 100644
index 0000000000..78b3c22382
--- /dev/null
+++ b/contributing/samples/api_registry_agent/README.md
@@ -0,0 +1,21 @@
+# BigQuery API Registry Agent
+
+This agent demonstrates how to use `ApiRegistry` to discover and interact with Google Cloud services like BigQuery via tools exposed by an MCP server registered in an API Registry.
+
+## Prerequisites
+
+- A Google Cloud project with the API Registry API enabled.
+- An MCP server exposing BigQuery tools registered in API Registry.
+
+## Configuration & Running
+
+1. **Configure:** Edit `agent.py` and replace `your-google-cloud-project-id` and `your-mcp-server-name` with your Google Cloud Project ID and the name of your registered MCP server (see the sketch at the end of this section).
+2. **Run in CLI:**
+ ```bash
+ adk run contributing/samples/api_registry_agent -- --log-level DEBUG
+ ```
+3. **Run in Web UI:**
+ ```bash
+ adk web contributing/samples/
+ ```
+ Navigate to `http://127.0.0.1:8080` and select the `api_registry_agent` agent.
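+
+For step 1, a minimal sketch of the edit in `agent.py` (the values below are placeholders, not real resources):
+
+```py
+PROJECT_ID = "my-gcp-project"        # your Google Cloud project ID
+MCP_SERVER_NAME = "bigquery-mcp"     # MCP server name registered in API Registry
+```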
diff --git a/contributing/samples/api_registry_agent/__init__.py b/contributing/samples/api_registry_agent/__init__.py
new file mode 100644
index 0000000000..c48963cdc7
--- /dev/null
+++ b/contributing/samples/api_registry_agent/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/contributing/samples/api_registry_agent/agent.py b/contributing/samples/api_registry_agent/agent.py
new file mode 100644
index 0000000000..6504822092
--- /dev/null
+++ b/contributing/samples/api_registry_agent/agent.py
@@ -0,0 +1,39 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.tools.api_registry import ApiRegistry
+
+# TODO: Fill in with your GCloud project id and MCP server name
+PROJECT_ID = "your-google-cloud-project-id"
+MCP_SERVER_NAME = "your-mcp-server-name"
+
+# Header required for BigQuery MCP server
+header_provider = lambda context: {
+ "x-goog-user-project": PROJECT_ID,
+}
+api_registry = ApiRegistry(PROJECT_ID, header_provider=header_provider)
+registry_tools = api_registry.get_toolset(
+ mcp_server_name=MCP_SERVER_NAME,
+)
+root_agent = LlmAgent(
+ model="gemini-2.0-flash",
+ name="bigquery_assistant",
+ instruction="""
+Help the user access their BigQuery data via API Registry tools.
+ """,
+ tools=[registry_tools],
+)
diff --git a/contributing/samples/computer_use/README.md b/contributing/samples/computer_use/README.md
index ff7ae6c0a5..38b6fe79c6 100644
--- a/contributing/samples/computer_use/README.md
+++ b/contributing/samples/computer_use/README.md
@@ -19,7 +19,7 @@ The computer use agent consists of:
Install the required Python packages from the requirements file:
```bash
-uv pip install -r internal/samples/computer_use/requirements.txt
+uv pip install -r contributing/samples/computer_use/requirements.txt
```
### 2. Install Playwright Dependencies
@@ -45,7 +45,7 @@ playwright install chromium
To start the computer use agent, run the following command from the project root:
```bash
-adk web internal/samples
+adk web contributing/samples
```
This will start the ADK web interface where you can interact with the computer_use agent.
diff --git a/contributing/samples/hello_world_stream_fc_args/__init__.py b/contributing/samples/hello_world_stream_fc_args/__init__.py
new file mode 100755
index 0000000000..c48963cdc7
--- /dev/null
+++ b/contributing/samples/hello_world_stream_fc_args/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/contributing/samples/hello_world_stream_fc_args/agent.py b/contributing/samples/hello_world_stream_fc_args/agent.py
new file mode 100755
index 0000000000..f613842171
--- /dev/null
+++ b/contributing/samples/hello_world_stream_fc_args/agent.py
@@ -0,0 +1,55 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from google.adk import Agent
+from google.genai import types
+
+
+def concat_number_and_string(num: int, s: str) -> str:
+ """Concatenate a number and a string.
+
+ Args:
+ num: The number to concatenate.
+ s: The string to concatenate.
+
+ Returns:
+ The concatenated string.
+ """
+ return str(num) + ': ' + s
+
+
+root_agent = Agent(
+ model='gemini-3-pro-preview',
+ name='hello_world_stream_fc_args',
+ description='Demo agent showcasing streaming function call arguments.',
+ instruction="""
+ You are a helpful assistant.
+ You can use the `concat_number_and_string` tool to concatenate a number and a string.
+ You should always call the concat_number_and_string tool to concatenate a number and a string.
+ You should never concatenate on your own.
+ """,
+ tools=[
+ concat_number_and_string,
+ ],
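+ # Turn off automatic function calling and enable streaming of function-call
+ # arguments, so partial argument chunks are emitted in the event stream.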
+ generate_content_config=types.GenerateContentConfig(
+ automatic_function_calling=types.AutomaticFunctionCallingConfig(
+ disable=True,
+ ),
+ tool_config=types.ToolConfig(
+ function_calling_config=types.FunctionCallingConfig(
+ stream_function_call_arguments=True,
+ ),
+ ),
+ ),
+)
diff --git a/contributing/samples/live_bidi_streaming_single_agent/agent.py b/contributing/samples/live_bidi_streaming_single_agent/agent.py
index c295adc136..9246fca9d5 100755
--- a/contributing/samples/live_bidi_streaming_single_agent/agent.py
+++ b/contributing/samples/live_bidi_streaming_single_agent/agent.py
@@ -65,8 +65,8 @@ async def check_prime(nums: list[int]) -> str:
root_agent = Agent(
- # model='gemini-live-2.5-flash-preview-native-audio-09-2025', # vertex
- model='gemini-2.5-flash-native-audio-preview-09-2025', # for AI studio
+ model='gemini-live-2.5-flash-preview-native-audio-09-2025', # vertex
+ # model='gemini-2.5-flash-native-audio-preview-09-2025', # for AI studio
# key
name='roll_dice_agent',
description=(
diff --git a/contributing/samples/migrate_session_db/sample-output/alembic.ini b/contributing/samples/migrate_session_db/sample-output/alembic.ini
index 6405320948..e346ee8ac6 100644
--- a/contributing/samples/migrate_session_db/sample-output/alembic.ini
+++ b/contributing/samples/migrate_session_db/sample-output/alembic.ini
@@ -21,7 +21,7 @@ prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
-# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
+# If specified, requires python>=3.10 and the tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
diff --git a/contributing/samples/spanner_rag_agent/README.md b/contributing/samples/spanner_rag_agent/README.md
index 8148983736..99b60794fe 100644
--- a/contributing/samples/spanner_rag_agent/README.md
+++ b/contributing/samples/spanner_rag_agent/README.md
@@ -57,9 +57,9 @@ model endpoint.
CREATE MODEL EmbeddingsModel INPUT(
content STRING(MAX),
) OUTPUT(
-embeddings STRUCT<statistics STRUCT<truncated BOOL, token_count FLOAT64>, values ARRAY<FLOAT64>>,
+embeddings STRUCT<values ARRAY<FLOAT64>>,
) REMOTE OPTIONS (
-endpoint = '//aiplatform.googleapis.com/projects/<PROJECT_ID>/locations/us-central1/publishers/google/models/text-embedding-004'
+endpoint = '//aiplatform.googleapis.com/projects/<PROJECT_ID>/locations/<LOCATION>/publishers/google/models/text-embedding-005'
);
```
@@ -187,40 +187,203 @@ type.
## Which tool to use and When?
-There are a few options to perform similarity search (see the `agent.py` for
-implementation details):
-
-1. Wraps the built-in `similarity_search` in the Spanner Toolset.
-
- - This provides an easy and controlled way to perform similarity search.
- You can specify different configurations related to vector search based
- on your need without having to figure out all the details for a vector
- search query.
-
-2. Wraps the built-in `execute_sql` in the Spanner Toolset.
-
- - `execute_sql` is a lower-level tool that you can have more control over
- with. With the flexibility, you can specify a complicated (parameterized)
- SQL query for your need, and let the `LlmAgent` pass the parameters.
-
-3. Use the Spanner Toolset (and all the tools that come with it) directly.
-
- - The most flexible and generic way. Instead of fixing configurations via
- code, you can also specify the configurations via `instruction` to
- the `LlmAgent` and let LLM to decide which tool to use and what parameters
- to pass to different tools. It might even combine different tools together!
- Note that in this usage, SQL generation is powered by the LlmAgent, which
- can be more suitable for data analysis and assistant scenarios.
- - To restrict the ability of an `LlmAgent`, `SpannerToolSet` also supports
- `tool_filter` to explicitly specify allowed tools. As an example, the
- following code specifies that only `execute_sql` and `get_table_schema`
- are allowed:
-
- ```py
- toolset = SpannerToolset(
- credentials_config=credentials_config,
- tool_filter=["execute_sql", "get_table_schema"],
- spanner_tool_settings=SpannerToolSettings(),
- )
- ```
-
+There are a few options to perform similarity search:
+
+1. Use the built-in `vector_store_similarity_search` in the Spanner Toolset with explicit `SpannerVectorStoreSettings` configuration.
+
+ - This provides an easy way to perform similarity search. You can specify
+ different configurations related to vector search based on your Spanner
+ database vector store table setup.
+
+ Example pseudocode (see the `agent.py` for details):
+
+ ```py
+ from google.adk.agents.llm_agent import LlmAgent
+ from google.adk.tools.spanner.settings import Capabilities
+ from google.adk.tools.spanner.settings import SpannerToolSettings
+ from google.adk.tools.spanner.settings import SpannerVectorStoreSettings
+ from google.adk.tools.spanner.spanner_toolset import SpannerToolset
+
+ # credentials_config = SpannerCredentialsConfig(...)
+
+ # Define Spanner tool config with the vector store settings.
+ vector_store_settings = SpannerVectorStoreSettings(
+ project_id="",
+ instance_id="",
+ database_id="",
+ table_name="products",
+ content_column="productDescription",
+ embedding_column="productDescriptionEmbedding",
+ vector_length=768,
+ vertex_ai_embedding_model_name="text-embedding-005",
+ selected_columns=[
+ "productId",
+ "productName",
+ "productDescription",
+ ],
+ nearest_neighbors_algorithm="EXACT_NEAREST_NEIGHBORS",
+ top_k=3,
+ distance_type="COSINE",
+ additional_filter="inventoryCount > 0",
+ )
+
+ tool_settings = SpannerToolSettings(
+ capabilities=[Capabilities.DATA_READ],
+ vector_store_settings=vector_store_settings,
+ )
+
+ # Get the Spanner toolset with the Spanner tool settings and credentials config.
+ spanner_toolset = SpannerToolset(
+ credentials_config=credentials_config,
+ spanner_tool_settings=tool_settings,
+ # Use `vector_store_similarity_search` only
+ tool_filter=["vector_store_similarity_search"],
+ )
+
+ root_agent = LlmAgent(
+ model="gemini-2.5-flash",
+ name="spanner_knowledge_base_agent",
+ description=(
+ "Agent to answer questions about product-specific recommendations."
+ ),
+ instruction="""
+ You are a helpful assistant that answers user questions about product-specific recommendations.
+ 1. Always use the `vector_store_similarity_search` tool to find relevant information.
+ 2. If no relevant information is found, say you don't know.
+ 3. Present all the relevant information naturally and well formatted in your response.
+ """,
+ tools=[spanner_toolset],
+ )
+ ```
+
+2. Use the built-in `similarity_search` in the Spanner Toolset.
+
+ - `similarity_search` is a lower-level tool that provides the most flexible
+ and generic way. All the necessary tool parameters must be specified when
+ interacting with the `LlmAgent` before the tool call is performed. This is
+ more suitable for data analysis, ad-hoc query, and assistant scenarios.
+
+ Example pseudocode:
+
+ ```py
+ from google.adk.agents.llm_agent import LlmAgent
+ from google.adk.tools.spanner.settings import Capabilities
+ from google.adk.tools.spanner.settings import SpannerToolSettings
+ from google.adk.tools.spanner.spanner_toolset import SpannerToolset
+
+ # credentials_config = SpannerCredentialsConfig(...)
+
+ tool_settings = SpannerToolSettings(
+ capabilities=[Capabilities.DATA_READ],
+ )
+
+ spanner_toolset = SpannerToolset(
+ credentials_config=credentials_config,
+ spanner_tool_settings=tool_settings,
+ # Use `similarity_search` only
+ tool_filter=["similarity_search"],
+ )
+
+ root_agent = LlmAgent(
+ model="gemini-2.5-flash",
+ name="spanner_knowledge_base_agent",
+ description=(
+ "Agent to answer questions by retrieving relevant information "
+ "from the Spanner database."
+ ),
+ instruction="""
+ You are a helpful assistant that answers user questions by finding the most relevant information from a Spanner database.
+ 1. Always use the `similarity_search` tool to find relevant information.
+ 2. If no relevant information is found, say you don't know.
+ 3. Present all the relevant information naturally and well formatted in your response.
+ """,
+ tools=[spanner_toolset],
+ )
+ ```
+
+3. Wrap the built-in `similarity_search` from the Spanner Toolset in a custom function tool.
+
+ - This provides a more controlled way to perform similarity search via code.
+ You can extend the tool as a wrapped function tool to add customized logic.
+
+ Example pseudocode:
+
+ ```py
+ from google.adk.agents.llm_agent import LlmAgent
+
+ from google.adk.tools.google_tool import GoogleTool
+ from google.adk.tools.spanner import search_tool
+ import google.auth
+ from google.auth.credentials import Credentials
+
+ # credentials_config = SpannerCredentialsConfig(...)
+
+ # Create a wrapped function tool for the agent on top of the built-in
+ # similarity_search tool in the Spanner toolset.
+ # This customized tool is used to perform a Spanner KNN vector search on an
+ # embedded knowledge base stored in a Spanner database table.
+ def wrapped_spanner_similarity_search(
+ search_query: str,
+ credentials: Credentials,
+ ) -> str:
+ """Perform a similarity search on the product catalog.
+
+ Args:
+ search_query: The search query to find relevant content.
+
+ Returns:
+ Relevant product catalog content with sources
+ """
+
+ # ... Customized logic ...
+
+ # Instead of fixing all parameters, you can also expose some of them for
+ # the LLM to decide.
+ return search_tool.similarity_search(
+ project_id="",
+ instance_id="",
+ database_id="",
+ table_name="products",
+ query=search_query,
+ embedding_column_to_search="productDescriptionEmbedding",
+ columns=[
+ "productId",
+ "productName",
+ "productDescription",
+ ],
+ embedding_options={
+ "vertex_ai_embedding_model_name": "text-embedding-005",
+ },
+ credentials=credentials,
+ additional_filter="inventoryCount > 0",
+ search_options={
+ "top_k": 3,
+ "distance_type": "EUCLIDEAN",
+ },
+ )
+
+ # ...
+
+ root_agent = LlmAgent(
+ model="gemini-2.5-flash",
+ name="spanner_knowledge_base_agent",
+ description=(
+ "Agent to answer questions about product-specific recommendations."
+ ),
+ instruction="""
+ You are a helpful assistant that answers user questions about product-specific recommendations.
+ 1. Always use the `wrapped_spanner_similarity_search` tool to find relevant information.
+ 2. If no relevant information is found, say you don't know.
+ 3. Present all the relevant information naturally and well formatted in your response.
+ """,
+ tools=[
+ # Add customized Spanner tool based on the built-in similarity_search
+ # in the Spanner toolset.
+ GoogleTool(
+ func=wrapped_spanner_similarity_search,
+ credentials_config=credentials_config,
+ tool_settings=tool_settings,
+ ),
+ ],
+ )
+ ```
diff --git a/contributing/samples/spanner_rag_agent/agent.py b/contributing/samples/spanner_rag_agent/agent.py
index 58ec95294e..1460242184 100644
--- a/contributing/samples/spanner_rag_agent/agent.py
+++ b/contributing/samples/spanner_rag_agent/agent.py
@@ -13,23 +13,15 @@
# limitations under the License.
import os
-from typing import Any
-from typing import Dict
-from typing import Optional
from google.adk.agents.llm_agent import LlmAgent
from google.adk.auth.auth_credential import AuthCredentialTypes
-from google.adk.tools.base_tool import BaseTool
-from google.adk.tools.google_tool import GoogleTool
-from google.adk.tools.spanner import query_tool
-from google.adk.tools.spanner import search_tool
from google.adk.tools.spanner.settings import Capabilities
from google.adk.tools.spanner.settings import SpannerToolSettings
+from google.adk.tools.spanner.settings import SpannerVectorStoreSettings
from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig
-from google.adk.tools.tool_context import ToolContext
+from google.adk.tools.spanner.spanner_toolset import SpannerToolset
import google.auth
-from google.auth.credentials import Credentials
-from pydantic import BaseModel
# Define an appropriate credential type
# Set to None to use the application default credentials (ADC) for a quick
@@ -37,9 +29,6 @@
CREDENTIALS_TYPE = None
-# Define Spanner tool config with read capability set to allowed.
-tool_settings = SpannerToolSettings(capabilities=[Capabilities.DATA_READ])
-
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
# Initialize the tools to do interactive OAuth
# The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET
@@ -67,172 +56,46 @@
credentials=application_default_credentials
)
+# Follow the instructions in README.md to set up the example Spanner database.
+# Replace the following settings with your specific Spanner database.
+
+# Define Spanner vector store settings.
+vector_store_settings = SpannerVectorStoreSettings(
+ project_id="",
+ instance_id="",
+ database_id="",
+ table_name="products",
+ content_column="productDescription",
+ embedding_column="productDescriptionEmbedding",
+ vector_length=768,
+ vertex_ai_embedding_model_name="text-embedding-005",
+ selected_columns=[
+ "productId",
+ "productName",
+ "productDescription",
+ ],
+ nearest_neighbors_algorithm="EXACT_NEAREST_NEIGHBORS",
+ top_k=3,
+ distance_type="COSINE",
+ additional_filter="inventoryCount > 0",
+)
-### Section 1: Extending the built-in Spanner Toolset for Custom Use Cases ###
-# This example illustrates how to extend the built-in Spanner toolset to create
-# a customized Spanner tool. This method is advantageous when you need to deal
-# with a specific use case:
-#
-# 1. Streamline the end user experience by pre-configuring the tool with fixed
-# parameters (such as a specific database, instance, or project) and a
-# dedicated SQL query, making it perfect for a single, focused use case
-# like vector search on a specific table.
-# 2. Enhance functionality by adding custom logic to manage tool inputs,
-# execution, and result processing, providing greater control over the
-# tool's behavior.
-class SpannerRagSetting(BaseModel):
- """Customized Spanner RAG settings for an example use case."""
-
- # Replace the following settings for your Spanner database used in the sample.
- project_id: str = ""
- instance_id: str = ""
- database_id: str = ""
-
- # Follow the instructions in README.md, the table name is "products" and the
- # Spanner embedding model name is "EmbeddingsModel" in this sample.
- table_name: str = "products"
- # Learn more about Spanner Vertex AI integration for embedding and Spanner
- # vector search.
- # https://cloud.google.com/spanner/docs/ml-tutorial-embeddings
- # https://cloud.google.com/spanner/docs/vector-search/overview
- embedding_model_name: str = "EmbeddingsModel"
-
- selected_columns: list[str] = [
- "productId",
- "productName",
- "productDescription",
- ]
- embedding_column_name: str = "productDescriptionEmbedding"
-
- additional_filter_expression: str = "inventoryCount > 0"
- vector_distance_function: str = "EUCLIDEAN_DISTANCE"
- top_k: int = 3
-
-
-RAG_SETTINGS = SpannerRagSetting()
-
-
-### (Option 1) Use the built-in similarity_search tool ###
-# Create a wrapped function tool for the agent on top of the built-in
-# similarity_search tool in the Spanner toolset.
-# This customized tool is used to perform a Spanner KNN vector search on a
-# embedded knowledge base stored in a Spanner database table.
-def wrapped_spanner_similarity_search(
- search_query: str,
- credentials: Credentials, # GoogleTool handles `credentials` automatically
- settings: SpannerToolSettings, # GoogleTool handles `settings` automatically
- tool_context: ToolContext, # GoogleTool handles `tool_context` automatically
-) -> str:
- """Perform a similarity search on the product catalog.
-
- Args:
- search_query: The search query to find relevant content.
-
- Returns:
- Relevant product catalog content with sources
- """
- columns = RAG_SETTINGS.selected_columns.copy()
-
- # Instead of fixing all parameters, you can also expose some of them for
- # the LLM to decide.
- return search_tool.similarity_search(
- RAG_SETTINGS.project_id,
- RAG_SETTINGS.instance_id,
- RAG_SETTINGS.database_id,
- RAG_SETTINGS.table_name,
- search_query,
- RAG_SETTINGS.embedding_column_name,
- columns,
- {
- "spanner_embedding_model_name": RAG_SETTINGS.embedding_model_name,
- },
- credentials,
- settings,
- tool_context,
- RAG_SETTINGS.additional_filter_expression,
- {
- "top_k": RAG_SETTINGS.top_k,
- "distance_type": RAG_SETTINGS.vector_distance_function,
- },
- )
-
-
-### (Option 2) Use the built-in execute_sql tool ###
-# Create a wrapped function tool for the agent on top of the built-in
-# execute_sql tool in the Spanner toolset.
-# This customized tool is used to perform a Spanner KNN vector search on a
-# embedded knowledge base stored in a Spanner database table.
-#
-# Compared with similarity_search, using execute_sql (a lower level tool) means
-# that you have more control, but you also need to do more work (e.g. to write
-# the SQL query from scratch). Consider using this option if your scenario is
-# more complicated than a plain similarity search.
-def wrapped_spanner_execute_sql_tool(
- search_query: str,
- credentials: Credentials, # GoogleTool handles `credentials` automatically
- settings: SpannerToolSettings, # GoogleTool handles `settings` automatically
- tool_context: ToolContext, # GoogleTool handles `tool_context` automatically
-) -> str:
- """Perform a similarity search on the product catalog.
-
- Args:
- search_query: The search query to find relevant content.
-
- Returns:
- Relevant product catalog content with sources
- """
-
- embedding_query = f"""SELECT embeddings.values
- FROM ML.PREDICT(
- MODEL {RAG_SETTINGS.embedding_model_name},
- (SELECT "{search_query}" as content)
- )
- """
-
- distance_alias = "distance"
- columns = [f"{column}" for column in RAG_SETTINGS.selected_columns]
- columns += [f"""{RAG_SETTINGS.vector_distance_function}(
- {RAG_SETTINGS.embedding_column_name},
- ({embedding_query})) AS {distance_alias}
- """]
- columns = ", ".join(columns)
-
- knn_query = f"""
- SELECT {columns}
- FROM {RAG_SETTINGS.table_name}
- WHERE {RAG_SETTINGS.additional_filter_expression}
- ORDER BY {distance_alias}
- LIMIT {RAG_SETTINGS.top_k}
- """
-
- # Customized tool based on the built-in Spanner toolset.
- return query_tool.execute_sql(
- project_id=RAG_SETTINGS.project_id,
- instance_id=RAG_SETTINGS.instance_id,
- database_id=RAG_SETTINGS.database_id,
- query=knn_query,
- credentials=credentials,
- settings=settings,
- tool_context=tool_context,
- )
-
-
-def inspect_tool_params(
- tool: BaseTool,
- args: Dict[str, Any],
- tool_context: ToolContext,
-) -> Optional[Dict]:
- """A callback function to inspect tool parameters before execution."""
- print("Inspect for tool: " + tool.name)
-
- actual_search_query_in_args = args.get("search_query")
- # Inspect the `search_query` when calling the tool for tutorial purposes.
- print(f"Tool args `search_query`: {actual_search_query_in_args}")
+# Define Spanner tool config with the vector store settings.
+tool_settings = SpannerToolSettings(
+ capabilities=[Capabilities.DATA_READ],
+ vector_store_settings=vector_store_settings,
+)
- pass
+# Get the Spanner toolset with the Spanner tool settings and credentials config.
+# Filter the tools to only include the `vector_store_similarity_search` tool.
+spanner_toolset = SpannerToolset(
+ credentials_config=credentials_config,
+ spanner_tool_settings=tool_settings,
+ # Comment to include all allowed tools.
+ tool_filter=["vector_store_similarity_search"],
+)
-### Section 2: Create the root agent ###
root_agent = LlmAgent(
model="gemini-2.5-flash",
name="spanner_knowledge_base_agent",
@@ -241,27 +104,10 @@ def inspect_tool_params(
),
instruction="""
You are a helpful assistant that answers user questions about product-specific recommendations.
- 1. Always use the `wrapped_spanner_similarity_search` tool to find relevant information.
- 2. If no relevant information is found, say you don't know.
- 3. Present all the relevant information naturally and well formatted in your response.
+ 1. Always use the `vector_store_similarity_search` tool to find information.
+ 2. Directly present all the information results from the `vector_store_similarity_search` tool naturally and well formatted in your response.
+ 3. If no information result is returned by the `vector_store_similarity_search` tool, say you don't know.
""",
- tools=[
- # # (Option 1)
- # # Add customized Spanner tool based on the built-in similarity_search
- # # in the Spanner toolset.
- GoogleTool(
- func=wrapped_spanner_similarity_search,
- credentials_config=credentials_config,
- tool_settings=tool_settings,
- ),
- # # (Option 2)
- # # Add customized Spanner tool based on the built-in execute_sql in
- # # the Spanner toolset.
- # GoogleTool(
- # func=wrapped_spanner_execute_sql_tool,
- # credentials_config=credentials_config,
- # tool_settings=tool_settings,
- # ),
- ],
- before_tool_callback=inspect_tool_params,
+ # Use the Spanner toolset for vector similarity search.
+ tools=[spanner_toolset],
)
diff --git a/contributing/samples/telemetry/main.py b/contributing/samples/telemetry/main.py
index e580060dc4..c6e05f0f62 100755
--- a/contributing/samples/telemetry/main.py
+++ b/contributing/samples/telemetry/main.py
@@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
+from contextlib import aclosing
import os
import time
@@ -46,19 +47,16 @@ async def run_prompt(session: Session, new_message: str):
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
- # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is
- # no longer supported.
- agen = runner.run_async(
- user_id=user_id_1,
- session_id=session.id,
- new_message=content,
- )
- try:
+ async with aclosing(
+ runner.run_async(
+ user_id=user_id_1,
+ session_id=session.id,
+ new_message=content,
+ )
+ ) as agen:
async for event in agen:
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
- finally:
- await agen.aclose()
async def run_prompt_bytes(session: Session, new_message: str):
content = types.Content(
@@ -70,20 +68,17 @@ async def run_prompt_bytes(session: Session, new_message: str):
],
)
print('** User says:', content.model_dump(exclude_none=True))
- # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is
- # no longer supported.
- agen = runner.run_async(
- user_id=user_id_1,
- session_id=session.id,
- new_message=content,
- run_config=RunConfig(save_input_blobs_as_artifacts=True),
- )
- try:
+ async with aclosing(
+ runner.run_async(
+ user_id=user_id_1,
+ session_id=session.id,
+ new_message=content,
+ run_config=RunConfig(save_input_blobs_as_artifacts=True),
+ )
+ ) as agen:
async for event in agen:
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
- finally:
- await agen.aclose()
start_time = time.time()
print('Start time:', start_time)
diff --git a/llms-full.txt b/llms-full.txt
index 4c744512e4..b84e9496ee 100644
--- a/llms-full.txt
+++ b/llms-full.txt
@@ -5620,7 +5620,7 @@ pip install google-cloud-aiplatform[adk,agent_engines]
```
!!!info
- Agent Engine only supported Python version >=3.9 and <=3.12.
+ Agent Engine only supports Python versions >=3.10 and <=3.12.
### Initialization
@@ -8073,7 +8073,7 @@ setting up a basic agent with multiple tools, and running it locally either in t
This quickstart assumes a local IDE (VS Code, PyCharm, IntelliJ IDEA, etc.)
-with Python 3.9+ or Java 17+ and terminal access. This method runs the
+with Python 3.10+ or Java 17+ and terminal access. This method runs the
application entirely on your machine and is recommended for internal development.
## 1. Set up Environment & Install ADK {#venv-install}
@@ -16475,7 +16475,7 @@ This guide covers two primary integration patterns:
Before you begin, ensure you have the following set up:
* **Set up ADK:** Follow the standard ADK [setup instructions](../get-started/quickstart.md/#venv-install) in the quickstart.
-* **Install/update Python/Java:** MCP requires Python version of 3.9 or higher for Python or Java 17+.
+* **Install/update Python/Java:** MCP requires Python 3.10 or higher (for Python) or Java 17+.
* **Setup Node.js and npx:** **(Python only)** Many community MCP servers are distributed as Node.js packages and run using `npx`. Install Node.js (which includes npx) if you haven't already. For details, see [https://nodejs.org/en](https://nodejs.org/en).
* **Verify Installations:** **(Python only)** Confirm `adk` and `npx` are in your PATH within the activated virtual environment:
diff --git a/pyproject.toml b/pyproject.toml
index 4f8e42bcf7..06ddb04ef2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,11 +40,11 @@ dependencies = [
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
- "google-cloud-storage>=3.0.0, <4.0.0", # For GCS Artifact service
- "google-genai>=1.45.0, <2.0.0", # Google GenAI SDK
+ "google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service
+ "google-genai>=1.51.0, <2.0.0", # Google GenAI SDK
"graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering
"jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation
- "mcp>=1.8.0, <2.0.0", # For MCP Toolset
+ "mcp>=1.10.0, <2.0.0", # For MCP Toolset
"opentelemetry-api>=1.37.0, <=1.37.0", # OpenTelemetry - limit upper version for sdk and api to not risk breaking changes from unstable _logs package.
"opentelemetry-exporter-gcp-logging>=1.9.0a0, <2.0.0",
"opentelemetry-exporter-gcp-monitoring>=1.9.0a0, <2.0.0",
diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py
index a21042cc10..dfe6f4a0a2 100644
--- a/src/google/adk/a2a/converters/part_converter.py
+++ b/src/google/adk/a2a/converters/part_converter.py
@@ -26,23 +26,11 @@
from typing import Optional
from typing import Union
-from .utils import _get_adk_metadata_key
-
-try:
- from a2a import types as a2a_types
-except ImportError as e:
- import sys
-
- if sys.version_info < (3, 10):
- raise ImportError(
- 'A2A requires Python 3.10 or above. Please upgrade your Python version.'
- ) from e
- else:
- raise e
-
+from a2a import types as a2a_types
from google.genai import types as genai_types
from ..experimental import a2a_experimental
+from .utils import _get_adk_metadata_key
logger = logging.getLogger('google_adk.' + __name__)
diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py
index 39db41dac6..1746ec0bca 100644
--- a/src/google/adk/a2a/converters/request_converter.py
+++ b/src/google/adk/a2a/converters/request_converter.py
@@ -15,23 +15,12 @@
from __future__ import annotations
from collections.abc import Callable
-import sys
from typing import Any
from typing import Optional
-from pydantic import BaseModel
-
-try:
- from a2a.server.agent_execution import RequestContext
-except ImportError as e:
- if sys.version_info < (3, 10):
- raise ImportError(
- 'A2A requires Python 3.10 or above. Please upgrade your Python version.'
- ) from e
- else:
- raise e
-
+from a2a.server.agent_execution import RequestContext
from google.genai import types as genai_types
+from pydantic import BaseModel
from ...runners import RunConfig
from ..experimental import a2a_experimental
diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py
index 608a818864..b6880aaa5c 100644
--- a/src/google/adk/a2a/executor/a2a_agent_executor.py
+++ b/src/google/adk/a2a/executor/a2a_agent_executor.py
@@ -23,34 +23,22 @@
from typing import Optional
import uuid
-from ...utils.context_utils import Aclosing
-
-try:
- from a2a.server.agent_execution import AgentExecutor
- from a2a.server.agent_execution.context import RequestContext
- from a2a.server.events.event_queue import EventQueue
- from a2a.types import Artifact
- from a2a.types import Message
- from a2a.types import Role
- from a2a.types import TaskArtifactUpdateEvent
- from a2a.types import TaskState
- from a2a.types import TaskStatus
- from a2a.types import TaskStatusUpdateEvent
- from a2a.types import TextPart
-
-except ImportError as e:
- import sys
-
- if sys.version_info < (3, 10):
- raise ImportError(
- 'A2A requires Python 3.10 or above. Please upgrade your Python version.'
- ) from e
- else:
- raise e
+from a2a.server.agent_execution import AgentExecutor
+from a2a.server.agent_execution.context import RequestContext
+from a2a.server.events.event_queue import EventQueue
+from a2a.types import Artifact
+from a2a.types import Message
+from a2a.types import Role
+from a2a.types import TaskArtifactUpdateEvent
+from a2a.types import TaskState
+from a2a.types import TaskStatus
+from a2a.types import TaskStatusUpdateEvent
+from a2a.types import TextPart
from google.adk.runners import Runner
from pydantic import BaseModel
from typing_extensions import override
+from ...utils.context_utils import Aclosing
from ..converters.event_converter import AdkEventToA2AEventsConverter
from ..converters.event_converter import convert_event_to_a2a_events
from ..converters.part_converter import A2APartToGenAIPartConverter
diff --git a/src/google/adk/a2a/utils/agent_card_builder.py b/src/google/adk/a2a/utils/agent_card_builder.py
index aa7f657f99..c007870931 100644
--- a/src/google/adk/a2a/utils/agent_card_builder.py
+++ b/src/google/adk/a2a/utils/agent_card_builder.py
@@ -15,25 +15,15 @@
from __future__ import annotations
import re
-import sys
from typing import Dict
from typing import List
from typing import Optional
-try:
- from a2a.types import AgentCapabilities
- from a2a.types import AgentCard
- from a2a.types import AgentProvider
- from a2a.types import AgentSkill
- from a2a.types import SecurityScheme
-except ImportError as e:
- if sys.version_info < (3, 10):
- raise ImportError(
- 'A2A requires Python 3.10 or above. Please upgrade your Python version.'
- ) from e
- else:
- raise e
-
+from a2a.types import AgentCapabilities
+from a2a.types import AgentCard
+from a2a.types import AgentProvider
+from a2a.types import AgentSkill
+from a2a.types import SecurityScheme
from ...agents.base_agent import BaseAgent
from ...agents.llm_agent import LlmAgent
diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py
index 72a2292fb3..1a1ba35618 100644
--- a/src/google/adk/a2a/utils/agent_to_a2a.py
+++ b/src/google/adk/a2a/utils/agent_to_a2a.py
@@ -15,30 +15,18 @@
from __future__ import annotations
import logging
-import sys
-
-try:
- from a2a.server.apps import A2AStarletteApplication
- from a2a.server.request_handlers import DefaultRequestHandler
- from a2a.server.tasks import InMemoryTaskStore
- from a2a.types import AgentCard
-except ImportError as e:
- if sys.version_info < (3, 10):
- raise ImportError(
- "A2A requires Python 3.10 or above. Please upgrade your Python version."
- ) from e
- else:
- raise e
-
from typing import Optional
from typing import Union
+from a2a.server.apps import A2AStarletteApplication
+from a2a.server.request_handlers import DefaultRequestHandler
+from a2a.server.tasks import InMemoryTaskStore
+from a2a.types import AgentCard
from starlette.applications import Starlette
from ...agents.base_agent import BaseAgent
from ...artifacts.in_memory_artifact_service import InMemoryArtifactService
from ...auth.credential_service.in_memory_credential_service import InMemoryCredentialService
-from ...cli.utils.logs import setup_adk_logger
from ...memory.in_memory_memory_service import InMemoryMemoryService
from ...runners import Runner
from ...sessions.in_memory_session_service import InMemorySessionService
@@ -117,7 +105,8 @@ def to_a2a(
app = to_a2a(agent, agent_card=my_custom_agent_card)
"""
# Set up ADK logging to ensure logs are visible when using uvicorn directly
- setup_adk_logger(logging.INFO)
+ adk_logger = logging.getLogger("google_adk")
+ adk_logger.setLevel(logging.INFO)
async def create_runner() -> Runner:
"""Create a runner for the agent."""
diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py
index 5710a21b7f..b5f8e88cde 100644
--- a/src/google/adk/agents/__init__.py
+++ b/src/google/adk/agents/__init__.py
@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
from .base_agent import BaseAgent
from .invocation_context import InvocationContext
from .live_request_queue import LiveRequest
@@ -22,6 +19,7 @@
from .llm_agent import Agent
from .llm_agent import LlmAgent
from .loop_agent import LoopAgent
+from .mcp_instruction_provider import McpInstructionProvider
from .parallel_agent import ParallelAgent
from .run_config import RunConfig
from .sequential_agent import SequentialAgent
@@ -31,6 +29,7 @@
'BaseAgent',
'LlmAgent',
'LoopAgent',
+ 'McpInstructionProvider',
'ParallelAgent',
'SequentialAgent',
'InvocationContext',
@@ -38,16 +37,3 @@
'LiveRequestQueue',
'RunConfig',
]
-
-if sys.version_info < (3, 10):
- logger = logging.getLogger('google_adk.' + __name__)
- logger.warning(
- 'MCP requires Python 3.10 or above. Please upgrade your Python'
- ' version in order to use it.'
- )
-else:
- from .mcp_instruction_provider import McpInstructionProvider
-
- __all__.extend([
- 'McpInstructionProvider',
- ])
diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py
index a644cb8b90..e15f9af981 100644
--- a/src/google/adk/agents/base_agent.py
+++ b/src/google/adk/agents/base_agent.py
@@ -15,6 +15,7 @@
from __future__ import annotations
import inspect
+import logging
from typing import Any
from typing import AsyncGenerator
from typing import Awaitable
@@ -49,6 +50,8 @@
if TYPE_CHECKING:
from .invocation_context import InvocationContext
+logger = logging.getLogger('google_adk.' + __name__)
+
_SingleAgentCallback: TypeAlias = Callable[
[CallbackContext],
Union[Awaitable[Optional[types.Content]], Optional[types.Content]],
@@ -563,6 +566,45 @@ def validate_name(cls, value: str):
)
return value
+ @field_validator('sub_agents', mode='after')
+ @classmethod
+ def validate_sub_agents_unique_names(
+ cls, value: list[BaseAgent]
+ ) -> list[BaseAgent]:
+ """Validates that all sub-agents have unique names.
+
+ Args:
+ value: The list of sub-agents to validate.
+
+ Returns:
+ The validated list of sub-agents.
+    """
+ if not value:
+ return value
+
+ seen_names: set[str] = set()
+ duplicates: set[str] = set()
+
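+    # Single pass over the sub-agents: record each name once and collect any
+    # repeats for the warning below.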
+ for sub_agent in value:
+ name = sub_agent.name
+ if name in seen_names:
+ duplicates.add(name)
+ else:
+ seen_names.add(name)
+
+ if duplicates:
+ duplicate_names_str = ', '.join(
+ f'`{name}`' for name in sorted(duplicates)
+ )
+ logger.warning(
+ 'Found duplicate sub-agent names: %s. '
+ 'All sub-agents must have unique names.',
+ duplicate_names_str,
+ )
+
+ return value
+
def __set_parent_agent_for_sub_agents(self) -> BaseAgent:
for sub_agent in self.sub_agents:
if sub_agent.parent_agent is not None:
diff --git a/src/google/adk/agents/config_agent_utils.py b/src/google/adk/agents/config_agent_utils.py
index 7982a9cf59..38ba2e2578 100644
--- a/src/google/adk/agents/config_agent_utils.py
+++ b/src/google/adk/agents/config_agent_utils.py
@@ -132,7 +132,7 @@ def resolve_agent_reference(
else:
return from_config(
os.path.join(
- referencing_agent_config_abs_path.rsplit("/", 1)[0],
+ os.path.dirname(referencing_agent_config_abs_path),
ref_config.config_path,
)
)
diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py
index 2f8a969fad..005d073cc7 100644
--- a/src/google/adk/agents/llm_agent.py
+++ b/src/google/adk/agents/llm_agent.py
@@ -910,7 +910,9 @@ def _parse_config(
from .config_agent_utils import resolve_callbacks
from .config_agent_utils import resolve_code_reference
- if config.model:
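+    # model_code (a CodeConfig) is checked first; the config validator already
+    # rejects configs that set both `model` and `model_code`.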
+ if config.model_code:
+ kwargs['model'] = resolve_code_reference(config.model_code)
+ elif config.model:
kwargs['model'] = config.model
if config.instruction:
kwargs['instruction'] = config.instruction
diff --git a/src/google/adk/agents/llm_agent_config.py b/src/google/adk/agents/llm_agent_config.py
index 7d2493597e..59c6d58869 100644
--- a/src/google/adk/agents/llm_agent_config.py
+++ b/src/google/adk/agents/llm_agent_config.py
@@ -15,6 +15,7 @@
from __future__ import annotations
import logging
+from typing import Any
from typing import List
from typing import Literal
from typing import Optional
@@ -22,6 +23,7 @@
from google.genai import types
from pydantic import ConfigDict
from pydantic import Field
+from pydantic import model_validator
from ..tools.tool_configs import ToolConfig
from .base_agent_config import BaseAgentConfig
@@ -52,11 +54,48 @@ class LlmAgentConfig(BaseAgentConfig):
model: Optional[str] = Field(
default=None,
description=(
- 'Optional. LlmAgent.model. If not set, the model will be inherited'
- ' from the ancestor.'
+ 'Optional. LlmAgent.model. Provide a model name string (e.g.'
+ ' "gemini-2.0-flash"). If not set, the model will be inherited from'
+ ' the ancestor. To construct a model instance from code, use'
+ ' model_code.'
),
)
+ model_code: Optional[CodeConfig] = Field(
+ default=None,
+ description=(
+ 'Optional. A CodeConfig that instantiates a BaseLlm implementation'
+ ' such as LiteLlm with custom arguments (API base, fallbacks,'
+ ' etc.). Cannot be set together with `model`.'
+ ),
+ )
+
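+  # Backward compatibility: earlier configs supplied `model` as a mapping
+  # (a CodeConfig-style dict). The validator below migrates that shape to
+  # `model_code` and logs a warning so existing YAML keeps loading.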
+ @model_validator(mode='before')
+ @classmethod
+ def _normalize_model_code(cls, data: Any) -> dict[str, Any] | Any:
+ if not isinstance(data, dict):
+ return data
+
+ model_value = data.get('model')
+ model_code = data.get('model_code')
+ if isinstance(model_value, dict) and model_code is None:
+ logger.warning(
+ 'Detected legacy `model` mapping. Use `model_code` to provide a'
+ ' CodeConfig for custom model construction.'
+ )
+ data = dict(data)
+ data['model_code'] = model_value
+ data['model'] = None
+
+ return data
+
+ @model_validator(mode='after')
+ def _validate_model_sources(self) -> LlmAgentConfig:
+ if self.model and self.model_code:
+ raise ValueError('Only one of `model` or `model_code` should be set.')
+
+ return self
+
instruction: str = Field(
description=(
'Required. LlmAgent.instruction. Dynamic instructions with'
diff --git a/src/google/adk/agents/mcp_instruction_provider.py b/src/google/adk/agents/mcp_instruction_provider.py
index e9f40663c9..20896a7a04 100644
--- a/src/google/adk/agents/mcp_instruction_provider.py
+++ b/src/google/adk/agents/mcp_instruction_provider.py
@@ -22,24 +22,12 @@
from typing import Dict
from typing import TextIO
+from mcp import types
+
+from ..tools.mcp_tool.mcp_session_manager import MCPSessionManager
from .llm_agent import InstructionProvider
from .readonly_context import ReadonlyContext
-# Attempt to import MCP Session Manager from the MCP library, and hints user to
-# upgrade their Python version to 3.10 if it fails.
-try:
- from mcp import types
-
- from ..tools.mcp_tool.mcp_session_manager import MCPSessionManager
-except ImportError as e:
- if sys.version_info < (3, 10):
- raise ImportError(
- "MCP Session Manager requires Python 3.10 or above. Please upgrade"
- " your Python version."
- ) from e
- else:
- raise e
-
class McpInstructionProvider(InstructionProvider):
"""Fetches agent instructions from an MCP server."""
diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py
index f7270a75c9..09e65a67a4 100644
--- a/src/google/adk/agents/parallel_agent.py
+++ b/src/google/adk/agents/parallel_agent.py
@@ -48,34 +48,13 @@ def _create_branch_ctx_for_sub_agent(
return invocation_context
-# TODO - remove once Python <3.11 is no longer supported.
-async def _merge_agent_run_pre_3_11(
+async def _merge_agent_run(
agent_runs: list[AsyncGenerator[Event, None]],
) -> AsyncGenerator[Event, None]:
- """Merges the agent run event generator.
- This version works in Python 3.9 and 3.10 and uses custom replacement for
- asyncio.TaskGroup for tasks cancellation and exception handling.
-
- This implementation guarantees for each agent, it won't move on until the
- generated event is processed by upstream runner.
-
- Args:
- agent_runs: A list of async generators that yield events from each agent.
-
- Yields:
- Event: The next event from the merged generator.
- """
+ """Merges agent runs using asyncio.TaskGroup on Python 3.11+."""
sentinel = object()
queue = asyncio.Queue()
- def propagate_exceptions(tasks):
- # Propagate exceptions and errors from tasks.
- for task in tasks:
- if task.done():
- # Ignore the result (None) of correctly finished tasks and re-raise
- # exceptions and errors.
- task.result()
-
# Agents are processed in parallel.
# Events for each agent are put on queue sequentially.
async def process_an_agent(events_for_one_agent):
@@ -89,39 +68,34 @@ async def process_an_agent(events_for_one_agent):
# Mark agent as finished.
await queue.put((sentinel, None))
- tasks = []
- try:
+ async with asyncio.TaskGroup() as tg:
for events_for_one_agent in agent_runs:
- tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent)))
+ tg.create_task(process_an_agent(events_for_one_agent))
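+    # TaskGroup cancels the remaining tasks and re-raises if any agent task
+    # fails, so no manual exception propagation is needed on this path.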
sentinel_count = 0
# Run until all agents finished processing.
while sentinel_count < len(agent_runs):
- propagate_exceptions(tasks)
event, resume_signal = await queue.get()
# Agent finished processing.
if event is sentinel:
sentinel_count += 1
else:
yield event
- # Signal to agent that event has been processed by runner and it can
- # continue now.
+ # Signal to agent that it should generate next event.
resume_signal.set()
- finally:
- for task in tasks:
- task.cancel()
-async def _merge_agent_run(
+# TODO - remove once Python <3.11 is no longer supported.
+async def _merge_agent_run_pre_3_11(
agent_runs: list[AsyncGenerator[Event, None]],
) -> AsyncGenerator[Event, None]:
- """Merges the agent run event generator.
+ """Merges agent runs for Python 3.10 without asyncio.TaskGroup.
- This implementation guarantees for each agent, it won't move on until the
- generated event is processed by upstream runner.
+ Uses custom cancellation and exception handling to mirror TaskGroup
+ semantics. Each agent waits until the runner processes emitted events.
Args:
- agent_runs: A list of async generators that yield events from each agent.
+ agent_runs: Async generators that yield events from each agent.
Yields:
Event: The next event from the merged generator.
@@ -129,6 +103,14 @@ async def _merge_agent_run(
sentinel = object()
queue = asyncio.Queue()
+ def propagate_exceptions(tasks):
+ # Propagate exceptions and errors from tasks.
+ for task in tasks:
+ if task.done():
+ # Ignore the result (None) of correctly finished tasks and re-raise
+ # exceptions and errors.
+ task.result()
+
# Agents are processed in parallel.
# Events for each agent are put on queue sequentially.
async def process_an_agent(events_for_one_agent):
@@ -142,21 +124,27 @@ async def process_an_agent(events_for_one_agent):
# Mark agent as finished.
await queue.put((sentinel, None))
- async with asyncio.TaskGroup() as tg:
+ tasks = []
+ try:
for events_for_one_agent in agent_runs:
- tg.create_task(process_an_agent(events_for_one_agent))
+ tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent)))
sentinel_count = 0
# Run until all agents finished processing.
while sentinel_count < len(agent_runs):
+ propagate_exceptions(tasks)
event, resume_signal = await queue.get()
# Agent finished processing.
if event is sentinel:
sentinel_count += 1
else:
yield event
- # Signal to agent that it should generate next event.
+ # Signal to agent that event has been processed by runner and it can
+ # continue now.
resume_signal.set()
+ finally:
+ for task in tasks:
+ task.cancel()
class ParallelAgent(BaseAgent):
@@ -195,13 +183,11 @@ async def _run_async_impl(
pause_invocation = False
try:
- # TODO remove if once Python <3.11 is no longer supported.
merge_func = (
_merge_agent_run
if sys.version_info >= (3, 11)
else _merge_agent_run_pre_3_11
)
-
async with Aclosing(merge_func(agent_runs)) as agen:
async for event in agen:
yield event
diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py
index 7b6ff5cdd9..8d133060ec 100644
--- a/src/google/adk/agents/remote_a2a_agent.py
+++ b/src/google/adk/agents/remote_a2a_agent.py
@@ -26,30 +26,22 @@
from urllib.parse import urlparse
import uuid
-try:
- from a2a.client import Client as A2AClient
- from a2a.client import ClientEvent as A2AClientEvent
- from a2a.client.card_resolver import A2ACardResolver
- from a2a.client.client import ClientConfig as A2AClientConfig
- from a2a.client.client_factory import ClientFactory as A2AClientFactory
- from a2a.client.errors import A2AClientError
- from a2a.types import AgentCard
- from a2a.types import Message as A2AMessage
- from a2a.types import Part as A2APart
- from a2a.types import Role
- from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent
- from a2a.types import TaskState
- from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent
- from a2a.types import TransportProtocol as A2ATransport
-except ImportError as e:
- import sys
-
- if sys.version_info < (3, 10):
- raise ImportError(
- "A2A requires Python 3.10 or above. Please upgrade your Python version."
- ) from e
- else:
- raise e
+from a2a.client import Client as A2AClient
+from a2a.client import ClientEvent as A2AClientEvent
+from a2a.client.card_resolver import A2ACardResolver
+from a2a.client.client import ClientConfig as A2AClientConfig
+from a2a.client.client_factory import ClientFactory as A2AClientFactory
+from a2a.client.errors import A2AClientError
+from a2a.types import AgentCard
+from a2a.types import Message as A2AMessage
+from a2a.types import Part as A2APart
+from a2a.types import Role
+from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent
+from a2a.types import TaskState
+from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent
+from a2a.types import TransportProtocol as A2ATransport
+from google.genai import types as genai_types
+import httpx
try:
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
@@ -57,9 +49,6 @@
# Fallback for older versions of a2a-sdk.
AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json"
-from google.genai import types as genai_types
-import httpx
-
from ..a2a.converters.event_converter import convert_a2a_message_to_event
from ..a2a.converters.event_converter import convert_a2a_task_to_event
from ..a2a.converters.event_converter import convert_event_to_a2a_message
@@ -417,7 +406,9 @@ async def _handle_a2a_response(
# This is the initial response for a streaming task or the complete
# response for a non-streaming task, which is the full task state.
# We process this to get the initial message.
- event = convert_a2a_task_to_event(task, self.name, ctx)
+ event = convert_a2a_task_to_event(
+ task, self.name, ctx, self._a2a_part_converter
+ )
# for streaming task, we update the event with the task status.
# We update the event as Thought updates.
if task and task.status and task.status.state == TaskState.submitted:
@@ -429,7 +420,7 @@ async def _handle_a2a_response(
):
# This is a streaming task status update with a message.
event = convert_a2a_message_to_event(
- update.status.message, self.name, ctx
+ update.status.message, self.name, ctx, self._a2a_part_converter
)
if event.content and update.status.state in [
TaskState.submitted,
@@ -447,7 +438,9 @@ async def _handle_a2a_response(
# signals:
# 1. append: True for partial updates, False for full updates.
# 2. last_chunk: True for full updates, False for partial updates.
- event = convert_a2a_task_to_event(task, self.name, ctx)
+ event = convert_a2a_task_to_event(
+ task, self.name, ctx, self._a2a_part_converter
+ )
else:
# This is a streaming update without a message (e.g. status change)
# or a partial artifact update. We don't emit an event for these
@@ -463,7 +456,9 @@ async def _handle_a2a_response(
# Otherwise, it's a regular A2AMessage for non-streaming responses.
elif isinstance(a2a_response, A2AMessage):
- event = convert_a2a_message_to_event(a2a_response, self.name, ctx)
+ event = convert_a2a_message_to_event(
+ a2a_response, self.name, ctx, self._a2a_part_converter
+ )
event.custom_metadata = event.custom_metadata or {}
if a2a_response.context_id:
diff --git a/src/google/adk/apps/compaction.py b/src/google/adk/apps/compaction.py
index a6f55f9ad6..4511b1b96e 100644
--- a/src/google/adk/apps/compaction.py
+++ b/src/google/adk/apps/compaction.py
@@ -72,9 +72,10 @@ async def _run_compaction_for_sliding_window(
beginning.
- A `CompactedEvent` is generated, summarizing events within
`invocation_id` range [1, 2].
- - The session now contains: `[E(inv=1, role=user), E(inv=1, role=model),
- E(inv=2, role=user), E(inv=2, role=model), E(inv=2, role=user),
- CompactedEvent(inv=[1, 2])]`.
+ - The session now contains: `[
+ E(inv=1, role=user), E(inv=1, role=model),
+ E(inv=2, role=user), E(inv=2, role=model),
+ CompactedEvent(inv=[1, 2])]`.
2. **After `invocation_id` 3 events are added:**
- No compaction happens yet, because only 1 new invocation (`inv=3`)
@@ -91,10 +92,13 @@ async def _run_compaction_for_sliding_window(
- The new compaction range is from `invocation_id` 2 to 4.
- A new `CompactedEvent` is generated, summarizing events within
`invocation_id` range [2, 4].
- - The session now contains: `[E(inv=1, role=user), E(inv=1, role=model),
- E(inv=2, role=user), E(inv=2, role=model), E(inv=2, role=user),
- CompactedEvent(inv=[1, 2]), E(inv=3, role=user), E(inv=3, role=model),
- E(inv=4, role=user), E(inv=4, role=model), CompactedEvent(inv=[2, 4])]`.
+ - The session now contains: `[
+ E(inv=1, role=user), E(inv=1, role=model),
+ E(inv=2, role=user), E(inv=2, role=model),
+ CompactedEvent(inv=[1, 2]),
+ E(inv=3, role=user), E(inv=3, role=model),
+ E(inv=4, role=user), E(inv=4, role=model),
+ CompactedEvent(inv=[2, 4])]`.
Args:
diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py
index 825e4a7a71..53a830c066 100644
--- a/src/google/adk/artifacts/file_artifact_service.py
+++ b/src/google/adk/artifacts/file_artifact_service.py
@@ -32,6 +32,7 @@
from pydantic import ValidationError
from typing_extensions import override
+from ..errors.input_validation_error import InputValidationError
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
@@ -100,14 +101,14 @@ def _resolve_scoped_artifact_path(
to `scope_root`.
Raises:
- ValueError: If `filename` resolves outside of `scope_root`.
+ InputValidationError: If `filename` resolves outside of `scope_root`.
"""
stripped = _strip_user_namespace(filename).strip()
pure_path = _to_posix_path(stripped)
scope_root_resolved = scope_root.resolve(strict=False)
if pure_path.is_absolute():
- raise ValueError(
+ raise InputValidationError(
f"Absolute artifact filename {filename!r} is not permitted; "
"provide a path relative to the storage scope."
)
@@ -118,7 +119,7 @@ def _resolve_scoped_artifact_path(
try:
relative = candidate.relative_to(scope_root_resolved)
except ValueError as exc:
- raise ValueError(
+ raise InputValidationError(
f"Artifact filename {filename!r} escapes storage directory "
f"{scope_root_resolved}"
) from exc
@@ -230,7 +231,7 @@ def _scope_root(
if _is_user_scoped(session_id, filename):
return _user_artifacts_dir(base)
if not session_id:
- raise ValueError(
+ raise InputValidationError(
"Session ID must be provided for session-scoped artifacts."
)
return _session_artifacts_dir(base, session_id)
@@ -371,7 +372,9 @@ def _save_artifact_sync(
content_path.write_text(artifact.text, encoding="utf-8")
mime_type = None
else:
- raise ValueError("Artifact must have either inline_data or text content.")
+ raise InputValidationError(
+ "Artifact must have either inline_data or text content."
+ )
canonical_uri = self._canonical_uri(
user_id=user_id,
diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py
index fc18dab6fc..2bf713a5e8 100644
--- a/src/google/adk/artifacts/gcs_artifact_service.py
+++ b/src/google/adk/artifacts/gcs_artifact_service.py
@@ -30,6 +30,7 @@
from google.genai import types
from typing_extensions import override
+from ..errors.input_validation_error import InputValidationError
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
@@ -161,7 +162,7 @@ def _get_blob_prefix(
return f"{app_name}/{user_id}/user/{filename}"
if session_id is None:
- raise ValueError(
+ raise InputValidationError(
"Session ID must be provided for session-scoped artifacts."
)
return f"{app_name}/{user_id}/{session_id}/{filename}"
@@ -230,7 +231,9 @@ def _save_artifact(
" GcsArtifactService."
)
else:
- raise ValueError("Artifact must have either inline_data or text.")
+ raise InputValidationError(
+ "Artifact must have either inline_data or text."
+ )
return version
diff --git a/src/google/adk/artifacts/in_memory_artifact_service.py b/src/google/adk/artifacts/in_memory_artifact_service.py
index 246e8a85fb..2c7dd14127 100644
--- a/src/google/adk/artifacts/in_memory_artifact_service.py
+++ b/src/google/adk/artifacts/in_memory_artifact_service.py
@@ -18,12 +18,13 @@
from typing import Any
from typing import Optional
-from google.adk.artifacts import artifact_util
from google.genai import types
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
+from . import artifact_util
+from ..errors.input_validation_error import InputValidationError
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
@@ -86,7 +87,7 @@ def _artifact_path(
return f"{app_name}/{user_id}/user/{filename}"
if session_id is None:
- raise ValueError(
+ raise InputValidationError(
"Session ID must be provided for session-scoped artifacts."
)
return f"{app_name}/{user_id}/{session_id}/{filename}"
@@ -125,7 +126,7 @@ async def save_artifact(
elif artifact.file_data is not None:
if artifact_util.is_artifact_ref(artifact):
if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri):
- raise ValueError(
+ raise InputValidationError(
f"Invalid artifact reference URI: {artifact.file_data.file_uri}"
)
# If it's a valid artifact URI, we store the artifact part as-is.
@@ -133,7 +134,7 @@ async def save_artifact(
else:
artifact_version.mime_type = artifact.file_data.mime_type
else:
- raise ValueError("Not supported artifact type.")
+        raise InputValidationError("Unsupported artifact type.")
self.artifacts[path].append(
_ArtifactEntry(data=artifact, artifact_version=artifact_version)
@@ -172,7 +173,7 @@ async def load_artifact(
artifact_data.file_data.file_uri
)
if not parsed_uri:
- raise ValueError(
+ raise InputValidationError(
"Invalid artifact reference URI:"
f" {artifact_data.file_data.file_uri}"
)
diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py
index 1b422fe335..78fe426628 100644
--- a/src/google/adk/cli/adk_web_server.py
+++ b/src/google/adk/cli/adk_web_server.py
@@ -61,9 +61,11 @@
from ..agents.run_config import RunConfig
from ..agents.run_config import StreamingMode
from ..apps.app import App
+from ..artifacts.base_artifact_service import ArtifactVersion
from ..artifacts.base_artifact_service import BaseArtifactService
from ..auth.credential_service.base_credential_service import BaseCredentialService
from ..errors.already_exists_error import AlreadyExistsError
+from ..errors.input_validation_error import InputValidationError
from ..errors.not_found_error import NotFoundError
from ..evaluation.base_eval_service import InferenceConfig
from ..evaluation.base_eval_service import InferenceRequest
@@ -194,6 +196,19 @@ class CreateSessionRequest(common.BaseModel):
)
+class SaveArtifactRequest(common.BaseModel):
+ """Request payload for saving a new artifact."""
+
+ filename: str = Field(description="Artifact filename.")
+ artifact: types.Part = Field(
+ description="Artifact payload encoded as google.genai.types.Part."
+ )
+ custom_metadata: Optional[dict[str, Any]] = Field(
+ default=None,
+ description="Optional metadata to associate with the artifact version.",
+ )
+
+
class AddSessionToEvalSetRequest(common.BaseModel):
eval_id: str
session_id: str
@@ -280,6 +295,17 @@ class ListMetricsInfoResponse(common.BaseModel):
metrics_info: list[MetricInfo]
+class AppInfo(common.BaseModel):
+ name: str
+ root_agent_name: str
+ description: str
+ language: Literal["yaml", "python"]
+
+
+class ListAppsResponse(common.BaseModel):
+ apps: list[AppInfo]
+
+
def _setup_telemetry(
otel_to_cloud: bool = False,
internal_exporters: Optional[list[SpanProcessor]] = None,
@@ -699,7 +725,14 @@ async def internal_lifespan(app: FastAPI):
)
@app.get("/list-apps")
- async def list_apps() -> list[str]:
+ async def list_apps(
+ detailed: bool = Query(
+ default=False, description="Return detailed app information"
+ )
+ ) -> list[str] | ListAppsResponse:
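+    """Lists available apps; pass `?detailed=true` for names plus metadata."""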
+ if detailed:
+ apps_info = self.agent_loader.list_agents_detailed()
+ return ListAppsResponse(apps=[AppInfo(**app) for app in apps_info])
return self.agent_loader.list_agents()
@app.get("/debug/trace/{event_id}", tags=[TAG_DEBUG])
@@ -1298,6 +1331,53 @@ async def load_artifact_version(
raise HTTPException(status_code=404, detail="Artifact not found")
return artifact
+ @app.post(
+ "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
+ response_model=ArtifactVersion,
+ response_model_exclude_none=True,
+ )
+ async def save_artifact(
+ app_name: str,
+ user_id: str,
+ session_id: str,
+ req: SaveArtifactRequest,
+ ) -> ArtifactVersion:
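+    """Saves an artifact for a session and returns its version metadata.
+
+    The body carries the filename, the artifact encoded as a
+    google.genai.types.Part, and optional custom metadata; an illustrative
+    payload is {"filename": "notes.txt", "artifact": {"text": "hello"}}.
+    """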
+ try:
+ version = await self.artifact_service.save_artifact(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=req.filename,
+ artifact=req.artifact,
+ custom_metadata=req.custom_metadata,
+ )
+ except InputValidationError as ive:
+ raise HTTPException(status_code=400, detail=str(ive)) from ive
+ except Exception as exc: # pylint: disable=broad-exception-caught
+ logger.error(
+ "Internal error while saving artifact %s for app=%s user=%s"
+ " session=%s: %s",
+ req.filename,
+ app_name,
+ user_id,
+ session_id,
+ exc,
+ exc_info=True,
+ )
+ raise HTTPException(status_code=500, detail=str(exc)) from exc
+ artifact_version = await self.artifact_service.get_artifact_version(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ filename=req.filename,
+ version=version,
+ )
+ if artifact_version is None:
+ raise HTTPException(
+ status_code=500, detail="Artifact metadata unavailable"
+ )
+ return artifact_version
+
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
response_model_exclude_none=True,
diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py
index ed294d3922..a1b63a4c46 100644
--- a/src/google/adk/cli/cli.py
+++ b/src/google/adk/cli/cli.py
@@ -15,6 +15,7 @@
from __future__ import annotations
from datetime import datetime
+from pathlib import Path
from typing import Optional
from typing import Union
@@ -22,21 +23,21 @@
from google.genai import types
from pydantic import BaseModel
-from ..agents.base_agent import BaseAgent
from ..agents.llm_agent import LlmAgent
from ..apps.app import App
from ..artifacts.base_artifact_service import BaseArtifactService
-from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..auth.credential_service.base_credential_service import BaseCredentialService
from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService
from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
-from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from ..utils.context_utils import Aclosing
from ..utils.env_utils import is_env_enabled
+from .service_registry import load_services_module
from .utils import envs
from .utils.agent_loader import AgentLoader
+from .utils.service_factory import create_artifact_service_from_options
+from .utils.service_factory import create_session_service_from_options
class InputFile(BaseModel):
@@ -66,7 +67,7 @@ async def run_input_file(
)
with open(input_path, 'r', encoding='utf-8') as f:
input_file = InputFile.model_validate_json(f.read())
- input_file.state['_time'] = datetime.now()
+ input_file.state['_time'] = datetime.now().isoformat()
session = await session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
@@ -134,6 +135,8 @@ async def run_cli(
saved_session_file: Optional[str] = None,
save_session: bool,
session_id: Optional[str] = None,
+ session_service_uri: Optional[str] = None,
+ artifact_service_uri: Optional[str] = None,
) -> None:
"""Runs an interactive CLI for a certain agent.
@@ -148,24 +151,48 @@ async def run_cli(
contains a previously saved session, exclusive with input_file.
save_session: bool, whether to save the session on exit.
session_id: Optional[str], the session ID to save the session to on exit.
+ session_service_uri: Optional[str], custom session service URI.
+ artifact_service_uri: Optional[str], custom artifact service URI.
"""
+ agent_parent_path = Path(agent_parent_dir).resolve()
+ agent_root = agent_parent_path / agent_folder_name
+ load_services_module(str(agent_root))
+ user_id = 'test_user'
- artifact_service = InMemoryArtifactService()
- session_service = InMemorySessionService()
- credential_service = InMemoryCredentialService()
+ # Create session and artifact services using factory functions
+ # Sessions persist under //.adk/session.db by default.
+ session_service = create_session_service_from_options(
+ base_dir=agent_parent_path,
+ session_service_uri=session_service_uri,
+ )
- user_id = 'test_user'
- agent_or_app = AgentLoader(agents_dir=agent_parent_dir).load_agent(
+ artifact_service = create_artifact_service_from_options(
+ base_dir=agent_root,
+ artifact_service_uri=artifact_service_uri,
+ )
+
+ credential_service = InMemoryCredentialService()
+ agents_dir = str(agent_parent_path)
+ agent_or_app = AgentLoader(agents_dir=agents_dir).load_agent(
agent_folder_name
)
session_app_name = (
agent_or_app.name if isinstance(agent_or_app, App) else agent_folder_name
)
- session = await session_service.create_session(
- app_name=session_app_name, user_id=user_id
- )
if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'):
- envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
+ envs.load_dotenv_for_agent(agent_folder_name, agents_dir)
+
+ # Helper function for printing events
+ def _print_event(event) -> None:
+ content = event.content
+ if not content or not content.parts:
+ return
+ text_parts = [part.text for part in content.parts if part.text]
+ if not text_parts:
+ return
+ author = event.author or 'system'
+ click.echo(f'[{author}]: {"".join(text_parts)}')
+
if input_file:
session = await run_input_file(
app_name=session_app_name,
@@ -177,16 +204,22 @@ async def run_cli(
input_path=input_file,
)
elif saved_session_file:
+ # Load the saved session from file
with open(saved_session_file, 'r', encoding='utf-8') as f:
loaded_session = Session.model_validate_json(f.read())
+ # Create a new session in the service, copying state from the file
+ session = await session_service.create_session(
+ app_name=session_app_name,
+ user_id=user_id,
+ state=loaded_session.state if loaded_session else None,
+ )
+
+ # Append events from the file to the new session and display them
if loaded_session:
for event in loaded_session.events:
await session_service.append_event(session, event)
- content = event.content
- if not content or not content.parts or not content.parts[0].text:
- continue
- click.echo(f'[{event.author}]: {content.parts[0].text}')
+ _print_event(event)
await run_interactively(
agent_or_app,
@@ -196,6 +229,9 @@ async def run_cli(
credential_service,
)
else:
+ session = await session_service.create_session(
+ app_name=session_app_name, user_id=user_id
+ )
click.echo(f'Running agent {agent_or_app.name}, type exit to exit.')
await run_interactively(
agent_or_app,
@@ -207,9 +243,7 @@ async def run_cli(
if save_session:
session_id = session_id or input('Session ID to save: ')
- session_path = (
- f'{agent_parent_dir}/{agent_folder_name}/{session_id}.session.json'
- )
+ session_path = agent_root / f'{session_id}.session.json'
# Fetch the session again to get all the details.
session = await session_service.get_session(
@@ -217,7 +251,9 @@ async def run_cli(
user_id=session.user_id,
session_id=session.id,
)
- with open(session_path, 'w', encoding='utf-8') as f:
- f.write(session.model_dump_json(indent=2, exclude_none=True))
+ session_path.write_text(
+ session.model_dump_json(indent=2, exclude_none=True, by_alias=True),
+ encoding='utf-8',
+ )
print('Session saved to', session_path)
diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py
index 8da65d0e21..45dce7fda6 100644
--- a/src/google/adk/cli/cli_deploy.py
+++ b/src/google/adk/cli/cli_deploy.py
@@ -915,7 +915,7 @@ def to_agent_engine(
)
else:
if project and region and not agent_engine_id.startswith('projects/'):
- agent_engine_id = f'projects/{project}/locations/{region}/agentEngines/{agent_engine_id}'
+ agent_engine_id = f'projects/{project}/locations/{region}/reasoningEngines/{agent_engine_id}'
client.agent_engines.update(name=agent_engine_id, config=agent_config)
click.secho(f'✅ Updated agent engine: {agent_engine_id}', fg='green')
finally:
diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py
index 529ee7319c..5d228f72f3 100644
--- a/src/google/adk/cli/cli_tools_click.py
+++ b/src/google/adk/cli/cli_tools_click.py
@@ -24,6 +24,7 @@
import os
from pathlib import Path
import tempfile
+import textwrap
from typing import Optional
import click
@@ -108,6 +109,18 @@ def parse_args(self, ctx, args):
logger = logging.getLogger("google_adk." + __name__)
+_ADK_WEB_WARNING = (
+ "ADK Web is for development purposes. It has access to all data and"
+ " should not be used in production."
+)
+
+
+def _warn_if_with_ui(with_ui: bool) -> None:
+ """Warn when deploying with the developer UI enabled."""
+ if with_ui:
+ click.secho(f"WARNING: {_ADK_WEB_WARNING}", fg="yellow", err=True)
+
+
@click.group(context_settings={"max_content_width": 240})
@click.version_option(version.__version__)
def main():
@@ -354,7 +367,62 @@ def validate_exclusive(ctx, param, value):
return value
+def adk_services_options():
+ """Decorator to add ADK services options to click commands."""
+
+ def decorator(func):
+ @click.option(
+ "--session_service_uri",
+ help=textwrap.dedent(
+ """\
+ Optional. The URI of the session service.
+ - Leave unset to use the in-memory session service (default).
+          - Use 'agentengine://' to connect to Agent Engine sessions; the part
+            after the scheme can either be the fully qualified resource name
+            'projects/abc/locations/us-central1/reasoningEngines/123' or the
+            resource id '123'.
+ - Use 'memory://' to run with the in-memory session service.
+ - Use 'sqlite://' to connect to a SQLite DB.
+ - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported database URIs."""
+ ),
+ )
+ @click.option(
+ "--artifact_service_uri",
+ type=str,
+ help=textwrap.dedent(
+ """\
+ Optional. The URI of the artifact service.
+ - Leave unset to store artifacts under '.adk/artifacts' locally.
+ - Use 'gs://' to connect to the GCS artifact service.
+ - Use 'memory://' to force the in-memory artifact service.
+ - Use 'file://' to store artifacts in a custom local directory."""
+ ),
+ default=None,
+ )
+ @click.option(
+ "--memory_service_uri",
+ type=str,
+ help=textwrap.dedent("""\
+ Optional. The URI of the memory service.
+ - Use 'rag://' to connect to Vertex AI Rag Memory Service.
+          - Use 'agentengine://' to connect to Agent Engine memory; the part
+            after the scheme can either be the fully qualified resource name
+            'projects/abc/locations/us-central1/reasoningEngines/123' or the
+            resource id '123'.
+ - Use 'memory://' to force the in-memory memory service."""),
+ default=None,
+ )
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
@main.command("run", cls=HelpfulCommand)
+@adk_services_options()
@click.option(
"--save_session",
type=bool,
@@ -409,6 +477,9 @@ def cli_run(
session_id: Optional[str],
replay: Optional[str],
resume: Optional[str],
+ session_service_uri: Optional[str] = None,
+ artifact_service_uri: Optional[str] = None,
+ memory_service_uri: Optional[str] = None,
):
"""Runs an interactive CLI for a certain agent.
@@ -420,6 +491,14 @@ def cli_run(
"""
logs.log_to_tmp_folder()
+ # Validation warning for memory_service_uri (not supported for adk run)
+ if memory_service_uri:
+ click.secho(
+ "WARNING: --memory_service_uri is not supported for adk run.",
+ fg="yellow",
+ err=True,
+ )
+
agent_parent_folder = os.path.dirname(agent)
agent_folder_name = os.path.basename(agent)
@@ -431,6 +510,8 @@ def cli_run(
saved_session_file=resume,
save_session=save_session,
session_id=session_id,
+ session_service_uri=session_service_uri,
+ artifact_service_uri=artifact_service_uri,
)
)
@@ -557,7 +638,7 @@ def cli_eval(
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
- from ..evaluation.user_simulator_provider import UserSimulatorProvider
+ from ..evaluation.simulation.user_simulator_provider import UserSimulatorProvider
from .cli_eval import _collect_eval_results
from .cli_eval import _collect_inferences
from .cli_eval import get_root_agent
@@ -865,55 +946,6 @@ def wrapper(*args, **kwargs):
return decorator
-def adk_services_options():
- """Decorator to add ADK services options to click commands."""
-
- def decorator(func):
- @click.option(
- "--session_service_uri",
- help=(
- """Optional. The URI of the session service.
- - Use 'agentengine://' to connect to Agent Engine
- sessions. can either be the full qualified resource
- name 'projects/abc/locations/us-central1/reasoningEngines/123' or
- the resource id '123'.
- - Use 'sqlite://' to connect to an aio-sqlite
- based session service, which is good for local development.
- - Use 'postgresql://:@:/'
- to connect to a PostgreSQL DB.
- - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls
- for more details on other database URIs supported by SQLAlchemy."""
- ),
- )
- @click.option(
- "--artifact_service_uri",
- type=str,
- help=(
- "Optional. The URI of the artifact service,"
- " supported URIs: gs:// for GCS artifact service."
- ),
- default=None,
- )
- @click.option(
- "--memory_service_uri",
- type=str,
- help=("""Optional. The URI of the memory service.
- - Use 'rag://' to connect to Vertex AI Rag Memory Service.
- - Use 'agentengine://' to connect to Agent Engine
- sessions. can either be the full qualified resource
- name 'projects/abc/locations/us-central1/reasoningEngines/123' or
- the resource id '123'."""),
- default=None,
- )
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- return func(*args, **kwargs)
-
- return wrapper
-
- return decorator
-
-
def deprecated_adk_services_options():
"""Deprecated ADK services options."""
@@ -921,7 +953,7 @@ def warn(alternative_param, ctx, param, value):
if value:
click.echo(
click.style(
- f"WARNING: Deprecated option {param.name} is used. Please use"
+ f"WARNING: Deprecated option --{param.name} is used. Please use"
f" {alternative_param} instead.",
fg="yellow",
),
@@ -1116,6 +1148,8 @@ def cli_web(
adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir
"""
+ session_service_uri = session_service_uri or session_db_url
+ artifact_service_uri = artifact_service_uri or artifact_storage_uri
logs.setup_adk_logger(getattr(logging, log_level.upper()))
@asynccontextmanager
@@ -1140,8 +1174,6 @@ async def _lifespan(app: FastAPI):
fg="green",
)
- session_service_uri = session_service_uri or session_db_url
- artifact_service_uri = artifact_service_uri or artifact_storage_uri
app = get_fast_api_app(
agents_dir=agents_dir,
session_service_uri=session_service_uri,
@@ -1215,10 +1247,10 @@ def cli_api_server(
adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir
"""
- logs.setup_adk_logger(getattr(logging, log_level.upper()))
-
session_service_uri = session_service_uri or session_db_url
artifact_service_uri = artifact_service_uri or artifact_storage_uri
+ logs.setup_adk_logger(getattr(logging, log_level.upper()))
+
config = uvicorn.Config(
get_fast_api_app(
agents_dir=agents_dir,
@@ -1408,6 +1440,8 @@ def cli_deploy_cloud_run(
err=True,
)
+ _warn_if_with_ui(with_ui)
+
session_service_uri = session_service_uri or session_db_url
artifact_service_uri = artifact_service_uri or artifact_storage_uri
@@ -1792,6 +1826,7 @@ def cli_deploy_gke(
--cluster_name=[cluster_name] path/to/my_agent
"""
try:
+ _warn_if_with_ui(with_ui)
cli_deploy.to_gke(
agent_folder=agent,
project=project,
diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py
index eec6bb646b..df06b1cf4c 100644
--- a/src/google/adk/cli/fast_api.py
+++ b/src/google/adk/cli/fast_api.py
@@ -14,6 +14,7 @@
from __future__ import annotations
+import importlib
import json
import logging
import os
@@ -51,6 +52,26 @@
logger = logging.getLogger("google_adk." + __name__)
+_LAZY_SERVICE_IMPORTS: dict[str, str] = {
+ "AgentLoader": ".utils.agent_loader",
+ "InMemoryArtifactService": "..artifacts.in_memory_artifact_service",
+ "InMemoryMemoryService": "..memory.in_memory_memory_service",
+ "InMemorySessionService": "..sessions.in_memory_session_service",
+ "LocalEvalSetResultsManager": "..evaluation.local_eval_set_results_manager",
+ "LocalEvalSetsManager": "..evaluation.local_eval_sets_manager",
+}
+
+
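+# PEP 562 module-level __getattr__: each name above is imported on first
+# access and cached in globals(), so tests can keep patching attributes such
+# as `fast_api.InMemorySessionService`.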
+def __getattr__(name: str):
+ """Lazily import defaults so patching in tests keeps working."""
+ if name not in _LAZY_SERVICE_IMPORTS:
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+ module = importlib.import_module(_LAZY_SERVICE_IMPORTS[name], __package__)
+ attr = getattr(module, name)
+ globals()[name] = attr
+ return attr
+
def get_fast_api_app(
*,
@@ -321,25 +342,14 @@ async def get_agent_builder(
)
if a2a:
- try:
- from a2a.server.apps import A2AStarletteApplication
- from a2a.server.request_handlers import DefaultRequestHandler
- from a2a.server.tasks import InMemoryTaskStore
- from a2a.types import AgentCard
- from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
-
- from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor
-
- except ImportError as e:
- import sys
-
- if sys.version_info < (3, 10):
- raise ImportError(
- "A2A requires Python 3.10 or above. Please upgrade your Python"
- " version."
- ) from e
- else:
- raise e
+ from a2a.server.apps import A2AStarletteApplication
+ from a2a.server.request_handlers import DefaultRequestHandler
+ from a2a.server.tasks import InMemoryTaskStore
+ from a2a.types import AgentCard
+ from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
+
+ from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor
+
# locate all a2a agent apps in the agents directory
base_path = Path.cwd() / agents_dir
# the root agents directory should be an existing folder
diff --git a/src/google/adk/cli/service_registry.py b/src/google/adk/cli/service_registry.py
index 9f23b73035..3e7921e075 100644
--- a/src/google/adk/cli/service_registry.py
+++ b/src/google/adk/cli/service_registry.py
@@ -76,7 +76,6 @@ def my_session_factory(uri: str, **kwargs):
from ..artifacts.base_artifact_service import BaseArtifactService
from ..memory.base_memory_service import BaseMemoryService
-from ..sessions import InMemorySessionService
from ..sessions.base_session_service import BaseSessionService
from ..utils import yaml_utils
@@ -272,18 +271,14 @@ def gcs_artifact_factory(uri: str, **kwargs):
kwargs_copy = kwargs.copy()
kwargs_copy.pop("agents_dir", None)
+ kwargs_copy.pop("per_agent", None)
parsed_uri = urlparse(uri)
bucket_name = parsed_uri.netloc
return GcsArtifactService(bucket_name=bucket_name, **kwargs_copy)
- def file_artifact_factory(uri: str, **kwargs):
+ def file_artifact_factory(uri: str, **_):
from ..artifacts.file_artifact_service import FileArtifactService
- per_agent = kwargs.get("per_agent", False)
- if per_agent:
- raise ValueError(
- "file:// artifact URIs are not supported in multi-agent mode."
- )
parsed_uri = urlparse(uri)
if parsed_uri.netloc not in ("", "localhost"):
raise ValueError(
diff --git a/src/google/adk/cli/utils/__init__.py b/src/google/adk/cli/utils/__init__.py
index 8aa11b252b..1800f5d04c 100644
--- a/src/google/adk/cli/utils/__init__.py
+++ b/src/google/adk/cli/utils/__init__.py
@@ -18,8 +18,10 @@
from ...agents.base_agent import BaseAgent
from ...agents.llm_agent import LlmAgent
+from .dot_adk_folder import DotAdkFolder
from .state import create_empty_state
__all__ = [
'create_empty_state',
+ 'DotAdkFolder',
]
diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py
index 9f01705d4f..d6965e5bbb 100644
--- a/src/google/adk/cli/utils/agent_loader.py
+++ b/src/google/adk/cli/utils/agent_loader.py
@@ -20,6 +20,8 @@
import os
from pathlib import Path
import sys
+from typing import Any
+from typing import Literal
from typing import Optional
from typing import Union
@@ -56,7 +58,7 @@ class AgentLoader(BaseAgentLoader):
"""
def __init__(self, agents_dir: str):
- self.agents_dir = agents_dir.rstrip("/")
+ self.agents_dir = str(Path(agents_dir))
self._original_sys_path = None
self._agent_cache: dict[str, Union[BaseAgent, App]] = {}
@@ -270,12 +272,13 @@ def _perform_load(self, agent_name: str) -> Union[BaseAgent, App]:
f"No root_agent found for '{agent_name}'. Searched in"
f" '{actual_agent_name}.agent.root_agent',"
f" '{actual_agent_name}.root_agent' and"
- f" '{actual_agent_name}/root_agent.yaml'.\n\nExpected directory"
- f" structure:\n /\n {actual_agent_name}/\n "
- " agent.py (with root_agent) OR\n root_agent.yaml\n\nThen run:"
- f" adk web \n\nEnsure '{agents_dir}/{actual_agent_name}' is"
- " structured correctly, an .env file can be loaded if present, and a"
- f" root_agent is exposed.{hint}"
+ f" '{actual_agent_name}{os.sep}root_agent.yaml'.\n\nExpected directory"
+ f" structure:\n {os.sep}\n "
+ f" {actual_agent_name}{os.sep}\n agent.py (with root_agent) OR\n "
+ " root_agent.yaml\n\nThen run: adk web \n\nEnsure"
+ f" '{os.path.join(agents_dir, actual_agent_name)}' is structured"
+ " correctly, an .env file can be loaded if present, and a root_agent"
+ f" is exposed.{hint}"
)
def _record_origin_metadata(
@@ -341,6 +344,50 @@ def list_agents(self) -> list[str]:
agent_names.sort()
return agent_names
+ def list_agents_detailed(self) -> list[dict[str, Any]]:
+    """Lists all agents with their name, root_agent_name, description and language."""
+ agent_names = self.list_agents()
+ apps_info = []
+
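+    # Agents that fail to load are logged and skipped so one broken agent does
+    # not hide the rest of the listing.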
+ for agent_name in agent_names:
+ try:
+ loaded = self.load_agent(agent_name)
+ if isinstance(loaded, App):
+ agent = loaded.root_agent
+ else:
+ agent = loaded
+
+ language = self._determine_agent_language(agent_name)
+
+ app_info = {
+ "name": agent_name,
+ "root_agent_name": agent.name,
+ "description": agent.description,
+ "language": language,
+ }
+ apps_info.append(app_info)
+
+ except Exception as e:
+ logger.error("Failed to load agent '%s': %s", agent_name, e)
+ continue
+
+ return apps_info
+
+ def _determine_agent_language(
+ self, agent_name: str
+ ) -> Literal["yaml", "python"]:
+ """Determine the type of agent based on file structure."""
+ base_path = Path.cwd() / self.agents_dir / agent_name
+
+ if (base_path / "root_agent.yaml").exists():
+ return "yaml"
+ elif (base_path / "agent.py").exists():
+ return "python"
+ elif (base_path / "__init__.py").exists():
+ return "python"
+
+ raise ValueError(f"Could not determine agent type for '{agent_name}'.")
+
def remove_agent_from_cache(self, agent_name: str):
# Clear module cache for the agent and its submodules
keys_to_delete = [
diff --git a/src/google/adk/cli/utils/base_agent_loader.py b/src/google/adk/cli/utils/base_agent_loader.py
index d62a6b8651..bcef0dae42 100644
--- a/src/google/adk/cli/utils/base_agent_loader.py
+++ b/src/google/adk/cli/utils/base_agent_loader.py
@@ -18,6 +18,7 @@
from abc import ABC
from abc import abstractmethod
+from typing import Any
from typing import Union
from ...agents.base_agent import BaseAgent
@@ -34,3 +35,15 @@ def load_agent(self, agent_name: str) -> Union[BaseAgent, App]:
@abstractmethod
def list_agents(self) -> list[str]:
"""Lists all agents available in the agent loader in alphabetical order."""
+
+  def list_agents_detailed(self) -> list[dict[str, Any]]:
+    """Lists agents with minimal metadata.
+
+    Subclasses (e.g. AgentLoader) override this to fill in root_agent_name,
+    description and language.
+    """
+    agent_names = self.list_agents()
+    return [
+        {
+            'name': name,
+            'root_agent_name': None,
+            'description': None,
+            'language': None,
+        }
+        for name in agent_names
+    ]
diff --git a/src/google/adk/cli/utils/local_storage.py b/src/google/adk/cli/utils/local_storage.py
index 9e6b3f3d54..ec7099b8c8 100644
--- a/src/google/adk/cli/utils/local_storage.py
+++ b/src/google/adk/cli/utils/local_storage.py
@@ -57,14 +57,39 @@ def create_local_database_session_service(
return SqliteSessionService(db_path=str(session_db_path))
+def create_local_session_service(
+ *,
+ base_dir: Path | str,
+ per_agent: bool = False,
+) -> BaseSessionService:
+ """Creates a local SQLite-backed session service.
+
+ Args:
+ base_dir: The base directory for the agent(s).
+ per_agent: If True, creates a PerAgentDatabaseSessionService that stores
+ sessions in each agent's .adk folder. If False, creates a single
+ SqliteSessionService at base_dir/.adk/session.db.
+
+ Returns:
+ A BaseSessionService instance backed by SQLite.
+ """
+ if per_agent:
+ logger.info(
+ "Using per-agent session storage rooted at %s",
+ base_dir,
+ )
+ return PerAgentDatabaseSessionService(agents_root=base_dir)
+
+ return create_local_database_session_service(base_dir=base_dir)
+
+
def create_local_artifact_service(
- *, base_dir: Path | str, per_agent: bool = False
+ *, base_dir: Path | str
) -> BaseArtifactService:
"""Creates a file-backed artifact service rooted in `.adk/artifacts`.
Args:
base_dir: Directory whose `.adk` folder will store artifacts.
- per_agent: Indicates whether the service is being used in multi-agent mode.
Returns:
A `FileArtifactService` scoped to the derived root directory.
@@ -72,13 +97,7 @@ def create_local_artifact_service(
manager = DotAdkFolder(base_dir)
artifact_root = manager.artifacts_dir
artifact_root.mkdir(parents=True, exist_ok=True)
- if per_agent:
- logger.info(
- "Using shared file artifact service rooted at %s for multi-agent mode",
- artifact_root,
- )
- else:
- logger.info("Using file artifact service at %s", artifact_root)
+ logger.info("Using file artifact service at %s", artifact_root)
return FileArtifactService(root_dir=artifact_root)
diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py
new file mode 100644
index 0000000000..60f4ddd3cf
--- /dev/null
+++ b/src/google/adk/cli/utils/service_factory.py
@@ -0,0 +1,120 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Any
+from typing import Optional
+
+from ...artifacts.base_artifact_service import BaseArtifactService
+from ...memory.base_memory_service import BaseMemoryService
+from ...sessions.base_session_service import BaseSessionService
+from ..service_registry import get_service_registry
+from .local_storage import create_local_artifact_service
+from .local_storage import create_local_session_service
+
+logger = logging.getLogger("google_adk." + __name__)
+
+
+def create_session_service_from_options(
+ *,
+ base_dir: Path | str,
+ session_service_uri: Optional[str] = None,
+ session_db_kwargs: Optional[dict[str, Any]] = None,
+) -> BaseSessionService:
+ """Creates a session service based on CLI/web options."""
+ base_path = Path(base_dir)
+ registry = get_service_registry()
+
+ kwargs: dict[str, Any] = {
+ "agents_dir": str(base_path),
+ }
+ if session_db_kwargs:
+ kwargs.update(session_db_kwargs)
+
+ if session_service_uri:
+ logger.info("Using session service URI: %s", session_service_uri)
+ service = registry.create_session_service(session_service_uri, **kwargs)
+ if service is not None:
+ return service
+
+ # Fallback to DatabaseSessionService if the registry doesn't support the
+ # session service URI scheme. This keeps support for SQLAlchemy-compatible
+ # databases like AlloyDB or Cloud Spanner without explicit registration.
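+  # (For example, a plain SQLAlchemy URL such as
+  # 'postgresql://user:pass@host:5432/db' is handled by DatabaseSessionService.)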
+ from ...sessions.database_session_service import DatabaseSessionService
+
+ fallback_kwargs = dict(kwargs)
+ fallback_kwargs.pop("agents_dir", None)
+ logger.info(
+ "Falling back to DatabaseSessionService for URI: %s",
+ session_service_uri,
+ )
+ return DatabaseSessionService(db_url=session_service_uri, **fallback_kwargs)
+
+ # Default to per-agent local SQLite storage in //.adk/.
+ return create_local_session_service(base_dir=base_path, per_agent=True)
+
+
+def create_memory_service_from_options(
+ *,
+ base_dir: Path | str,
+ memory_service_uri: Optional[str] = None,
+) -> BaseMemoryService:
+ """Creates a memory service based on CLI/web options."""
+ base_path = Path(base_dir)
+ registry = get_service_registry()
+
+ if memory_service_uri:
+ logger.info("Using memory service URI: %s", memory_service_uri)
+ service = registry.create_memory_service(
+ memory_service_uri,
+ agents_dir=str(base_path),
+ )
+ if service is None:
+ raise ValueError(f"Unsupported memory service URI: {memory_service_uri}")
+ return service
+
+ logger.info("Using in-memory memory service")
+ from ...memory.in_memory_memory_service import InMemoryMemoryService
+
+ return InMemoryMemoryService()
+
+
+def create_artifact_service_from_options(
+ *,
+ base_dir: Path | str,
+ artifact_service_uri: Optional[str] = None,
+) -> BaseArtifactService:
+ """Creates an artifact service based on CLI/web options."""
+ base_path = Path(base_dir)
+ registry = get_service_registry()
+
+ if artifact_service_uri:
+ logger.info("Using artifact service URI: %s", artifact_service_uri)
+ service = registry.create_artifact_service(
+ artifact_service_uri,
+ agents_dir=str(base_path),
+ )
+ if service is None:
+ logger.warning(
+ "Unsupported artifact service URI: %s, falling back to in-memory",
+ artifact_service_uri,
+ )
+ from ...artifacts.in_memory_artifact_service import InMemoryArtifactService
+
+ return InMemoryArtifactService()
+ return service
+
+ return create_local_artifact_service(base_dir=base_path)
diff --git a/src/google/adk/code_executors/unsafe_local_code_executor.py b/src/google/adk/code_executors/unsafe_local_code_executor.py
index 6dd2ae9d8c..b47fbd17e9 100644
--- a/src/google/adk/code_executors/unsafe_local_code_executor.py
+++ b/src/google/adk/code_executors/unsafe_local_code_executor.py
@@ -72,7 +72,7 @@ def execute_code(
_prepare_globals(code_execution_input.code, globals_)
stdout = io.StringIO()
with redirect_stdout(stdout):
- exec(code_execution_input.code, globals_)
+ exec(code_execution_input.code, globals_, globals_)
output = stdout.getvalue()
except Exception as e:
error = str(e)
diff --git a/src/google/adk/errors/input_validation_error.py b/src/google/adk/errors/input_validation_error.py
new file mode 100644
index 0000000000..76b1625a10
--- /dev/null
+++ b/src/google/adk/errors/input_validation_error.py
@@ -0,0 +1,28 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+
+class InputValidationError(ValueError):
+ """Represents an error raised when user input fails validation."""
+
+ def __init__(self, message="Invalid input."):
+ """Initializes the InputValidationError exception.
+
+ Args:
+ message (str): A message describing why the input is invalid.
+ """
+ self.message = message
+ super().__init__(self.message)
diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py
index cafa712f56..514681bfa9 100644
--- a/src/google/adk/evaluation/agent_evaluator.py
+++ b/src/google/adk/evaluation/agent_evaluator.py
@@ -50,7 +50,7 @@
from .evaluator import EvalStatus
from .in_memory_eval_sets_manager import InMemoryEvalSetsManager
from .local_eval_sets_manager import convert_eval_set_to_pydantic_schema
-from .user_simulator_provider import UserSimulatorProvider
+from .simulation.user_simulator_provider import UserSimulatorProvider
logger = logging.getLogger("google_adk." + __name__)
diff --git a/src/google/adk/evaluation/eval_config.py b/src/google/adk/evaluation/eval_config.py
index d5b94af5e1..13b2e92274 100644
--- a/src/google/adk/evaluation/eval_config.py
+++ b/src/google/adk/evaluation/eval_config.py
@@ -27,7 +27,7 @@
from ..evaluation.eval_metrics import EvalMetric
from .eval_metrics import BaseCriterion
from .eval_metrics import Threshold
-from .user_simulator import BaseUserSimulatorConfig
+from .simulation.user_simulator import BaseUserSimulatorConfig
logger = logging.getLogger("google_adk." + __name__)
diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py
index e9c7dc5436..5d8b48c150 100644
--- a/src/google/adk/evaluation/evaluation_generator.py
+++ b/src/google/adk/evaluation/evaluation_generator.py
@@ -45,9 +45,9 @@
from .eval_case import SessionInput
from .eval_set import EvalSet
from .request_intercepter_plugin import _RequestIntercepterPlugin
-from .user_simulator import Status as UserSimulatorStatus
-from .user_simulator import UserSimulator
-from .user_simulator_provider import UserSimulatorProvider
+from .simulation.user_simulator import Status as UserSimulatorStatus
+from .simulation.user_simulator import UserSimulator
+from .simulation.user_simulator_provider import UserSimulatorProvider
_USER_AUTHOR = "user"
_DEFAULT_AUTHOR = "agent"
diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py
index 806a8d690d..f454266e00 100644
--- a/src/google/adk/evaluation/local_eval_service.py
+++ b/src/google/adk/evaluation/local_eval_service.py
@@ -31,6 +31,8 @@
from ..memory.base_memory_service import BaseMemoryService
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
+from ..utils._client_labels_utils import client_label_context
+from ..utils._client_labels_utils import EVAL_CLIENT_LABEL
from ..utils.feature_decorator import experimental
from .base_eval_service import BaseEvalService
from .base_eval_service import EvaluateConfig
@@ -53,7 +55,7 @@
from .evaluator import PerInvocationResult
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
from .metric_evaluator_registry import MetricEvaluatorRegistry
-from .user_simulator_provider import UserSimulatorProvider
+from .simulation.user_simulator_provider import UserSimulatorProvider
logger = logging.getLogger('google_adk.' + __name__)
@@ -249,11 +251,12 @@ async def _evaluate_single_inference_result(
for eval_metric in evaluate_config.eval_metrics:
# Perform evaluation of the metric.
try:
- evaluation_result = await self._evaluate_metric(
- eval_metric=eval_metric,
- actual_invocations=inference_result.inferences,
- expected_invocations=eval_case.conversation,
- )
+ with client_label_context(EVAL_CLIENT_LABEL):
+ evaluation_result = await self._evaluate_metric(
+ eval_metric=eval_metric,
+ actual_invocations=inference_result.inferences,
+ expected_invocations=eval_case.conversation,
+ )
except Exception as e:
# We intentionally catch the Exception as we don't want failures to
# affect other metric evaluation.
@@ -403,17 +406,18 @@ async def _perform_inference_single_eval_item(
)
try:
- inferences = (
- await EvaluationGenerator._generate_inferences_from_root_agent(
- root_agent=root_agent,
- user_simulator=self._user_simulator_provider.provide(eval_case),
- initial_session=initial_session,
- session_id=session_id,
- session_service=self._session_service,
- artifact_service=self._artifact_service,
- memory_service=self._memory_service,
- )
- )
+ with client_label_context(EVAL_CLIENT_LABEL):
+ inferences = (
+ await EvaluationGenerator._generate_inferences_from_root_agent(
+ root_agent=root_agent,
+ user_simulator=self._user_simulator_provider.provide(eval_case),
+ initial_session=initial_session,
+ session_id=session_id,
+ session_service=self._session_service,
+ artifact_service=self._artifact_service,
+ memory_service=self._memory_service,
+ )
+ )
inference_result.inferences = inferences
inference_result.status = InferenceStatus.SUCCESS
diff --git a/src/google/adk/evaluation/simulation/__init__.py b/src/google/adk/evaluation/simulation/__init__.py
new file mode 100644
index 0000000000..0a2669d7a2
--- /dev/null
+++ b/src/google/adk/evaluation/simulation/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/google/adk/evaluation/llm_backed_user_simulator.py b/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py
similarity index 96%
rename from src/google/adk/evaluation/llm_backed_user_simulator.py
rename to src/google/adk/evaluation/simulation/llm_backed_user_simulator.py
index 2fbfcc44d1..4af228772d 100644
--- a/src/google/adk/evaluation/llm_backed_user_simulator.py
+++ b/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py
@@ -22,14 +22,14 @@
from pydantic import Field
from typing_extensions import override
-from ..events.event import Event
-from ..models.llm_request import LlmRequest
-from ..models.registry import LLMRegistry
-from ..utils.context_utils import Aclosing
-from ..utils.feature_decorator import experimental
-from ._retry_options_utils import add_default_retry_options_if_not_present
-from .conversation_scenarios import ConversationScenario
-from .evaluator import Evaluator
+from ...events.event import Event
+from ...models.llm_request import LlmRequest
+from ...models.registry import LLMRegistry
+from ...utils.context_utils import Aclosing
+from ...utils.feature_decorator import experimental
+from .._retry_options_utils import add_default_retry_options_if_not_present
+from ..conversation_scenarios import ConversationScenario
+from ..evaluator import Evaluator
from .user_simulator import BaseUserSimulatorConfig
from .user_simulator import NextUserMessage
from .user_simulator import Status
diff --git a/src/google/adk/evaluation/static_user_simulator.py b/src/google/adk/evaluation/simulation/static_user_simulator.py
similarity index 93%
rename from src/google/adk/evaluation/static_user_simulator.py
rename to src/google/adk/evaluation/simulation/static_user_simulator.py
index 4c5e2cb54d..e1de18a706 100644
--- a/src/google/adk/evaluation/static_user_simulator.py
+++ b/src/google/adk/evaluation/simulation/static_user_simulator.py
@@ -19,10 +19,10 @@
from typing_extensions import override
-from ..events.event import Event
-from ..utils.feature_decorator import experimental
-from .eval_case import StaticConversation
-from .evaluator import Evaluator
+from ...events.event import Event
+from ...utils.feature_decorator import experimental
+from ..eval_case import StaticConversation
+from ..evaluator import Evaluator
from .user_simulator import BaseUserSimulatorConfig
from .user_simulator import NextUserMessage
from .user_simulator import Status
diff --git a/src/google/adk/evaluation/user_simulator.py b/src/google/adk/evaluation/simulation/user_simulator.py
similarity index 95%
rename from src/google/adk/evaluation/user_simulator.py
rename to src/google/adk/evaluation/simulation/user_simulator.py
index c5ab013d7c..57656b76de 100644
--- a/src/google/adk/evaluation/user_simulator.py
+++ b/src/google/adk/evaluation/simulation/user_simulator.py
@@ -26,10 +26,10 @@
from pydantic import model_validator
from pydantic import ValidationError
-from ..events.event import Event
-from ..utils.feature_decorator import experimental
-from .common import EvalBaseModel
-from .evaluator import Evaluator
+from ...events.event import Event
+from ...utils.feature_decorator import experimental
+from ..common import EvalBaseModel
+from ..evaluator import Evaluator
class BaseUserSimulatorConfig(BaseModel):
diff --git a/src/google/adk/evaluation/user_simulator_provider.py b/src/google/adk/evaluation/simulation/user_simulator_provider.py
similarity index 97%
rename from src/google/adk/evaluation/user_simulator_provider.py
rename to src/google/adk/evaluation/simulation/user_simulator_provider.py
index 1aea8c8c92..b1bfd3226c 100644
--- a/src/google/adk/evaluation/user_simulator_provider.py
+++ b/src/google/adk/evaluation/simulation/user_simulator_provider.py
@@ -16,8 +16,8 @@
from typing import Optional
-from ..utils.feature_decorator import experimental
-from .eval_case import EvalCase
+from ...utils.feature_decorator import experimental
+from ..eval_case import EvalCase
from .llm_backed_user_simulator import LlmBackedUserSimulator
from .static_user_simulator import StaticUserSimulator
from .user_simulator import BaseUserSimulatorConfig
diff --git a/src/google/adk/flows/llm_flows/agent_transfer.py b/src/google/adk/flows/llm_flows/agent_transfer.py
index 037b8c6d50..ea144bf75d 100644
--- a/src/google/adk/flows/llm_flows/agent_transfer.py
+++ b/src/google/adk/flows/llm_flows/agent_transfer.py
@@ -24,9 +24,8 @@
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
-from ...tools.function_tool import FunctionTool
from ...tools.tool_context import ToolContext
-from ...tools.transfer_to_agent_tool import transfer_to_agent
+from ...tools.transfer_to_agent_tool import TransferToAgentTool
from ._base_llm_processor import BaseLlmRequestProcessor
if typing.TYPE_CHECKING:
@@ -50,13 +49,18 @@ async def run_async(
if not transfer_targets:
return
+ transfer_to_agent_tool = TransferToAgentTool(
+ agent_names=[agent.name for agent in transfer_targets]
+ )
+
llm_request.append_instructions([
_build_target_agents_instructions(
- invocation_context.agent, transfer_targets
+ transfer_to_agent_tool.name,
+ invocation_context.agent,
+ transfer_targets,
)
])
- transfer_to_agent_tool = FunctionTool(func=transfer_to_agent)
tool_context = ToolContext(invocation_context)
await transfer_to_agent_tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request
@@ -80,10 +84,13 @@ def _build_target_agents_info(target_agent: BaseAgent) -> str:
def _build_target_agents_instructions(
- agent: LlmAgent, target_agents: list[BaseAgent]
+ tool_name: str,
+ agent: LlmAgent,
+ target_agents: list[BaseAgent],
) -> str:
# Build list of available agent names for the NOTE
- # target_agents already includes parent agent if applicable, so no need to add it again
+ # target_agents already includes parent agent if applicable,
+ # so no need to add it again
available_agent_names = [target_agent.name for target_agent in target_agents]
# Sort for consistency
@@ -101,15 +108,16 @@ def _build_target_agents_instructions(
_build_target_agents_info(target_agent) for target_agent in target_agents
])}
-If you are the best to answer the question according to your description, you
-can answer it.
+If you are the best to answer the question according to your description,
+you can answer it.
If another agent is better for answering the question according to its
-description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
-question to that agent. When transferring, do not generate any text other than
-the function call.
+description, call `{tool_name}` function to transfer the question to that
+agent. When transferring, do not generate any text other than the function
+call.
-**NOTE**: the only available agents for `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function are {formatted_agent_names}.
+**NOTE**: the only available agents for `{tool_name}` function are
+{formatted_agent_names}.
"""
if agent.parent_agent and not agent.disallow_transfer_to_parent:
@@ -119,9 +127,6 @@ def _build_target_agents_instructions(
return si
-_TRANSFER_TO_AGENT_FUNCTION_NAME = transfer_to_agent.__name__
-
-
def _get_transfer_targets(agent: LlmAgent) -> list[BaseAgent]:
from ...agents.llm_agent import LlmAgent
diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index db50e77809..824cd26be1 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -156,7 +156,7 @@ async def run_live(
break
logger.debug('Receive new event: %s', event)
yield event
- # send back the function response
+ # send back the function response to models
if event.get_function_responses():
logger.debug(
'Sending back last function response event: %s', event
@@ -164,6 +164,16 @@ async def run_live(
invocation_context.live_request_queue.send_content(
event.content
)
+ # We handle agent transfer here in `run_live` rather than
+ # in `_postprocess_live` to prevent duplication of function
+ # response processing. If agent transfer were handled in
+ # `_postprocess_live`, events yielded from the child agent's
+ # `run_live` would bubble up to the parent agent's `run_live`,
+ # causing `event.get_function_responses()` to be true in both
+ # child and parent, and `send_content()` to be called twice for
+ # the same function response. By handling agent transfer here,
+ # we ensure that only the child agent processes its own function
+ # responses after the transfer.
if (
event.content
and event.content.parts
@@ -174,7 +184,21 @@ async def run_live(
await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY)
# cancel the tasks that belongs to the closed connection.
send_task.cancel()
+ logger.debug('Closing live connection')
await llm_connection.close()
+ logger.debug('Live connection closed.')
+ # transfer to the sub agent.
+ transfer_to_agent = event.actions.transfer_to_agent
+ if transfer_to_agent:
+ logger.debug('Transferring to agent: %s', transfer_to_agent)
+ agent_to_run = self._get_agent_to_run(
+ invocation_context, transfer_to_agent
+ )
+ async with Aclosing(
+ agent_to_run.run_live(invocation_context)
+ ) as agen:
+ async for item in agen:
+ yield item
if (
event.content
and event.content.parts
@@ -638,15 +662,6 @@ async def _postprocess_live(
)
yield final_event
- transfer_to_agent = function_response_event.actions.transfer_to_agent
- if transfer_to_agent:
- agent_to_run = self._get_agent_to_run(
- invocation_context, transfer_to_agent
- )
- async with Aclosing(agent_to_run.run_live(invocation_context)) as agen:
- async for item in agen:
- yield item
-
async def _postprocess_run_processors_async(
self, invocation_context: InvocationContext, llm_response: LlmResponse
) -> AsyncGenerator[Event, None]:
diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py
index 9274cd462d..fefa014c45 100644
--- a/src/google/adk/flows/llm_flows/contents.py
+++ b/src/google/adk/flows/llm_flows/contents.py
@@ -224,7 +224,8 @@ def _contains_empty_content(event: Event) -> bool:
This can happen to the events that only changed session state.
When both content and transcriptions are empty, the event will be considered
- as empty.
+ as empty. The content is considered empty if none of its parts contain text,
+ inline data, file data, function call, or function response.
Args:
event: The event to check.
@@ -239,7 +240,14 @@ def _contains_empty_content(event: Event) -> bool:
not event.content
or not event.content.role
or not event.content.parts
- or event.content.parts[0].text == ''
+ or all(
+ not p.text
+ and not p.inline_data
+ and not p.file_data
+ and not p.function_call
+ and not p.function_response
+ for p in [event.content.parts[0]]
+ )
) and (not event.output_transcription and not event.input_transcription)
diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py
index b8f434c563..5df012e027 100644
--- a/src/google/adk/memory/vertex_ai_memory_bank_service.py
+++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py
@@ -48,14 +48,15 @@ def __init__(
Args:
project: The project ID of the Memory Bank to use.
location: The location of the Memory Bank to use.
- agent_engine_id: The ID of the agent engine to use for the Memory Bank.
+ agent_engine_id: The ID of the agent engine to use for the Memory Bank,
e.g. '456' in
- 'projects/my-project/locations/us-central1/reasoningEngines/456'.
+ 'projects/my-project/locations/us-central1/reasoningEngines/456'. To
+ extract from api_resource.name, use:
+ ``agent_engine.api_resource.name.split('/')[-1]``
express_mode_api_key: The API key to use for Express Mode. If not
provided, the API key from the GOOGLE_API_KEY environment variable will
- be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true.
- Do not use Google AI Studio API key for this field. For more details,
- visit
+ be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true. Do
+ not use Google AI Studio API key for this field. For more details, visit
https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
"""
self._project = project
@@ -65,6 +66,14 @@ def __init__(
project, location, express_mode_api_key
)
+ if agent_engine_id and '/' in agent_engine_id:
+ logger.warning(
+ "agent_engine_id appears to be a full resource path: '%s'. "
+ "Expected just the ID (e.g., '456'). "
+ "Extract the ID using: agent_engine.api_resource.name.split('/')[-1]",
+ agent_engine_id,
+ )
+
@override
async def add_session_to_memory(self, session: Session):
if not self._agent_engine_id:
diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py
index 9f3c2a2c48..1be0cc698e 100644
--- a/src/google/adk/models/__init__.py
+++ b/src/google/adk/models/__init__.py
@@ -33,3 +33,23 @@
LLMRegistry.register(Gemini)
LLMRegistry.register(Gemma)
LLMRegistry.register(ApigeeLlm)
+
+# Optionally register Claude if anthropic package is installed
+try:
+ from .anthropic_llm import Claude
+
+ LLMRegistry.register(Claude)
+ __all__.append('Claude')
+except Exception:
+ # Claude support requires: pip install google-adk[extensions]
+ pass
+
+# Optionally register LiteLlm if litellm package is installed
+try:
+ from .lite_llm import LiteLlm
+
+ LLMRegistry.register(LiteLlm)
+ __all__.append('LiteLlm')
+except Exception:
+ # LiteLLM support requires: pip install google-adk[extensions]
+ pass
diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py
index 6f343367a3..163fbe4571 100644
--- a/src/google/adk/models/anthropic_llm.py
+++ b/src/google/adk/models/anthropic_llm.py
@@ -28,7 +28,8 @@
from typing import TYPE_CHECKING
from typing import Union
-from anthropic import AnthropicVertex
+from anthropic import AsyncAnthropic
+from anthropic import AsyncAnthropicVertex
from anthropic import NOT_GIVEN
from anthropic import types as anthropic_types
from google.genai import types
@@ -41,7 +42,7 @@
if TYPE_CHECKING:
from .llm_request import LlmRequest
-__all__ = ["Claude"]
+__all__ = ["AnthropicLlm", "Claude"]
logger = logging.getLogger("google_adk." + __name__)
@@ -155,9 +156,11 @@ def content_to_message_param(
) -> anthropic_types.MessageParam:
message_block = []
for part in content.parts or []:
- # Image data is not supported in Claude for model turns.
- if _is_image_part(part):
- logger.warning("Image data is not supported in Claude for model turns.")
+ # Image data is not supported in Claude for assistant turns.
+ if content.role != "user" and _is_image_part(part):
+ logger.warning(
+ "Image data is not supported in Claude for assistant turns."
+ )
continue
message_block.append(part_to_message_block(part))
@@ -262,15 +265,15 @@ def function_declaration_to_tool_param(
)
-class Claude(BaseLlm):
- """Integration with Claude models served from Vertex AI.
+class AnthropicLlm(BaseLlm):
+ """Integration with Claude models via the Anthropic API.
Attributes:
model: The name of the Claude model.
max_tokens: The maximum number of tokens to generate.
"""
- model: str = "claude-3-5-sonnet-v2@20241022"
+ model: str = "claude-sonnet-4-20250514"
max_tokens: int = 8192
@classmethod
@@ -302,7 +305,7 @@ async def generate_content_async(
else NOT_GIVEN
)
# TODO(b/421255973): Enable streaming for anthropic models.
- message = self._anthropic_client.messages.create(
+ message = await self._anthropic_client.messages.create(
model=llm_request.model,
system=llm_request.config.system_instruction,
messages=messages,
@@ -313,7 +316,23 @@ async def generate_content_async(
yield message_to_generate_content_response(message)
@cached_property
- def _anthropic_client(self) -> AnthropicVertex:
+ def _anthropic_client(self) -> AsyncAnthropic:
+ return AsyncAnthropic()
+
+
+class Claude(AnthropicLlm):
+ """Integration with Claude models served from Vertex AI.
+
+ Attributes:
+ model: The name of the Claude model.
+ max_tokens: The maximum number of tokens to generate.
+ """
+
+ model: str = "claude-3-5-sonnet-v2@20241022"
+
+ @cached_property
+ @override
+ def _anthropic_client(self) -> AsyncAnthropicVertex:
if (
"GOOGLE_CLOUD_PROJECT" not in os.environ
or "GOOGLE_CLOUD_LOCATION" not in os.environ
@@ -323,7 +342,7 @@ def _anthropic_client(self) -> AnthropicVertex:
" Anthropic on Vertex."
)
- return AnthropicVertex(
+ return AsyncAnthropicVertex(
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
region=os.environ["GOOGLE_CLOUD_LOCATION"],
)
diff --git a/src/google/adk/models/base_llm_connection.py b/src/google/adk/models/base_llm_connection.py
index afce550b13..1bf522740e 100644
--- a/src/google/adk/models/base_llm_connection.py
+++ b/src/google/adk/models/base_llm_connection.py
@@ -72,7 +72,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
Yields:
LlmResponse: The model response.
"""
- pass
+ # We need to yield here to help type checkers infer the correct type.
+ yield
@abstractmethod
async def close(self):
diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 0b72c79f83..55d4b62e96 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -21,6 +21,7 @@
from google.genai import types
from ..utils.context_utils import Aclosing
+from ..utils.variant_utils import GoogleLLMVariant
from .base_llm_connection import BaseLlmConnection
from .llm_response import LlmResponse
@@ -36,10 +37,15 @@
class GeminiLlmConnection(BaseLlmConnection):
"""The Gemini model connection."""
- def __init__(self, gemini_session: live.AsyncSession):
+ def __init__(
+ self,
+ gemini_session: live.AsyncSession,
+ api_backend: GoogleLLMVariant = GoogleLLMVariant.VERTEX_AI,
+ ):
self._gemini_session = gemini_session
self._input_transcription_text: str = ''
self._output_transcription_text: str = ''
+ self._api_backend = api_backend
async def send_history(self, history: list[types.Content]):
"""Sends the conversation history to the gemini model.
@@ -171,6 +177,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
yield self.__build_full_text_response(text)
text = ''
yield llm_response
+ # Note: in some cases, tool_call may arrive before
+ # generation_complete, causing transcription to appear after
+ # tool_call in the session log.
if message.server_content.input_transcription:
if message.server_content.input_transcription.text:
self._input_transcription_text += (
@@ -215,6 +224,32 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
partial=False,
)
self._output_transcription_text = ''
+ # The Gemini API might not send a transcription finished signal.
+ # Instead, we rely on generation_complete, turn_complete or
+ # interrupted signals to flush any pending transcriptions.
+ if self._api_backend == GoogleLLMVariant.GEMINI_API and (
+ message.server_content.interrupted
+ or message.server_content.turn_complete
+ or message.server_content.generation_complete
+ ):
+ if self._input_transcription_text:
+ yield LlmResponse(
+ input_transcription=types.Transcription(
+ text=self._input_transcription_text,
+ finished=True,
+ ),
+ partial=False,
+ )
+ self._input_transcription_text = ''
+ if self._output_transcription_text:
+ yield LlmResponse(
+ output_transcription=types.Transcription(
+ text=self._output_transcription_text,
+ finished=True,
+ ),
+ partial=False,
+ )
+ self._output_transcription_text = ''
if message.server_content.turn_complete:
if text:
yield self.__build_full_text_response(text)
@@ -244,7 +279,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
if message.session_resumption_update:
- logger.info('Received session resumption message: %s', message)
+ logger.debug('Received session resumption message: %s', message)
yield (
LlmResponse(
live_session_resumption_update=message.session_resumption_update
diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py
index 1bdd311104..6b21cf62c7 100644
--- a/src/google/adk/models/google_llm.py
+++ b/src/google/adk/models/google_llm.py
@@ -19,8 +19,6 @@
import copy
from functools import cached_property
import logging
-import os
-import sys
from typing import AsyncGenerator
from typing import cast
from typing import Optional
@@ -31,7 +29,7 @@
from google.genai.errors import ClientError
from typing_extensions import override
-from .. import version
+from ..utils._client_labels_utils import get_client_labels
from ..utils.context_utils import Aclosing
from ..utils.streaming_utils import StreamingResponseAggregator
from ..utils.variant_utils import GoogleLLMVariant
@@ -49,8 +47,7 @@
_NEW_LINE = '\n'
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
-_AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine'
-_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID'
+
_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """
On how to mitigate this issue, please refer to:
@@ -245,7 +242,7 @@ def api_client(self) -> Client:
return Client(
http_options=types.HttpOptions(
- headers=self._tracking_headers,
+ headers=self._tracking_headers(),
retry_options=self.retry_options,
)
)
@@ -258,16 +255,12 @@ def _api_backend(self) -> GoogleLLMVariant:
else GoogleLLMVariant.GEMINI_API
)
- @cached_property
def _tracking_headers(self) -> dict[str, str]:
- framework_label = f'google-adk/{version.__version__}'
- if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME):
- framework_label = f'{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}'
- language_label = 'gl-python/' + sys.version.split()[0]
- version_header_value = f'{framework_label} {language_label}'
+ labels = get_client_labels()
+ header_value = ' '.join(labels)
tracking_headers = {
- 'x-goog-api-client': version_header_value,
- 'user-agent': version_header_value,
+ 'x-goog-api-client': header_value,
+ 'user-agent': header_value,
}
return tracking_headers
@@ -286,7 +279,7 @@ def _live_api_client(self) -> Client:
return Client(
http_options=types.HttpOptions(
- headers=self._tracking_headers, api_version=self._live_api_version
+ headers=self._tracking_headers(), api_version=self._live_api_version
)
)
@@ -310,7 +303,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
if not llm_request.live_connect_config.http_options.headers:
llm_request.live_connect_config.http_options.headers = {}
llm_request.live_connect_config.http_options.headers.update(
- self._tracking_headers
+ self._tracking_headers()
)
llm_request.live_connect_config.http_options.api_version = (
self._live_api_version
@@ -325,6 +318,23 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
types.Part.from_text(text=llm_request.config.system_instruction)
],
)
+ if (
+ llm_request.live_connect_config.session_resumption
+ and llm_request.live_connect_config.session_resumption.transparent
+ ):
+ logger.debug(
+ 'session resumption config: %s',
+ llm_request.live_connect_config.session_resumption,
+ )
+ logger.debug(
+ 'self._api_backend: %s',
+ self._api_backend,
+ )
+ if self._api_backend == GoogleLLMVariant.GEMINI_API:
+ raise ValueError(
+ 'Transparent session resumption is only supported for Vertex AI'
+ ' backend. Please use Vertex AI backend.'
+ )
llm_request.live_connect_config.tools = llm_request.config.tools
logger.info('Connecting to live for model: %s', llm_request.model)
logger.debug('Connecting to live with llm_request:%s', llm_request)
@@ -332,7 +342,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
async with self._live_api_client.aio.live.connect(
model=llm_request.model, config=llm_request.live_connect_config
) as live_session:
- yield GeminiLlmConnection(live_session)
+ yield GeminiLlmConnection(live_session, api_backend=self._api_backend)
async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None:
"""Adapt the google computer use predefined functions to the adk computer use toolset."""
@@ -380,7 +390,7 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None:
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
"""Merge tracking headers to the given headers."""
headers = headers or {}
- for key, tracking_header_value in self._tracking_headers.items():
+ for key, tracking_header_value in self._tracking_headers().items():
custom_value = headers.get(key, None)
if not custom_value:
headers[key] = tracking_header_value
diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index c263a41b2a..162db05945 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -39,8 +39,8 @@
from litellm import acompletion
from litellm import ChatCompletionAssistantMessage
from litellm import ChatCompletionAssistantToolCall
-from litellm import ChatCompletionDeveloperMessage
from litellm import ChatCompletionMessageToolCall
+from litellm import ChatCompletionSystemMessage
from litellm import ChatCompletionToolMessage
from litellm import ChatCompletionUserMessage
from litellm import completion
@@ -96,6 +96,64 @@ def _decode_inline_text_data(raw_bytes: bytes) -> str:
return raw_bytes.decode("latin-1", errors="replace")
+def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]:
+ """Yields textual fragments from provider specific reasoning payloads."""
+ if reasoning_value is None:
+ return
+
+ if isinstance(reasoning_value, types.Content):
+ if not reasoning_value.parts:
+ return
+ for part in reasoning_value.parts:
+ if part and part.text:
+ yield part.text
+ return
+
+ if isinstance(reasoning_value, str):
+ yield reasoning_value
+ return
+
+ if isinstance(reasoning_value, list):
+ for value in reasoning_value:
+ yield from _iter_reasoning_texts(value)
+ return
+
+ if isinstance(reasoning_value, dict):
+ # LiteLLM currently nests “reasoning” text under a few known keys.
+ # (Documented in https://docs.litellm.ai/docs/openai#reasoning-outputs)
+ for key in ("text", "content", "reasoning", "reasoning_content"):
+ text_value = reasoning_value.get(key)
+ if isinstance(text_value, str):
+ yield text_value
+ return
+
+ text_attr = getattr(reasoning_value, "text", None)
+ if isinstance(text_attr, str):
+ yield text_attr
+ elif isinstance(reasoning_value, (int, float, bool)):
+ yield str(reasoning_value)
+
+
+def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]:
+ """Converts provider reasoning payloads into Gemini thought parts."""
+ return [
+ types.Part(text=text, thought=True)
+ for text in _iter_reasoning_texts(reasoning_value)
+ if text
+ ]
+
+
+def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any:
+ """Fetches the reasoning payload from a LiteLLM message or dict."""
+ if message is None:
+ return None
+ if hasattr(message, "reasoning_content"):
+ return getattr(message, "reasoning_content")
+ if isinstance(message, dict):
+ return message.get("reasoning_content")
+ return None
+
+
class ChatCompletionFileUrlObject(TypedDict, total=False):
file_data: str
file_id: str
@@ -113,6 +171,10 @@ class TextChunk(BaseModel):
text: str
+class ReasoningChunk(BaseModel):
+ parts: List[types.Part]
+
+
class UsageMetadataChunk(BaseModel):
prompt_tokens: int
completion_tokens: int
@@ -570,10 +632,8 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
}
-def _schema_to_dict(schema: types.Schema) -> dict:
- """Recursively converts a types.Schema to a pure-python dict
-
- with all enum values written as lower-case strings.
+def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict:
+ """Recursively converts a schema object or dict to a pure-python dict.
Args:
schema: The schema to convert.
@@ -581,38 +641,36 @@ def _schema_to_dict(schema: types.Schema) -> dict:
Returns:
The dictionary representation of the schema.
"""
- # Dump without json encoding so we still get Enum members
- schema_dict = schema.model_dump(exclude_none=True)
+ schema_dict = (
+ schema.model_dump(exclude_none=True)
+ if isinstance(schema, types.Schema)
+ else dict(schema)
+ )
+ enum_values = schema_dict.get("enum")
+ if isinstance(enum_values, (list, tuple)):
+ schema_dict["enum"] = [value for value in enum_values if value is not None]
- # ---- normalise this level ------------------------------------------------
- if "type" in schema_dict:
- # schema_dict["type"] can be an Enum or a str
+ if "type" in schema_dict and schema_dict["type"] is not None:
t = schema_dict["type"]
- schema_dict["type"] = (t.value if isinstance(t, types.Type) else t).lower()
+ schema_dict["type"] = (
+ t.value if isinstance(t, types.Type) else str(t)
+ ).lower()
- # ---- recurse into `items` -----------------------------------------------
if "items" in schema_dict:
- schema_dict["items"] = _schema_to_dict(
- schema.items
- if isinstance(schema.items, types.Schema)
- else types.Schema.model_validate(schema_dict["items"])
+ items = schema_dict["items"]
+ schema_dict["items"] = (
+ _schema_to_dict(items)
+ if isinstance(items, (types.Schema, dict))
+ else items
)
- # ---- recurse into `properties` ------------------------------------------
if "properties" in schema_dict:
new_props = {}
for key, value in schema_dict["properties"].items():
- # value is a dict → rebuild a Schema object and recurse
- if isinstance(value, dict):
- new_props[key] = _schema_to_dict(types.Schema.model_validate(value))
- # value is already a Schema instance
- elif isinstance(value, types.Schema):
+ if isinstance(value, (types.Schema, dict)):
new_props[key] = _schema_to_dict(value)
- # plain dict without nested schemas
else:
new_props[key] = value
- if "type" in new_props[key]:
- new_props[key]["type"] = new_props[key]["type"].lower()
schema_dict["properties"] = new_props
return schema_dict
@@ -660,7 +718,6 @@ def _function_declaration_to_tool_param(
},
}
- # Handle required field from parameters
required_fields = (
getattr(function_declaration.parameters, "required", None)
if function_declaration.parameters
@@ -668,8 +725,6 @@ def _function_declaration_to_tool_param(
)
if required_fields:
tool_params["function"]["parameters"]["required"] = required_fields
- # parameters_json_schema already has required field in the json schema,
- # no need to add it separately
return tool_params
@@ -678,7 +733,14 @@ def _model_response_to_chunk(
response: ModelResponse,
) -> Generator[
Tuple[
- Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
+ Optional[
+ Union[
+ TextChunk,
+ FunctionChunk,
+ UsageMetadataChunk,
+ ReasoningChunk,
+ ]
+ ],
Optional[str],
],
None,
@@ -703,11 +765,18 @@ def _model_response_to_chunk(
message_content: Optional[OpenAIMessageContent] = None
tool_calls: list[ChatCompletionMessageToolCall] = []
+ reasoning_parts: List[types.Part] = []
if message is not None:
(
message_content,
tool_calls,
) = _split_message_content_and_tool_calls(message)
+ reasoning_value = _extract_reasoning_value(message)
+ if reasoning_value:
+ reasoning_parts = _convert_reasoning_value_to_parts(reasoning_value)
+
+ if reasoning_parts:
+ yield ReasoningChunk(parts=reasoning_parts), finish_reason
if message_content:
yield TextChunk(text=message_content), finish_reason
@@ -771,8 +840,13 @@ def _model_response_to_generate_content_response(
if not message:
raise ValueError("No message in response")
+ thought_parts = _convert_reasoning_value_to_parts(
+ _extract_reasoning_value(message)
+ )
llm_response = _message_to_generate_content_response(
- message, model_version=response.model
+ message,
+ model_version=response.model,
+ thought_parts=thought_parts or None,
)
if finish_reason:
# If LiteLLM already provides a FinishReason enum (e.g., for Gemini), use
@@ -797,7 +871,11 @@ def _model_response_to_generate_content_response(
def _message_to_generate_content_response(
- message: Message, *, is_partial: bool = False, model_version: str = None
+ message: Message,
+ *,
+ is_partial: bool = False,
+ model_version: str = None,
+ thought_parts: Optional[List[types.Part]] = None,
) -> LlmResponse:
"""Converts a litellm message to LlmResponse.
@@ -810,7 +888,13 @@ def _message_to_generate_content_response(
The LlmResponse.
"""
- parts = []
+ parts: List[types.Part] = []
+ if not thought_parts:
+ thought_parts = _convert_reasoning_value_to_parts(
+ _extract_reasoning_value(message)
+ )
+ if thought_parts:
+ parts.extend(thought_parts)
message_content, tool_calls = _split_message_content_and_tool_calls(message)
if isinstance(message_content, str) and message_content:
parts.append(types.Part.from_text(text=message_content))
@@ -899,8 +983,8 @@ def _get_completion_inputs(
if llm_request.config.system_instruction:
messages.insert(
0,
- ChatCompletionDeveloperMessage(
- role="developer",
+ ChatCompletionSystemMessage(
+ role="system",
content=llm_request.config.system_instruction,
),
)
@@ -972,15 +1056,9 @@ def _build_function_declaration_log(
k: v.model_dump(exclude_none=True)
for k, v in func_decl.parameters.properties.items()
})
- elif func_decl.parameters_json_schema:
- param_str = str(func_decl.parameters_json_schema)
-
return_str = "None"
if func_decl.response:
return_str = str(func_decl.response.model_dump(exclude_none=True))
- elif func_decl.response_json_schema:
- return_str = str(func_decl.response_json_schema)
-
return f"{func_decl.name}: {param_str} -> {return_str}"
@@ -1182,6 +1260,7 @@ async def generate_content_async(
if stream:
text = ""
+ reasoning_parts: List[types.Part] = []
# Track function calls by index
function_calls = {} # index -> {name, args, id}
completion_args["stream"] = True
@@ -1223,6 +1302,14 @@ async def generate_content_async(
is_partial=True,
model_version=part.model,
)
+ elif isinstance(chunk, ReasoningChunk):
+ if chunk.parts:
+ reasoning_parts.extend(chunk.parts)
+ yield LlmResponse(
+ content=types.Content(role="model", parts=list(chunk.parts)),
+ partial=True,
+ model_version=part.model,
+ )
elif isinstance(chunk, UsageMetadataChunk):
usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=chunk.prompt_tokens,
@@ -1256,16 +1343,27 @@ async def generate_content_async(
tool_calls=tool_calls,
),
model_version=part.model,
+ thought_parts=list(reasoning_parts)
+ if reasoning_parts
+ else None,
)
)
text = ""
+ reasoning_parts = []
function_calls.clear()
- elif finish_reason == "stop" and text:
+ elif finish_reason == "stop" and (text or reasoning_parts):
+ message_content = text if text else None
aggregated_llm_response = _message_to_generate_content_response(
- ChatCompletionAssistantMessage(role="assistant", content=text),
+ ChatCompletionAssistantMessage(
+ role="assistant", content=message_content
+ ),
model_version=part.model,
+ thought_parts=list(reasoning_parts)
+ if reasoning_parts
+ else None,
)
text = ""
+ reasoning_parts = []
# waiting until streaming ends to yield the llm_response as litellm tends
# to send chunk that contains usage_metadata after the chunk with
@@ -1290,11 +1388,19 @@ async def generate_content_async(
def supported_models(cls) -> list[str]:
"""Provides the list of supported models.
- LiteLlm supports all models supported by litellm. We do not keep track of
- these models here. So we return an empty list.
+ This registers common provider prefixes. LiteLlm can handle many more,
+ but these patterns activate the integration for the most common use cases.
+ See https://docs.litellm.ai/docs/providers for a full list.
Returns:
A list of supported models.
"""
- return []
+ return [
+ # For OpenAI models (e.g., "openai/gpt-4o")
+ r"openai/.*",
+ # For Groq models via Groq API (e.g., "groq/llama3-70b-8192")
+ r"groq/.*",
+ # For Anthropic models (e.g., "anthropic/claude-3-opus-20240229")
+ r"anthropic/.*",
+ ]
diff --git a/src/google/adk/models/registry.py b/src/google/adk/models/registry.py
index 22e24d4c18..852996ff40 100644
--- a/src/google/adk/models/registry.py
+++ b/src/google/adk/models/registry.py
@@ -99,4 +99,26 @@ def resolve(model: str) -> type[BaseLlm]:
if re.compile(regex).fullmatch(model):
return llm_class
- raise ValueError(f'Model {model} not found.')
+ # Provide helpful error messages for known patterns
+ error_msg = f'Model {model} not found.'
+
+ # Check if it matches known patterns that require optional dependencies
+ if re.match(r'^claude-', model):
+ error_msg += (
+ '\n\nClaude models require the anthropic package.'
+ '\nInstall it with: pip install google-adk[extensions]'
+ '\nOr: pip install anthropic>=0.43.0'
+ )
+ elif '/' in model:
+ # Any model with provider/model format likely needs LiteLLM
+ error_msg += (
+ '\n\nProvider-style models (e.g., "provider/model-name") require'
+ ' the litellm package.'
+ '\nInstall it with: pip install google-adk[extensions]'
+ '\nOr: pip install litellm>=1.75.5'
+ '\n\nSupported providers include: openai, groq, anthropic, and 100+'
+ ' others.'
+ '\nSee https://docs.litellm.ai/docs/providers for a full list.'
+ )
+
+ raise ValueError(error_msg)
diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py
index 2bb0168928..4cf5a29546 100644
--- a/src/google/adk/runners.py
+++ b/src/google/adk/runners.py
@@ -19,6 +19,7 @@
import logging
from pathlib import Path
import queue
+import sys
from typing import Any
from typing import AsyncGenerator
from typing import Callable
@@ -66,6 +67,23 @@
logger = logging.getLogger('google_adk.' + __name__)
+def _is_tool_call_or_response(event: Event) -> bool:
+ return bool(event.get_function_calls() or event.get_function_responses())
+
+
+def _is_transcription(event: Event) -> bool:
+ return (
+ event.input_transcription is not None
+ or event.output_transcription is not None
+ )
+
+
+def _has_non_empty_transcription_text(transcription) -> bool:
+ return bool(
+ transcription and transcription.text and transcription.text.strip()
+ )
+
+
class Runner:
"""The Runner class is used to run agents.
@@ -625,6 +643,7 @@ async def _exec_with_plugin(
invocation_context: The invocation context
session: The current session
execute_fn: A callable that returns an AsyncGenerator of Events
+ is_live_call: Whether this is a live call
Yields:
Events from the execution, including any generated by plugins
@@ -650,13 +669,74 @@ async def _exec_with_plugin(
yield early_exit_event
else:
# Step 2: Otherwise continue with normal execution
+ # Note for live/bidi:
+ # the transcription may arrive later than the action (function call
+ # event and thus function response event). In this case, the order of
+ # transcription and function call events will be wrong if we simply
+ # append events as they arrive. To address this, we check whether a
+ # transcription is in progress. If so, we hold off appending function
+ # call events until the transcription is finished. A transcription in
+ # progress can be identified by checking whether the transcription
+ # event is partial. When the next transcription event is not partial,
+ # the previous transcription is finished. Any buffered function call
+ # events are then appended after this finished (non-partial)
+ # transcription event.
+ buffered_events: list[Event] = []
+ is_transcribing: bool = False
+
async with Aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
- if not event.partial:
- if self._should_append_event(event, is_live_call):
+ if is_live_call:
+ if event.partial and _is_transcription(event):
+ is_transcribing = True
+ if is_transcribing and _is_tool_call_or_response(event):
+ # Only buffer non-partial function call and function response
+ # events.
+ buffered_events.append(event)
+ continue
+ # Note for live/bidi: an audio response is considered a
+ # non-partial event (event.partial=None).
+ # event.partial=False and event.partial=None are treated as
+ # non-partial events; event.partial=True is treated as a partial
+ # event.
+ if event.partial is not True:
+ if _is_transcription(event) and (
+ _has_non_empty_transcription_text(event.input_transcription)
+ or _has_non_empty_transcription_text(
+ event.output_transcription
+ )
+ ):
+ # transcription end signal, append buffered events
+ is_transcribing = False
+ logger.debug(
+ 'Appending transcription finished event: %s', event
+ )
+ if self._should_append_event(event, is_live_call):
+ await self.session_service.append_event(
+ session=session, event=event
+ )
+
+ for buffered_event in buffered_events:
+ logger.debug('Appending buffered event: %s', buffered_event)
+ await self.session_service.append_event(
+ session=session, event=buffered_event
+ )
+ buffered_events = []
+ else:
+ # A non-transcription event or an empty transcription event (for
+ # example, an event that stores a blob reference) should be appended.
+ if self._should_append_event(event, is_live_call):
+ logger.debug('Appending non-buffered event: %s', event)
+ await self.session_service.append_event(
+ session=session, event=event
+ )
+ else:
+ if event.partial is not True:
await self.session_service.append_event(
session=session, event=event
)
+
# Step 3: Run the on_event callbacks to optionally modify the event.
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
@@ -985,16 +1065,16 @@ async def run_debug(
over session management, event streaming, and error handling.
Args:
- user_messages: Message(s) to send to the agent. Can be:
- - Single string: "What is 2+2?"
- - List of strings: ["Hello!", "What's my name?"]
+ user_messages: Message(s) to send to the agent. Can be: - Single string:
+ "What is 2+2?" - List of strings: ["Hello!", "What's my name?"]
user_id: User identifier. Defaults to "debug_user_id".
- session_id: Session identifier for conversation persistence.
- Defaults to "debug_session_id". Reuse the same ID to continue a conversation.
+ session_id: Session identifier for conversation persistence. Defaults to
+ "debug_session_id". Reuse the same ID to continue a conversation.
run_config: Optional configuration for the agent execution.
- quiet: If True, suppresses console output. Defaults to False (output shown).
- verbose: If True, shows detailed tool calls and responses. Defaults to False
- for cleaner output showing only final agent responses.
+ quiet: If True, suppresses console output. Defaults to False (output
+ shown).
+ verbose: If True, shows detailed tool calls and responses. Defaults to
+ False for cleaner output showing only final agent responses.
Returns:
list[Event]: All events from all messages.
@@ -1011,7 +1091,8 @@ async def run_debug(
>>> await runner.run_debug(["Hello!", "What's my name?"])
Continue a debug session:
- >>> await runner.run_debug("What did we discuss?") # Continues default session
+ >>> await runner.run_debug("What did we discuss?") # Continues default session
Separate debug sessions:
>>> await runner.run_debug("Hi", user_id="alice", session_id="debug1")
@@ -1353,7 +1434,12 @@ async def close(self):
logger.info('Runner closed.')
- async def __aenter__(self):
+ if sys.version_info < (3, 11):
+ Self = 'Runner' # pylint: disable=invalid-name
+ else:
+ from typing import Self # pylint: disable=g-import-not-at-top
+
+ async def __aenter__(self) -> Self:
"""Async context manager entry."""
return self
diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py
index b929f23409..a352918211 100644
--- a/src/google/adk/sessions/database_session_service.py
+++ b/src/google/adk/sessions/database_session_service.py
@@ -722,6 +722,11 @@ async def append_event(self, session: Session, event: Event) -> Event:
if session_state_delta:
storage_session.state = storage_session.state | session_state_delta
+ if storage_session._dialect_name == "sqlite":
+ update_time = datetime.utcfromtimestamp(event.timestamp)
+ else:
+ update_time = datetime.fromtimestamp(event.timestamp)
+ storage_session.update_time = update_time
sql_session.add(StorageEvent.from_event(session, event))
await sql_session.commit()
diff --git a/src/google/adk/sessions/migration/README.md b/src/google/adk/sessions/migration/README.md
new file mode 100644
index 0000000000..6a9079534e
--- /dev/null
+++ b/src/google/adk/sessions/migration/README.md
@@ -0,0 +1,109 @@
+# Process for a New Schema Version
+
+This document outlines the steps required to introduce a new database schema
+version for `DatabaseSessionService`. Let's assume you are introducing schema
+version `2.0`, migrating from `1.0`.
+
+## 1. Update SQLAlchemy Models
+
+Modify the SQLAlchemy model classes (`StorageSession`, `StorageEvent`,
+`StorageAppState`, `StorageUserState`, `StorageMetadata`) in
+`database_session_service.py` to reflect the new `2.0` schema. This could
+involve adding new `mapped_column` definitions, changing types, or adding new
+classes for new tables.
+
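+As a sketch only, a change for this step might look like the following. The
+`channel` column is hypothetical and exists purely for illustration; use
+whatever columns your `2.0` schema actually needs:
+
+```python
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class StorageSession(Base):  # Base as defined in database_session_service.py
+  __tablename__ = "sessions"
+  ...
+  # New in the hypothetical 2.0 schema, for illustration only.
+  channel: Mapped[str] = mapped_column(default="default")
+```
+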
+## 2. Create a New Migration Script
+
+You need to create a script that migrates data from schema `1.0` to `2.0`.
+
+* Create a new file, for example:
+ `google/adk/sessions/migration/migrate_1_0_to_2_0.py`.
+* This script must contain a `migrate(source_db_url: str, dest_db_url: str)`
+ function, similar to `migrate_from_sqlalchemy_pickle.py`.
+* Inside this function:
+ * Connect to the `source_db_url` (which has schema 1.0) and `dest_db_url`
+ engines using SQLAlchemy.
+ * **Important**: Create the tables in the destination database using the
+ new 2.0 schema definition by calling
+ `dss.Base.metadata.create_all(dest_engine)`.
+ * Read data from the source tables (schema 1.0). The recommended way to do
+ this without relying on outdated models is to use `sqlalchemy.text`,
+ like:
+
+ ```python
+ from sqlalchemy import text
+ ...
+ rows = source_session.execute(text("SELECT * FROM sessions")).mappings().all()
+ ```
+
+ * For each row read from the source, transform the data as necessary to
+ fit the `2.0` schema, and create an instance of the corresponding new
+ SQLAlchemy model (e.g., `dss.StorageSession(...)`).
+ * Add these new `2.0` objects to the destination session, ideally using
+ `dest_session.merge()` to upsert (see the sketch after this list).
+ * After migrating data for all tables, ensure the destination database is
+ marked with the new schema version:
+
+ ```python
+ from google.adk.sessions import database_session_service as dss
+ from google.adk.sessions.migration import _schema_check
+ ...
+ dest_session.merge(
+ dss.StorageMetadata(
+ key=_schema_check.SCHEMA_VERSION_KEY,
+ value="2.0",
+ )
+ )
+ dest_session.commit()
+ ```
+
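+A minimal sketch of the per-row transform and merge for the `sessions` table
+might look like this (the column names and the `channel` default are
+illustrative, not the actual schema):
+
+```python
+for row in rows:
+  dest_session.merge(
+      dss.StorageSession(
+          app_name=row["app_name"],
+          user_id=row["user_id"],
+          id=row["id"],
+          # ... map the remaining 1.0 columns as needed ...
+          # Populate any column that is new in 2.0 (hypothetical here).
+          channel="default",
+      )
+  )
+dest_session.commit()
+```
+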
+## 3. Update Schema Version Constant
+
+You need to update `CURRENT_SCHEMA_VERSION` in
+`google/adk/sessions/migration/_schema_check.py` to reflect the new version:
+
+```python
+CURRENT_SCHEMA_VERSION = "2.0"
+```
+
+This will also update `LATEST_VERSION` in `migration_runner.py`, as it uses this
+constant.
+
+## 4. Register the New Migration in Migration Runner
+
+In `google/adk/sessions/migration/migration_runner.py`, import your new
+migration script and add it to the `MIGRATIONS` dictionary. This tells the
+runner how to get from version `1.0` to `2.0`. For example:
+
+```python
+from google.adk.sessions.migration import _schema_check
+from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle
+from google.adk.sessions.migration import migrate_1_0_to_2_0
+...
+MIGRATIONS = {
+ _schema_check.SCHEMA_VERSION_0_1_PICKLE: (
+ _schema_check.SCHEMA_VERSION_1_0_JSON,
+ migrate_from_sqlalchemy_pickle.migrate,
+ ),
+ _schema_check.SCHEMA_VERSION_1_0_JSON: (
+ "2.0",
+ migrate_1_0_to_2_0.migrate,
+ ),
+}
+```
+
+## 5. Update `DatabaseSessionService` Business Logic
+
+If your schema change affects how data should be read or written during normal
+operation (e.g., you added a new column that needs to be populated on session
+creation), update the methods within `DatabaseSessionService` (`create_session`,
+`get_session`, `append_event`, etc.) in `database_session_service.py`
+accordingly.
+
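+For instance, if the hypothetical `channel` column from the sketches above had
+to be populated on session creation, the `create_session` change could be
+sketched as follows (argument names are illustrative):
+
+```python
+storage_session = StorageSession(
+    app_name=app_name,
+    user_id=user_id,
+    id=session_id,
+    state=state or {},
+    channel="default",  # new in the hypothetical 2.0 schema
+)
+sql_session.add(storage_session)
+```
+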
+## 6. CLI Command Changes
+
+No changes are needed for the Click command definition in `cli_tools_click.py`.
+The `adk migrate session` command calls `migration_runner.upgrade()`, which will
+now automatically detect the source database version and apply the necessary
+migration steps (e.g., `0.1 -> 1.0 -> 2.0`, or `1.0 -> 2.0`) to reach
+`LATEST_VERSION`.
\ No newline at end of file
diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py
index 10d05f6dfd..8ba6531f52 100644
--- a/src/google/adk/sessions/sqlite_session_service.py
+++ b/src/google/adk/sessions/sqlite_session_service.py
@@ -323,7 +323,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
# Trim temp state before persisting
event = self._trim_temp_delta_state(event)
- now = time.time()
+ event_timestamp = event.timestamp
async with self._get_db_connection() as db:
# Check for stale session
@@ -355,11 +355,15 @@ async def append_event(self, session: Session, event: Event) -> Event:
if app_state_delta:
await self._upsert_app_state(
- db, session.app_name, app_state_delta, now
+ db, session.app_name, app_state_delta, event_timestamp
)
if user_state_delta:
await self._upsert_user_state(
- db, session.app_name, session.user_id, user_state_delta, now
+ db,
+ session.app_name,
+ session.user_id,
+ user_state_delta,
+ event_timestamp,
)
if session_state_delta:
await self._update_session_state_in_db(
@@ -368,7 +372,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
session.user_id,
session.id,
session_state_delta,
- now,
+ event_timestamp,
)
has_session_state_delta = True
@@ -392,12 +396,17 @@ async def append_event(self, session: Session, event: Event) -> Event:
await db.execute(
"UPDATE sessions SET update_time=? WHERE app_name=? AND user_id=?"
" AND id=?",
- (now, session.app_name, session.user_id, session.id),
+ (
+ event_timestamp,
+ session.app_name,
+ session.user_id,
+ session.id,
+ ),
)
await db.commit()
- # Update timestamp with commit time
- session.last_update_time = now
+ # Update timestamp based on event time
+ session.last_update_time = event_timestamp
# Also update the in-memory session
await super().append_event(session=session, event=event)
diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py
index 1777bd93c5..32264adcbd 100644
--- a/src/google/adk/tools/__init__.py
+++ b/src/google/adk/tools/__init__.py
@@ -21,6 +21,7 @@
if TYPE_CHECKING:
from ..auth.auth_tool import AuthToolArguments
from .agent_tool import AgentTool
+ from .api_registry import ApiRegistry
from .apihub_tool.apihub_toolset import APIHubToolset
from .base_tool import BaseTool
from .discovery_engine_search_tool import DiscoveryEngineSearchTool
@@ -37,6 +38,7 @@
from .preload_memory_tool import preload_memory_tool as preload_memory
from .tool_context import ToolContext
from .transfer_to_agent_tool import transfer_to_agent
+ from .transfer_to_agent_tool import TransferToAgentTool
from .url_context_tool import url_context
from .vertex_ai_search_tool import VertexAiSearchTool
@@ -75,10 +77,15 @@
'preload_memory': ('.preload_memory_tool', 'preload_memory_tool'),
'ToolContext': ('.tool_context', 'ToolContext'),
'transfer_to_agent': ('.transfer_to_agent_tool', 'transfer_to_agent'),
+ 'TransferToAgentTool': (
+ '.transfer_to_agent_tool',
+ 'TransferToAgentTool',
+ ),
'url_context': ('.url_context_tool', 'url_context'),
'VertexAiSearchTool': ('.vertex_ai_search_tool', 'VertexAiSearchTool'),
'MCPToolset': ('.mcp_tool.mcp_toolset', 'MCPToolset'),
'McpToolset': ('.mcp_tool.mcp_toolset', 'McpToolset'),
+ 'ApiRegistry': ('.api_registry', 'ApiRegistry'),
}
__all__ = list(_LAZY_MAPPING.keys())
diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py
index 702f6e43aa..46d8616619 100644
--- a/src/google/adk/tools/agent_tool.py
+++ b/src/google/adk/tools/agent_tool.py
@@ -45,11 +45,22 @@ class AgentTool(BaseTool):
Attributes:
agent: The agent to wrap.
skip_summarization: Whether to skip summarization of the agent output.
+ include_plugins: Whether to propagate plugins from the parent runner context
+ to the agent's runner. When True (default), the agent will inherit all
+ plugins from its parent. Set to False to run the agent with an isolated
+ plugin environment.
"""
- def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
+ def __init__(
+ self,
+ agent: BaseAgent,
+ skip_summarization: bool = False,
+ *,
+ include_plugins: bool = True,
+ ):
self.agent = agent
self.skip_summarization: bool = skip_summarization
+ self.include_plugins = include_plugins
super().__init__(name=agent.name, description=agent.description)
@@ -68,6 +79,8 @@ def _get_declaration(self) -> types.FunctionDeclaration:
result = _automatic_function_calling_util.build_function_declaration(
func=self.agent.input_schema, variant=self._api_variant
)
+ # Override the description with the agent's description
+ result.description = self.agent.description
else:
result = types.FunctionDeclaration(
parameters=types.Schema(
@@ -130,6 +143,11 @@ async def run_async(
invocation_context.app_name if invocation_context else None
)
child_app_name = parent_app_name or self.agent.name
+ plugins = (
+ tool_context._invocation_context.plugin_manager.plugins
+ if self.include_plugins
+ else None
+ )
runner = Runner(
app_name=child_app_name,
agent=self.agent,
@@ -137,7 +155,7 @@ async def run_async(
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
credential_service=tool_context._invocation_context.credential_service,
- plugins=list(tool_context._invocation_context.plugin_manager.plugins),
+ plugins=plugins,
)
state_dict = {
@@ -192,7 +210,9 @@ def from_config(
agent_tool_config.agent, config_abs_path
)
return cls(
- agent=agent, skip_summarization=agent_tool_config.skip_summarization
+ agent=agent,
+ skip_summarization=agent_tool_config.skip_summarization,
+ include_plugins=agent_tool_config.include_plugins,
)
@@ -204,3 +224,6 @@ class AgentToolConfig(BaseToolConfig):
skip_summarization: bool = False
"""Whether to skip summarization of the agent output."""
+
+ include_plugins: bool = True
+ """Whether to include plugins from the parent runner context."""
diff --git a/src/google/adk/tools/api_registry.py b/src/google/adk/tools/api_registry.py
new file mode 100644
index 0000000000..e3f0076404
--- /dev/null
+++ b/src/google/adk/tools/api_registry.py
@@ -0,0 +1,123 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import google.auth
+import google.auth.transport.requests
+import httpx
+
+from ..agents.readonly_context import ReadonlyContext
+from .base_toolset import ToolPredicate
+from .mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
+from .mcp_tool.mcp_toolset import McpToolset
+
+API_REGISTRY_URL = "https://cloudapiregistry.googleapis.com"
+
+
+class ApiRegistry:
+ """Registry that provides McpToolsets for MCP servers registered in API Registry."""
+
+ def __init__(
+ self,
+ api_registry_project_id: str,
+ location: str = "global",
+ header_provider: Optional[
+ Callable[[ReadonlyContext], Dict[str, str]]
+ ] = None,
+ ):
+ """Initialize the API Registry.
+
+ Args:
+ api_registry_project_id: The project ID for the Google Cloud API Registry.
+ location: The location of the API Registry resources.
+ header_provider: Optional function to provide additional headers for MCP
+ server calls.
+ """
+ self.api_registry_project_id = api_registry_project_id
+ self.location = location
+ self._credentials, _ = google.auth.default()
+ self._mcp_servers: Dict[str, Dict[str, Any]] = {}
+ self._header_provider = header_provider
+
+ url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers"
+ try:
+ request = google.auth.transport.requests.Request()
+ self._credentials.refresh(request)
+ headers = {
+ "Authorization": f"Bearer {self._credentials.token}",
+ "Content-Type": "application/json",
+ }
+ with httpx.Client() as client:
+ response = client.get(url, headers=headers)
+ response.raise_for_status()
+ mcp_servers_list = response.json().get("mcpServers", [])
+ for server in mcp_servers_list:
+ server_name = server.get("name", "")
+ if server_name:
+ self._mcp_servers[server_name] = server
+ except (httpx.HTTPError, ValueError) as e:
+ # Surface errors from fetching or parsing the MCP server list.
+ raise RuntimeError(
+ f"Error fetching MCP servers from API Registry: {e}"
+ ) from e
+
+ def get_toolset(
+ self,
+ mcp_server_name: str,
+ tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
+ tool_name_prefix: Optional[str] = None,
+ ) -> McpToolset:
+ """Return the MCP Toolset based on the params.
+
+ Args:
+ mcp_server_name: The name of the MCP server registered in the API Registry
+ to get tools from.
+ tool_filter: Optional filter to select specific tools. Can be a list of
+ tool names or a ToolPredicate function.
+ tool_name_prefix: Optional prefix to prepend to the names of the tools
+ returned by the toolset.
+
+ Returns:
+ McpToolset: A toolset for the MCP server specified.
+ """
+ server = self._mcp_servers.get(mcp_server_name)
+ if not server:
+ raise ValueError(
+ f"MCP server {mcp_server_name} not found in API Registry."
+ )
+ if not server.get("urls"):
+ raise ValueError(f"MCP server {mcp_server_name} has no URLs.")
+
+ mcp_server_url = server["urls"][0]
+ request = google.auth.transport.requests.Request()
+ self._credentials.refresh(request)
+ headers = {
+ "Authorization": f"Bearer {self._credentials.token}",
+ }
+ return McpToolset(
+ connection_params=StreamableHTTPConnectionParams(
+ url="https://" + mcp_server_url,
+ headers=headers,
+ ),
+ tool_filter=tool_filter,
+ tool_name_prefix=tool_name_prefix,
+ header_provider=self._header_provider,
+ )
diff --git a/src/google/adk/tools/bigquery/config.py b/src/google/adk/tools/bigquery/config.py
index 7768f214ed..39b6a3d9b6 100644
--- a/src/google/adk/tools/bigquery/config.py
+++ b/src/google/adk/tools/bigquery/config.py
@@ -101,6 +101,16 @@ class BigQueryToolConfig(BaseModel):
locations, see https://cloud.google.com/bigquery/docs/locations.
"""
+ job_labels: Optional[dict[str, str]] = None
+ """Labels to apply to BigQuery jobs for tracking and monitoring.
+
+ These labels will be added to all BigQuery jobs executed by the tools.
+ Labels must be key-value pairs where both keys and values are strings.
+ Labels can be used for billing, monitoring, and resource organization.
+ For more information about labels, see
+ https://cloud.google.com/bigquery/docs/labels-intro.
+ """
+
@field_validator('maximum_bytes_billed')
@classmethod
def validate_maximum_bytes_billed(cls, v):
@@ -121,3 +131,13 @@ def validate_application_name(cls, v):
if v and ' ' in v:
raise ValueError('Application name should not contain spaces.')
return v
+
+ @field_validator('job_labels')
+ @classmethod
+ def validate_job_labels(cls, v):
+ """Validate that job_labels keys are not empty."""
+ if v is not None:
+ for key in v.keys():
+ if not key:
+ raise ValueError('Label keys cannot be empty.')
+ return v
diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py
index 666dc3c5a1..5bcd734e70 100644
--- a/src/google/adk/tools/bigquery/query_tool.py
+++ b/src/google/adk/tools/bigquery/query_tool.py
@@ -68,7 +68,10 @@ def _execute_sql(
bq_connection_properties = []
# BigQuery job labels if applicable
- bq_job_labels = {}
+ bq_job_labels = (
+ settings.job_labels.copy() if settings and settings.job_labels else {}
+ )
+
if caller_id:
bq_job_labels["adk-bigquery-tool"] = caller_id
if settings and settings.application_name:
diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py
index eaef479274..875b82e5b9 100644
--- a/src/google/adk/tools/crewai_tool.py
+++ b/src/google/adk/tools/crewai_tool.py
@@ -30,16 +30,9 @@
try:
from crewai.tools import BaseTool as CrewaiBaseTool
except ImportError as e:
- import sys
-
- if sys.version_info < (3, 10):
- raise ImportError(
- 'Crewai Tools require Python 3.10+. Please upgrade your Python version.'
- ) from e
- else:
- raise ImportError(
- "Crewai Tools require pip install 'google-adk[extensions]'."
- ) from e
+ raise ImportError(
+ "Crewai Tools require pip install 'google-adk[extensions]'."
+ ) from e
class CrewaiTool(FunctionTool):
diff --git a/src/google/adk/tools/mcp_tool/__init__.py b/src/google/adk/tools/mcp_tool/__init__.py
index f1e56b99c4..1170b2e1af 100644
--- a/src/google/adk/tools/mcp_tool/__init__.py
+++ b/src/google/adk/tools/mcp_tool/__init__.py
@@ -39,15 +39,7 @@
except ImportError as e:
import logging
- import sys
logger = logging.getLogger('google_adk.' + __name__)
-
- if sys.version_info < (3, 10):
- logger.warning(
- 'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
- ' version.'
- )
- else:
- logger.debug('MCP Tool is not installed')
- logger.debug(e)
+ logger.debug('MCP Tool is not installed')
+ logger.debug(e)
diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py
index d95d48f282..c9c4c2ae66 100644
--- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py
+++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py
@@ -29,24 +29,13 @@
from typing import Union
import anyio
+from mcp import ClientSession
+from mcp import StdioServerParameters
+from mcp.client.sse import sse_client
+from mcp.client.stdio import stdio_client
+from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel
-try:
- from mcp import ClientSession
- from mcp import StdioServerParameters
- from mcp.client.sse import sse_client
- from mcp.client.stdio import stdio_client
- from mcp.client.streamable_http import streamablehttp_client
-except ImportError as e:
-
- if sys.version_info < (3, 10):
- raise ImportError(
- 'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
- ' version.'
- ) from e
- else:
- raise e
-
logger = logging.getLogger('google_adk.' + __name__)
@@ -108,10 +97,10 @@ class StreamableHTTPConnectionParams(BaseModel):
terminate_on_close: bool = True
-def retry_on_closed_resource(func):
- """Decorator to automatically retry action when MCP session is closed.
+def retry_on_errors(func):
+ """Decorator to automatically retry action when MCP session errors occur.
- When MCP session was closed, the decorator will automatically retry the
+ When MCP session errors occur, the decorator will automatically retry the
action once. The create_session method will handle creating a new session
if the old one was disconnected.
@@ -126,11 +115,11 @@ def retry_on_closed_resource(func):
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
- except (anyio.ClosedResourceError, anyio.BrokenResourceError):
- # If the session connection is closed or unusable, we will retry the
- # function to reconnect to the server. create_session will handle
- # detecting and replacing disconnected sessions.
- logger.info('Retrying %s due to closed resource', func.__name__)
+ except Exception as e:
+ # If an error is thrown, we will retry the function to reconnect to the
+ # server. create_session will handle detecting and replacing disconnected
+ # sessions.
+ logger.info('Retrying %s due to error: %s', func.__name__, e)
return await func(self, *args, **kwargs)
return wrapper
diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py
index ad420a3d0f..b15f2c73fe 100644
--- a/src/google/adk/tools/mcp_tool/mcp_tool.py
+++ b/src/google/adk/tools/mcp_tool/mcp_tool.py
@@ -17,7 +17,6 @@
import base64
import inspect
import logging
-import sys
from typing import Any
from typing import Callable
from typing import Dict
@@ -27,35 +26,21 @@
from fastapi.openapi.models import APIKeyIn
from google.genai.types import FunctionDeclaration
+from mcp.types import Tool as McpBaseTool
from typing_extensions import override
from ...agents.readonly_context import ReadonlyContext
-from ...features import FeatureName
-from ...features import is_feature_enabled
-from .._gemini_schema_util import _to_gemini_schema
-from .mcp_session_manager import MCPSessionManager
-from .mcp_session_manager import retry_on_closed_resource
-
-# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
-# their Python version to 3.10 if it fails.
-try:
- from mcp.types import Tool as McpBaseTool
-except ImportError as e:
- if sys.version_info < (3, 10):
- raise ImportError(
- "MCP Tool requires Python 3.10 or above. Please upgrade your Python"
- " version."
- ) from e
- else:
- raise e
-
-
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from ...auth.auth_tool import AuthConfig
+from ...features import FeatureName
+from ...features import is_feature_enabled
+from .._gemini_schema_util import _to_gemini_schema
from ..base_authenticated_tool import BaseAuthenticatedTool
# import
from ..tool_context import ToolContext
+from .mcp_session_manager import MCPSessionManager
+from .mcp_session_manager import retry_on_errors
logger = logging.getLogger("google_adk." + __name__)
@@ -195,7 +180,7 @@ async def run_async(
return {"error": "This tool call is rejected."}
return await super().run_async(args=args, tool_context=tool_context)
- @retry_on_closed_resource
+ @retry_on_errors
@override
async def _run_async_impl(
self, *, args, tool_context: ToolContext, credential: AuthCredential
diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py
index daa88f9031..035b75878b 100644
--- a/src/google/adk/tools/mcp_tool/mcp_toolset.py
+++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py
@@ -25,6 +25,8 @@
from typing import Union
import warnings
+from mcp import StdioServerParameters
+from mcp.types import ListToolsResult
from pydantic import model_validator
from typing_extensions import override
@@ -37,27 +39,10 @@
from ..tool_configs import BaseToolConfig
from ..tool_configs import ToolArgsConfig
from .mcp_session_manager import MCPSessionManager
-from .mcp_session_manager import retry_on_closed_resource
+from .mcp_session_manager import retry_on_errors
from .mcp_session_manager import SseConnectionParams
from .mcp_session_manager import StdioConnectionParams
from .mcp_session_manager import StreamableHTTPConnectionParams
-
-# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
-# their Python version to 3.10 if it fails.
-try:
- from mcp import StdioServerParameters
- from mcp.types import ListToolsResult
-except ImportError as e:
- import sys
-
- if sys.version_info < (3, 10):
- raise ImportError(
- "MCP Tool requires Python 3.10 or above. Please upgrade your Python"
- " version."
- ) from e
- else:
- raise e
-
from .mcp_tool import MCPTool
logger = logging.getLogger("google_adk." + __name__)
@@ -155,7 +140,7 @@ def __init__(
self._auth_credential = auth_credential
self._require_confirmation = require_confirmation
- @retry_on_closed_resource
+ @retry_on_errors
async def get_tools(
self,
readonly_context: Optional[ReadonlyContext] = None,
diff --git a/src/google/adk/tools/openapi_tool/common/common.py b/src/google/adk/tools/openapi_tool/common/common.py
index 7187b1bd1b..1df3125e3d 100644
--- a/src/google/adk/tools/openapi_tool/common/common.py
+++ b/src/google/adk/tools/openapi_tool/common/common.py
@@ -64,11 +64,9 @@ class ApiParameter(BaseModel):
required: bool = False
def model_post_init(self, _: Any):
- self.py_name = (
- self.py_name
- if self.py_name
- else rename_python_keywords(_to_snake_case(self.original_name))
- )
+ if not self.py_name:
+ inferred_name = rename_python_keywords(_to_snake_case(self.original_name))
+ self.py_name = inferred_name or self._default_py_name()
if isinstance(self.param_schema, str):
self.param_schema = Schema.model_validate_json(self.param_schema)
@@ -77,6 +75,16 @@ def model_post_init(self, _: Any):
self.type_hint = TypeHintHelper.get_type_hint(self.param_schema)
return self
+ def _default_py_name(self) -> str:
+ location_defaults = {
+ 'body': 'body',
+ 'query': 'query_param',
+ 'path': 'path_param',
+ 'header': 'header_param',
+ 'cookie': 'cookie_param',
+ }
+ return location_defaults.get(self.param_location or '', 'value')
+
@model_serializer
def _serialize(self):
return {
diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py
index 06d692a2b0..326ff6787e 100644
--- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py
+++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py
@@ -139,10 +139,19 @@ def _process_request_body(self):
)
)
else:
+ # Prefer explicit body name to avoid empty keys when schema lacks type
+ # information (e.g., oneOf/anyOf/allOf) while retaining legacy behavior
+ # for simple scalar types.
+ if schema and (schema.oneOf or schema.anyOf or schema.allOf):
+ param_name = 'body'
+ elif not schema or not schema.type:
+ param_name = 'body'
+ else:
+ param_name = ''
+
self._params.append(
- # Empty name for unnamed body param
ApiParameter(
- original_name='',
+ original_name=param_name,
param_location='body',
param_schema=schema,
description=description,
diff --git a/src/google/adk/tools/spanner/search_tool.py b/src/google/adk/tools/spanner/search_tool.py
index b3cf797edf..2b6b9777dc 100644
--- a/src/google/adk/tools/spanner/search_tool.py
+++ b/src/google/adk/tools/spanner/search_tool.py
@@ -20,23 +20,34 @@
from typing import List
from typing import Optional
-from google.adk.tools.spanner import client
-from google.adk.tools.spanner.settings import SpannerToolSettings
-from google.adk.tools.tool_context import ToolContext
from google.auth.credentials import Credentials
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from google.cloud.spanner_v1.database import Database
-# Embedding options
-_SPANNER_EMBEDDING_MODEL_NAME = "spanner_embedding_model_name"
-_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT = "vertex_ai_embedding_model_endpoint"
+from . import client
+from . import utils
+from .settings import APPROXIMATE_NEAREST_NEIGHBORS
+from .settings import EXACT_NEAREST_NEIGHBORS
+from .settings import SpannerToolSettings
+
+# Embedding model settings.
+# Only for Spanner GoogleSQL dialect databases; embeddings are generated with
+# Spanner's ML.PREDICT function.
+_SPANNER_GSQL_EMBEDDING_MODEL_NAME = "spanner_googlesql_embedding_model_name"
+# Only for Spanner PostgreSQL dialect databases; embeddings are generated with
+# spanner.ML_PREDICT_ROW against a Vertex AI embedding model endpoint.
+_SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT = (
+ "spanner_postgresql_vertex_ai_embedding_model_endpoint"
+)
+# For both Spanner GoogleSQL and PostgreSQL dialects: a Vertex AI embedding
+# model is used to generate embeddings for vector similarity search.
+_VERTEX_AI_EMBEDDING_MODEL_NAME = "vertex_ai_embedding_model_name"
+_OUTPUT_DIMENSIONALITY = "output_dimensionality"
# Search options
_TOP_K = "top_k"
_DISTANCE_TYPE = "distance_type"
_NEAREST_NEIGHBORS_ALGORITHM = "nearest_neighbors_algorithm"
-_EXACT_NEAREST_NEIGHBORS = "EXACT_NEAREST_NEIGHBORS"
-_APPROXIMATE_NEAREST_NEIGHBORS = "APPROXIMATE_NEAREST_NEIGHBORS"
_NUM_LEAVES_TO_SEARCH = "num_leaves_to_search"
# Constants
@@ -48,12 +59,12 @@
def _generate_googlesql_for_embedding_query(
- spanner_embedding_model_name: str,
+ spanner_gsql_embedding_model_name: str,
) -> str:
return f"""
SELECT embeddings.values
FROM ML.PREDICT(
- MODEL {spanner_embedding_model_name},
+ MODEL {spanner_gsql_embedding_model_name},
(SELECT CAST(@{_GOOGLESQL_PARAMETER_TEXT_QUERY} AS STRING) as content)
)
"""
@@ -61,37 +72,60 @@ def _generate_googlesql_for_embedding_query(
def _generate_postgresql_for_embedding_query(
vertex_ai_embedding_model_endpoint: str,
+ output_dimensionality: Optional[int],
) -> str:
+ instances_json = f"""
+ 'instances',
+ JSONB_BUILD_ARRAY(
+ JSONB_BUILD_OBJECT(
+ 'content',
+ ${_POSTGRESQL_PARAMETER_TEXT_QUERY}::TEXT
+ )
+ )
+ """
+
+ params_list = []
+ if output_dimensionality is not None:
+ params_list.append(f"""
+ 'parameters',
+ JSONB_BUILD_OBJECT(
+ 'outputDimensionality',
+ {output_dimensionality}
+ )
+ """)
+
+ jsonb_build_args = ",\n".join([instances_json] + params_list)
+
return f"""
- SELECT spanner.FLOAT32_ARRAY( spanner.ML_PREDICT_ROW(
- '{vertex_ai_embedding_model_endpoint}',
- JSONB_BUILD_OBJECT(
- 'instances',
- JSONB_BUILD_ARRAY( JSONB_BUILD_OBJECT(
- 'content',
- ${_POSTGRESQL_PARAMETER_TEXT_QUERY}::TEXT
- ))
+ SELECT spanner.FLOAT32_ARRAY(
+ spanner.ML_PREDICT_ROW(
+ '{vertex_ai_embedding_model_endpoint}',
+ JSONB_BUILD_OBJECT(
+ {jsonb_build_args}
+ )
+ ) -> 'predictions' -> 0 -> 'embeddings' -> 'values'
)
- ) -> 'predictions'->0->'embeddings'->'values' )
"""
def _get_embedding_for_query(
database: Database,
dialect: DatabaseDialect,
- spanner_embedding_model_name: Optional[str],
- vertex_ai_embedding_model_endpoint: Optional[str],
+ spanner_gsql_embedding_model_name: Optional[str],
+ spanner_pg_vertex_ai_embedding_model_endpoint: Optional[str],
query: str,
+ output_dimensionality: Optional[int] = None,
) -> List[float]:
"""Gets the embedding for the query."""
if dialect == DatabaseDialect.POSTGRESQL:
embedding_query = _generate_postgresql_for_embedding_query(
- vertex_ai_embedding_model_endpoint
+ spanner_pg_vertex_ai_embedding_model_endpoint,
+ output_dimensionality,
)
params = {f"p{_POSTGRESQL_PARAMETER_TEXT_QUERY}": query}
else:
embedding_query = _generate_googlesql_for_embedding_query(
- spanner_embedding_model_name
+ spanner_gsql_embedding_model_name
)
params = {_GOOGLESQL_PARAMETER_TEXT_QUERY: query}
with database.snapshot() as snapshot:
@@ -101,8 +135,8 @@ def _get_embedding_for_query(
def _get_postgresql_distance_function(distance_type: str) -> str:
return {
- "COSINE_DISTANCE": "spanner.cosine_distance",
- "EUCLIDEAN_DISTANCE": "spanner.euclidean_distance",
+ "COSINE": "spanner.cosine_distance",
+ "EUCLIDEAN": "spanner.euclidean_distance",
"DOT_PRODUCT": "spanner.dot_product",
}[distance_type]
@@ -110,13 +144,13 @@ def _get_postgresql_distance_function(distance_type: str) -> str:
def _get_googlesql_distance_function(distance_type: str, ann: bool) -> str:
if ann:
return {
- "COSINE_DISTANCE": "APPROX_COSINE_DISTANCE",
- "EUCLIDEAN_DISTANCE": "APPROX_EUCLIDEAN_DISTANCE",
+ "COSINE": "APPROX_COSINE_DISTANCE",
+ "EUCLIDEAN": "APPROX_EUCLIDEAN_DISTANCE",
"DOT_PRODUCT": "APPROX_DOT_PRODUCT",
}[distance_type]
return {
- "COSINE_DISTANCE": "COSINE_DISTANCE",
- "EUCLIDEAN_DISTANCE": "EUCLIDEAN_DISTANCE",
+ "COSINE": "COSINE_DISTANCE",
+ "EUCLIDEAN": "EUCLIDEAN_DISTANCE",
"DOT_PRODUCT": "DOT_PRODUCT",
}[distance_type]
@@ -172,7 +206,7 @@ def _generate_sql_for_ann(
"""Generates a SQL query for ANN search."""
if dialect == DatabaseDialect.POSTGRESQL:
raise NotImplementedError(
- f"{_APPROXIMATE_NEAREST_NEIGHBORS} is not supported for PostgreSQL"
+ f"{APPROXIMATE_NEAREST_NEIGHBORS} is not supported for PostgreSQL"
" dialect."
)
distance_function = _get_googlesql_distance_function(distance_type, ann=True)
@@ -206,8 +240,6 @@ def similarity_search(
columns: List[str],
embedding_options: Dict[str, str],
credentials: Credentials,
- settings: SpannerToolSettings,
- tool_context: ToolContext,
additional_filter: Optional[str] = None,
search_options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
@@ -234,21 +266,34 @@ def similarity_search(
columns (List[str]): A list of column names, representing the additional
columns to return in the search results.
embedding_options (Dict[str, str]): A dictionary of options to use for
- the embedding service. The following options are supported:
- - spanner_embedding_model_name: (For GoogleSQL dialect) The
+ the embedding service. **Exactly one of the following three keys
+ MUST be present in this dictionary**:
+ `vertex_ai_embedding_model_name`, `spanner_googlesql_embedding_model_name`,
+ or `spanner_postgresql_vertex_ai_embedding_model_endpoint`.
+ - vertex_ai_embedding_model_name (str): (Supported for both **GoogleSQL and
+ PostgreSQL** dialect Spanner databases) The name of a
+ public Vertex AI embedding model (e.g., `'text-embedding-005'`).
+ If specified, the tool generates embeddings client-side using the
+ Vertex AI embedding model.
+ - spanner_googlesql_embedding_model_name (str): (For GoogleSQL dialect) The
name of the embedding model that is registered in Spanner via a
`CREATE MODEL` statement. For more details, see
https://cloud.google.com/spanner/docs/ml-tutorial-embeddings#generate_and_store_text_embeddings
- - vertex_ai_embedding_model_endpoint: (For PostgreSQL dialect)
- The fully qualified endpoint of the Vertex AI embedding model,
- in the format of
+ If specified, embedding generation is performed using Spanner's
+ `ML.PREDICT` function.
+ - spanner_postgresql_vertex_ai_embedding_model_endpoint (str):
+ (For PostgreSQL dialect) The fully qualified endpoint of the Vertex AI
+ embedding model, in the format of
`projects/$project/locations/$location/publishers/google/models/$model_name`,
where $project is the project hosting the Vertex AI endpoint,
$location is the location of the endpoint, and $model_name is
the name of the text embedding model.
+ If specified, embedding generation is performed using Spanner's
+ `spanner.ML_PREDICT_ROW` function.
+ - output_dimensionality: Optional. The output dimensionality of the
+ embedding. If not specified, the embedding model's default output
+ dimensionality will be used.
credentials (Credentials): The credentials to use for the request.
- settings (SpannerToolSettings): The configuration for the tool.
- tool_context (ToolContext): The context for the tool.
additional_filter (Optional[str]): An optional filter to apply to the
search query. If provided, this will be added to the WHERE clause of the
final query.
@@ -257,9 +302,9 @@ def similarity_search(
- top_k: The number of most similar documents to return. The
default value is 4.
- distance_type: The distance type to use to perform the
- similarity search. Valid values include "COSINE_DISTANCE",
- "EUCLIDEAN_DISTANCE", and "DOT_PRODUCT". Default value is
- "COSINE_DISTANCE".
+ similarity search. Valid values include "COSINE",
+ "EUCLIDEAN", and "DOT_PRODUCT". Default value is
+ "COSINE".
- nearest_neighbors_algorithm: The nearest neighbors search
algorithm to use. Valid values include "EXACT_NEAREST_NEIGHBORS"
and "APPROXIMATE_NEAREST_NEIGHBORS". Default value is
@@ -287,15 +332,13 @@ def similarity_search(
... embedding_column_to_search="product_description_embedding",
... columns=["product_name", "product_description", "price_in_cents"],
... credentials=credentials,
- ... settings=settings,
- ... tool_context=tool_context,
... additional_filter="price_in_cents < 100000",
... embedding_options={
- ... "spanner_embedding_model_name": "my_embedding_model"
+ ... "vertex_ai_embedding_model_name": "text-embedding-005"
... },
... search_options={
... "top_k": 2,
- ... "distance_type": "COSINE_DISTANCE"
+ ... "distance_type": "COSINE"
... }
... )
{
@@ -336,33 +379,68 @@ def similarity_search(
embedding_options = {}
if search_options is None:
search_options = {}
- spanner_embedding_model_name = embedding_options.get(
- _SPANNER_EMBEDDING_MODEL_NAME
+
+ exclusive_embedding_model_keys = {
+ _VERTEX_AI_EMBEDDING_MODEL_NAME,
+ _SPANNER_GSQL_EMBEDDING_MODEL_NAME,
+ _SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT,
+ }
+ if (
+ len(
+ exclusive_embedding_model_keys.intersection(
+ embedding_options.keys()
+ )
+ )
+ != 1
+ ):
+ raise ValueError("Exactly one embedding model option must be specified.")
+
+ vertex_ai_embedding_model_name = embedding_options.get(
+ _VERTEX_AI_EMBEDDING_MODEL_NAME
+ )
+ spanner_gsql_embedding_model_name = embedding_options.get(
+ _SPANNER_GSQL_EMBEDDING_MODEL_NAME
+ )
+ spanner_pg_vertex_ai_embedding_model_endpoint = embedding_options.get(
+ _SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT
)
if (
database.database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL
- and spanner_embedding_model_name is None
+ and vertex_ai_embedding_model_name is None
+ and spanner_gsql_embedding_model_name is None
):
raise ValueError(
- f"embedding_options['{_SPANNER_EMBEDDING_MODEL_NAME}']"
- " must be specified for GoogleSQL dialect."
+ f"embedding_options['{_VERTEX_AI_EMBEDDING_MODEL_NAME}'] or"
+ f" embedding_options['{_SPANNER_GSQL_EMBEDDING_MODEL_NAME}'] must be"
+ " specified for GoogleSQL dialect Spanner database."
)
- vertex_ai_embedding_model_endpoint = embedding_options.get(
- _VERTEX_AI_EMBEDDING_MODEL_ENDPOINT
- )
if (
database.database_dialect == DatabaseDialect.POSTGRESQL
- and vertex_ai_embedding_model_endpoint is None
+ and vertex_ai_embedding_model_name is None
+ and spanner_pg_vertex_ai_embedding_model_endpoint is None
):
raise ValueError(
- f"embedding_options['{_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT}']"
- " must be specified for PostgreSQL dialect."
+ f"embedding_options['{_VERTEX_AI_EMBEDDING_MODEL_NAME}'] or"
+ f" embedding_options['{_SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT}']"
+ " must be specified for PostgreSQL dialect Spanner database."
+ )
+ output_dimensionality = embedding_options.get(_OUTPUT_DIMENSIONALITY)
+ if (
+ output_dimensionality is not None
+ and spanner_gsql_embedding_model_name is not None
+ ):
+ # Currently, Spanner GoogleSQL ML.PREDICT does not support the
+ # output_dimensionality parameter for embedding model inference.
+ raise ValueError(
+ f"embedding_options['{_OUTPUT_DIMENSIONALITY}'] is not supported when"
+ f" embedding_options['{_SPANNER_GSQL_EMBEDDING_MODEL_NAME}'] is"
+ " specified."
)
# Use cosine distance by default.
distance_type = search_options.get(_DISTANCE_TYPE)
if distance_type is None:
- distance_type = "COSINE_DISTANCE"
+ distance_type = "COSINE"
top_k = search_options.get(_TOP_K)
if top_k is None:
@@ -370,26 +448,36 @@ def similarity_search(
# Use EXACT_NEAREST_NEIGHBORS (i.e. kNN) by default.
nearest_neighbors_algorithm = search_options.get(
- _NEAREST_NEIGHBORS_ALGORITHM, _EXACT_NEAREST_NEIGHBORS
+ _NEAREST_NEIGHBORS_ALGORITHM,
+ EXACT_NEAREST_NEIGHBORS,
)
if nearest_neighbors_algorithm not in (
- _EXACT_NEAREST_NEIGHBORS,
- _APPROXIMATE_NEAREST_NEIGHBORS,
+ EXACT_NEAREST_NEIGHBORS,
+ APPROXIMATE_NEAREST_NEIGHBORS,
):
raise NotImplementedError(
f"Unsupported search_options['{_NEAREST_NEIGHBORS_ALGORITHM}']:"
f" {nearest_neighbors_algorithm}"
)
- embedding = _get_embedding_for_query(
- database,
- database.database_dialect,
- spanner_embedding_model_name,
- vertex_ai_embedding_model_endpoint,
- query,
- )
+ # Generate embedding for the query according to the embedding options.
+ if vertex_ai_embedding_model_name:
+ embedding = utils.embed_contents(
+ vertex_ai_embedding_model_name,
+ [query],
+ output_dimensionality,
+ )[0]
+ else:
+ embedding = _get_embedding_for_query(
+ database,
+ database.database_dialect,
+ spanner_gsql_embedding_model_name,
+ spanner_pg_vertex_ai_embedding_model_endpoint,
+ query,
+ output_dimensionality,
+ )
- if nearest_neighbors_algorithm == _EXACT_NEAREST_NEIGHBORS:
+ if nearest_neighbors_algorithm == EXACT_NEAREST_NEIGHBORS:
sql = _generate_sql_for_knn(
database.database_dialect,
table_name,
@@ -438,5 +526,100 @@ def similarity_search(
except Exception as ex:
return {
"status": "ERROR",
- "error_details": str(ex),
+ "error_details": repr(ex),
+ }
+
+
+def vector_store_similarity_search(
+ query: str,
+ credentials: Credentials,
+ settings: SpannerToolSettings,
+) -> Dict[str, Any]:
+ """Performs a semantic similarity search to retrieve relevant context from the Spanner vector store.
+
+ This function performs vector similarity search directly on a vector store
+ table in Spanner database and returns the relevant data.
+
+ Args:
+ query (str): The search string based on the user's question.
+ credentials (Credentials): The credentials to use for the request.
+ settings (SpannerToolSettings): The configuration for the tool.
+
+ Returns:
+ Dict[str, Any]: A dictionary representing the result of the search.
+ On success, it contains {"status": "SUCCESS", "rows": [...]}. The last
+ column of each row is the distance between the query and the row result.
+ On error, it contains {"status": "ERROR", "error_details": "..."}.
+
+ Examples:
+ >>> vector_store_similarity_search(
+ ... query="Spanner database optimization techniques for high QPS",
+ ... credentials=credentials,
+ ... settings=settings
+ ... )
+ {
+ "status": "SUCCESS",
+ "rows": [
+ (
+ "Optimizing Query Performance",
+ 0.12,
+ ),
+ (
+ "Schema Design Best Practices",
+ 0.25,
+ ),
+ (
+ "Using Secondary Indexes Effectively",
+ 0.31,
+ ),
+ ...
+ ],
+ }
+ """
+
+ try:
+ if not settings or not settings.vector_store_settings:
+ raise ValueError("Spanner vector store settings are not set.")
+
+ # Get the embedding model settings.
+ embedding_options = {
+ _VERTEX_AI_EMBEDDING_MODEL_NAME: (
+ settings.vector_store_settings.vertex_ai_embedding_model_name
+ ),
+ _OUTPUT_DIMENSIONALITY: settings.vector_store_settings.vector_length,
+ }
+
+ # Get the search settings.
+ search_options = {
+ _TOP_K: settings.vector_store_settings.top_k,
+ _DISTANCE_TYPE: settings.vector_store_settings.distance_type,
+ _NEAREST_NEIGHBORS_ALGORITHM: (
+ settings.vector_store_settings.nearest_neighbors_algorithm
+ ),
+ }
+ if (
+ settings.vector_store_settings.nearest_neighbors_algorithm
+ == APPROXIMATE_NEAREST_NEIGHBORS
+ ):
+ search_options[_NUM_LEAVES_TO_SEARCH] = (
+ settings.vector_store_settings.num_leaves_to_search
+ )
+
+ return similarity_search(
+ project_id=settings.vector_store_settings.project_id,
+ instance_id=settings.vector_store_settings.instance_id,
+ database_id=settings.vector_store_settings.database_id,
+ table_name=settings.vector_store_settings.table_name,
+ query=query,
+ embedding_column_to_search=settings.vector_store_settings.embedding_column,
+ columns=settings.vector_store_settings.selected_columns,
+ embedding_options=embedding_options,
+ credentials=credentials,
+ additional_filter=settings.vector_store_settings.additional_filter,
+ search_options=search_options,
+ )
+ except Exception as ex:
+ return {
+ "status": "ERROR",
+ "error_details": repr(ex),
}
diff --git a/src/google/adk/tools/spanner/settings.py b/src/google/adk/tools/spanner/settings.py
index 5d097258f4..a76331ba65 100644
--- a/src/google/adk/tools/spanner/settings.py
+++ b/src/google/adk/tools/spanner/settings.py
@@ -16,20 +16,115 @@
from enum import Enum
from typing import List
+from typing import Literal
+from typing import Optional
from pydantic import BaseModel
+from pydantic import model_validator
from ...utils.feature_decorator import experimental
+# Vector similarity search nearest neighbors search algorithms.
+EXACT_NEAREST_NEIGHBORS = "EXACT_NEAREST_NEIGHBORS"
+APPROXIMATE_NEAREST_NEIGHBORS = "APPROXIMATE_NEAREST_NEIGHBORS"
+NearestNeighborsAlgorithm = Literal[
+ EXACT_NEAREST_NEIGHBORS,
+ APPROXIMATE_NEAREST_NEIGHBORS,
+]
+
class Capabilities(Enum):
"""Capabilities indicating what type of operation tools are allowed to be performed on Spanner."""
- DATA_READ = 'data_read'
+ DATA_READ = "data_read"
"""Read only data operations tools are allowed."""
-@experimental('Tool settings defaults may have breaking change in the future.')
+class SpannerVectorStoreSettings(BaseModel):
+ """Settings for Spanner Vector Store.
+
+ This is used for vector similarity search in a Spanner vector store table.
+ Provide the vector store table and the embedding model settings to use with
+ the `vector_store_similarity_search` tool.
+ """
+
+ project_id: str
+ """Required. The GCP project id in which the Spanner database resides."""
+
+ instance_id: str
+ """Required. The instance id of the Spanner database."""
+
+ database_id: str
+ """Required. The database id of the Spanner database."""
+
+ table_name: str
+ """Required. The name of the vector store table to use for vector similarity search."""
+
+ content_column: str
+ """Required. The name of the content column in the vector store table. By default, this column value is also returned as part of the vector similarity search result."""
+
+ embedding_column: str
+ """Required. The name of the embedding column to search in the vector store table."""
+
+ vector_length: int
+ """Required. The dimension of the vectors in the `embedding_column`."""
+
+ vertex_ai_embedding_model_name: str
+ """Required. The Vertex AI embedding model name, which is used to generate embeddings for vector store and vector similarity search.
+ For example, 'text-embedding-005'.
+
+ Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
+ Otherwise, a runtime error might be raised during a query.
+ """
+
+ selected_columns: List[str] = []
+ """Required. The vector store table columns to return in the vector similarity search result.
+
+ By default, only the `content_column` value and the distance value are returned.
+ If specified, the list of selected columns and the distance value are returned.
+ For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value.
+ """
+
+ nearest_neighbors_algorithm: NearestNeighborsAlgorithm = (
+ "EXACT_NEAREST_NEIGHBORS"
+ )
+ """The algorithm used to perform vector similarity search. This value can be EXACT_NEAREST_NEIGHBORS or APPROXIMATE_NEAREST_NEIGHBORS.
+
+ For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
+ For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
+ """
+
+ top_k: int = 4
+ """Required. The number of neighbors to return for each vector similarity search query. The default value is 4."""
+
+ distance_type: str = "COSINE"
+ """Required. The distance metric used to build the vector index or perform vector similarity search. This value can be COSINE, DOT_PRODUCT, or EUCLIDEAN."""
+
+ num_leaves_to_search: Optional[int] = None
+ """Optional. This option specifies how many leaf nodes of the index are searched.
+
+ Note: this option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
+ For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices
+ """
+
+ additional_filter: Optional[str] = None
+ """Optional. An optional filter to apply to the search query. If provided, this will be added to the WHERE clause of the final query."""
+
+ @model_validator(mode="after")
+ def __post_init__(self):
+ """Validate the embedding settings."""
+ if not self.vector_length or self.vector_length <= 0:
+ raise ValueError(
+ "Invalid vector length in the Spanner vector store settings."
+ )
+
+ if not self.selected_columns:
+ self.selected_columns = [self.content_column]
+
+ return self
+
+
+@experimental("Tool settings defaults may have breaking change in the future.")
class SpannerToolSettings(BaseModel):
"""Settings for Spanner tools."""
@@ -44,3 +139,6 @@ class SpannerToolSettings(BaseModel):
max_executed_query_result_rows: int = 50
"""Maximum number of rows to return from a query result."""
+
+ vector_store_settings: Optional[SpannerVectorStoreSettings] = None
+ """Settings for Spanner vector store and vector similarity search."""
diff --git a/src/google/adk/tools/spanner/spanner_toolset.py b/src/google/adk/tools/spanner/spanner_toolset.py
index 861314abb3..6496014f74 100644
--- a/src/google/adk/tools/spanner/spanner_toolset.py
+++ b/src/google/adk/tools/spanner/spanner_toolset.py
@@ -47,6 +47,8 @@ class SpannerToolset(BaseToolset):
- spanner_list_named_schemas
- spanner_get_table_schema
- spanner_execute_sql
+ - spanner_similarity_search
+ - spanner_vector_store_similarity_search
"""
def __init__(
@@ -121,6 +123,16 @@ async def get_tools(
tool_settings=self._tool_settings,
)
)
+ if self._tool_settings.vector_store_settings:
+ # Only add the vector store similarity search tool if the vector store
+ # settings are specified.
+ all_tools.append(
+ GoogleTool(
+ func=search_tool.vector_store_similarity_search,
+ credentials_config=self._credentials_config,
+ tool_settings=self._tool_settings,
+ )
+ )
return [
tool
diff --git a/src/google/adk/tools/spanner/utils.py b/src/google/adk/tools/spanner/utils.py
index 64c03859e1..f1b710ec85 100644
--- a/src/google/adk/tools/spanner/utils.py
+++ b/src/google/adk/tools/spanner/utils.py
@@ -105,3 +105,27 @@ def execute_sql(
"status": "ERROR",
"error_details": str(ex),
}
+
+
+def embed_contents(
+ vertex_ai_embedding_model_name: str,
+ contents: list[str],
+ output_dimensionality: Optional[int] = None,
+) -> list[list[float]]:
+ """Embed the given contents into a list of vectors using the given Vertex AI embedding model."""
+ try:
+ from google.genai import Client
+ from google.genai.types import EmbedContentConfig
+
+ client = Client()
+ config = EmbedContentConfig()
+ if output_dimensionality:
+ config.output_dimensionality = output_dimensionality
+ response = client.models.embed_content(
+ model=vertex_ai_embedding_model_name,
+ contents=contents,
+ config=config,
+ )
+ return [list(e.values) for e in response.embeddings]
+ except Exception as ex:
+ raise RuntimeError(f"Failed to embed content: {ex!r}") from ex
diff --git a/src/google/adk/tools/transfer_to_agent_tool.py b/src/google/adk/tools/transfer_to_agent_tool.py
index 99ee234b30..2124e6aab9 100644
--- a/src/google/adk/tools/transfer_to_agent_tool.py
+++ b/src/google/adk/tools/transfer_to_agent_tool.py
@@ -14,6 +14,12 @@
from __future__ import annotations
+from typing import Optional
+
+from google.genai import types
+from typing_extensions import override
+
+from .function_tool import FunctionTool
from .tool_context import ToolContext
@@ -23,7 +29,61 @@ def transfer_to_agent(agent_name: str, tool_context: ToolContext) -> None:
This tool hands off control to another agent when it's more suitable to
answer the user's question according to the agent's description.
+ Note:
+ For most use cases, you should use TransferToAgentTool instead of this
+ function directly. TransferToAgentTool provides additional enum constraints
+ that prevent LLMs from hallucinating invalid agent names.
+
Args:
agent_name: the agent name to transfer to.
"""
tool_context.actions.transfer_to_agent = agent_name
+
+
+class TransferToAgentTool(FunctionTool):
+ """A specialized FunctionTool for agent transfer with enum constraints.
+
+ This tool enhances the base transfer_to_agent function by adding JSON Schema
+ enum constraints to the agent_name parameter. This prevents LLMs from
+ hallucinating invalid agent names by restricting choices to only valid agents.
+
+ Attributes:
+ agent_names: List of valid agent names that can be transferred to.
+ """
+
+ def __init__(
+ self,
+ agent_names: list[str],
+ ):
+ """Initialize the TransferToAgentTool.
+
+ Args:
+ agent_names: List of valid agent names that can be transferred to.
+ """
+ super().__init__(func=transfer_to_agent)
+ self._agent_names = agent_names
+
+ @override
+ def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
+ """Add enum constraint to the agent_name parameter.
+
+ Returns:
+ FunctionDeclaration with enum constraint on agent_name parameter.
+ """
+ function_decl = super()._get_declaration()
+ if not function_decl:
+ return function_decl
+
+ # Handle parameters (types.Schema object)
+ if function_decl.parameters:
+ agent_name_schema = function_decl.parameters.properties.get('agent_name')
+ if agent_name_schema:
+ agent_name_schema.enum = self._agent_names
+
+ # Handle parameters_json_schema (dict)
+ if function_decl.parameters_json_schema:
+ properties = function_decl.parameters_json_schema.get('properties', {})
+ if 'agent_name' in properties:
+ properties['agent_name']['enum'] = self._agent_names
+
+ return function_decl
diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py
index aff5be1552..e0c228be0e 100644
--- a/src/google/adk/tools/vertex_ai_search_tool.py
+++ b/src/google/adk/tools/vertex_ai_search_tool.py
@@ -14,6 +14,7 @@
from __future__ import annotations
+import logging
from typing import Optional
from typing import TYPE_CHECKING
@@ -25,6 +26,8 @@
from .base_tool import BaseTool
from .tool_context import ToolContext
+logger = logging.getLogger('google_adk.' + __name__)
+
if TYPE_CHECKING:
from ..models import LlmRequest
@@ -102,6 +105,30 @@ async def process_llm_request(
)
llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.tools = llm_request.config.tools or []
+
+ # Format data_store_specs concisely for logging
+ if self.data_store_specs:
+ spec_ids = [
+ spec.data_store.split('/')[-1] if spec.data_store else 'unnamed'
+ for spec in self.data_store_specs
+ ]
+ specs_info = (
+ f'{len(self.data_store_specs)} spec(s): [{", ".join(spec_ids)}]'
+ )
+ else:
+ specs_info = None
+
+ logger.debug(
+ 'Adding Vertex AI Search tool config to LLM request: '
+ 'datastore=%s, engine=%s, filter=%s, max_results=%s, '
+ 'data_store_specs=%s',
+ self.data_store_id,
+ self.search_engine_id,
+ self.filter,
+ self.max_results,
+ specs_info,
+ )
+
llm_request.config.tools.append(
types.Tool(
retrieval=types.Retrieval(
diff --git a/src/google/adk/utils/_client_labels_utils.py b/src/google/adk/utils/_client_labels_utils.py
new file mode 100644
index 0000000000..72858c3c1d
--- /dev/null
+++ b/src/google/adk/utils/_client_labels_utils.py
@@ -0,0 +1,78 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+import contextvars
+import os
+import sys
+from typing import List
+from typing import Optional
+
+from .. import version
+
+_ADK_LABEL = "google-adk"
+_LANGUAGE_LABEL = "gl-python"
+_AGENT_ENGINE_TELEMETRY_TAG = "remote_reasoning_engine"
+_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = "GOOGLE_CLOUD_AGENT_ENGINE_ID"
+
+
+EVAL_CLIENT_LABEL = f"google-adk-eval/{version.__version__}"
+"""Label used to denote calls emerging to external system as a part of Evals."""
+
+# The ContextVar holds the client label collected for the current request.
+_LABEL_CONTEXT: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
+ "_LABEL_CONTEXT", default=None
+)
+
+
+def _get_default_labels() -> List[str]:
+ """Returns a list of labels that are always added."""
+ framework_label = f"{_ADK_LABEL}/{version.__version__}"
+
+ if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME):
+ framework_label = f"{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}"
+
+ language_label = f"{_LANGUAGE_LABEL}/" + sys.version.split()[0]
+ return [framework_label, language_label]
+
+
+@contextmanager
+def client_label_context(client_label: str):
+ """Runs the operation within the context of the given client label."""
+ current_client_label = _LABEL_CONTEXT.get()
+
+ if current_client_label is not None:
+ raise ValueError(
+ "Client label already exists. You can only add one client label."
+ )
+
+ token = _LABEL_CONTEXT.set(client_label)
+
+ try:
+ yield
+ finally:
+ # Restore the previous state of the context variable
+ _LABEL_CONTEXT.reset(token)
+
+
+def get_client_labels() -> List[str]:
+ """Returns the current list of client labels that can be added to HTTP Headers."""
+ labels = _get_default_labels()
+ current_client_label = _LABEL_CONTEXT.get()
+
+ if current_client_label:
+ labels.append(current_client_label)
+
+ return labels
diff --git a/src/google/adk/utils/context_utils.py b/src/google/adk/utils/context_utils.py
index bd8dacb9d8..a75feae3dd 100644
--- a/src/google/adk/utils/context_utils.py
+++ b/src/google/adk/utils/context_utils.py
@@ -20,30 +20,7 @@
from __future__ import annotations
-from contextlib import AbstractAsyncContextManager
-from typing import Any
-from typing import AsyncGenerator
+from contextlib import aclosing
-
-class Aclosing(AbstractAsyncContextManager):
- """Async context manager for safely finalizing an asynchronously cleaned-up
- resource such as an async generator, calling its ``aclose()`` method.
- Needed to correctly close contexts for OTel spans.
- See https://github.com/google/adk-python/issues/1670#issuecomment-3115891100.
-
- Based on
- https://docs.python.org/3/library/contextlib.html#contextlib.aclosing
- which is available in Python 3.10+.
-
- TODO: replace all occurrences with contextlib.aclosing once Python 3.9 is no
- longer supported.
- """
-
- def __init__(self, async_generator: AsyncGenerator[Any, None]):
- self.async_generator = async_generator
-
- async def __aenter__(self):
- return self.async_generator
-
- async def __aexit__(self, *exc_info):
- await self.async_generator.aclose()
+# Re-export aclosing for backward compatibility
+Aclosing = aclosing
diff --git a/src/google/adk/utils/streaming_utils.py b/src/google/adk/utils/streaming_utils.py
index eb75365467..eae80aa7cc 100644
--- a/src/google/adk/utils/streaming_utils.py
+++ b/src/google/adk/utils/streaming_utils.py
@@ -14,6 +14,7 @@
from __future__ import annotations
+from typing import Any
from typing import AsyncGenerator
from typing import Optional
@@ -43,6 +44,12 @@ def __init__(self):
self._current_text_is_thought: Optional[bool] = None
self._finish_reason: Optional[types.FinishReason] = None
+ # For streaming function call arguments
+ self._current_fc_name: Optional[str] = None
+ self._current_fc_args: dict[str, Any] = {}
+ self._current_fc_id: Optional[str] = None
+ self._current_thought_signature: Optional[str] = None
+
def _flush_text_buffer_to_sequence(self):
"""Flush current text buffer to parts sequence.
@@ -61,6 +68,171 @@ def _flush_text_buffer_to_sequence(self):
self._current_text_buffer = ''
self._current_text_is_thought = None
+ def _get_value_from_partial_arg(
+ self, partial_arg: types.PartialArg, json_path: str
+ ):
+ """Extract value from a partial argument.
+
+ Args:
+ partial_arg: The partial argument object
+ json_path: JSONPath for this argument
+
+ Returns:
+ Tuple of (value, has_value) where has_value indicates if a value exists
+ """
+ value = None
+ has_value = False
+
+ if partial_arg.string_value is not None:
+ # For streaming strings, append chunks to existing value
+ string_chunk = partial_arg.string_value
+ has_value = True
+
+ # Get current value for this path (if any)
+ path_without_prefix = (
+ json_path[2:] if json_path.startswith('$.') else json_path
+ )
+ path_parts = path_without_prefix.split('.')
+
+ # Try to get existing value
+ existing_value = self._current_fc_args
+ for part in path_parts:
+ if isinstance(existing_value, dict) and part in existing_value:
+ existing_value = existing_value[part]
+ else:
+ existing_value = None
+ break
+
+ # Append to existing string or set new value
+ if isinstance(existing_value, str):
+ value = existing_value + string_chunk
+ else:
+ value = string_chunk
+
+ elif partial_arg.number_value is not None:
+ value = partial_arg.number_value
+ has_value = True
+ elif partial_arg.bool_value is not None:
+ value = partial_arg.bool_value
+ has_value = True
+ elif partial_arg.null_value is not None:
+ value = None
+ has_value = True
+
+ return value, has_value
+
+ def _set_value_by_json_path(self, json_path: str, value: Any):
+ """Set a value in _current_fc_args using JSONPath notation.
+
+ Args:
+ json_path: JSONPath string like "$.location" or "$.location.latitude"
+ value: The value to set
+ """
+ # Remove leading "$." from jsonPath
+ if json_path.startswith('$.'):
+ path = json_path[2:]
+ else:
+ path = json_path
+
+ # Split path into components
+ path_parts = path.split('.')
+
+ # Navigate to the correct location and set the value
+ current = self._current_fc_args
+ for part in path_parts[:-1]:
+ if part not in current:
+ current[part] = {}
+ current = current[part]
+
+ # Set the final value
+ current[path_parts[-1]] = value
+
+ def _flush_function_call_to_sequence(self):
+ """Flush current function call to parts sequence.
+
+ This creates a complete FunctionCall part from accumulated partial args.
+ """
+ if self._current_fc_name:
+ # Create function call part with accumulated args
+ fc_part = types.Part.from_function_call(
+ name=self._current_fc_name,
+ args=self._current_fc_args.copy(),
+ )
+
+ # Set the ID if provided (directly on the function_call object)
+ if self._current_fc_id and fc_part.function_call:
+ fc_part.function_call.id = self._current_fc_id
+
+ # Set thought_signature if provided (on the Part, not FunctionCall)
+ if self._current_thought_signature:
+ fc_part.thought_signature = self._current_thought_signature
+
+ self._parts_sequence.append(fc_part)
+
+ # Reset FC state
+ self._current_fc_name = None
+ self._current_fc_args = {}
+ self._current_fc_id = None
+ self._current_thought_signature = None
+
+ def _process_streaming_function_call(self, fc: types.FunctionCall):
+ """Process a streaming function call with partialArgs.
+
+ Args:
+ fc: The function call object with partial_args
+ """
+ # Save function name if present (first chunk)
+ if fc.name:
+ self._current_fc_name = fc.name
+ if fc.id:
+ self._current_fc_id = fc.id
+
+ # Process each partial argument
+ for partial_arg in getattr(fc, 'partial_args', []):
+ json_path = partial_arg.json_path
+ if not json_path:
+ continue
+
+ # Extract value from partial arg
+ value, has_value = self._get_value_from_partial_arg(
+ partial_arg, json_path
+ )
+
+ # Set the value using JSONPath (only if a value was provided)
+ if has_value:
+ self._set_value_by_json_path(json_path, value)
+
+ # Check if function call is complete
+ fc_will_continue = getattr(fc, 'will_continue', False)
+ if not fc_will_continue:
+ # Function call complete, flush it
+ self._flush_text_buffer_to_sequence()
+ self._flush_function_call_to_sequence()
+
+ def _process_function_call_part(self, part: types.Part):
+ """Process a function call part (streaming or non-streaming).
+
+ Args:
+ part: The part containing a function call
+ """
+ fc = part.function_call
+
+ # Check if this is a streaming FC (has partialArgs)
+ if hasattr(fc, 'partial_args') and fc.partial_args:
+ # Streaming function call arguments
+
+ # Save thought_signature from the part (first chunk should have it)
+ if part.thought_signature and not self._current_thought_signature:
+ self._current_thought_signature = part.thought_signature
+ self._process_streaming_function_call(fc)
+ else:
+ # Non-streaming function call (standard format with args)
+ # Skip empty function calls (used as streaming end markers)
+ if fc.name:
+ # Flush any buffered text first, then add the FC part
+ self._flush_text_buffer_to_sequence()
+ self._parts_sequence.append(part)
+
async def process_response(
self, response: types.GenerateContentResponse
) -> AsyncGenerator[LlmResponse, None]:
@@ -101,8 +273,12 @@ async def process_response(
if not self._current_text_buffer:
self._current_text_is_thought = part.thought
self._current_text_buffer += part.text
+ elif part.function_call:
+ # Process the function call (handles both streaming and
+ # non-streaming args)
+ self._process_function_call_part(part)
else:
- # Non-text part (function_call, bytes, etc.)
+ # Other non-text parts (bytes, etc.)
# Flush any buffered text first, then add the non-text part
self._flush_text_buffer_to_sequence()
self._parts_sequence.append(part)
@@ -155,8 +331,9 @@ def close(self) -> Optional[LlmResponse]:
if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING):
# Always generate final aggregated response in progressive mode
if self._response and self._response.candidates:
- # Flush any remaining text buffer to complete the sequence
+ # Flush any remaining buffers to complete the sequence
self._flush_text_buffer_to_sequence()
+ self._flush_function_call_to_sequence()
# Use the parts sequence which preserves original ordering
final_parts = self._parts_sequence
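For readers following the new streaming path above, here is a minimal standalone sketch of how JSONPath-addressed `partial_args` chunks fold into a single nested args dict. The chunk dicts are an assumed shape for illustration only; the accumulator mirrors the behavior of `_set_value_by_json_path` and the string-append rule in `_get_value_from_partial_arg`, and is not the ADK code itself.

```python
from typing import Any


def accumulate_partial_args(chunks: list[dict[str, Any]]) -> dict[str, Any]:
  """Fold (json_path, value) chunks into a nested dict, concatenating strings."""
  args: dict[str, Any] = {}
  for chunk in chunks:
    path = chunk["json_path"]
    path = path[2:] if path.startswith("$.") else path
    parts = path.split(".")
    current = args
    for part in parts[:-1]:  # walk/create intermediate dicts
      current = current.setdefault(part, {})
    value = chunk["value"]
    existing = current.get(parts[-1])
    if isinstance(existing, str) and isinstance(value, str):
      current[parts[-1]] = existing + value  # string chunks are appended
    else:
      current[parts[-1]] = value
  return args


# Three streamed chunks for one function call accumulate into complete args.
assert accumulate_partial_args([
    {"json_path": "$.location.city", "value": "San Fran"},
    {"json_path": "$.location.city", "value": "cisco"},
    {"json_path": "$.unit", "value": "celsius"},
]) == {"location": {"city": "San Francisco"}, "unit": "celsius"}
```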
diff --git a/src/google/adk/version.py b/src/google/adk/version.py
index 0a21522cb6..a287db284c 100644
--- a/src/google/adk/version.py
+++ b/src/google/adk/version.py
@@ -13,4 +13,4 @@
# limitations under the License.
# version: major.minor.patch
-__version__ = "1.18.0"
+__version__ = "1.20.0"
diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py
index cb3f7a6858..49b7d3c2b6 100644
--- a/tests/unittests/a2a/converters/test_event_converter.py
+++ b/tests/unittests/a2a/converters/test_event_converter.py
@@ -12,50 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
from unittest.mock import Mock
from unittest.mock import patch
+from a2a.types import DataPart
+from a2a.types import Message
+from a2a.types import Role
+from a2a.types import Task
+from a2a.types import TaskState
+from a2a.types import TaskStatusUpdateEvent
+from google.adk.a2a.converters.event_converter import _create_artifact_id
+from google.adk.a2a.converters.event_converter import _create_error_status_event
+from google.adk.a2a.converters.event_converter import _create_status_update_event
+from google.adk.a2a.converters.event_converter import _get_adk_metadata_key
+from google.adk.a2a.converters.event_converter import _get_context_metadata
+from google.adk.a2a.converters.event_converter import _process_long_running_tool
+from google.adk.a2a.converters.event_converter import _serialize_metadata_value
+from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR
+from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event
+from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events
+from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message
+from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE
+from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.events.event import Event
+from google.adk.events.event_actions import EventActions
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from a2a.types import DataPart
- from a2a.types import Message
- from a2a.types import Role
- from a2a.types import Task
- from a2a.types import TaskState
- from a2a.types import TaskStatusUpdateEvent
- from google.adk.a2a.converters.event_converter import _create_artifact_id
- from google.adk.a2a.converters.event_converter import _create_error_status_event
- from google.adk.a2a.converters.event_converter import _create_status_update_event
- from google.adk.a2a.converters.event_converter import _get_adk_metadata_key
- from google.adk.a2a.converters.event_converter import _get_context_metadata
- from google.adk.a2a.converters.event_converter import _process_long_running_tool
- from google.adk.a2a.converters.event_converter import _serialize_metadata_value
- from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR
- from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event
- from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events
- from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message
- from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE
- from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX
- from google.adk.agents.invocation_context import InvocationContext
- from google.adk.events.event import Event
- from google.adk.events.event_actions import EventActions
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Imports are not needed since tests will be skipped due to pytestmark.
- # The imported names are only used within test methods, not at module level,
- # so no NameError occurs during module compilation.
- pass
- else:
- raise e
-
class TestEventConverter:
"""Test suite for event_converter module."""
diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py
index 5a8bad1096..541ab7709d 100644
--- a/tests/unittests/a2a/converters/test_part_converter.py
+++ b/tests/unittests/a2a/converters/test_part_converter.py
@@ -13,38 +13,21 @@
# limitations under the License.
import json
-import sys
from unittest.mock import Mock
from unittest.mock import patch
+from a2a import types as a2a_types
+from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT
+from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE
+from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
+from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE
+from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY
+from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part
+from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part
+from google.adk.a2a.converters.utils import _get_adk_metadata_key
+from google.genai import types as genai_types
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from a2a import types as a2a_types
- from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT
- from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE
- from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
- from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE
- from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY
- from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part
- from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part
- from google.adk.a2a.converters.utils import _get_adk_metadata_key
- from google.genai import types as genai_types
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Imports are not needed since tests will be skipped due to pytestmark.
- # The imported names are only used within test methods, not at module level,
- # so no NameError occurs during module compilation.
- pass
- else:
- raise e
-
class TestConvertA2aPartToGenaiPart:
"""Test cases for convert_a2a_part_to_genai_part function."""
diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py
index a7c21e4dbc..173b122d7c 100644
--- a/tests/unittests/a2a/converters/test_request_converter.py
+++ b/tests/unittests/a2a/converters/test_request_converter.py
@@ -12,33 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
from unittest.mock import Mock
from unittest.mock import patch
+from a2a.server.agent_execution import RequestContext
+from google.adk.a2a.converters.request_converter import _get_user_id
+from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request
+from google.adk.runners import RunConfig
+from google.genai import types as genai_types
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from a2a.server.agent_execution import RequestContext
- from google.adk.a2a.converters.request_converter import _get_user_id
- from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request
- from google.adk.runners import RunConfig
- from google.genai import types as genai_types
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Imports are not needed since tests will be skipped due to pytestmark.
- # The imported names are only used within test methods, not at module level,
- # so no NameError occurs during module compilation.
- pass
- else:
- raise e
-
class TestGetUserId:
"""Test cases for _get_user_id function."""
diff --git a/tests/unittests/a2a/converters/test_utils.py b/tests/unittests/a2a/converters/test_utils.py
index 6c8511161a..0d896852aa 100644
--- a/tests/unittests/a2a/converters/test_utils.py
+++ b/tests/unittests/a2a/converters/test_utils.py
@@ -12,31 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
-
+from google.adk.a2a.converters.utils import _from_a2a_context_id
+from google.adk.a2a.converters.utils import _get_adk_metadata_key
+from google.adk.a2a.converters.utils import _to_a2a_context_id
+from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX
+from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from google.adk.a2a.converters.utils import _from_a2a_context_id
- from google.adk.a2a.converters.utils import _get_adk_metadata_key
- from google.adk.a2a.converters.utils import _to_a2a_context_id
- from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX
- from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Imports are not needed since tests will be skipped due to pytestmark.
- # The imported names are only used within test methods, not at module level,
- # so no NameError occurs during module compilation.
- pass
- else:
- raise e
-
class TestUtilsFunctions:
"""Test suite for utils module functions."""
diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py
index 4bcc7a91d7..58d7521f7d 100644
--- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py
+++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py
@@ -12,41 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
+from a2a.server.agent_execution.context import RequestContext
+from a2a.server.events.event_queue import EventQueue
+from a2a.types import Message
+from a2a.types import TaskState
+from a2a.types import TextPart
+from google.adk.a2a.converters.request_converter import AgentRunRequest
+from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
+from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig
+from google.adk.events.event import Event
+from google.adk.runners import RunConfig
+from google.adk.runners import Runner
+from google.genai.types import Content
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from a2a.server.agent_execution.context import RequestContext
- from a2a.server.events.event_queue import EventQueue
- from a2a.types import Message
- from a2a.types import TaskState
- from a2a.types import TextPart
- from google.adk.a2a.converters.request_converter import AgentRunRequest
- from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
- from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig
- from google.adk.events.event import Event
- from google.adk.runners import RunConfig
- from google.adk.runners import Runner
- from google.genai.types import Content
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Imports are not needed since tests will be skipped due to pytestmark.
- # The imported names are only used within test methods, not at module level,
- # so no NameError occurs during module compilation.
- pass
- else:
- raise e
-
class TestA2aAgentExecutor:
"""Test suite for A2aAgentExecutor class."""
diff --git a/tests/unittests/a2a/executor/test_task_result_aggregator.py b/tests/unittests/a2a/executor/test_task_result_aggregator.py
index 9d03db9dc8..b809b62728 100644
--- a/tests/unittests/a2a/executor/test_task_result_aggregator.py
+++ b/tests/unittests/a2a/executor/test_task_result_aggregator.py
@@ -12,35 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
from unittest.mock import Mock
+from a2a.types import Message
+from a2a.types import Part
+from a2a.types import Role
+from a2a.types import TaskState
+from a2a.types import TaskStatus
+from a2a.types import TaskStatusUpdateEvent
+from a2a.types import TextPart
+from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from a2a.types import Message
- from a2a.types import Part
- from a2a.types import Role
- from a2a.types import TaskState
- from a2a.types import TaskStatus
- from a2a.types import TaskStatusUpdateEvent
- from a2a.types import TextPart
- from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Imports are not needed since tests will be skipped due to pytestmark.
- # The imported names are only used within test methods, not at module level,
- # so no NameError occurs during module compilation.
- pass
- else:
- raise e
-
def create_test_message(text: str):
"""Helper function to create a test Message object."""
diff --git a/tests/unittests/a2a/utils/test_agent_card_builder.py b/tests/unittests/a2a/utils/test_agent_card_builder.py
index e0b62468e5..3bf3202897 100644
--- a/tests/unittests/a2a/utils/test_agent_card_builder.py
+++ b/tests/unittests/a2a/utils/test_agent_card_builder.py
@@ -12,55 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
from unittest.mock import Mock
from unittest.mock import patch
+from a2a.types import AgentCapabilities
+from a2a.types import AgentCard
+from a2a.types import AgentProvider
+from a2a.types import AgentSkill
+from a2a.types import SecurityScheme
+from google.adk.a2a.utils.agent_card_builder import _build_agent_description
+from google.adk.a2a.utils.agent_card_builder import _build_llm_agent_description_with_instructions
+from google.adk.a2a.utils.agent_card_builder import _build_loop_description
+from google.adk.a2a.utils.agent_card_builder import _build_orchestration_skill
+from google.adk.a2a.utils.agent_card_builder import _build_parallel_description
+from google.adk.a2a.utils.agent_card_builder import _build_sequential_description
+from google.adk.a2a.utils.agent_card_builder import _convert_example_tool_examples
+from google.adk.a2a.utils.agent_card_builder import _extract_examples_from_instruction
+from google.adk.a2a.utils.agent_card_builder import _get_agent_skill_name
+from google.adk.a2a.utils.agent_card_builder import _get_agent_type
+from google.adk.a2a.utils.agent_card_builder import _get_default_description
+from google.adk.a2a.utils.agent_card_builder import _get_input_modes
+from google.adk.a2a.utils.agent_card_builder import _get_output_modes
+from google.adk.a2a.utils.agent_card_builder import _get_workflow_description
+from google.adk.a2a.utils.agent_card_builder import _replace_pronouns
+from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder
+from google.adk.agents.base_agent import BaseAgent
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.agents.loop_agent import LoopAgent
+from google.adk.agents.parallel_agent import ParallelAgent
+from google.adk.agents.sequential_agent import SequentialAgent
+from google.adk.tools.example_tool import ExampleTool
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from a2a.types import AgentCapabilities
- from a2a.types import AgentCard
- from a2a.types import AgentProvider
- from a2a.types import AgentSkill
- from a2a.types import SecurityScheme
- from google.adk.a2a.utils.agent_card_builder import _build_agent_description
- from google.adk.a2a.utils.agent_card_builder import _build_llm_agent_description_with_instructions
- from google.adk.a2a.utils.agent_card_builder import _build_loop_description
- from google.adk.a2a.utils.agent_card_builder import _build_orchestration_skill
- from google.adk.a2a.utils.agent_card_builder import _build_parallel_description
- from google.adk.a2a.utils.agent_card_builder import _build_sequential_description
- from google.adk.a2a.utils.agent_card_builder import _convert_example_tool_examples
- from google.adk.a2a.utils.agent_card_builder import _extract_examples_from_instruction
- from google.adk.a2a.utils.agent_card_builder import _get_agent_skill_name
- from google.adk.a2a.utils.agent_card_builder import _get_agent_type
- from google.adk.a2a.utils.agent_card_builder import _get_default_description
- from google.adk.a2a.utils.agent_card_builder import _get_input_modes
- from google.adk.a2a.utils.agent_card_builder import _get_output_modes
- from google.adk.a2a.utils.agent_card_builder import _get_workflow_description
- from google.adk.a2a.utils.agent_card_builder import _replace_pronouns
- from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder
- from google.adk.agents.base_agent import BaseAgent
- from google.adk.agents.llm_agent import LlmAgent
- from google.adk.agents.loop_agent import LoopAgent
- from google.adk.agents.parallel_agent import ParallelAgent
- from google.adk.agents.sequential_agent import SequentialAgent
- from google.adk.tools.example_tool import ExampleTool
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Imports are not needed since tests will be skipped due to pytestmark.
- # The imported names are only used within test methods, not at module level,
- # so no NameError occurs during module compilation.
- pass
- else:
- raise e
-
class TestAgentCardBuilder:
"""Test suite for AgentCardBuilder class."""
diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py
index ee80b0233b..503e572f2f 100644
--- a/tests/unittests/a2a/utils/test_agent_to_a2a.py
+++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py
@@ -12,42 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
+from a2a.server.apps import A2AStarletteApplication
+from a2a.server.request_handlers import DefaultRequestHandler
+from a2a.server.tasks import InMemoryTaskStore
+from a2a.types import AgentCard
+from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
+from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder
+from google.adk.a2a.utils.agent_to_a2a import to_a2a
+from google.adk.agents.base_agent import BaseAgent
+from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
+from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
+from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
+from google.adk.runners import Runner
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
import pytest
-
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from a2a.server.apps import A2AStarletteApplication
- from a2a.server.request_handlers import DefaultRequestHandler
- from a2a.server.tasks import InMemoryTaskStore
- from a2a.types import AgentCard
- from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
- from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder
- from google.adk.a2a.utils.agent_to_a2a import to_a2a
- from google.adk.agents.base_agent import BaseAgent
- from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
- from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
- from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
- from google.adk.runners import Runner
- from google.adk.sessions.in_memory_session_service import InMemorySessionService
- from starlette.applications import Starlette
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Imports are not needed since tests will be skipped due to pytestmark.
- # The imported names are only used within test methods, not at module level,
- # so no NameError occurs during module compilation.
- pass
- else:
- raise e
+from starlette.applications import Starlette
class TestToA2A:
diff --git a/tests/unittests/agents/test_agent_config.py b/tests/unittests/agents/test_agent_config.py
index c2300f5f5d..86fda7fc9b 100644
--- a/tests/unittests/agents/test_agent_config.py
+++ b/tests/unittests/agents/test_agent_config.py
@@ -12,18 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import ntpath
+import os
from pathlib import Path
+from textwrap import dedent
from typing import Literal
from typing import Type
+from unittest import mock
from google.adk.agents import config_agent_utils
from google.adk.agents.agent_config import AgentConfig
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.base_agent_config import BaseAgentConfig
+from google.adk.agents.common_configs import AgentRefConfig
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.loop_agent import LoopAgent
from google.adk.agents.parallel_agent import ParallelAgent
from google.adk.agents.sequential_agent import SequentialAgent
+from google.adk.models.lite_llm import LiteLlm
import pytest
import yaml
@@ -254,6 +260,53 @@ def test_agent_config_discriminator_llm_agent_with_sub_agents(
assert config.root.agent_class == agent_class_value
+def test_agent_config_litellm_model_with_custom_args(tmp_path: Path):
+ yaml_content = """\
+name: managed_api_agent
+description: Agent using LiteLLM managed endpoint
+instruction: Respond concisely.
+model_code:
+ name: google.adk.models.lite_llm.LiteLlm
+ args:
+ - name: model
+ value: kimi/k2
+ - name: api_base
+ value: https://proxy.litellm.ai/v1
+"""
+ config_file = tmp_path / "litellm_agent.yaml"
+ config_file.write_text(yaml_content)
+
+ agent = config_agent_utils.from_config(str(config_file))
+
+ assert isinstance(agent, LlmAgent)
+ assert isinstance(agent.model, LiteLlm)
+ assert agent.model.model == "kimi/k2"
+ assert agent.model._additional_args.get("api_base") == (
+ "https://proxy.litellm.ai/v1"
+ )
+
+
+def test_agent_config_legacy_model_mapping_still_supported(tmp_path: Path):
+ yaml_content = """\
+name: managed_api_agent
+description: Agent using LiteLLM managed endpoint
+instruction: Respond concisely.
+model:
+ name: google.adk.models.lite_llm.LiteLlm
+ args:
+ - name: model
+ value: kimi/k2
+"""
+ config_file = tmp_path / "legacy_litellm_agent.yaml"
+ config_file.write_text(yaml_content)
+
+ agent = config_agent_utils.from_config(str(config_file))
+
+ assert isinstance(agent, LlmAgent)
+ assert isinstance(agent.model, LiteLlm)
+ assert agent.model.model == "kimi/k2"
+
+
def test_agent_config_discriminator_custom_agent():
class MyCustomAgentConfig(BaseAgentConfig):
agent_class: Literal["mylib.agents.MyCustomAgent"] = (
@@ -280,3 +333,91 @@ class MyCustomAgentConfig(BaseAgentConfig):
config.root.model_dump()
)
assert my_custom_config.other_field == "other value"
+
+
+@pytest.mark.parametrize(
+ ("config_rel_path", "child_rel_path", "child_name", "instruction"),
+ [
+ (
+ Path("main.yaml"),
+ Path("sub_agents/child.yaml"),
+ "child_agent",
+ "I am a child agent",
+ ),
+ (
+ Path("level1/level2/nested_main.yaml"),
+ Path("sub/nested_child.yaml"),
+ "nested_child",
+ "I am nested",
+ ),
+ ],
+)
+def test_resolve_agent_reference_resolves_relative_paths(
+ config_rel_path: Path,
+ child_rel_path: Path,
+ child_name: str,
+ instruction: str,
+ tmp_path: Path,
+):
+ """Verify resolve_agent_reference resolves relative sub-agent paths."""
+ config_file = tmp_path / config_rel_path
+ config_file.parent.mkdir(parents=True, exist_ok=True)
+
+ child_config_path = config_file.parent / child_rel_path
+ child_config_path.parent.mkdir(parents=True, exist_ok=True)
+ child_config_path.write_text(dedent(f"""
+ agent_class: LlmAgent
+ name: {child_name}
+ model: gemini-2.0-flash
+ instruction: {instruction}
+ """).lstrip())
+
+ config_file.write_text(dedent(f"""
+ agent_class: LlmAgent
+ name: main_agent
+ model: gemini-2.0-flash
+ instruction: I am the main agent
+ sub_agents:
+ - config_path: {child_rel_path.as_posix()}
+ """).lstrip())
+
+ ref_config = AgentRefConfig(config_path=child_rel_path.as_posix())
+ agent = config_agent_utils.resolve_agent_reference(
+ ref_config, str(config_file)
+ )
+
+ assert agent.name == child_name
+
+ config_dir = os.path.dirname(str(config_file.resolve()))
+ assert config_dir == str(config_file.parent.resolve())
+
+ expected_child_path = os.path.join(config_dir, *child_rel_path.parts)
+ assert os.path.exists(expected_child_path)
+
+
+def test_resolve_agent_reference_uses_windows_dirname():
+ """Ensure Windows-style config references resolve via os.path.dirname."""
+ ref_config = AgentRefConfig(config_path="sub\\child.yaml")
+ recorded: dict[str, str] = {}
+
+ def fake_from_config(path: str):
+ recorded["path"] = path
+ return "sentinel"
+
+ with (
+ mock.patch.object(
+ config_agent_utils,
+ "from_config",
+ autospec=True,
+ side_effect=fake_from_config,
+ ),
+ mock.patch.object(config_agent_utils.os, "path", ntpath),
+ ):
+ referencing = r"C:\workspace\agents\main.yaml"
+ result = config_agent_utils.resolve_agent_reference(ref_config, referencing)
+
+ expected_path = ntpath.join(
+ ntpath.dirname(referencing), ref_config.config_path
+ )
+ assert result == "sentinel"
+ assert recorded["path"] == expected_path
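The two tests above exercise relative sub-agent config resolution. As a hedged sketch of the underlying idea only: the child `config_path` is joined onto the directory of the referencing config file via `os.path` (which is why patching `os.path` with `ntpath` exercises Windows semantics). The helper name below is hypothetical, not ADK's API.

```python
import os


def resolve_relative_config_path(referencing_config: str, config_path: str) -> str:
  # Resolve a sub-agent reference relative to the referencing config's directory.
  return os.path.join(os.path.dirname(referencing_config), config_path)


# POSIX example: yields "/repo/agents/sub_agents/child.yaml".
print(resolve_relative_config_path("/repo/agents/main.yaml", "sub_agents/child.yaml"))
```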
diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py
index 663179f670..259bdd51c2 100644
--- a/tests/unittests/agents/test_base_agent.py
+++ b/tests/unittests/agents/test_base_agent.py
@@ -16,6 +16,7 @@
from enum import Enum
from functools import partial
+import logging
from typing import AsyncGenerator
from typing import List
from typing import Optional
@@ -854,6 +855,112 @@ def test_set_parent_agent_for_sub_agent_twice(
)
+def test_validate_sub_agents_unique_names_single_duplicate(
+ request: pytest.FixtureRequest,
+ caplog: pytest.LogCaptureFixture,
+):
+ """Test that duplicate sub-agent names logs a warning."""
+ duplicate_name = f'{request.function.__name__}_duplicate_agent'
+ sub_agent_1 = _TestingAgent(name=duplicate_name)
+ sub_agent_2 = _TestingAgent(name=duplicate_name)
+
+ with caplog.at_level(logging.WARNING):
+ _ = _TestingAgent(
+ name=f'{request.function.__name__}_parent',
+ sub_agents=[sub_agent_1, sub_agent_2],
+ )
+ assert f'Found duplicate sub-agent names: `{duplicate_name}`' in caplog.text
+
+
+def test_validate_sub_agents_unique_names_multiple_duplicates(
+ request: pytest.FixtureRequest,
+ caplog: pytest.LogCaptureFixture,
+):
+ """Test that multiple duplicate sub-agent names are all reported."""
+ duplicate_name_1 = f'{request.function.__name__}_duplicate_1'
+ duplicate_name_2 = f'{request.function.__name__}_duplicate_2'
+
+ sub_agents = [
+ _TestingAgent(name=duplicate_name_1),
+ _TestingAgent(name=f'{request.function.__name__}_unique'),
+ _TestingAgent(name=duplicate_name_1), # Duplicate of duplicate_name_1
+ _TestingAgent(name=duplicate_name_2),
+ _TestingAgent(name=duplicate_name_2), # Duplicate of duplicate_name_2
+ ]
+
+ with caplog.at_level(logging.WARNING):
+ _ = _TestingAgent(
+ name=f'{request.function.__name__}_parent',
+ sub_agents=sub_agents,
+ )
+
+ # Verify each duplicate name appears exactly once in the warning message
+ assert caplog.text.count(duplicate_name_1) == 1
+ assert caplog.text.count(duplicate_name_2) == 1
+ # Verify both duplicate names are present
+ assert duplicate_name_1 in caplog.text
+ assert duplicate_name_2 in caplog.text
+
+
+def test_validate_sub_agents_unique_names_triple_duplicate(
+ request: pytest.FixtureRequest,
+ caplog: pytest.LogCaptureFixture,
+):
+ """Test that a name appearing three times is reported only once."""
+ duplicate_name = f'{request.function.__name__}_triple_duplicate'
+
+ sub_agents = [
+ _TestingAgent(name=duplicate_name),
+ _TestingAgent(name=f'{request.function.__name__}_unique'),
+ _TestingAgent(name=duplicate_name), # Second occurrence
+ _TestingAgent(name=duplicate_name), # Third occurrence
+ ]
+
+ with caplog.at_level(logging.WARNING):
+ _ = _TestingAgent(
+ name=f'{request.function.__name__}_parent',
+ sub_agents=sub_agents,
+ )
+
+ # Verify the duplicate name appears exactly once in the warning message
+ # (not three times even though it appears three times in the list)
+ assert caplog.text.count(duplicate_name) == 1
+ assert duplicate_name in caplog.text
+
+
+def test_validate_sub_agents_unique_names_no_duplicates(
+ request: pytest.FixtureRequest,
+):
+ """Test that unique sub-agent names pass validation."""
+ sub_agents = [
+ _TestingAgent(name=f'{request.function.__name__}_sub_agent_1'),
+ _TestingAgent(name=f'{request.function.__name__}_sub_agent_2'),
+ _TestingAgent(name=f'{request.function.__name__}_sub_agent_3'),
+ ]
+
+ parent = _TestingAgent(
+ name=f'{request.function.__name__}_parent',
+ sub_agents=sub_agents,
+ )
+
+ assert len(parent.sub_agents) == 3
+ assert parent.sub_agents[0].name == f'{request.function.__name__}_sub_agent_1'
+ assert parent.sub_agents[1].name == f'{request.function.__name__}_sub_agent_2'
+ assert parent.sub_agents[2].name == f'{request.function.__name__}_sub_agent_3'
+
+
+def test_validate_sub_agents_unique_names_empty_list(
+ request: pytest.FixtureRequest,
+):
+ """Test that empty sub-agents list passes validation."""
+ parent = _TestingAgent(
+ name=f'{request.function.__name__}_parent',
+ sub_agents=[],
+ )
+
+ assert len(parent.sub_agents) == 0
+
+
if __name__ == '__main__':
pytest.main([__file__])
diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py
index c57254dbc8..577923f7bf 100644
--- a/tests/unittests/agents/test_llm_agent_fields.py
+++ b/tests/unittests/agents/test_llm_agent_fields.py
@@ -22,6 +22,9 @@
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.readonly_context import ReadonlyContext
+from google.adk.models.anthropic_llm import Claude
+from google.adk.models.google_llm import Gemini
+from google.adk.models.lite_llm import LiteLlm
from google.adk.models.llm_request import LlmRequest
from google.adk.models.registry import LLMRegistry
from google.adk.sessions.in_memory_session_service import InMemorySessionService
@@ -411,3 +414,47 @@ async def test_handle_vais_only(self):
assert len(tools) == 1
assert tools[0].name == 'vertex_ai_search'
assert tools[0].__class__.__name__ == 'VertexAiSearchTool'
+
+
+# Tests for multi-provider model support via string model names
+@pytest.mark.parametrize(
+ 'model_name',
+ [
+ 'gemini-1.5-flash',
+ 'gemini-2.0-flash-exp',
+ ],
+)
+def test_agent_with_gemini_string_model(model_name):
+ """Test that Agent accepts Gemini model strings and resolves to Gemini."""
+ agent = LlmAgent(name='test_agent', model=model_name)
+ assert isinstance(agent.canonical_model, Gemini)
+ assert agent.canonical_model.model == model_name
+
+
+@pytest.mark.parametrize(
+ 'model_name',
+ [
+ 'claude-3-5-sonnet-v2@20241022',
+ 'claude-sonnet-4@20250514',
+ ],
+)
+def test_agent_with_claude_string_model(model_name):
+ """Test that Agent accepts Claude model strings and resolves to Claude."""
+ agent = LlmAgent(name='test_agent', model=model_name)
+ assert isinstance(agent.canonical_model, Claude)
+ assert agent.canonical_model.model == model_name
+
+
+@pytest.mark.parametrize(
+ 'model_name',
+ [
+ 'openai/gpt-4o',
+ 'groq/llama3-70b-8192',
+ 'anthropic/claude-3-opus-20240229',
+ ],
+)
+def test_agent_with_litellm_string_model(model_name):
+ """Test that Agent accepts LiteLLM provider strings."""
+ agent = LlmAgent(name='test_agent', model=model_name)
+ assert isinstance(agent.canonical_model, LiteLlm)
+ assert agent.canonical_model.model == model_name
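Taken together, these parametrized cases imply a simple dispatch on the model string: bare Gemini names resolve to `Gemini`, bare Claude names to `Claude`, and provider-prefixed strings such as `openai/gpt-4o` to `LiteLlm`. The sketch below reproduces only what the tests assert and is not the registry's actual implementation.

```python
def guess_model_wrapper(model_name: str) -> str:
  # Prefix-based routing mirroring the expectations in the tests above.
  if model_name.startswith("gemini-"):
    return "Gemini"
  if model_name.startswith("claude-"):
    return "Claude"
  if "/" in model_name:  # provider-prefixed, LiteLLM-style name
    return "LiteLlm"
  raise ValueError(f"No wrapper known for {model_name!r}")


assert guess_model_wrapper("gemini-2.0-flash-exp") == "Gemini"
assert guess_model_wrapper("claude-sonnet-4@20250514") == "Claude"
assert guess_model_wrapper("groq/llama3-70b-8192") == "LiteLlm"
```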
diff --git a/tests/unittests/agents/test_mcp_instruction_provider.py b/tests/unittests/agents/test_mcp_instruction_provider.py
index 1f2d098c2a..256d812630 100644
--- a/tests/unittests/agents/test_mcp_instruction_provider.py
+++ b/tests/unittests/agents/test_mcp_instruction_provider.py
@@ -13,34 +13,14 @@
# limitations under the License.
"""Unit tests for McpInstructionProvider."""
-import sys
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
+from google.adk.agents.mcp_instruction_provider import McpInstructionProvider
from google.adk.agents.readonly_context import ReadonlyContext
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10),
- reason="MCP instruction provider requires Python 3.10+",
-)
-
-# Import dependencies with version checking
-try:
- from google.adk.agents.mcp_instruction_provider import McpInstructionProvider
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Create dummy classes to prevent NameError during test collection
- # Tests will be skipped anyway due to pytestmark
- class DummyClass:
- pass
-
- McpInstructionProvider = DummyClass
- else:
- raise e
-
class TestMcpInstructionProvider:
"""Unit tests for McpInstructionProvider."""
diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py
index 561a381870..e7865f39ba 100644
--- a/tests/unittests/agents/test_remote_a2a_agent.py
+++ b/tests/unittests/agents/test_remote_a2a_agent.py
@@ -14,70 +14,38 @@
import json
from pathlib import Path
-import sys
import tempfile
from unittest.mock import AsyncMock
from unittest.mock import create_autospec
from unittest.mock import Mock
from unittest.mock import patch
+from a2a.client.client import ClientConfig
+from a2a.client.client import Consumer
+from a2a.client.client_factory import ClientFactory
+from a2a.types import AgentCapabilities
+from a2a.types import AgentCard
+from a2a.types import AgentSkill
+from a2a.types import Artifact
+from a2a.types import Message as A2AMessage
+from a2a.types import Part as A2ATaskStatus
+from a2a.types import SendMessageSuccessResponse
+from a2a.types import Task as A2ATask
+from a2a.types import TaskArtifactUpdateEvent
+from a2a.types import TaskState
+from a2a.types import TaskStatus
+from a2a.types import TaskStatusUpdateEvent
+from a2a.types import TextPart
+from google.adk.agents.invocation_context import InvocationContext
+from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX
+from google.adk.agents.remote_a2a_agent import AgentCardResolutionError
+from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
from google.adk.events.event import Event
from google.adk.sessions.session import Session
from google.genai import types as genai_types
import httpx
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from a2a.client.client import ClientConfig
- from a2a.client.client import Consumer
- from a2a.client.client_factory import ClientFactory
- from a2a.types import AgentCapabilities
- from a2a.types import AgentCard
- from a2a.types import AgentSkill
- from a2a.types import Artifact
- from a2a.types import Message as A2AMessage
- from a2a.types import Part as A2ATaskStatus
- from a2a.types import SendMessageSuccessResponse
- from a2a.types import Task as A2ATask
- from a2a.types import TaskArtifactUpdateEvent
- from a2a.types import TaskState
- from a2a.types import TaskStatus
- from a2a.types import TaskStatusUpdateEvent
- from a2a.types import TextPart
- from google.adk.agents.invocation_context import InvocationContext
- from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX
- from google.adk.agents.remote_a2a_agent import AgentCardResolutionError
- from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Create dummy classes to prevent NameError during module compilation.
- # These are needed because the module has type annotations and module-level
- # helper functions that reference imported types.
- class DummyTypes:
- pass
-
- AgentCapabilities = DummyTypes()
- AgentCard = DummyTypes()
- AgentSkill = DummyTypes()
- A2AMessage = DummyTypes()
- SendMessageSuccessResponse = DummyTypes()
- A2ATask = DummyTypes()
- TaskStatusUpdateEvent = DummyTypes()
- Artifact = DummyTypes()
- TaskArtifactUpdateEvent = DummyTypes()
- InvocationContext = DummyTypes()
- RemoteA2aAgent = DummyTypes()
- AgentCardResolutionError = Exception
- A2A_METADATA_PREFIX = ""
- else:
- raise e
-
# Helper function to create a proper AgentCard for testing
def create_test_agent_card(
@@ -723,6 +691,7 @@ async def test_handle_a2a_response_success_with_message(self):
mock_a2a_message,
self.agent.name,
self.mock_context,
+ self.mock_a2a_part_converter,
)
# Check that metadata was added
assert result.custom_metadata is not None
@@ -760,6 +729,7 @@ async def test_handle_a2a_response_with_task_completed_and_no_update(self):
mock_a2a_task,
self.agent.name,
self.mock_context,
+ self.mock_a2a_part_converter,
)
# Check the parts are not updated as Thought
assert result.content.parts[0].thought is None
@@ -864,6 +834,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self):
mock_a2a_task,
self.agent.name,
self.mock_context,
+ self.mock_a2a_part_converter,
)
# Check the parts are updated as Thought
assert result.content.parts[0].thought is True
@@ -909,6 +880,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self):
mock_a2a_message,
self.agent.name,
self.mock_context,
+ self.mock_a2a_part_converter,
)
# Check that metadata was added
assert result.custom_metadata is not None
@@ -954,6 +926,7 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message(
mock_a2a_message,
self.agent.name,
self.mock_context,
+ self.mock_a2a_part_converter,
)
# Check that metadata was added
assert result.custom_metadata is not None
@@ -1009,7 +982,10 @@ async def test_handle_a2a_response_with_artifact_update(self):
assert result == mock_event
mock_convert.assert_called_once_with(
- mock_a2a_task, self.agent.name, self.mock_context
+ mock_a2a_task,
+ self.agent.name,
+ self.mock_context,
+ self.agent._a2a_part_converter,
)
# Check that metadata was added
assert result.custom_metadata is not None
@@ -1039,6 +1015,8 @@ class TestRemoteA2aAgentMessageHandlingFromFactory:
def setup_method(self):
"""Setup test fixtures."""
+ self.mock_a2a_part_converter = Mock()
+
self.agent_card = create_test_agent_card()
self.agent = RemoteA2aAgent(
name="test_agent",
@@ -1046,6 +1024,7 @@ def setup_method(self):
a2a_client_factory=ClientFactory(
config=ClientConfig(httpx_client=httpx.AsyncClient()),
),
+ a2a_part_converter=self.mock_a2a_part_converter,
)
# Mock session and context
@@ -1173,7 +1152,10 @@ async def test_handle_a2a_response_success_with_message(self):
assert result == mock_event
mock_convert.assert_called_once_with(
- mock_a2a_message, self.agent.name, self.mock_context
+ mock_a2a_message,
+ self.agent.name,
+ self.mock_context,
+ self.mock_a2a_part_converter,
)
# Check that metadata was added
assert result.custom_metadata is not None
@@ -1211,6 +1193,7 @@ async def test_handle_a2a_response_with_task_completed_and_no_update(self):
mock_a2a_task,
self.agent.name,
self.mock_context,
+ self.mock_a2a_part_converter,
)
# Check the parts are not updated as Thought
assert result.content.parts[0].thought is None
@@ -1251,6 +1234,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self):
mock_a2a_task,
self.agent.name,
self.mock_context,
+ self.agent._a2a_part_converter,
)
# Check the parts are updated as Thought
assert result.content.parts[0].thought is True
@@ -1296,6 +1280,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self):
mock_a2a_message,
self.agent.name,
self.mock_context,
+ self.agent._a2a_part_converter,
)
# Check that metadata was added
assert result.custom_metadata is not None
@@ -1341,6 +1326,7 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message(
mock_a2a_message,
self.agent.name,
self.mock_context,
+ self.agent._a2a_part_converter,
)
# Check that metadata was added
assert result.custom_metadata is not None
@@ -1396,7 +1382,10 @@ async def test_handle_a2a_response_with_artifact_update(self):
assert result == mock_event
mock_convert.assert_called_once_with(
- mock_a2a_task, self.agent.name, self.mock_context
+ mock_a2a_task,
+ self.agent.name,
+ self.mock_context,
+ self.agent._a2a_part_converter,
)
# Check that metadata was added
assert result.custom_metadata is not None
diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py
index f7b457f73b..c68ad512c0 100644
--- a/tests/unittests/artifacts/test_artifact_service.py
+++ b/tests/unittests/artifacts/test_artifact_service.py
@@ -32,6 +32,7 @@
from google.adk.artifacts.file_artifact_service import FileArtifactService
from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
+from google.adk.errors.input_validation_error import InputValidationError
from google.genai import types
import pytest
@@ -732,7 +733,7 @@ async def test_file_save_artifact_rejects_out_of_scope_paths(
"""FileArtifactService prevents path traversal outside of its storage roots."""
artifact_service = FileArtifactService(root_dir=tmp_path / "artifacts")
part = types.Part(text="content")
- with pytest.raises(ValueError):
+ with pytest.raises(InputValidationError):
await artifact_service.save_artifact(
app_name="myapp",
user_id="user123",
@@ -757,7 +758,7 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path):
/ "diagram.png"
)
part = types.Part(text="content")
- with pytest.raises(ValueError):
+ with pytest.raises(InputValidationError):
await artifact_service.save_artifact(
app_name="myapp",
user_id="user123",
diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py
index 2d7b9472ba..75d5679084 100755
--- a/tests/unittests/cli/test_fast_api.py
+++ b/tests/unittests/cli/test_fast_api.py
@@ -30,7 +30,9 @@
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.run_config import RunConfig
from google.adk.apps.app import App
+from google.adk.artifacts.base_artifact_service import ArtifactVersion
from google.adk.cli.fast_api import get_fast_api_app
+from google.adk.errors.input_validation_error import InputValidationError
from google.adk.evaluation.eval_case import EvalCase
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.eval_result import EvalSetResult
@@ -190,6 +192,14 @@ def load_agent(self, app_name):
def list_agents(self):
return ["test_app"]
+ def list_agents_detailed(self):
+ return [{
+ "name": "test_app",
+ "root_agent_name": "test_agent",
+ "description": "A test agent for unit testing",
+ "language": "python",
+ }]
+
return MockAgentLoader(".")
@@ -203,48 +213,135 @@ def mock_session_service():
def mock_artifact_service():
"""Create a mock artifact service."""
- # Storage for artifacts
- artifacts = {}
+ artifacts: dict[str, list[dict[str, Any]]] = {}
+
+ def _artifact_key(
+ app_name: str, user_id: str, session_id: Optional[str], filename: str
+ ) -> str:
+ if session_id is None:
+ return f"{app_name}:{user_id}:user:{filename}"
+ return f"{app_name}:{user_id}:{session_id}:{filename}"
+
+ def _canonical_uri(
+ app_name: str,
+ user_id: str,
+ session_id: Optional[str],
+ filename: str,
+ version: int,
+ ) -> str:
+ if session_id is None:
+ return (
+ f"artifact://apps/{app_name}/users/{user_id}/artifacts/"
+ f"{filename}/versions/{version}"
+ )
+ return (
+ f"artifact://apps/{app_name}/users/{user_id}/sessions/{session_id}/"
+ f"artifacts/{filename}/versions/{version}"
+ )
class MockArtifactService:
+ def __init__(self):
+ self._artifacts = artifacts
+ self.save_artifact_side_effect: Optional[BaseException] = None
+
+ async def save_artifact(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ filename: str,
+ artifact: types.Part,
+ session_id: Optional[str] = None,
+ custom_metadata: Optional[dict[str, Any]] = None,
+ ) -> int:
+ if self.save_artifact_side_effect is not None:
+ effect = self.save_artifact_side_effect
+ if isinstance(effect, BaseException):
+ raise effect
+ raise TypeError(
+ "save_artifact_side_effect must be an exception instance."
+ )
+ key = _artifact_key(app_name, user_id, session_id, filename)
+ entries = artifacts.setdefault(key, [])
+ version = len(entries)
+ artifact_version = ArtifactVersion(
+ version=version,
+ canonical_uri=_canonical_uri(
+ app_name, user_id, session_id, filename, version
+ ),
+ custom_metadata=custom_metadata or {},
+ )
+ if artifact.inline_data is not None:
+ artifact_version.mime_type = artifact.inline_data.mime_type
+ elif artifact.text is not None:
+ artifact_version.mime_type = "text/plain"
+ elif artifact.file_data is not None:
+ artifact_version.mime_type = artifact.file_data.mime_type
+
+ entries.append({
+ "version": version,
+ "artifact": artifact,
+ "metadata": artifact_version,
+ })
+ return version
+
async def load_artifact(
self, app_name, user_id, session_id, filename, version=None
):
"""Load an artifact by filename."""
- key = f"{app_name}:{user_id}:{session_id}:{filename}"
+ key = _artifact_key(app_name, user_id, session_id, filename)
if key not in artifacts:
return None
if version is not None:
- # Get a specific version
- for v in artifacts[key]:
- if v["version"] == version:
- return v["artifact"]
+ for entry in artifacts[key]:
+ if entry["version"] == version:
+ return entry["artifact"]
return None
- # Get the latest version
- return sorted(artifacts[key], key=lambda x: x["version"])[-1]["artifact"]
+ return artifacts[key][-1]["artifact"]
async def list_artifact_keys(self, app_name, user_id, session_id):
"""List artifact names for a session."""
prefix = f"{app_name}:{user_id}:{session_id}:"
return [
- k.split(":")[-1] for k in artifacts.keys() if k.startswith(prefix)
+ key.split(":")[-1]
+ for key in artifacts.keys()
+ if key.startswith(prefix)
]
async def list_versions(self, app_name, user_id, session_id, filename):
"""List versions of an artifact."""
- key = f"{app_name}:{user_id}:{session_id}:{filename}"
+ key = _artifact_key(app_name, user_id, session_id, filename)
if key not in artifacts:
return []
- return [a["version"] for a in artifacts[key]]
+ return [entry["version"] for entry in artifacts[key]]
async def delete_artifact(self, app_name, user_id, session_id, filename):
"""Delete an artifact."""
- key = f"{app_name}:{user_id}:{session_id}:{filename}"
- if key in artifacts:
- del artifacts[key]
+ key = _artifact_key(app_name, user_id, session_id, filename)
+ artifacts.pop(key, None)
+
+ async def get_artifact_version(
+ self,
+ *,
+ app_name: str,
+ user_id: str,
+ filename: str,
+ session_id: Optional[str] = None,
+ version: Optional[int] = None,
+ ) -> Optional[ArtifactVersion]:
+ key = _artifact_key(app_name, user_id, session_id, filename)
+ entries = artifacts.get(key)
+ if not entries:
+ return None
+ if version is None:
+ return entries[-1]["metadata"]
+ for entry in entries:
+ if entry["version"] == version:
+ return entry["metadata"]
+ return None
return MockArtifactService()
@@ -412,8 +509,6 @@ async def create_test_eval_set(
@pytest.fixture
def temp_agents_dir_with_a2a():
"""Create a temporary agents directory with A2A agent configurations for testing."""
- if sys.version_info < (3, 10):
- pytest.skip("A2A requires Python 3.10+")
with tempfile.TemporaryDirectory() as temp_dir:
# Create test agent directory
agent_dir = Path(temp_dir) / "test_a2a_agent"
@@ -457,9 +552,6 @@ def test_app_with_a2a(
temp_agents_dir_with_a2a,
):
"""Create a TestClient for the FastAPI app with A2A enabled."""
- if sys.version_info < (3, 10):
- pytest.skip("A2A requires Python 3.10+")
-
# Mock A2A related classes
with (
patch("signal.signal", return_value=None),
@@ -548,6 +640,26 @@ def test_list_apps(test_app):
logger.info(f"Listed apps: {data}")
+def test_list_apps_detailed(test_app):
+ """Test listing available applications with detailed metadata."""
+ response = test_app.get("/list-apps?detailed=true")
+
+ assert response.status_code == 200
+ data = response.json()
+ assert isinstance(data, dict)
+ assert "apps" in data
+ assert isinstance(data["apps"], list)
+
+ for app in data["apps"]:
+ assert "name" in app
+ assert "rootAgentName" in app
+ assert "description" in app
+ assert "language" in app
+ assert app["language"] in ["yaml", "python"]
+
+ logger.info(f"Listed apps: {data}")
+
+
def test_create_session_with_id(test_app, test_session_info):
"""Test creating a session with a specific ID."""
new_session_id = "new_session_id"
@@ -782,6 +894,87 @@ def test_list_artifact_names(test_app, create_test_session):
logger.info(f"Listed {len(data)} artifacts")
+def test_save_artifact(test_app, create_test_session, mock_artifact_service):
+ """Test saving an artifact through the FastAPI endpoint."""
+ info = create_test_session
+ url = (
+ f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
+ f"{info['session_id']}/artifacts"
+ )
+ artifact_part = types.Part(text="hello world")
+ payload = {
+ "filename": "greeting.txt",
+ "artifact": artifact_part.model_dump(by_alias=True, exclude_none=True),
+ }
+
+ response = test_app.post(url, json=payload)
+ assert response.status_code == 200
+ data = response.json()
+ assert data["version"] == 0
+ assert data["customMetadata"] == {}
+ assert data["mimeType"] in (None, "text/plain")
+ assert data["canonicalUri"].endswith(
+ f"/sessions/{info['session_id']}/artifacts/"
+ f"{payload['filename']}/versions/0"
+ )
+ assert isinstance(data["createTime"], float)
+
+ key = (
+ f"{info['app_name']}:{info['user_id']}:{info['session_id']}:"
+ f"{payload['filename']}"
+ )
+ stored = mock_artifact_service._artifacts[key][0]
+ assert stored["artifact"].text == "hello world"
+
+
+def test_save_artifact_returns_400_on_validation_error(
+ test_app, create_test_session, mock_artifact_service
+):
+ """Test save artifact endpoint surfaces validation errors as HTTP 400."""
+ info = create_test_session
+ url = (
+ f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
+ f"{info['session_id']}/artifacts"
+ )
+ artifact_part = types.Part(text="bad data")
+ payload = {
+ "filename": "invalid.txt",
+ "artifact": artifact_part.model_dump(by_alias=True, exclude_none=True),
+ }
+
+ mock_artifact_service.save_artifact_side_effect = InputValidationError(
+ "invalid artifact"
+ )
+
+ response = test_app.post(url, json=payload)
+ assert response.status_code == 400
+ assert response.json()["detail"] == "invalid artifact"
+
+
+def test_save_artifact_returns_500_on_unexpected_error(
+ test_app, create_test_session, mock_artifact_service
+):
+ """Test save artifact endpoint surfaces unexpected errors as HTTP 500."""
+ info = create_test_session
+ url = (
+ f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
+ f"{info['session_id']}/artifacts"
+ )
+ artifact_part = types.Part(text="bad data")
+ payload = {
+ "filename": "invalid.txt",
+ "artifact": artifact_part.model_dump(by_alias=True, exclude_none=True),
+ }
+
+ mock_artifact_service.save_artifact_side_effect = RuntimeError(
+ "unexpected failure"
+ )
+
+ response = test_app.post(url, json=payload)
+ assert response.status_code == 500
+ assert response.json()["detail"] == "unexpected failure"
+
+
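A hedged sketch of the status-code mapping these two tests expect from the save-artifact endpoint; the wrapper below is illustrative only, not the ADK handler.

```python
from fastapi import HTTPException
from google.adk.errors.input_validation_error import InputValidationError


async def save_artifact_endpoint(artifact_service, **kwargs) -> int:
  try:
    return await artifact_service.save_artifact(**kwargs)
  except InputValidationError as e:  # validation problems surface as HTTP 400
    raise HTTPException(status_code=400, detail=str(e)) from e
  except Exception as e:  # anything unexpected surfaces as HTTP 500
    raise HTTPException(status_code=500, detail=str(e)) from e
```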
def test_create_eval_set(test_app, test_session_info):
"""Test creating an eval set."""
url = f"/apps/{test_session_info['app_name']}/eval_sets/test_eval_set_id"
@@ -952,9 +1145,6 @@ def list_agents(self):
assert "dotSrc" in response.json()
-@pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
def test_a2a_agent_discovery(test_app_with_a2a):
"""Test that A2A agents are properly discovered and configured."""
# This test mainly verifies that the A2A setup doesn't break the app
@@ -963,9 +1153,6 @@ def test_a2a_agent_discovery(test_app_with_a2a):
logger.info("A2A agent discovery test passed")
-@pytest.mark.skipif(
- sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
-)
def test_a2a_disabled_by_default(test_app):
"""Test that A2A functionality is disabled by default."""
# The regular test_app fixture has a2a=False
diff --git a/tests/unittests/cli/utils/test_agent_loader.py b/tests/unittests/cli/utils/test_agent_loader.py
index 5c66160aed..4950fecbd3 100644
--- a/tests/unittests/cli/utils/test_agent_loader.py
+++ b/tests/unittests/cli/utils/test_agent_loader.py
@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import ntpath
import os
from pathlib import Path
+from pathlib import PureWindowsPath
+import re
import sys
import tempfile
from textwrap import dedent
+from google.adk.cli.utils import agent_loader as agent_loader_module
from google.adk.cli.utils.agent_loader import AgentLoader
from pydantic import ValidationError
import pytest
@@ -280,6 +284,43 @@ def test_load_multiple_different_agents(self):
assert agent2 is not agent3
assert agent1.agent_id != agent2.agent_id != agent3.agent_id
+ def test_error_messages_use_os_sep_consistently(self):
+ """Verify error messages use os.sep instead of hardcoded '/'."""
+ del self
+ with tempfile.TemporaryDirectory() as temp_dir:
+ loader = AgentLoader(temp_dir)
+ agent_name = "missing_agent"
+
+ expected_path = os.path.join(temp_dir, agent_name)
+
+ with pytest.raises(ValueError) as exc_info:
+ loader.load_agent(agent_name)
+
+ exc_info.match(re.escape(expected_path))
+ exc_info.match(re.escape(f"{agent_name}{os.sep}root_agent.yaml"))
+ exc_info.match(re.escape(f"{os.sep}"))
+
+ def test_agent_loader_with_mocked_windows_path(self, monkeypatch):
+ """Mock Path() to simulate Windows behavior and catch regressions.
+
+ REGRESSION TEST: Fails with rstrip('/'), passes with str(Path()).
+ """
+ del self
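+    # Trailing backslash is deliberate: rstrip('/') would leave it in place,
+    # while str(Path(...)) normalizes it away.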
+ windows_path = "C:\\Users\\dev\\agents\\"
+
+ with monkeypatch.context() as m:
+ m.setattr(
+ agent_loader_module,
+ "Path",
+ lambda path_str: PureWindowsPath(path_str),
+ )
+ loader = AgentLoader(windows_path)
+
+ expected = str(PureWindowsPath(windows_path))
+ assert loader.agents_dir == expected
+ assert not loader.agents_dir.endswith("\\")
+ assert not loader.agents_dir.endswith("/")
+
def test_agent_not_found_error(self):
"""Test that appropriate error is raised when agent is not found."""
with tempfile.TemporaryDirectory() as temp_dir:
diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py
index 0de59598b3..73ae89a986 100644
--- a/tests/unittests/cli/utils/test_cli.py
+++ b/tests/unittests/cli/utils/test_cli.py
@@ -28,7 +28,12 @@
import click
from google.adk.agents.base_agent import BaseAgent
from google.adk.apps.app import App
+from google.adk.artifacts.file_artifact_service import FileArtifactService
+from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
+from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
import google.adk.cli.cli as cli
+from google.adk.cli.utils.service_factory import create_artifact_service_from_options
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
import pytest
@@ -151,9 +156,9 @@ def _echo(msg: str) -> None:
input_path = tmp_path / "input.json"
input_path.write_text(json.dumps(input_json))
- artifact_service = cli.InMemoryArtifactService()
- session_service = cli.InMemorySessionService()
- credential_service = cli.InMemoryCredentialService()
+ artifact_service = InMemoryArtifactService()
+ session_service = InMemorySessionService()
+ credential_service = InMemoryCredentialService()
dummy_root = BaseAgent(name="root")
session = await cli.run_input_file(
@@ -189,6 +194,34 @@ async def test_run_cli_with_input_file(fake_agent, tmp_path: Path) -> None:
)
+@pytest.mark.asyncio
+async def test_run_cli_loads_services_module(
+ fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+ """run_cli should load custom services from the agents directory."""
+ parent_dir, folder_name = fake_agent
+ input_json = {"state": {}, "queries": ["ping"]}
+ input_path = tmp_path / "input.json"
+ input_path.write_text(json.dumps(input_json))
+
+ loaded_dirs: list[str] = []
+ monkeypatch.setattr(
+ cli, "load_services_module", lambda path: loaded_dirs.append(path)
+ )
+
+ agent_root = parent_dir / folder_name
+
+ await cli.run_cli(
+ agent_parent_dir=str(parent_dir),
+ agent_folder_name=folder_name,
+ input_file=str(input_path),
+ saved_session_file=None,
+ save_session=False,
+ )
+
+ assert loaded_dirs == [str(agent_root.resolve())]
+
+
@pytest.mark.asyncio
async def test_run_cli_app_uses_app_name_for_sessions(
fake_app_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
@@ -197,15 +230,20 @@ async def test_run_cli_app_uses_app_name_for_sessions(
parent_dir, folder_name, app_name = fake_app_agent
created_app_names: List[str] = []
- original_session_cls = cli.InMemorySessionService
-
- class _SpySessionService(original_session_cls):
+ class _SpySessionService(InMemorySessionService):
async def create_session(self, *, app_name: str, **kwargs: Any) -> Any:
created_app_names.append(app_name)
return await super().create_session(app_name=app_name, **kwargs)
- monkeypatch.setattr(cli, "InMemorySessionService", _SpySessionService)
+ spy_session_service = _SpySessionService()
+
+ def _session_factory(**_: Any) -> InMemorySessionService:
+ return spy_session_service
+
+ monkeypatch.setattr(
+ cli, "create_session_service_from_options", _session_factory
+ )
input_json = {"state": {}, "queries": ["ping"]}
input_path = tmp_path / "input_app.json"
@@ -253,16 +291,76 @@ async def test_run_cli_save_session(
assert "id" in data and "events" in data
+def test_create_artifact_service_defaults_to_file(tmp_path: Path) -> None:
+ """Service factory should default to FileArtifactService when URI is unset."""
+ service = create_artifact_service_from_options(base_dir=tmp_path)
+ assert isinstance(service, FileArtifactService)
+ expected_root = Path(tmp_path) / ".adk" / "artifacts"
+ assert service.root_dir == expected_root
+ assert expected_root.exists()
+
+
+def test_create_artifact_service_uses_shared_root(
+ tmp_path: Path,
+) -> None:
+  """Artifact service should use a single shared file-backed artifact root."""
+ service = create_artifact_service_from_options(base_dir=tmp_path)
+ assert isinstance(service, FileArtifactService)
+ expected_root = Path(tmp_path) / ".adk" / "artifacts"
+ assert service.root_dir == expected_root
+ assert expected_root.exists()
+
+
+def test_create_artifact_service_respects_memory_uri(tmp_path: Path) -> None:
+ """Service factory should honor memory:// URIs."""
+ service = create_artifact_service_from_options(
+ base_dir=tmp_path, artifact_service_uri="memory://"
+ )
+ assert isinstance(service, InMemoryArtifactService)
+
+
+def test_create_artifact_service_accepts_file_uri(tmp_path: Path) -> None:
+ """Service factory should allow custom local roots via file:// URIs."""
+ custom_root = tmp_path / "custom_artifacts"
+ service = create_artifact_service_from_options(
+ base_dir=tmp_path, artifact_service_uri=custom_root.as_uri()
+ )
+ assert isinstance(service, FileArtifactService)
+ assert service.root_dir == custom_root
+ assert custom_root.exists()
+
+
+@pytest.mark.asyncio
+async def test_run_cli_accepts_memory_scheme(
+ fake_agent, tmp_path: Path
+) -> None:
+ """run_cli should allow configuring in-memory services via memory:// URIs."""
+ parent_dir, folder_name = fake_agent
+ input_json = {"state": {}, "queries": []}
+ input_path = tmp_path / "noop.json"
+ input_path.write_text(json.dumps(input_json))
+
+ await cli.run_cli(
+ agent_parent_dir=str(parent_dir),
+ agent_folder_name=folder_name,
+ input_file=str(input_path),
+ saved_session_file=None,
+ save_session=False,
+ session_service_uri="memory://",
+ artifact_service_uri="memory://",
+ )
+
+
@pytest.mark.asyncio
async def test_run_interactively_whitespace_and_exit(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""run_interactively should skip blank input, echo once, then exit."""
# make a session that belongs to dummy agent
- session_service = cli.InMemorySessionService()
+ session_service = InMemorySessionService()
sess = await session_service.create_session(app_name="dummy", user_id="u")
- artifact_service = cli.InMemoryArtifactService()
- credential_service = cli.InMemoryCredentialService()
+ artifact_service = InMemoryArtifactService()
+ credential_service = InMemoryCredentialService()
root_agent = BaseAgent(name="root")
# fake user input: blank -> 'hello' -> 'exit'
diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py
index be9015ca87..95b561e57b 100644
--- a/tests/unittests/cli/utils/test_cli_tools_click.py
+++ b/tests/unittests/cli/utils/test_cli_tools_click.py
@@ -76,8 +76,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401
# Fixtures
@pytest.fixture(autouse=True)
-def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None:
+def _mute_click(request, monkeypatch: pytest.MonkeyPatch) -> None:
"""Suppress click output during tests."""
+  # Allow tests to opt out of muting by using the 'unmute_click' marker
+ if "unmute_click" in request.keywords:
+ return
monkeypatch.setattr(click, "echo", lambda *a, **k: None)
# Keep secho for error messages
# monkeypatch.setattr(click, "secho", lambda *a, **k: None)
@@ -121,32 +124,70 @@ def test_cli_create_cmd_invokes_run_cmd(
cli_tools_click.main,
["create", "--model", "gemini", "--api_key", "key123", str(app_dir)],
)
- assert result.exit_code == 0
+ assert result.exit_code == 0, (result.output, repr(result.exception))
assert rec.calls, "cli_create.run_cmd must be called"
# cli run
-@pytest.mark.asyncio
-async def test_cli_run_invokes_run_cli(
- tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+@pytest.mark.parametrize(
+ "cli_args,expected_session_uri,expected_artifact_uri",
+ [
+ pytest.param(
+ [
+ "--session_service_uri",
+ "memory://",
+ "--artifact_service_uri",
+ "memory://",
+ ],
+ "memory://",
+ "memory://",
+ id="memory_scheme_uris",
+ ),
+ pytest.param(
+ [],
+ None,
+ None,
+ id="default_uris_none",
+ ),
+ ],
+)
+def test_cli_run_service_uris(
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+ cli_args: list,
+ expected_session_uri: str,
+ expected_artifact_uri: str,
) -> None:
- """`adk run` should call run_cli via asyncio.run with correct parameters."""
- rec = _Recorder()
- monkeypatch.setattr(cli_tools_click, "run_cli", lambda **kwargs: rec(kwargs))
- monkeypatch.setattr(
- cli_tools_click.asyncio, "run", lambda coro: coro
- ) # pass-through
-
- # create dummy agent directory
+ """`adk run` should forward service URIs correctly to run_cli."""
agent_dir = tmp_path / "agent"
agent_dir.mkdir()
(agent_dir / "__init__.py").touch()
(agent_dir / "agent.py").touch()
+ # Capture the coroutine's locals before closing it
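+  # asyncio.run is stubbed out below, so the run_cli coroutine never executes;
+  # its bound keyword arguments remain readable from the coroutine frame.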
+ captured_locals = []
+
+ def capture_asyncio_run(coro):
+ # Extract the locals before closing the coroutine
+ if coro.cr_frame is not None:
+ captured_locals.append(dict(coro.cr_frame.f_locals))
+ coro.close() # Properly close the coroutine to avoid warnings
+
+ monkeypatch.setattr(cli_tools_click.asyncio, "run", capture_asyncio_run)
+
runner = CliRunner()
- result = runner.invoke(cli_tools_click.main, ["run", str(agent_dir)])
- assert result.exit_code == 0
- assert rec.calls and rec.calls[0][0][0]["agent_folder_name"] == "agent"
+ result = runner.invoke(
+ cli_tools_click.main,
+ ["run", *cli_args, str(agent_dir)],
+ )
+ assert result.exit_code == 0, (result.output, repr(result.exception))
+ assert len(captured_locals) == 1, "Expected asyncio.run to be called once"
+
+ # Verify the kwargs passed to run_cli
+ coro_locals = captured_locals[0]
+ assert coro_locals.get("session_service_uri") == expected_session_uri
+ assert coro_locals.get("artifact_service_uri") == expected_artifact_uri
+ assert coro_locals["agent_folder_name"] == "agent"
# cli deploy cloud_run
@@ -520,10 +561,13 @@ def test_cli_web_passes_service_uris(
assert called_kwargs.get("memory_service_uri") == "rag://mycorpus"
-def test_cli_web_passes_deprecated_uris(
- tmp_path: Path, monkeypatch: pytest.MonkeyPatch, _patch_uvicorn: _Recorder
+@pytest.mark.unmute_click
+def test_cli_web_warns_and_maps_deprecated_uris(
+ tmp_path: Path,
+ _patch_uvicorn: _Recorder,
+ monkeypatch: pytest.MonkeyPatch,
) -> None:
- """`adk web` should use deprecated URIs if new ones are not provided."""
+ """`adk web` should accept deprecated URI flags with warnings."""
agents_dir = tmp_path / "agents"
agents_dir.mkdir()
@@ -542,11 +586,14 @@ def test_cli_web_passes_deprecated_uris(
"gs://deprecated",
],
)
+
assert result.exit_code == 0
- assert mock_get_app.calls
called_kwargs = mock_get_app.calls[0][1]
assert called_kwargs.get("session_service_uri") == "sqlite:///deprecated.db"
assert called_kwargs.get("artifact_service_uri") == "gs://deprecated"
+ # Check output for deprecation warnings (CliRunner captures both stdout and stderr)
+ assert "--session_db_url" in result.output
+ assert "--artifact_storage_uri" in result.output
def test_cli_eval_with_eval_set_file_path(
diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py
new file mode 100644
index 0000000000..9d9afdd23b
--- /dev/null
+++ b/tests/unittests/cli/utils/test_service_factory.py
@@ -0,0 +1,141 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for service factory helpers."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import Mock
+
+from google.adk.cli.utils.local_storage import PerAgentDatabaseSessionService
+import google.adk.cli.utils.service_factory as service_factory
+from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
+from google.adk.sessions.database_session_service import DatabaseSessionService
+import pytest
+
+
+def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch):
+ registry = Mock()
+ expected = object()
+ registry.create_session_service.return_value = expected
+ monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry)
+
+ result = service_factory.create_session_service_from_options(
+ base_dir=tmp_path,
+ session_service_uri="sqlite:///test.db",
+ )
+
+ assert result is expected
+ registry.create_session_service.assert_called_once_with(
+ "sqlite:///test.db",
+ agents_dir=str(tmp_path),
+ )
+
+
+@pytest.mark.asyncio
+async def test_create_session_service_defaults_to_per_agent_sqlite(
+ tmp_path: Path,
+) -> None:
+ agent_dir = tmp_path / "agent_a"
+ agent_dir.mkdir()
+ service = service_factory.create_session_service_from_options(
+ base_dir=tmp_path,
+ )
+
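+  # With no session_service_uri the factory defaults to a per-agent SQLite
+  # database stored under <agents_dir>/<app_name>/.adk/session.db.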
+ assert isinstance(service, PerAgentDatabaseSessionService)
+ session = await service.create_session(app_name="agent_a", user_id="user")
+ assert session.app_name == "agent_a"
+ assert (agent_dir / ".adk" / "session.db").exists()
+
+
+def test_create_session_service_falls_back_to_database(
+ tmp_path: Path, monkeypatch
+):
+ registry = Mock()
+ registry.create_session_service.return_value = None
+ monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry)
+
+ service = service_factory.create_session_service_from_options(
+ base_dir=tmp_path,
+ session_service_uri="sqlite+aiosqlite:///:memory:",
+ session_db_kwargs={"echo": True},
+ )
+
+ assert isinstance(service, DatabaseSessionService)
+ assert service.db_engine.url.drivername == "sqlite+aiosqlite"
+ assert service.db_engine.echo is True
+ registry.create_session_service.assert_called_once_with(
+ "sqlite+aiosqlite:///:memory:",
+ agents_dir=str(tmp_path),
+ echo=True,
+ )
+
+
+def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch):
+ registry = Mock()
+ expected = object()
+ registry.create_artifact_service.return_value = expected
+ monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry)
+
+ result = service_factory.create_artifact_service_from_options(
+ base_dir=tmp_path,
+ artifact_service_uri="gs://bucket/path",
+ )
+
+ assert result is expected
+ registry.create_artifact_service.assert_called_once_with(
+ "gs://bucket/path",
+ agents_dir=str(tmp_path),
+ )
+
+
+def test_create_memory_service_uses_registry(tmp_path: Path, monkeypatch):
+ registry = Mock()
+ expected = object()
+ registry.create_memory_service.return_value = expected
+ monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry)
+
+ result = service_factory.create_memory_service_from_options(
+ base_dir=tmp_path,
+ memory_service_uri="rag://my-corpus",
+ )
+
+ assert result is expected
+ registry.create_memory_service.assert_called_once_with(
+ "rag://my-corpus",
+ agents_dir=str(tmp_path),
+ )
+
+
+def test_create_memory_service_defaults_to_in_memory(tmp_path: Path):
+ service = service_factory.create_memory_service_from_options(
+ base_dir=tmp_path
+ )
+
+ assert isinstance(service, InMemoryMemoryService)
+
+
+def test_create_memory_service_raises_on_unknown_scheme(
+ tmp_path: Path, monkeypatch
+):
+ registry = Mock()
+ registry.create_memory_service.return_value = None
+ monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry)
+
+ with pytest.raises(ValueError):
+ service_factory.create_memory_service_from_options(
+ base_dir=tmp_path,
+ memory_service_uri="unknown://foo",
+ )
diff --git a/tests/unittests/code_executors/test_unsafe_local_code_executor.py b/tests/unittests/code_executors/test_unsafe_local_code_executor.py
index eeb10b34fa..e5d5c4f792 100644
--- a/tests/unittests/code_executors/test_unsafe_local_code_executor.py
+++ b/tests/unittests/code_executors/test_unsafe_local_code_executor.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import textwrap
from unittest.mock import MagicMock
from google.adk.agents.base_agent import BaseAgent
@@ -101,3 +102,22 @@ def test_execute_code_empty(self, mock_invocation_context: InvocationContext):
result = executor.execute_code(mock_invocation_context, code_input)
assert result.stdout == ""
assert result.stderr == ""
+
+ def test_execute_code_nested_function_call(
+ self, mock_invocation_context: InvocationContext
+ ):
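+    # Functions defined inside the executed snippet must be able to call one
+    # another (guards against exec() namespace splits between helpers).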
+ executor = UnsafeLocalCodeExecutor()
+ code_input = CodeExecutionInput(code=(textwrap.dedent("""
+ def helper(name):
+ return f'hi {name}'
+
+ def run():
+ print(helper('ada'))
+
+ run()
+ """)))
+
+ result = executor.execute_code(mock_invocation_context, code_input)
+
+ assert result.stderr == ""
+ assert result.stdout == "hi ada\n"
diff --git a/tests/unittests/evaluation/simulation/__init__.py b/tests/unittests/evaluation/simulation/__init__.py
new file mode 100644
index 0000000000..0a2669d7a2
--- /dev/null
+++ b/tests/unittests/evaluation/simulation/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unittests/evaluation/test_llm_backed_user_simulator.py b/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py
similarity index 95%
rename from tests/unittests/evaluation/test_llm_backed_user_simulator.py
rename to tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py
index 6ef3969e70..75db778bc7 100644
--- a/tests/unittests/evaluation/test_llm_backed_user_simulator.py
+++ b/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py
@@ -15,9 +15,9 @@
from __future__ import annotations
from google.adk.evaluation import conversation_scenarios
-from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulator
-from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig
-from google.adk.evaluation.user_simulator import Status
+from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulator
+from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig
+from google.adk.evaluation.simulation.user_simulator import Status
from google.adk.events.event import Event
from google.genai import types
import pytest
@@ -112,7 +112,7 @@ async def to_async_iter(items):
def mock_llm_agent(mocker):
"""Provides a mock LLM agent."""
mock_llm_registry_cls = mocker.patch(
- "google.adk.evaluation.llm_backed_user_simulator.LLMRegistry"
+ "google.adk.evaluation.simulation.llm_backed_user_simulator.LLMRegistry"
)
mock_llm_registry = mocker.MagicMock()
mock_llm_registry_cls.return_value = mock_llm_registry
diff --git a/tests/unittests/evaluation/test_static_user_simulator.py b/tests/unittests/evaluation/simulation/test_static_user_simulator.py
similarity index 93%
rename from tests/unittests/evaluation/test_static_user_simulator.py
rename to tests/unittests/evaluation/simulation/test_static_user_simulator.py
index 5cc70c80e6..f18c23f5f2 100644
--- a/tests/unittests/evaluation/test_static_user_simulator.py
+++ b/tests/unittests/evaluation/simulation/test_static_user_simulator.py
@@ -14,9 +14,9 @@
from __future__ import annotations
-from google.adk.evaluation import static_user_simulator
-from google.adk.evaluation import user_simulator
from google.adk.evaluation.eval_case import Invocation
+from google.adk.evaluation.simulation import static_user_simulator
+from google.adk.evaluation.simulation import user_simulator
from google.genai import types
import pytest
diff --git a/tests/unittests/evaluation/test_user_simulator.py b/tests/unittests/evaluation/simulation/test_user_simulator.py
similarity index 90%
rename from tests/unittests/evaluation/test_user_simulator.py
rename to tests/unittests/evaluation/simulation/test_user_simulator.py
index c3e1e606ee..dbe7aff1db 100644
--- a/tests/unittests/evaluation/test_user_simulator.py
+++ b/tests/unittests/evaluation/simulation/test_user_simulator.py
@@ -14,8 +14,8 @@
from __future__ import annotations
-from google.adk.evaluation.user_simulator import NextUserMessage
-from google.adk.evaluation.user_simulator import Status
+from google.adk.evaluation.simulation.user_simulator import NextUserMessage
+from google.adk.evaluation.simulation.user_simulator import Status
from google.genai.types import Content
import pytest
diff --git a/tests/unittests/evaluation/test_user_simulator_provider.py b/tests/unittests/evaluation/simulation/test_user_simulator_provider.py
similarity index 86%
rename from tests/unittests/evaluation/test_user_simulator_provider.py
rename to tests/unittests/evaluation/simulation/test_user_simulator_provider.py
index 7cff4241b6..c4fb826fb7 100644
--- a/tests/unittests/evaluation/test_user_simulator_provider.py
+++ b/tests/unittests/evaluation/simulation/test_user_simulator_provider.py
@@ -16,10 +16,10 @@
from google.adk.evaluation import conversation_scenarios
from google.adk.evaluation import eval_case
-from google.adk.evaluation import user_simulator_provider
-from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulator
-from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig
-from google.adk.evaluation.static_user_simulator import StaticUserSimulator
+from google.adk.evaluation.simulation import user_simulator_provider
+from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulator
+from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig
+from google.adk.evaluation.simulation.static_user_simulator import StaticUserSimulator
from google.genai import types
import pytest
@@ -52,7 +52,7 @@ def test_provide_static_user_simulator(self):
def test_provide_llm_backed_user_simulator(self, mocker):
"""Tests the case when a LlmBackedUserSimulator should be provided."""
mock_llm_registry = mocker.patch(
- 'google.adk.evaluation.llm_backed_user_simulator.LLMRegistry',
+ 'google.adk.evaluation.simulation.llm_backed_user_simulator.LLMRegistry',
autospec=True,
)
mock_llm_registry.return_value.resolve.return_value = mocker.Mock()
diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py
index 27372f12c2..873239e7f4 100644
--- a/tests/unittests/evaluation/test_evaluation_generator.py
+++ b/tests/unittests/evaluation/test_evaluation_generator.py
@@ -18,9 +18,9 @@
from google.adk.evaluation.app_details import AppDetails
from google.adk.evaluation.evaluation_generator import EvaluationGenerator
from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin
-from google.adk.evaluation.user_simulator import NextUserMessage
-from google.adk.evaluation.user_simulator import Status as UserSimulatorStatus
-from google.adk.evaluation.user_simulator import UserSimulator
+from google.adk.evaluation.simulation.user_simulator import NextUserMessage
+from google.adk.evaluation.simulation.user_simulator import Status as UserSimulatorStatus
+from google.adk.evaluation.simulation.user_simulator import UserSimulator
from google.adk.events.event import Event
from google.adk.models.llm_request import LlmRequest
from google.genai import types
diff --git a/tests/unittests/evaluation/test_local_eval_service.py b/tests/unittests/evaluation/test_local_eval_service.py
index cf2ca342f3..66080828d8 100644
--- a/tests/unittests/evaluation/test_local_eval_service.py
+++ b/tests/unittests/evaluation/test_local_eval_service.py
@@ -536,9 +536,6 @@ def test_generate_final_eval_status_doesn_t_throw_on(eval_service):
@pytest.mark.asyncio
-@pytest.mark.skipif(
- sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+"
-)
async def test_mcp_stdio_agent_no_runtime_error(mocker):
"""Test that LocalEvalService can handle MCP stdio agents without RuntimeError.
diff --git a/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py b/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py
index be97a627a1..b180a589cb 100644
--- a/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py
+++ b/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py
@@ -126,15 +126,16 @@ async def test_agent_transfer_includes_sorted_agent_names_in_system_instructions
Agent description: Peer agent
-If you are the best to answer the question according to your description, you
-can answer it.
+If you are the best to answer the question according to your description,
+you can answer it.
If another agent is better for answering the question according to its
-description, call `transfer_to_agent` function to transfer the
-question to that agent. When transferring, do not generate any text other than
-the function call.
+description, call `transfer_to_agent` function to transfer the question to that
+agent. When transferring, do not generate any text other than the function
+call.
-**NOTE**: the only available agents for `transfer_to_agent` function are `a_agent`, `m_agent`, `parent_agent`, `peer_agent`, `z_agent`.
+**NOTE**: the only available agents for `transfer_to_agent` function are
+`a_agent`, `m_agent`, `parent_agent`, `peer_agent`, `z_agent`.
If neither you nor the other agents are best for the question, transfer to your parent agent parent_agent."""
@@ -189,15 +190,16 @@ async def test_agent_transfer_system_instructions_without_parent():
Agent description: Second sub-agent
-If you are the best to answer the question according to your description, you
-can answer it.
+If you are the best to answer the question according to your description,
+you can answer it.
If another agent is better for answering the question according to its
-description, call `transfer_to_agent` function to transfer the
-question to that agent. When transferring, do not generate any text other than
-the function call.
+description, call `transfer_to_agent` function to transfer the question to that
+agent. When transferring, do not generate any text other than the function
+call.
-**NOTE**: the only available agents for `transfer_to_agent` function are `agent1`, `agent2`."""
+**NOTE**: the only available agents for `transfer_to_agent` function are
+`agent1`, `agent2`."""
assert expected_content in instructions
@@ -248,15 +250,16 @@ async def test_agent_transfer_simplified_parent_instructions():
Agent description: Parent agent
-If you are the best to answer the question according to your description, you
-can answer it.
+If you are the best to answer the question according to your description,
+you can answer it.
If another agent is better for answering the question according to its
-description, call `transfer_to_agent` function to transfer the
-question to that agent. When transferring, do not generate any text other than
-the function call.
+description, call `transfer_to_agent` function to transfer the question to that
+agent. When transferring, do not generate any text other than the function
+call.
-**NOTE**: the only available agents for `transfer_to_agent` function are `parent_agent`, `sub_agent`.
+**NOTE**: the only available agents for `transfer_to_agent` function are
+`parent_agent`, `sub_agent`.
If neither you nor the other agents are best for the question, transfer to your parent agent parent_agent."""
diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py
index cf55630b67..b2aa91dbee 100644
--- a/tests/unittests/flows/llm_flows/test_contents.py
+++ b/tests/unittests/flows/llm_flows/test_contents.py
@@ -465,6 +465,43 @@ async def test_events_with_empty_content_are_skipped():
author="user",
content=types.UserContent("How are you?"),
),
+ # Event with content that has only empty text part
+ Event(
+ invocation_id="inv6",
+ author="user",
+ content=types.Content(parts=[types.Part(text="")], role="model"),
+ ),
+ # Event with content that has only inline data part
+ Event(
+ invocation_id="inv7",
+ author="user",
+ content=types.Content(
+ parts=[
+ types.Part(
+ inline_data=types.Blob(
+ data=b"test", mime_type="image/png"
+ )
+ )
+ ],
+ role="user",
+ ),
+ ),
+ # Event with content that has only file data part
+ Event(
+ invocation_id="inv8",
+ author="user",
+ content=types.Content(
+ parts=[
+ types.Part(
+ file_data=types.FileData(
+ file_uri="gs://test_bucket/test_file.png",
+ mime_type="image/png",
+ )
+ )
+ ],
+ role="user",
+ ),
+ ),
]
invocation_context.session.events = events
@@ -478,4 +515,23 @@ async def test_events_with_empty_content_are_skipped():
assert llm_request.contents == [
types.UserContent("Hello"),
types.UserContent("How are you?"),
+ types.Content(
+ parts=[
+ types.Part(
+ inline_data=types.Blob(data=b"test", mime_type="image/png")
+ )
+ ],
+ role="user",
+ ),
+ types.Content(
+ parts=[
+ types.Part(
+ file_data=types.FileData(
+ file_uri="gs://test_bucket/test_file.png",
+ mime_type="image/png",
+ )
+ )
+ ],
+ role="user",
+ ),
]
diff --git a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py
index e64613ff8d..e589d51c7d 100644
--- a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py
+++ b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py
@@ -397,3 +397,245 @@ def test_progressive_sse_preserves_part_ordering():
# Part 4: Second function call (from chunk8)
assert parts[4].function_call.name == "get_weather"
assert parts[4].function_call.args["location"] == "New York"
+
+
+def test_progressive_sse_streaming_function_call_arguments():
+ """Test streaming function call arguments feature.
+
+ This test simulates the streamFunctionCallArguments feature where a function
+ call's arguments are streamed incrementally across multiple chunks:
+
+ Chunk 1: FC name + partial location argument ("New ")
+ Chunk 2: Continue location argument ("York") -> concatenated to "New York"
+ Chunk 3: Add unit argument ("celsius"), willContinue=False -> FC complete
+
+    Expected result: FunctionCall(name="get_weather",
+                     args={"location": "New York", "unit": "celsius"},
+                     id="fc_001")
+ """
+
+ aggregator = StreamingResponseAggregator()
+
+ # Chunk 1: FC name + partial location argument
+ chunk1_fc = types.FunctionCall(
+ name="get_weather",
+ id="fc_001",
+ partial_args=[
+ types.PartialArg(json_path="$.location", string_value="New ")
+ ],
+ will_continue=True,
+ )
+ chunk1 = types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=types.Content(
+ role="model", parts=[types.Part(function_call=chunk1_fc)]
+ )
+ )
+ ]
+ )
+
+ # Chunk 2: Continue streaming location argument
+ chunk2_fc = types.FunctionCall(
+ partial_args=[
+ types.PartialArg(json_path="$.location", string_value="York")
+ ],
+ will_continue=True,
+ )
+ chunk2 = types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=types.Content(
+ role="model", parts=[types.Part(function_call=chunk2_fc)]
+ )
+ )
+ ]
+ )
+
+ # Chunk 3: Add unit argument, FC complete
+ chunk3_fc = types.FunctionCall(
+ partial_args=[
+ types.PartialArg(json_path="$.unit", string_value="celsius")
+ ],
+ will_continue=False, # FC complete
+ )
+ chunk3 = types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=types.Content(
+ role="model", parts=[types.Part(function_call=chunk3_fc)]
+ ),
+ finish_reason=types.FinishReason.STOP,
+ )
+ ]
+ )
+
+ # Process all chunks through aggregator
+ processed_chunks = []
+ for chunk in [chunk1, chunk2, chunk3]:
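+    # The same aggregator instance is reused for every chunk, so partial
+    # argument fragments accumulate until will_continue is False.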
+
+ async def process():
+ results = []
+ async for response in aggregator.process_response(chunk):
+ results.append(response)
+ return results
+
+ import asyncio
+
+ chunk_results = asyncio.run(process())
+ processed_chunks.extend(chunk_results)
+
+ # Get final aggregated response
+ final_response = aggregator.close()
+
+ # Verify final aggregated response has complete FC
+ assert final_response is not None
+ assert len(final_response.content.parts) == 1
+
+ fc_part = final_response.content.parts[0]
+ assert fc_part.function_call is not None
+ assert fc_part.function_call.name == "get_weather"
+ assert fc_part.function_call.id == "fc_001"
+
+ # Verify arguments were correctly assembled from streaming chunks
+ args = fc_part.function_call.args
+ assert args["location"] == "New York" # "New " + "York" concatenated
+ assert args["unit"] == "celsius"
+
+
+def test_progressive_sse_preserves_thought_signature():
+ """Test that thought_signature is preserved when streaming FC arguments.
+
+ This test verifies that when a streaming function call has a thought_signature
+ in the Part, it is correctly preserved in the final aggregated FunctionCall.
+ """
+
+ aggregator = StreamingResponseAggregator()
+
+ # Create a thought signature (simulating what Gemini returns)
+ # thought_signature is bytes (base64 encoded)
+ test_thought_signature = b"test_signature_abc123"
+
+ # Chunk with streaming FC args and thought_signature
+ chunk_fc = types.FunctionCall(
+ name="add_5_numbers",
+ id="fc_003",
+ partial_args=[
+ types.PartialArg(json_path="$.num1", number_value=10),
+ types.PartialArg(json_path="$.num2", number_value=20),
+ ],
+ will_continue=False,
+ )
+
+ # Create Part with both function_call AND thought_signature
+ chunk_part = types.Part(
+ function_call=chunk_fc, thought_signature=test_thought_signature
+ )
+
+ chunk = types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=types.Content(role="model", parts=[chunk_part]),
+ finish_reason=types.FinishReason.STOP,
+ )
+ ]
+ )
+
+ # Process chunk through aggregator
+ async def process():
+ results = []
+ async for response in aggregator.process_response(chunk):
+ results.append(response)
+ return results
+
+ import asyncio
+
+ asyncio.run(process())
+
+ # Get final aggregated response
+ final_response = aggregator.close()
+
+ # Verify thought_signature was preserved in the Part
+ assert final_response is not None
+ assert len(final_response.content.parts) == 1
+
+ fc_part = final_response.content.parts[0]
+ assert fc_part.function_call is not None
+ assert fc_part.function_call.name == "add_5_numbers"
+
+ assert fc_part.thought_signature == test_thought_signature
+
+
+def test_progressive_sse_handles_empty_function_call():
+ """Test that empty function calls are skipped.
+
+ When using streamFunctionCallArguments, Gemini may send an empty
+ functionCall: {} as the final chunk to signal streaming completion.
+ This test verifies that such empty function calls are properly skipped
+ and don't cause errors.
+ """
+
+ aggregator = StreamingResponseAggregator()
+
+ # Chunk 1: Streaming FC with partial args
+ chunk1_fc = types.FunctionCall(
+ name="concat_number_and_string",
+ id="fc_001",
+ partial_args=[
+ types.PartialArg(json_path="$.num", number_value=100),
+ types.PartialArg(json_path="$.s", string_value="ADK"),
+ ],
+ will_continue=False,
+ )
+ chunk1 = types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=types.Content(
+ role="model", parts=[types.Part(function_call=chunk1_fc)]
+ )
+ )
+ ]
+ )
+
+ # Chunk 2: Empty function call (streaming end marker)
+ chunk2_fc = types.FunctionCall() # Empty function call
+ chunk2 = types.GenerateContentResponse(
+ candidates=[
+ types.Candidate(
+ content=types.Content(
+ role="model", parts=[types.Part(function_call=chunk2_fc)]
+ ),
+ finish_reason=types.FinishReason.STOP,
+ )
+ ]
+ )
+
+ # Process all chunks through aggregator
+ async def process():
+ results = []
+ for chunk in [chunk1, chunk2]:
+ async for response in aggregator.process_response(chunk):
+ results.append(response)
+ return results
+
+ import asyncio
+
+ asyncio.run(process())
+
+ # Get final aggregated response
+ final_response = aggregator.close()
+
+ # Verify final response only has the real FC, not the empty one
+ assert final_response is not None
+ assert len(final_response.content.parts) == 1
+
+ fc_part = final_response.content.parts[0]
+ assert fc_part.function_call is not None
+ assert fc_part.function_call.name == "concat_number_and_string"
+ assert fc_part.function_call.id == "fc_001"
+
+ # Verify arguments
+ args = fc_part.function_call.args
+ assert args["num"] == 100
+ assert args["s"] == "ADK"
diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py
index e5ac8cc051..e1880abf0d 100644
--- a/tests/unittests/models/test_anthropic_llm.py
+++ b/tests/unittests/models/test_anthropic_llm.py
@@ -19,7 +19,9 @@
from anthropic import types as anthropic_types
from google.adk import version as adk_version
from google.adk.models import anthropic_llm
+from google.adk.models.anthropic_llm import AnthropicLlm
from google.adk.models.anthropic_llm import Claude
+from google.adk.models.anthropic_llm import content_to_message_param
from google.adk.models.anthropic_llm import function_declaration_to_tool_param
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
@@ -358,6 +360,37 @@ async def mock_coro():
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
+@pytest.mark.asyncio
+async def test_anthropic_llm_generate_content_async(
+ llm_request, generate_content_response, generate_llm_response
+):
+ anthropic_llm_instance = AnthropicLlm(model="claude-sonnet-4-20250514")
+ with mock.patch.object(
+ anthropic_llm_instance, "_anthropic_client"
+ ) as mock_client:
+ with mock.patch.object(
+ anthropic_llm,
+ "message_to_generate_content_response",
+ return_value=generate_llm_response,
+ ):
+ # Create a mock coroutine that returns the generate_content_response.
+ async def mock_coro():
+ return generate_content_response
+
+ # Assign the coroutine to the mocked method
+ mock_client.messages.create.return_value = mock_coro()
+
+ responses = [
+ resp
+ async for resp in anthropic_llm_instance.generate_content_async(
+ llm_request, stream=False
+ )
+ ]
+ assert len(responses) == 1
+ assert isinstance(responses[0], LlmResponse)
+ assert responses[0].content.parts[0].text == "Hello, how can I help you?"
+
+
@pytest.mark.asyncio
async def test_generate_content_async_with_max_tokens(
llm_request, generate_content_response, generate_llm_response
@@ -462,3 +495,81 @@ def test_part_to_message_block_with_multiple_content_items():
assert isinstance(result, dict)
# Multiple text items should be joined with newlines
assert result["content"] == "First part\nSecond part"
+
+
+content_to_message_param_test_cases = [
+ (
+ "user_role_with_text_and_image",
+ Content(
+ role="user",
+ parts=[
+ Part.from_text(text="What's in this image?"),
+ Part(
+ inline_data=types.Blob(
+ mime_type="image/jpeg", data=b"fake_image_data"
+ )
+ ),
+ ],
+ ),
+ "user",
+ 2, # Expected content length
+ False, # Should not log warning
+ ),
+ (
+ "model_role_with_text_and_image",
+ Content(
+ role="model",
+ parts=[
+ Part.from_text(text="I see a cat."),
+ Part(
+ inline_data=types.Blob(
+ mime_type="image/png", data=b"fake_image_data"
+ )
+ ),
+ ],
+ ),
+ "assistant",
+ 1, # Image filtered out, only text remains
+ True, # Should log warning
+ ),
+ (
+ "assistant_role_with_text_and_image",
+ Content(
+ role="assistant",
+ parts=[
+ Part.from_text(text="Here's what I found."),
+ Part(
+ inline_data=types.Blob(
+ mime_type="image/webp", data=b"fake_image_data"
+ )
+ ),
+ ],
+ ),
+ "assistant",
+ 1, # Image filtered out, only text remains
+ True, # Should log warning
+ ),
+]
+
+
+@pytest.mark.parametrize(
+ "_, content, expected_role, expected_content_length, should_log_warning",
+ content_to_message_param_test_cases,
+ ids=[case[0] for case in content_to_message_param_test_cases],
+)
+def test_content_to_message_param_with_images(
+ _, content, expected_role, expected_content_length, should_log_warning
+):
+ """Test content_to_message_param handles images correctly based on role."""
+ with mock.patch("google.adk.models.anthropic_llm.logger") as mock_logger:
+ result = content_to_message_param(content)
+
+ assert result["role"] == expected_role
+ assert len(result["content"]) == expected_content_length
+
+ if should_log_warning:
+ mock_logger.warning.assert_called_once_with(
+ "Image data is not supported in Claude for assistant turns."
+ )
+ else:
+ mock_logger.warning.assert_not_called()
diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py
index 6d3e685748..190007603c 100644
--- a/tests/unittests/models/test_gemini_llm_connection.py
+++ b/tests/unittests/models/test_gemini_llm_connection.py
@@ -15,6 +15,7 @@
from unittest import mock
from google.adk.models.gemini_llm_connection import GeminiLlmConnection
+from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
import pytest
@@ -28,7 +29,17 @@ def mock_gemini_session():
@pytest.fixture
def gemini_connection(mock_gemini_session):
"""GeminiLlmConnection instance with mocked session."""
- return GeminiLlmConnection(mock_gemini_session)
+ return GeminiLlmConnection(
+ mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
+ )
+
+
+@pytest.fixture
+def gemini_api_connection(mock_gemini_session):
+ """GeminiLlmConnection instance with mocked session for Gemini API."""
+ return GeminiLlmConnection(
+ mock_gemini_session, api_backend=GoogleLLMVariant.GEMINI_API
+ )
@pytest.fixture
@@ -225,6 +236,227 @@ async def mock_receive_generator():
assert content_response.content == mock_content
+@pytest.mark.asyncio
+async def test_receive_transcript_finished_on_interrupt(
+ gemini_api_connection,
+ mock_gemini_session,
+):
+ """Test receive finishes transcription on interrupt signal."""
+
+ message1 = mock.Mock()
+ message1.usage_metadata = None
+ message1.server_content = mock.Mock()
+ message1.server_content.model_turn = None
+ message1.server_content.interrupted = False
+ message1.server_content.input_transcription = types.Transcription(
+ text='Hello', finished=False
+ )
+ message1.server_content.output_transcription = None
+ message1.server_content.turn_complete = False
+ message1.server_content.generation_complete = False
+ message1.tool_call = None
+ message1.session_resumption_update = None
+
+ message2 = mock.Mock()
+ message2.usage_metadata = None
+ message2.server_content = mock.Mock()
+ message2.server_content.model_turn = None
+ message2.server_content.interrupted = False
+ message2.server_content.input_transcription = None
+ message2.server_content.output_transcription = types.Transcription(
+ text='How can', finished=False
+ )
+ message2.server_content.turn_complete = False
+ message2.server_content.generation_complete = False
+ message2.tool_call = None
+ message2.session_resumption_update = None
+
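+  # message3 carries only the interrupt flag; it should flush both pending
+  # partial transcriptions as finished before the interrupted response.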
+ message3 = mock.Mock()
+ message3.usage_metadata = None
+ message3.server_content = mock.Mock()
+ message3.server_content.model_turn = None
+ message3.server_content.interrupted = True
+ message3.server_content.input_transcription = None
+ message3.server_content.output_transcription = None
+ message3.server_content.turn_complete = False
+ message3.server_content.generation_complete = False
+ message3.tool_call = None
+ message3.session_resumption_update = None
+
+ async def mock_receive_generator():
+ yield message1
+ yield message2
+ yield message3
+
+ receive_mock = mock.Mock(return_value=mock_receive_generator())
+ mock_gemini_session.receive = receive_mock
+
+ responses = [resp async for resp in gemini_api_connection.receive()]
+
+ assert len(responses) == 5
+ assert responses[4].interrupted is True
+
+ assert responses[0].input_transcription.text == 'Hello'
+ assert responses[0].input_transcription.finished is False
+ assert responses[0].partial is True
+ assert responses[1].output_transcription.text == 'How can'
+ assert responses[1].output_transcription.finished is False
+ assert responses[1].partial is True
+ assert responses[2].input_transcription.text == 'Hello'
+ assert responses[2].input_transcription.finished is True
+ assert responses[2].partial is False
+ assert responses[3].output_transcription.text == 'How can'
+ assert responses[3].output_transcription.finished is True
+ assert responses[3].partial is False
+
+
+@pytest.mark.asyncio
+async def test_receive_transcript_finished_on_generation_complete(
+ gemini_api_connection,
+ mock_gemini_session,
+):
+ """Test receive finishes transcription on generation_complete signal."""
+
+ message1 = mock.Mock()
+ message1.usage_metadata = None
+ message1.server_content = mock.Mock()
+ message1.server_content.model_turn = None
+ message1.server_content.interrupted = False
+ message1.server_content.input_transcription = types.Transcription(
+ text='Hello', finished=False
+ )
+ message1.server_content.output_transcription = None
+ message1.server_content.turn_complete = False
+ message1.server_content.generation_complete = False
+ message1.tool_call = None
+ message1.session_resumption_update = None
+
+ message2 = mock.Mock()
+ message2.usage_metadata = None
+ message2.server_content = mock.Mock()
+ message2.server_content.model_turn = None
+ message2.server_content.interrupted = False
+ message2.server_content.input_transcription = None
+ message2.server_content.output_transcription = types.Transcription(
+ text='How can', finished=False
+ )
+ message2.server_content.turn_complete = False
+ message2.server_content.generation_complete = False
+ message2.tool_call = None
+ message2.session_resumption_update = None
+
+ message3 = mock.Mock()
+ message3.usage_metadata = None
+ message3.server_content = mock.Mock()
+ message3.server_content.model_turn = None
+ message3.server_content.interrupted = False
+ message3.server_content.input_transcription = None
+ message3.server_content.output_transcription = None
+ message3.server_content.turn_complete = False
+ message3.server_content.generation_complete = True
+ message3.tool_call = None
+ message3.session_resumption_update = None
+
+ async def mock_receive_generator():
+ yield message1
+ yield message2
+ yield message3
+
+ receive_mock = mock.Mock(return_value=mock_receive_generator())
+ mock_gemini_session.receive = receive_mock
+
+ responses = [resp async for resp in gemini_api_connection.receive()]
+
+ assert len(responses) == 4
+
+ assert responses[0].input_transcription.text == 'Hello'
+ assert responses[0].input_transcription.finished is False
+ assert responses[0].partial is True
+ assert responses[1].output_transcription.text == 'How can'
+ assert responses[1].output_transcription.finished is False
+ assert responses[1].partial is True
+ assert responses[2].input_transcription.text == 'Hello'
+ assert responses[2].input_transcription.finished is True
+ assert responses[2].partial is False
+ assert responses[3].output_transcription.text == 'How can'
+ assert responses[3].output_transcription.finished is True
+ assert responses[3].partial is False
+
+
+@pytest.mark.asyncio
+async def test_receive_transcript_finished_on_turn_complete(
+ gemini_api_connection,
+ mock_gemini_session,
+):
+  """Test receive finishes transcription on turn_complete signal."""
+
+ message1 = mock.Mock()
+ message1.usage_metadata = None
+ message1.server_content = mock.Mock()
+ message1.server_content.model_turn = None
+ message1.server_content.interrupted = False
+ message1.server_content.input_transcription = types.Transcription(
+ text='Hello', finished=False
+ )
+ message1.server_content.output_transcription = None
+ message1.server_content.turn_complete = False
+ message1.server_content.generation_complete = False
+ message1.tool_call = None
+ message1.session_resumption_update = None
+
+ message2 = mock.Mock()
+ message2.usage_metadata = None
+ message2.server_content = mock.Mock()
+ message2.server_content.model_turn = None
+ message2.server_content.interrupted = False
+ message2.server_content.input_transcription = None
+ message2.server_content.output_transcription = types.Transcription(
+ text='How can', finished=False
+ )
+ message2.server_content.turn_complete = False
+ message2.server_content.generation_complete = False
+ message2.tool_call = None
+ message2.session_resumption_update = None
+
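+  # message3 signals turn_complete only; the pending input/output
+  # transcriptions should be emitted as finished before the turn_complete event.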
+ message3 = mock.Mock()
+ message3.usage_metadata = None
+ message3.server_content = mock.Mock()
+ message3.server_content.model_turn = None
+ message3.server_content.interrupted = False
+ message3.server_content.input_transcription = None
+ message3.server_content.output_transcription = None
+ message3.server_content.turn_complete = True
+ message3.server_content.generation_complete = False
+ message3.tool_call = None
+ message3.session_resumption_update = None
+
+ async def mock_receive_generator():
+ yield message1
+ yield message2
+ yield message3
+
+ receive_mock = mock.Mock(return_value=mock_receive_generator())
+ mock_gemini_session.receive = receive_mock
+
+ responses = [resp async for resp in gemini_api_connection.receive()]
+
+ assert len(responses) == 5
+ assert responses[4].turn_complete is True
+
+ assert responses[0].input_transcription.text == 'Hello'
+ assert responses[0].input_transcription.finished is False
+ assert responses[0].partial is True
+ assert responses[1].output_transcription.text == 'How can'
+ assert responses[1].output_transcription.finished is False
+ assert responses[1].partial is True
+ assert responses[2].input_transcription.text == 'Hello'
+ assert responses[2].input_transcription.finished is True
+ assert responses[2].partial is False
+ assert responses[3].output_transcription.text == 'How can'
+ assert responses[3].output_transcription.finished is True
+ assert responses[3].partial is False
+
+
@pytest.mark.asyncio
async def test_receive_handles_input_transcription_fragments(
gemini_connection, mock_gemini_session
@@ -240,6 +472,7 @@ async def test_receive_handles_input_transcription_fragments(
)
message1.server_content.output_transcription = None
message1.server_content.turn_complete = False
+ message1.server_content.generation_complete = False
message1.tool_call = None
message1.session_resumption_update = None
@@ -253,6 +486,7 @@ async def test_receive_handles_input_transcription_fragments(
)
message2.server_content.output_transcription = None
message2.server_content.turn_complete = False
+ message2.server_content.generation_complete = False
message2.tool_call = None
message2.session_resumption_update = None
@@ -266,6 +500,7 @@ async def test_receive_handles_input_transcription_fragments(
)
message3.server_content.output_transcription = None
message3.server_content.turn_complete = False
+ message3.server_content.generation_complete = False
message3.tool_call = None
message3.session_resumption_update = None
@@ -306,6 +541,7 @@ async def test_receive_handles_output_transcription_fragments(
text='How can', finished=False
)
message1.server_content.turn_complete = False
+ message1.server_content.generation_complete = False
message1.tool_call = None
message1.session_resumption_update = None
@@ -319,6 +555,7 @@ async def test_receive_handles_output_transcription_fragments(
text=' I help?', finished=False
)
message2.server_content.turn_complete = False
+ message2.server_content.generation_complete = False
message2.tool_call = None
message2.session_resumption_update = None
@@ -332,6 +569,7 @@ async def test_receive_handles_output_transcription_fragments(
text=None, finished=True
)
message3.server_content.turn_complete = False
+ message3.server_content.generation_complete = False
message3.tool_call = None
message3.session_resumption_update = None
diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py
index f2419daf3f..ddf1b07667 100644
--- a/tests/unittests/models/test_google_llm.py
+++ b/tests/unittests/models/test_google_llm.py
@@ -22,8 +22,6 @@
from google.adk.agents.context_cache_config import ContextCacheConfig
from google.adk.models.cache_metadata import CacheMetadata
from google.adk.models.gemini_llm_connection import GeminiLlmConnection
-from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
-from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG
from google.adk.models.google_llm import _build_function_declaration_log
from google.adk.models.google_llm import _build_request_log
from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE
@@ -31,6 +29,8 @@
from google.adk.models.google_llm import Gemini
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
+from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
+from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
from google.genai.errors import ClientError
@@ -142,13 +142,6 @@ def llm_request_with_computer_use():
)
-@pytest.fixture
-def mock_os_environ():
- initial_env = os.environ.copy()
- with mock.patch.dict(os.environ, initial_env, clear=False) as m:
- yield m
-
-
def test_supported_models():
models = Gemini.supported_models()
assert len(models) == 4
@@ -193,12 +186,15 @@ def test_client_version_header():
)
-def test_client_version_header_with_agent_engine(mock_os_environ):
- os.environ[_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME] = "my_test_project"
+def test_client_version_header_with_agent_engine(monkeypatch):
+ monkeypatch.setenv(
+ _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, "my_test_project"
+ )
model = Gemini(model="gemini-1.5-flash")
client = model.api_client
- # Check that ADK version with telemetry tag and Python version are present in headers
+ # Check that ADK version with telemetry tag and Python version are present in
+ # headers
adk_version_with_telemetry = (
f"google-adk/{adk_version.__version__}+{_AGENT_ENGINE_TELEMETRY_TAG}"
)
@@ -473,8 +469,9 @@ async def test_generate_content_async_with_custom_headers(
"""Test that tracking headers are updated when custom headers are provided."""
# Add custom headers to the request config
custom_headers = {"custom-header": "custom-value"}
- for key in gemini_llm._tracking_headers:
- custom_headers[key] = "custom " + gemini_llm._tracking_headers[key]
+ tracking_headers = gemini_llm._tracking_headers()
+ for key in tracking_headers:
+ custom_headers[key] = "custom " + tracking_headers[key]
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
with mock.patch.object(gemini_llm, "api_client") as mock_client:
@@ -497,8 +494,9 @@ async def mock_coro():
config_arg = call_args.kwargs["config"]
for key, value in config_arg.http_options.headers.items():
- if key in gemini_llm._tracking_headers:
- assert value == gemini_llm._tracking_headers[key] + " custom"
+ tracking_headers = gemini_llm._tracking_headers()
+ if key in tracking_headers:
+ assert value == tracking_headers[key] + " custom"
else:
assert value == custom_headers[key]
@@ -547,7 +545,7 @@ async def mock_coro():
config_arg = call_args.kwargs["config"]
expected_headers = custom_headers.copy()
- expected_headers.update(gemini_llm._tracking_headers)
+ expected_headers.update(gemini_llm._tracking_headers())
assert config_arg.http_options.headers == expected_headers
assert len(responses) == 2
@@ -601,7 +599,7 @@ async def mock_coro():
assert final_config.http_options is not None
assert (
final_config.http_options.headers["x-goog-api-client"]
- == gemini_llm._tracking_headers["x-goog-api-client"]
+ == gemini_llm._tracking_headers()["x-goog-api-client"]
)
assert len(responses) == 2 if stream else 1
@@ -635,7 +633,7 @@ def test_live_api_client_properties(gemini_llm):
assert http_options.api_version == "v1beta1"
# Check that tracking headers are included
- tracking_headers = gemini_llm._tracking_headers
+ tracking_headers = gemini_llm._tracking_headers()
for key, value in tracking_headers.items():
assert key in http_options.headers
assert value in http_options.headers[key]
@@ -673,7 +671,7 @@ async def __aexit__(self, *args):
# Verify that tracking headers were merged with custom headers
expected_headers = custom_headers.copy()
- expected_headers.update(gemini_llm._tracking_headers)
+ expected_headers.update(gemini_llm._tracking_headers())
assert config_arg.http_options.headers == expected_headers
# Verify that API version was set
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index 8f2ae50b42..f65fc77a61 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -17,7 +17,6 @@
from unittest.mock import Mock
import warnings
-from google.adk.models.lite_llm import _build_function_declaration_log
from google.adk.models.lite_llm import _content_to_message_param
from google.adk.models.lite_llm import _FINISH_REASON_MAPPING
from google.adk.models.lite_llm import _function_declaration_to_tool_param
@@ -25,7 +24,9 @@
from google.adk.models.lite_llm import _get_content
from google.adk.models.lite_llm import _message_to_generate_content_response
from google.adk.models.lite_llm import _model_response_to_chunk
+from google.adk.models.lite_llm import _model_response_to_generate_content_response
from google.adk.models.lite_llm import _parse_tool_calls_from_text
+from google.adk.models.lite_llm import _schema_to_dict
from google.adk.models.lite_llm import _split_message_content_and_tool_calls
from google.adk.models.lite_llm import _to_litellm_response_format
from google.adk.models.lite_llm import _to_litellm_role
@@ -286,6 +287,28 @@ def test_to_litellm_response_format_handles_genai_schema_instance():
)
+def test_schema_to_dict_filters_none_enum_values():
+ # Use model_construct to bypass strict enum validation.
+ top_level_schema = types.Schema.model_construct(
+ type=types.Type.STRING,
+ enum=["ACTIVE", None, "INACTIVE"],
+ )
+ nested_schema = types.Schema.model_construct(
+ type=types.Type.OBJECT,
+ properties={
+ "status": types.Schema.model_construct(
+ type=types.Type.STRING, enum=["READY", None, "DONE"]
+ ),
+ },
+ )
+
+ assert _schema_to_dict(top_level_schema)["enum"] == ["ACTIVE", "INACTIVE"]
+ assert _schema_to_dict(nested_schema)["properties"]["status"]["enum"] == [
+ "READY",
+ "DONE",
+ ]
+
+
MULTIPLE_FUNCTION_CALLS_STREAM = [
ModelResponse(
choices=[
@@ -630,54 +653,6 @@ def completion(self, model, messages, tools, stream, **kwargs):
)
-def test_build_function_declaration_log():
- """Test that _build_function_declaration_log formats function declarations correctly."""
- # Test case 1: Function with parameters and response
- func_decl1 = types.FunctionDeclaration(
- name="test_func1",
- description="Test function 1",
- parameters=types.Schema(
- type=types.Type.OBJECT,
- properties={
- "param1": types.Schema(
- type=types.Type.STRING, description="param1 desc"
- )
- },
- ),
- response=types.Schema(type=types.Type.BOOLEAN, description="return bool"),
- )
- log1 = _build_function_declaration_log(func_decl1)
- assert log1 == (
- "test_func1: {'param1': {'description': 'param1 desc', 'type':"
- " }} -> {'description': 'return bool', 'type':"
- " }"
- )
-
- # Test case 2: Function with JSON schema parameters and response
- func_decl2 = types.FunctionDeclaration(
- name="test_func2",
- description="Test function 2",
- parameters_json_schema={
- "type": "object",
- "properties": {"param2": {"type": "integer"}},
- },
- response_json_schema={"type": "string"},
- )
- log2 = _build_function_declaration_log(func_decl2)
- assert log2 == (
- "test_func2: {'type': 'object', 'properties': {'param2': {'type':"
- " 'integer'}}} -> {'type': 'string'}"
- )
-
- # Test case 3: Function with no parameters and no response
- func_decl3 = types.FunctionDeclaration(
- name="test_func3",
- description="Test function 3",
- )
- log3 = _build_function_declaration_log(func_decl3)
- assert log3 == "test_func3: {} -> None"
-
-
@pytest.mark.asyncio
async def test_generate_content_async(mock_acompletion, lite_llm_instance):
@@ -1220,7 +1195,7 @@ async def test_generate_content_async_with_system_instruction(
_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "test_model"
- assert kwargs["messages"][0]["role"] == "developer"
+ assert kwargs["messages"][0]["role"] == "system"
assert kwargs["messages"][0]["content"] == "Test system instruction"
assert kwargs["messages"][1]["role"] == "user"
assert kwargs["messages"][1]["content"] == "Test prompt"
@@ -1535,6 +1510,42 @@ def test_message_to_generate_content_response_with_model():
assert response.model_version == "gemini-2.5-pro"
+def test_message_to_generate_content_response_reasoning_content():
+ message = {
+ "role": "assistant",
+ "content": "Visible text",
+ "reasoning_content": "Hidden chain",
+ }
+ response = _message_to_generate_content_response(message)
+
+ assert len(response.content.parts) == 2
+ thought_part = response.content.parts[0]
+ text_part = response.content.parts[1]
+ assert thought_part.text == "Hidden chain"
+ assert thought_part.thought is True
+ assert text_part.text == "Visible text"
+
+
+def test_model_response_to_generate_content_response_reasoning_content():
+ model_response = ModelResponse(
+ model="thinking-model",
+ choices=[{
+ "message": {
+ "role": "assistant",
+ "content": "Answer",
+ "reasoning_content": "Step-by-step",
+ },
+ "finish_reason": "stop",
+ }],
+ )
+
+ response = _model_response_to_generate_content_response(model_response)
+
+ assert response.content.parts[0].text == "Step-by-step"
+ assert response.content.parts[0].thought is True
+ assert response.content.parts[1].text == "Answer"
+
+
def test_parse_tool_calls_from_text_multiple_calls():
text = (
'{"name":"alpha","arguments":{"value":1}}\n'
diff --git a/tests/unittests/models/test_models.py b/tests/unittests/models/test_models.py
index 70246c7bc1..8575064baa 100644
--- a/tests/unittests/models/test_models.py
+++ b/tests/unittests/models/test_models.py
@@ -15,7 +15,7 @@
from google.adk import models
from google.adk.models.anthropic_llm import Claude
from google.adk.models.google_llm import Gemini
-from google.adk.models.registry import LLMRegistry
+from google.adk.models.lite_llm import LiteLlm
import pytest
@@ -34,6 +34,7 @@
],
)
def test_match_gemini_family(model_name):
+ """Test that Gemini models are resolved correctly."""
assert models.LLMRegistry.resolve(model_name) is Gemini
@@ -51,12 +52,63 @@ def test_match_gemini_family(model_name):
],
)
def test_match_claude_family(model_name):
- LLMRegistry.register(Claude)
-
+ """Test that Claude models are resolved correctly."""
assert models.LLMRegistry.resolve(model_name) is Claude
+@pytest.mark.parametrize(
+ 'model_name',
+ [
+ 'openai/gpt-4o',
+ 'openai/gpt-4o-mini',
+ 'groq/llama3-70b-8192',
+ 'groq/mixtral-8x7b-32768',
+ 'anthropic/claude-3-opus-20240229',
+ 'anthropic/claude-3-5-sonnet-20241022',
+ ],
+)
+def test_match_litellm_family(model_name):
+ """Test that LiteLLM models are resolved correctly."""
+ assert models.LLMRegistry.resolve(model_name) is LiteLlm
+
+
def test_non_exist_model():
with pytest.raises(ValueError) as e_info:
models.LLMRegistry.resolve('non-exist-model')
assert 'Model non-exist-model not found.' in str(e_info.value)
+
+
+def test_helpful_error_for_claude_without_extensions():
+ """Test that missing Claude models show helpful install instructions.
+
+ Note: This test may pass even when anthropic IS installed, because it
+ only checks the error message format when a model is not found.
+ """
+ # Use a non-existent Claude model variant to trigger the error
+ with pytest.raises(ValueError) as e_info:
+ models.LLMRegistry.resolve('claude-nonexistent-model-xyz')
+
+ error_msg = str(e_info.value)
+ # The error should mention anthropic package and installation instructions
+ # These checks work whether or not anthropic is actually installed
+ assert 'Model claude-nonexistent-model-xyz not found' in error_msg
+ assert 'anthropic package' in error_msg
+ assert 'pip install' in error_msg
+
+
+def test_helpful_error_for_litellm_without_extensions():
+ """Test that missing LiteLLM models show helpful install instructions.
+
+ Note: This test may pass even when litellm IS installed, because it
+ only checks the error message format when a model is not found.
+ """
+ # Use a non-existent provider to trigger the error
+ with pytest.raises(ValueError) as e_info:
+ models.LLMRegistry.resolve('unknown-provider/gpt-4o')
+
+ error_msg = str(e_info.value)
+ # The error should mention litellm package for provider-style models
+ assert 'Model unknown-provider/gpt-4o not found' in error_msg
+ assert 'litellm package' in error_msg
+ assert 'pip install' in error_msg
+ assert 'Provider-style models' in error_msg
diff --git a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py
index 1e15f33899..2cf52e99cb 100644
--- a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py
+++ b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py
@@ -57,10 +57,8 @@ async def extract_error_from_result(
return None
-# Inheriting from IsolatedAsyncioTestCase ensures these tests works in Python
-# 3.9. See https://github.com/pytest-dev/pytest-asyncio/issues/1039
-# Without this, the tests will fail with a "RuntimeError: There is no current
-# event loop in thread 'MainThread'."
+# Inheriting from IsolatedAsyncioTestCase ensures consistent behavior.
+# See https://github.com/pytest-dev/pytest-asyncio/issues/1039
class TestReflectAndRetryToolPlugin(IsolatedAsyncioTestCase):
"""Comprehensive tests for ReflectAndRetryToolPlugin focusing on behavior."""
diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py
index 661e6ead59..45aa3feede 100644
--- a/tests/unittests/sessions/test_session_service.py
+++ b/tests/unittests/sessions/test_session_service.py
@@ -531,6 +531,79 @@ async def test_append_event_complete(service_type, tmp_path):
)
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ 'service_type',
+ [
+ SessionServiceType.IN_MEMORY,
+ SessionServiceType.DATABASE,
+ SessionServiceType.SQLITE,
+ ],
+)
+async def test_session_last_update_time_updates_on_event(
+ service_type, tmp_path
+):
+ session_service = get_session_service(service_type, tmp_path)
+ app_name = 'my_app'
+ user_id = 'user'
+
+ session = await session_service.create_session(
+ app_name=app_name, user_id=user_id
+ )
+ original_update_time = session.last_update_time
+
+ event_timestamp = original_update_time + 10
+ event = Event(
+ invocation_id='invocation',
+ author='user',
+ timestamp=event_timestamp,
+ )
+ await session_service.append_event(session=session, event=event)
+
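+ # Timestamps may round-trip through storage as floats, so compare with a small tolerance.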
+ assert session.last_update_time == pytest.approx(event_timestamp, abs=1e-6)
+
+ refreshed_session = await session_service.get_session(
+ app_name=app_name, user_id=user_id, session_id=session.id
+ )
+ assert refreshed_session is not None
+ assert refreshed_session.last_update_time == pytest.approx(
+ event_timestamp, abs=1e-6
+ )
+ assert refreshed_session.last_update_time > original_update_time
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ 'service_type',
+ [
+ SessionServiceType.IN_MEMORY,
+ SessionServiceType.DATABASE,
+ SessionServiceType.SQLITE,
+ ],
+)
+async def test_session_last_update_time_updates_with_default_event_timestamp(service_type):
+ session_service = get_session_service(service_type)
+ app_name = 'my_app'
+ user_id = 'user'
+
+ session = await session_service.create_session(
+ app_name=app_name, user_id=user_id
+ )
+ original_update_time = session.last_update_time
+
+ event = Event(invocation_id='invocation', author='user')
+ await session_service.append_event(session=session, event=event)
+
+ assert session.last_update_time >= event.timestamp
+
+ refreshed_session = await session_service.get_session(
+ app_name=app_name, user_id=user_id, session_id=session.id
+ )
+ assert refreshed_session is not None
+ assert refreshed_session.last_update_time >= event.timestamp
+ assert refreshed_session.last_update_time > original_update_time
+
+
@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
diff --git a/tests/unittests/telemetry/test_functional.py b/tests/unittests/telemetry/test_functional.py
index 409571ad1f..43fe672333 100644
--- a/tests/unittests/telemetry/test_functional.py
+++ b/tests/unittests/telemetry/test_functional.py
@@ -103,7 +103,7 @@ def wrapped_firstiter(coro):
isinstance(referrer, Aclosing)
or isinstance(indirect_referrer, Aclosing)
for referrer in gc.get_referrers(coro)
- # Some coroutines have a layer of indirection in python 3.9 and 3.10
+ # Some coroutines have a layer of indirection in Python 3.10
for indirect_referrer in gc.get_referrers(referrer)
), f'Coro `{coro.__name__}` is not wrapped with Aclosing'
firstiter(coro)
diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py
index eef83a1f5e..1791100e1f 100644
--- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py
+++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py
@@ -1709,6 +1709,65 @@ def test_execute_sql_job_labels(
}
+@pytest.mark.parametrize(
+ ("write_mode", "dry_run", "query_call_count", "query_and_wait_call_count"),
+ [
+ pytest.param(WriteMode.ALLOWED, False, 0, 1, id="write-allowed"),
+ pytest.param(WriteMode.ALLOWED, True, 1, 0, id="write-allowed-dry-run"),
+ pytest.param(WriteMode.BLOCKED, False, 1, 1, id="write-blocked"),
+ pytest.param(WriteMode.BLOCKED, True, 2, 0, id="write-blocked-dry-run"),
+ pytest.param(WriteMode.PROTECTED, False, 2, 1, id="write-protected"),
+ pytest.param(
+ WriteMode.PROTECTED, True, 3, 0, id="write-protected-dry-run"
+ ),
+ ],
+)
+def test_execute_sql_user_job_labels_augment_internal_labels(
+ write_mode, dry_run, query_call_count, query_and_wait_call_count
+):
+ """Test execute_sql tool augments user job_labels with internal labels."""
+ project = "my_project"
+ query = "SELECT 123 AS num"
+ statement_type = "SELECT"
+ credentials = mock.create_autospec(Credentials, instance=True)
+ user_labels = {"environment": "test", "team": "data"}
+ tool_settings = BigQueryToolConfig(
+ write_mode=write_mode,
+ job_labels=user_labels,
+ )
+ tool_context = mock.create_autospec(ToolContext, instance=True)
+ tool_context.state.get.return_value = None
+
+ with mock.patch.object(bigquery, "Client", autospec=True) as Client:
+ bq_client = Client.return_value
+
+ query_job = mock.create_autospec(bigquery.QueryJob)
+ query_job.statement_type = statement_type
+ bq_client.query.return_value = query_job
+
+ query_tool.execute_sql(
+ project,
+ query,
+ credentials,
+ tool_settings,
+ tool_context,
+ dry_run=dry_run,
+ )
+
+ assert bq_client.query.call_count == query_call_count
+ assert bq_client.query_and_wait.call_count == query_and_wait_call_count
+ # Build expected labels from user_labels + internal label
+ expected_labels = {**user_labels, "adk-bigquery-tool": "execute_sql"}
+ for call_args_list in [
+ bq_client.query.call_args_list,
+ bq_client.query_and_wait.call_args_list,
+ ]:
+ for call_args in call_args_list:
+ _, mock_kwargs = call_args
+ # Verify user labels are preserved and internal label is added
+ assert mock_kwargs["job_config"].labels == expected_labels
+
+
@pytest.mark.parametrize(
("tool_call", "expected_tool_label"),
[
@@ -1850,6 +1909,94 @@ def test_ml_tool_job_labels_w_application_name(tool_call, expected_tool_label):
assert mock_kwargs["job_config"].labels == expected_labels
+@pytest.mark.parametrize(
+ ("tool_call", "expected_labels"),
+ [
+ pytest.param(
+ lambda tool_context: query_tool.forecast(
+ project_id="test-project",
+ history_data="SELECT * FROM `test-dataset.test-table`",
+ timestamp_col="ts_col",
+ data_col="data_col",
+ credentials=mock.create_autospec(Credentials, instance=True),
+ settings=BigQueryToolConfig(
+ write_mode=WriteMode.ALLOWED,
+ job_labels={"environment": "prod", "app": "forecaster"},
+ ),
+ tool_context=tool_context,
+ ),
+ {
+ "environment": "prod",
+ "app": "forecaster",
+ "adk-bigquery-tool": "forecast",
+ },
+ id="forecast",
+ ),
+ pytest.param(
+ lambda tool_context: query_tool.analyze_contribution(
+ project_id="test-project",
+ input_data="test-dataset.test-table",
+ dimension_id_cols=["dim1", "dim2"],
+ contribution_metric="SUM(metric)",
+ is_test_col="is_test",
+ credentials=mock.create_autospec(Credentials, instance=True),
+ settings=BigQueryToolConfig(
+ write_mode=WriteMode.ALLOWED,
+ job_labels={"environment": "prod", "app": "analyzer"},
+ ),
+ tool_context=tool_context,
+ ),
+ {
+ "environment": "prod",
+ "app": "analyzer",
+ "adk-bigquery-tool": "analyze_contribution",
+ },
+ id="analyze-contribution",
+ ),
+ pytest.param(
+ lambda tool_context: query_tool.detect_anomalies(
+ project_id="test-project",
+ history_data="SELECT * FROM `test-dataset.test-table`",
+ times_series_timestamp_col="ts_timestamp",
+ times_series_data_col="ts_data",
+ credentials=mock.create_autospec(Credentials, instance=True),
+ settings=BigQueryToolConfig(
+ write_mode=WriteMode.ALLOWED,
+ job_labels={"environment": "prod", "app": "detector"},
+ ),
+ tool_context=tool_context,
+ ),
+ {
+ "environment": "prod",
+ "app": "detector",
+ "adk-bigquery-tool": "detect_anomalies",
+ },
+ id="detect-anomalies",
+ ),
+ ],
+)
+def test_ml_tool_user_job_labels_augment_internal_labels(
+ tool_call, expected_labels
+):
+ """Test ML tools augment user job_labels with internal labels."""
+
+ with mock.patch.object(bigquery, "Client", autospec=True) as Client:
+ bq_client = Client.return_value
+
+ tool_context = mock.create_autospec(ToolContext, instance=True)
+ tool_context.state.get.return_value = None
+ tool_call(tool_context)
+
+ for call_args_list in [
+ bq_client.query.call_args_list,
+ bq_client.query_and_wait.call_args_list,
+ ]:
+ for call_args in call_args_list:
+ _, mock_kwargs = call_args
+ # Verify user labels are preserved and internal label is added
+ assert mock_kwargs["job_config"].labels == expected_labels
+
+
def test_execute_sql_max_rows_config():
"""Test execute_sql tool respects max_query_result_rows from config."""
project = "my_project"
@@ -2014,3 +2161,93 @@ def test_tool_call_doesnt_change_global_settings(tool_call):
# Test settings write mode after
assert settings.write_mode == WriteMode.ALLOWED
+
+
+@pytest.mark.parametrize(
+ ("tool_call",),
+ [
+ pytest.param(
+ lambda settings, tool_context: query_tool.execute_sql(
+ project_id="test-project",
+ query="SELECT * FROM `test-dataset.test-table`",
+ credentials=mock.create_autospec(Credentials, instance=True),
+ settings=settings,
+ tool_context=tool_context,
+ ),
+ id="execute-sql",
+ ),
+ pytest.param(
+ lambda settings, tool_context: query_tool.forecast(
+ project_id="test-project",
+ history_data="SELECT * FROM `test-dataset.test-table`",
+ timestamp_col="ts_col",
+ data_col="data_col",
+ credentials=mock.create_autospec(Credentials, instance=True),
+ settings=settings,
+ tool_context=tool_context,
+ ),
+ id="forecast",
+ ),
+ pytest.param(
+ lambda settings, tool_context: query_tool.analyze_contribution(
+ project_id="test-project",
+ input_data="test-dataset.test-table",
+ dimension_id_cols=["dim1", "dim2"],
+ contribution_metric="SUM(metric)",
+ is_test_col="is_test",
+ credentials=mock.create_autospec(Credentials, instance=True),
+ settings=settings,
+ tool_context=tool_context,
+ ),
+ id="analyze-contribution",
+ ),
+ pytest.param(
+ lambda settings, tool_context: query_tool.detect_anomalies(
+ project_id="test-project",
+ history_data="SELECT * FROM `test-dataset.test-table`",
+ times_series_timestamp_col="ts_timestamp",
+ times_series_data_col="ts_data",
+ credentials=mock.create_autospec(Credentials, instance=True),
+ settings=settings,
+ tool_context=tool_context,
+ ),
+ id="detect-anomalies",
+ ),
+ ],
+)
+def test_tool_call_doesnt_mutate_job_labels(tool_call):
+ """Test query tools don't mutate job_labels in global settings."""
+ original_labels = {"environment": "test", "team": "data"}
+ settings = BigQueryToolConfig(
+ write_mode=WriteMode.ALLOWED,
+ job_labels=original_labels.copy(),
+ )
+ tool_context = mock.create_autospec(ToolContext, instance=True)
+ tool_context.state.get.return_value = (
+ "test-bq-session-id",
+ "_anonymous_dataset",
+ )
+
+ with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
+ # The mock instance
+ bq_client = Client.return_value
+
+ # Simulate the result of query API
+ query_job = mock.create_autospec(bigquery.QueryJob)
+ query_job.destination.dataset_id = "_anonymous_dataset"
+ bq_client.query.return_value = query_job
+ bq_client.query_and_wait.return_value = []
+
+ # Test job_labels before
+ assert settings.job_labels == original_labels
+ assert "adk-bigquery-tool" not in settings.job_labels
+
+ # Call the tool
+ result = tool_call(settings, tool_context)
+
+ # Test successful execution of the tool
+ assert result == {"status": "SUCCESS", "rows": []}
+
+ # Test job_labels remain unchanged after tool call
+ assert settings.job_labels == original_labels
+ assert "adk-bigquery-tool" not in settings.job_labels
diff --git a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py
index 5854c97797..072ccea7d0 100644
--- a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py
+++ b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py
@@ -77,3 +77,61 @@ def test_bigquery_tool_config_invalid_maximum_bytes_billed():
),
):
BigQueryToolConfig(maximum_bytes_billed=10_485_759)
+
+
+@pytest.mark.parametrize(
+ "labels",
+ [
+ pytest.param(
+ {"environment": "test", "team": "data"},
+ id="valid-labels",
+ ),
+ pytest.param(
+ {},
+ id="empty-labels",
+ ),
+ pytest.param(
+ None,
+ id="none-labels",
+ ),
+ ],
+)
+def test_bigquery_tool_config_valid_labels(labels):
+ """Test BigQueryToolConfig accepts valid labels."""
+ with pytest.warns(UserWarning):
+ config = BigQueryToolConfig(job_labels=labels)
+ assert config.job_labels == labels
+
+
+@pytest.mark.parametrize(
+ ("labels", "message"),
+ [
+ pytest.param(
+ "invalid",
+ "Input should be a valid dictionary",
+ id="invalid-type",
+ ),
+ pytest.param(
+ {123: "value"},
+ "Input should be a valid string",
+ id="non-str-key",
+ ),
+ pytest.param(
+ {"key": 123},
+ "Input should be a valid string",
+ id="non-str-value",
+ ),
+ pytest.param(
+ {"": "value"},
+ "Label keys cannot be empty",
+ id="empty-label-key",
+ ),
+ ],
+)
+def test_bigquery_tool_config_invalid_labels(labels, message):
+ """Test BigQueryToolConfig raises an exception with invalid labels."""
+ with pytest.raises(
+ ValueError,
+ match=message,
+ ):
+ BigQueryToolConfig(job_labels=labels)
diff --git a/tests/unittests/tools/computer_use/test_computer_use_tool.py b/tests/unittests/tools/computer_use/test_computer_use_tool.py
index 4dbdfbb5c0..f3843b87a6 100644
--- a/tests/unittests/tools/computer_use/test_computer_use_tool.py
+++ b/tests/unittests/tools/computer_use/test_computer_use_tool.py
@@ -47,7 +47,7 @@ async def tool_context(self):
@pytest.fixture
def mock_computer_function(self):
"""Fixture providing a mock computer function."""
- # Create a real async function instead of AsyncMock for Python 3.9 compatibility
+ # Create a real async function instead of AsyncMock for better test control
calls = []
async def mock_func(*args, **kwargs):
diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py
index 6c001ccf65..74eabe9d4d 100644
--- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py
+++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py
@@ -22,46 +22,14 @@
from unittest.mock import Mock
from unittest.mock import patch
+from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
+from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors
+from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
+from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
+from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
+from mcp import StdioServerParameters
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
- from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource
- from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
- from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
- from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Create dummy classes to prevent NameError during test collection
- # Tests will be skipped anyway due to pytestmark
- class DummyClass:
- pass
-
- MCPSessionManager = DummyClass
- retry_on_closed_resource = lambda x: x
- SseConnectionParams = DummyClass
- StdioConnectionParams = DummyClass
- StreamableHTTPConnectionParams = DummyClass
- else:
- raise e
-
-# Import real MCP classes
-try:
- from mcp import StdioServerParameters
-except ImportError:
- # Create a mock if MCP is not available
- class StdioServerParameters:
-
- def __init__(self, command="test_command", args=None):
- self.command = command
- self.args = args or []
-
class MockClientSession:
"""Mock ClientSession for testing."""
@@ -375,12 +343,12 @@ async def test_close_with_errors(self):
assert "Close error 1" in error_output
-def test_retry_on_closed_resource_decorator():
- """Test the retry_on_closed_resource decorator."""
+def test_retry_on_errors_decorator():
+ """Test the retry_on_errors decorator."""
call_count = 0
- @retry_on_closed_resource
+ @retry_on_errors
async def mock_function(self):
nonlocal call_count
call_count += 1
diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py
index 17b1d8e54e..1284e73bce 100644
--- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py
+++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
@@ -23,39 +22,15 @@
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_credential import ServiceAccount
+from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
+from google.adk.tools.mcp_tool.mcp_tool import MCPTool
+from google.adk.tools.tool_context import ToolContext
+from google.genai.types import FunctionDeclaration
+from google.genai.types import Type
+from mcp.types import CallToolResult
+from mcp.types import TextContent
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
- from google.adk.tools.mcp_tool.mcp_tool import MCPTool
- from google.adk.tools.tool_context import ToolContext
- from google.genai.types import FunctionDeclaration
- from google.genai.types import Type
- from mcp.types import CallToolResult
- from mcp.types import TextContent
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Create dummy classes to prevent NameError during test collection
- # Tests will be skipped anyway due to pytestmark
- class DummyClass:
- pass
-
- MCPSessionManager = DummyClass
- MCPTool = DummyClass
- ToolContext = DummyClass
- FunctionDeclaration = DummyClass
- Type = DummyClass
- CallToolResult = DummyClass
- TextContent = DummyClass
- else:
- raise e
-
# Mock MCP Tool from mcp.types
class MockMCPTool:
diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py
index 82a5c9a3e7..5809efe56f 100644
--- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py
+++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py
@@ -20,47 +20,17 @@
from unittest.mock import Mock
from unittest.mock import patch
+from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.auth.auth_credential import AuthCredential
+from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
+from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
+from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
+from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
+from google.adk.tools.mcp_tool.mcp_tool import MCPTool
+from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
+from mcp import StdioServerParameters
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from google.adk.agents.readonly_context import ReadonlyContext
- from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
- from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
- from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
- from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
- from google.adk.tools.mcp_tool.mcp_tool import MCPTool
- from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
- from mcp import StdioServerParameters
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Create dummy classes to prevent NameError during test collection
- # Tests will be skipped anyway due to pytestmark
- class DummyClass:
- pass
-
- class StdioServerParameters:
-
- def __init__(self, command="test_command", args=None):
- self.command = command
- self.args = args or []
-
- MCPSessionManager = DummyClass
- SseConnectionParams = DummyClass
- StdioConnectionParams = DummyClass
- StreamableHTTPConnectionParams = DummyClass
- MCPTool = DummyClass
- MCPToolset = DummyClass
- ReadonlyContext = DummyClass
- else:
- raise e
-
class MockMCPTool:
"""Mock MCP Tool for testing."""
diff --git a/tests/unittests/tools/openapi_tool/common/test_common.py b/tests/unittests/tools/openapi_tool/common/test_common.py
index 47aeb79fdb..faece5be89 100644
--- a/tests/unittests/tools/openapi_tool/common/test_common.py
+++ b/tests/unittests/tools/openapi_tool/common/test_common.py
@@ -74,6 +74,24 @@ def test_api_parameter_keyword_rename(self):
)
assert param.py_name == 'param_in'
+ def test_api_parameter_uses_location_default_when_name_missing(self):
+ schema = Schema(type='string')
+ param = ApiParameter(
+ original_name='',
+ param_location='body',
+ param_schema=schema,
+ )
+ assert param.py_name == 'body'
+
+ def test_api_parameter_uses_value_default_when_location_unknown(self):
+ schema = Schema(type='integer')
+ param = ApiParameter(
+ original_name='',
+ param_location='',
+ param_schema=schema,
+ )
+ assert param.py_name == 'value'
+
def test_api_parameter_custom_py_name(self):
schema = Schema(type='integer')
param = ApiParameter(
diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py
index 26cb944a22..83741c97a2 100644
--- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py
+++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py
@@ -164,6 +164,40 @@ def test_process_request_body_no_name():
assert parser._params[0].param_location == 'body'
+def test_process_request_body_one_of_schema_assigns_name():
+ """Ensures oneOf bodies result in a named parameter."""
+ operation = Operation(
+ operationId='one_of_request',
+ requestBody=RequestBody(
+ content={
+ 'application/json': MediaType(
+ schema=Schema(
+ oneOf=[
+ Schema(
+ type='object',
+ properties={
+ 'type': Schema(type='string'),
+ 'stage': Schema(type='string'),
+ },
+ )
+ ],
+ discriminator={'propertyName': 'type'},
+ )
+ )
+ }
+ ),
+ responses={'200': Response(description='ok')},
+ )
+ parser = OperationParser(operation)
+ params = parser.get_parameters()
+ assert len(params) == 1
+ assert params[0].original_name == 'body'
+ assert params[0].py_name == 'body'
+ schema = parser.get_json_schema()
+ assert 'body' in schema['properties']
+ assert '' not in schema['properties']
+
+
def test_process_request_body_empty_object():
"""Test _process_request_body with a schema that is of type object but with no properties."""
operation = Operation(
diff --git a/tests/unittests/tools/retrieval/test_files_retrieval.py b/tests/unittests/tools/retrieval/test_files_retrieval.py
index ea4b99cd98..dfb7215dce 100644
--- a/tests/unittests/tools/retrieval/test_files_retrieval.py
+++ b/tests/unittests/tools/retrieval/test_files_retrieval.py
@@ -14,7 +14,6 @@
"""Tests for FilesRetrieval tool."""
-import sys
import unittest.mock as mock
from google.adk.tools.retrieval.files_retrieval import _get_default_embedding_model
@@ -111,9 +110,6 @@ def mock_import(name, *args, **kwargs):
def test_get_default_embedding_model_success(self):
"""Test _get_default_embedding_model returns Google embedding when available."""
- # Skip this test in Python 3.9 where llama_index.embeddings.google_genai may not be available
- if sys.version_info < (3, 10):
- pytest.skip("llama_index.embeddings.google_genai requires Python 3.10+")
# Mock the module creation to avoid import issues
mock_module = mock.MagicMock()
diff --git a/tests/unittests/tools/spanner/test_search_tool.py b/tests/unittests/tools/spanner/test_search_tool.py
index 1b330b01d5..c69aa444ec 100644
--- a/tests/unittests/tools/spanner/test_search_tool.py
+++ b/tests/unittests/tools/spanner/test_search_tool.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from unittest import mock
from unittest.mock import MagicMock
-from unittest.mock import patch
+from google.adk.tools.spanner import client
from google.adk.tools.spanner import search_tool
+from google.adk.tools.spanner import utils
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
import pytest
@@ -35,29 +37,59 @@ def mock_spanner_ids():
}
-@patch("google.adk.tools.spanner.client.get_spanner_client")
+@pytest.mark.parametrize(
+ ("embedding_option_key", "embedding_option_value", "expected_embedding"),
+ [
+ pytest.param(
+ "spanner_googlesql_embedding_model_name",
+ "EmbeddingsModel",
+ [0.1, 0.2, 0.3],
+ id="spanner_googlesql_embedding_model",
+ ),
+ pytest.param(
+ "vertex_ai_embedding_model_name",
+ "text-embedding-005",
+ [0.4, 0.5, 0.6],
+ id="vertex_ai_embedding_model",
+ ),
+ ],
+)
+@mock.patch.object(utils, "embed_contents")
+@mock.patch.object(client, "get_spanner_client")
def test_similarity_search_knn_success(
- mock_get_spanner_client, mock_spanner_ids, mock_credentials
+ mock_get_spanner_client,
+ mock_embed_contents,
+ mock_spanner_ids,
+ mock_credentials,
+ embedding_option_key,
+ embedding_option_value,
+ expected_embedding,
):
"""Test similarity_search function with kNN success."""
mock_spanner_client = MagicMock()
mock_instance = MagicMock()
mock_database = MagicMock()
mock_snapshot = MagicMock()
- mock_embedding_result = MagicMock()
- mock_embedding_result.one.return_value = ([0.1, 0.2, 0.3],)
- # First call to execute_sql is for getting the embedding
- # Second call is for the kNN search
- mock_snapshot.execute_sql.side_effect = [
- mock_embedding_result,
- iter([("result1",), ("result2",)]),
- ]
mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
mock_instance.database.return_value = mock_database
mock_spanner_client.instance.return_value = mock_instance
mock_get_spanner_client.return_value = mock_spanner_client
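+ # Vertex AI models compute the query embedding via embed_contents; GoogleSQL models
+ # fetch it with an initial execute_sql call before the kNN search.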
+ if embedding_option_key == "vertex_ai_embedding_model_name":
+ mock_embed_contents.return_value = [expected_embedding]
+ # execute_sql is called once for the kNN search
+ mock_snapshot.execute_sql.return_value = iter([("result1",), ("result2",)])
+ else:
+ mock_embedding_result = MagicMock()
+ mock_embedding_result.one.return_value = (expected_embedding,)
+ # First call to execute_sql is for getting the embedding,
+ # second call is for the kNN search
+ mock_snapshot.execute_sql.side_effect = [
+ mock_embedding_result,
+ iter([("result1",), ("result2",)]),
+ ]
+
result = search_tool.similarity_search(
project_id=mock_spanner_ids["project_id"],
instance_id=mock_spanner_ids["instance_id"],
@@ -66,10 +98,8 @@ def test_similarity_search_knn_success(
query="test query",
embedding_column_to_search="embedding_col",
columns=["col1"],
- embedding_options={"spanner_embedding_model_name": "test_model"},
+ embedding_options={embedding_option_key: embedding_option_value},
credentials=mock_credentials,
- settings=MagicMock(),
- tool_context=MagicMock(),
)
assert result["status"] == "SUCCESS", result
assert result["rows"] == [("result1",), ("result2",)]
@@ -79,10 +109,14 @@ def test_similarity_search_knn_success(
sql = call_args.args[0]
assert "COSINE_DISTANCE" in sql
assert "@embedding" in sql
- assert call_args.kwargs == {"params": {"embedding": [0.1, 0.2, 0.3]}}
+ assert call_args.kwargs == {"params": {"embedding": expected_embedding}}
+ if embedding_option_key == "vertex_ai_embedding_model_name":
+ mock_embed_contents.assert_called_once_with(
+ embedding_option_value, ["test query"], None
+ )
-@patch("google.adk.tools.spanner.client.get_spanner_client")
+@mock.patch.object(client, "get_spanner_client")
def test_similarity_search_ann_success(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
@@ -113,10 +147,10 @@ def test_similarity_search_ann_success(
query="test query",
embedding_column_to_search="embedding_col",
columns=["col1"],
- embedding_options={"spanner_embedding_model_name": "test_model"},
+ embedding_options={
+ "spanner_googlesql_embedding_model_name": "test_model"
+ },
credentials=mock_credentials,
- settings=MagicMock(),
- tool_context=MagicMock(),
search_options={
"nearest_neighbors_algorithm": "APPROXIMATE_NEAREST_NEIGHBORS"
},
@@ -130,7 +164,7 @@ def test_similarity_search_ann_success(
assert call_args.kwargs == {"params": {"embedding": [0.1, 0.2, 0.3]}}
-@patch("google.adk.tools.spanner.client.get_spanner_client")
+@mock.patch.object(client, "get_spanner_client")
def test_similarity_search_error(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
@@ -143,17 +177,17 @@ def test_similarity_search_error(
table_name=mock_spanner_ids["table_name"],
query="test query",
embedding_column_to_search="embedding_col",
- embedding_options={"spanner_embedding_model_name": "test_model"},
+ embedding_options={
+ "spanner_googlesql_embedding_model_name": "test_model"
+ },
columns=["col1"],
credentials=mock_credentials,
- settings=MagicMock(),
- tool_context=MagicMock(),
)
assert result["status"] == "ERROR"
- assert result["error_details"] == "Test Exception"
+ assert "Test Exception" in result["error_details"]
-@patch("google.adk.tools.spanner.client.get_spanner_client")
+@mock.patch.object(client, "get_spanner_client")
def test_similarity_search_postgresql_knn_success(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
@@ -182,10 +216,12 @@ def test_similarity_search_postgresql_knn_success(
query="test query",
embedding_column_to_search="embedding_col",
columns=["col1"],
- embedding_options={"vertex_ai_embedding_model_endpoint": "test_endpoint"},
+ embedding_options={
+ "spanner_postgresql_vertex_ai_embedding_model_endpoint": (
+ "test_endpoint"
+ )
+ },
credentials=mock_credentials,
- settings=MagicMock(),
- tool_context=MagicMock(),
)
assert result["status"] == "SUCCESS", result
assert result["rows"] == [("pg_result",)]
@@ -196,7 +232,7 @@ def test_similarity_search_postgresql_knn_success(
assert call_args.kwargs == {"params": {"p1": [0.1, 0.2, 0.3]}}
-@patch("google.adk.tools.spanner.client.get_spanner_client")
+@mock.patch.object(client, "get_spanner_client")
def test_similarity_search_postgresql_ann_unsupported(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
@@ -217,27 +253,28 @@ def test_similarity_search_postgresql_ann_unsupported(
query="test query",
embedding_column_to_search="embedding_col",
columns=["col1"],
- embedding_options={"vertex_ai_embedding_model_endpoint": "test_endpoint"},
+ embedding_options={
+ "spanner_postgresql_vertex_ai_embedding_model_endpoint": (
+ "test_endpoint"
+ )
+ },
credentials=mock_credentials,
- settings=MagicMock(),
- tool_context=MagicMock(),
search_options={
"nearest_neighbors_algorithm": "APPROXIMATE_NEAREST_NEIGHBORS"
},
)
assert result["status"] == "ERROR"
assert (
- result["error_details"]
- == "APPROXIMATE_NEAREST_NEIGHBORS is not supported for PostgreSQL"
- " dialect."
+ "APPROXIMATE_NEAREST_NEIGHBORS is not supported for PostgreSQL dialect."
+ in result["error_details"]
)
-@patch("google.adk.tools.spanner.client.get_spanner_client")
-def test_similarity_search_missing_spanner_embedding_model_name_error(
+@mock.patch.object(client, "get_spanner_client")
+def test_similarity_search_gsql_missing_embedding_model_error(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
- """Test similarity_search with missing spanner_embedding_model_name."""
+ """Test similarity_search with missing embedding_options for GoogleSQL dialect."""
mock_spanner_client = MagicMock()
mock_instance = MagicMock()
mock_database = MagicMock()
@@ -254,24 +291,27 @@ def test_similarity_search_missing_spanner_embedding_model_name_error(
query="test query",
embedding_column_to_search="embedding_col",
columns=["col1"],
- embedding_options={},
+ embedding_options={
+ "spanner_postgresql_vertex_ai_embedding_model_endpoint": (
+ "test_endpoint"
+ )
+ },
credentials=mock_credentials,
- settings=MagicMock(),
- tool_context=MagicMock(),
)
assert result["status"] == "ERROR"
assert (
- "embedding_options['spanner_embedding_model_name'] must be"
- " specified for GoogleSQL dialect."
+ "embedding_options['vertex_ai_embedding_model_name'] or"
+ " embedding_options['spanner_googlesql_embedding_model_name'] must be"
+ " specified for GoogleSQL dialect Spanner database."
in result["error_details"]
)
-@patch("google.adk.tools.spanner.client.get_spanner_client")
-def test_similarity_search_missing_vertex_ai_embedding_model_endpoint_error(
+@mock.patch.object(client, "get_spanner_client")
+def test_similarity_search_pg_missing_embedding_model_error(
mock_get_spanner_client, mock_spanner_ids, mock_credentials
):
- """Test similarity_search with missing vertex_ai_embedding_model_endpoint."""
+ """Test similarity_search with missing embedding_options for PostgreSQL dialect."""
mock_spanner_client = MagicMock()
mock_instance = MagicMock()
mock_database = MagicMock()
@@ -288,14 +328,153 @@ def test_similarity_search_missing_vertex_ai_embedding_model_endpoint_error(
query="test query",
embedding_column_to_search="embedding_col",
columns=["col1"],
- embedding_options={},
+ embedding_options={
+ "spanner_googlesql_embedding_model_name": "EmbeddingsModel"
+ },
+ credentials=mock_credentials,
+ )
+ assert result["status"] == "ERROR"
+ assert (
+ "embedding_options['vertex_ai_embedding_model_name'] or"
+ " embedding_options['spanner_postgresql_vertex_ai_embedding_model_endpoint']"
+ " must be specified for PostgreSQL dialect Spanner database."
+ in result["error_details"]
+ )
+
+
+@pytest.mark.parametrize(
+ "embedding_options",
+ [
+ pytest.param(
+ {
+ "vertex_ai_embedding_model_name": "test-model",
+ "spanner_googlesql_embedding_model_name": "test-model-2",
+ },
+ id="vertex_ai_and_googlesql",
+ ),
+ pytest.param(
+ {
+ "vertex_ai_embedding_model_name": "test-model",
+ "spanner_postgresql_vertex_ai_embedding_model_endpoint": (
+ "test-endpoint"
+ ),
+ },
+ id="vertex_ai_and_postgresql",
+ ),
+ pytest.param(
+ {
+ "spanner_googlesql_embedding_model_name": "test-model",
+ "spanner_postgresql_vertex_ai_embedding_model_endpoint": (
+ "test-endpoint"
+ ),
+ },
+ id="googlesql_and_postgresql",
+ ),
+ pytest.param(
+ {
+ "vertex_ai_embedding_model_name": "test-model",
+ "spanner_googlesql_embedding_model_name": "test-model-2",
+ "spanner_postgresql_vertex_ai_embedding_model_endpoint": (
+ "test-endpoint"
+ ),
+ },
+ id="all_three_models",
+ ),
+ pytest.param(
+ {},
+ id="no_models",
+ ),
+ ],
+)
+@mock.patch.object(client, "get_spanner_client")
+def test_similarity_search_multiple_embedding_options_error(
+ mock_get_spanner_client,
+ mock_spanner_ids,
+ mock_credentials,
+ embedding_options,
+):
+ """Test similarity_search with multiple embedding models."""
+ mock_spanner_client = MagicMock()
+ mock_instance = MagicMock()
+ mock_database = MagicMock()
+ mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
+ mock_instance.database.return_value = mock_database
+ mock_spanner_client.instance.return_value = mock_instance
+ mock_get_spanner_client.return_value = mock_spanner_client
+
+ result = search_tool.similarity_search(
+ project_id=mock_spanner_ids["project_id"],
+ instance_id=mock_spanner_ids["instance_id"],
+ database_id=mock_spanner_ids["database_id"],
+ table_name=mock_spanner_ids["table_name"],
+ query="test query",
+ embedding_column_to_search="embedding_col",
+ columns=["col1"],
+ embedding_options=embedding_options,
credentials=mock_credentials,
- settings=MagicMock(),
- tool_context=MagicMock(),
)
assert result["status"] == "ERROR"
assert (
- "embedding_options['vertex_ai_embedding_model_endpoint'] must "
- "be specified for PostgreSQL dialect."
+ "Exactly one embedding model option must be specified."
in result["error_details"]
)
+
+
+@mock.patch.object(client, "get_spanner_client")
+def test_similarity_search_output_dimensionality_gsql_error(
+ mock_get_spanner_client, mock_spanner_ids, mock_credentials
+):
+ """Test similarity_search with output_dimensionality and spanner_googlesql_embedding_model_name."""
+ mock_spanner_client = MagicMock()
+ mock_instance = MagicMock()
+ mock_database = MagicMock()
+ mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
+ mock_instance.database.return_value = mock_database
+ mock_spanner_client.instance.return_value = mock_instance
+ mock_get_spanner_client.return_value = mock_spanner_client
+
+ result = search_tool.similarity_search(
+ project_id=mock_spanner_ids["project_id"],
+ instance_id=mock_spanner_ids["instance_id"],
+ database_id=mock_spanner_ids["database_id"],
+ table_name=mock_spanner_ids["table_name"],
+ query="test query",
+ embedding_column_to_search="embedding_col",
+ columns=["col1"],
+ embedding_options={
+ "spanner_googlesql_embedding_model_name": "EmbeddingsModel",
+ "output_dimensionality": 128,
+ },
+ credentials=mock_credentials,
+ )
+ assert result["status"] == "ERROR"
+ assert "is not supported when" in result["error_details"]
+
+
+@mock.patch.object(client, "get_spanner_client")
+def test_similarity_search_unsupported_algorithm_error(
+ mock_get_spanner_client, mock_spanner_ids, mock_credentials
+):
+ """Test similarity_search with an unsupported nearest neighbors algorithm."""
+ mock_spanner_client = MagicMock()
+ mock_instance = MagicMock()
+ mock_database = MagicMock()
+ mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
+ mock_instance.database.return_value = mock_database
+ mock_spanner_client.instance.return_value = mock_instance
+ mock_get_spanner_client.return_value = mock_spanner_client
+
+ result = search_tool.similarity_search(
+ project_id=mock_spanner_ids["project_id"],
+ instance_id=mock_spanner_ids["instance_id"],
+ database_id=mock_spanner_ids["database_id"],
+ table_name=mock_spanner_ids["table_name"],
+ query="test query",
+ embedding_column_to_search="embedding_col",
+ columns=["col1"],
+ embedding_options={"vertex_ai_embedding_model_name": "test-model"},
+ credentials=mock_credentials,
+ search_options={"nearest_neighbors_algorithm": "INVALID_ALGORITHM"},
+ )
+ assert result["status"] == "ERROR"
+ assert "Unsupported search_options" in result["error_details"]
diff --git a/tests/unittests/tools/spanner/test_spanner_tool_settings.py b/tests/unittests/tools/spanner/test_spanner_tool_settings.py
index f74922b248..730c9e0efe 100644
--- a/tests/unittests/tools/spanner/test_spanner_tool_settings.py
+++ b/tests/unittests/tools/spanner/test_spanner_tool_settings.py
@@ -15,9 +15,23 @@
from __future__ import annotations
from google.adk.tools.spanner.settings import SpannerToolSettings
+from google.adk.tools.spanner.settings import SpannerVectorStoreSettings
+from pydantic import ValidationError
import pytest
+def common_spanner_vector_store_settings(vector_length=None):
+ return {
+ "project_id": "test-project",
+ "instance_id": "test-instance",
+ "database_id": "test-database",
+ "table_name": "test-table",
+ "content_column": "test-content-column",
+ "embedding_column": "test-embedding-column",
+ "vector_length": 128 if vector_length is None else vector_length,
+ }
+
+
def test_spanner_tool_settings_experimental_warning():
"""Test SpannerToolSettings experimental warning."""
with pytest.warns(
@@ -25,3 +39,34 @@ def test_spanner_tool_settings_experimental_warning():
match="Tool settings defaults may have breaking change in the future.",
):
SpannerToolSettings()
+
+
+def test_spanner_vector_store_settings_all_fields_present():
+ """Test SpannerVectorStoreSettings with all required fields present."""
+ settings = SpannerVectorStoreSettings(
+ **common_spanner_vector_store_settings(),
+ vertex_ai_embedding_model_name="test-embedding-model",
+ )
+ assert settings is not None
+ assert settings.selected_columns == ["test-content-column"]
+ assert settings.vertex_ai_embedding_model_name == "test-embedding-model"
+
+
+def test_spanner_vector_store_settings_missing_embedding_model_name():
+ """Test SpannerVectorStoreSettings with missing vertex_ai_embedding_model_name."""
+ with pytest.raises(ValidationError) as excinfo:
+ SpannerVectorStoreSettings(**common_spanner_vector_store_settings())
+ assert "Field required" in str(excinfo.value)
+ assert "vertex_ai_embedding_model_name" in str(excinfo.value)
+
+
+def test_spanner_vector_store_settings_invalid_vector_length():
+ """Test SpannerVectorStoreSettings with invalid vector_length."""
+ with pytest.raises(ValidationError) as excinfo:
+ SpannerVectorStoreSettings(
+ **common_spanner_vector_store_settings(vector_length=0),
+ vertex_ai_embedding_model_name="test-embedding-model",
+ )
+ assert "Invalid vector length in the Spanner vector store settings." in str(
+ excinfo.value
+ )
diff --git a/tests/unittests/tools/spanner/test_spanner_toolset.py b/tests/unittests/tools/spanner/test_spanner_toolset.py
index 163832559d..a583a2f884 100644
--- a/tests/unittests/tools/spanner/test_spanner_toolset.py
+++ b/tests/unittests/tools/spanner/test_spanner_toolset.py
@@ -18,6 +18,7 @@
from google.adk.tools.spanner import SpannerCredentialsConfig
from google.adk.tools.spanner import SpannerToolset
from google.adk.tools.spanner.settings import SpannerToolSettings
+from google.adk.tools.spanner.settings import SpannerVectorStoreSettings
import pytest
@@ -184,3 +185,50 @@ async def test_spanner_toolset_without_read_capability(
expected_tool_names = set(returned_tools)
actual_tool_names = set([tool.name for tool in tools])
assert actual_tool_names == expected_tool_names
+
+
+@pytest.mark.asyncio
+async def test_spanner_toolset_with_vector_store_search():
+ """Test Spanner toolset with vector store search.
+
+ This test verifies the behavior of the Spanner toolset when vector store
+ settings are provided.
+ """
+ credentials_config = SpannerCredentialsConfig(
+ client_id="abc", client_secret="def"
+ )
+
+ spanner_tool_settings = SpannerToolSettings(
+ vector_store_settings=SpannerVectorStoreSettings(
+ project_id="test-project",
+ instance_id="test-instance",
+ database_id="test-database",
+ table_name="test-table",
+ content_column="test-content-column",
+ embedding_column="test-embedding-column",
+ vector_length=128,
+ vertex_ai_embedding_model_name="test-embedding-model",
+ )
+ )
+ toolset = SpannerToolset(
+ credentials_config=credentials_config,
+ spanner_tool_settings=spanner_tool_settings,
+ )
+ tools = await toolset.get_tools()
+ assert tools is not None
+
+ assert len(tools) == 8
+ assert all([isinstance(tool, GoogleTool) for tool in tools])
+
+ expected_tool_names = set([
+ "list_table_names",
+ "list_table_indexes",
+ "list_table_index_columns",
+ "list_named_schemas",
+ "get_table_schema",
+ "execute_sql",
+ "similarity_search",
+ "vector_store_similarity_search",
+ ])
+ actual_tool_names = set([tool.name for tool in tools])
+ assert actual_tool_names == expected_tool_names
diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py
index 85e8b9caa1..a9723b4347 100644
--- a/tests/unittests/tools/test_agent_tool.py
+++ b/tests/unittests/tools/test_agent_tool.py
@@ -570,3 +570,135 @@ class CustomInput(BaseModel):
# Should have string response schema for VERTEX_AI when no output_schema
assert declaration.response is not None
assert declaration.response.type == types.Type.STRING
+
+
+def test_include_plugins_default_true():
+ """Test that plugins are propagated by default (include_plugins=True)."""
+
+ # Create a test plugin that tracks callbacks
+ class TrackingPlugin(BasePlugin):
+
+ def __init__(self, name: str):
+ super().__init__(name)
+ self.before_agent_calls = 0
+
+ async def before_agent_callback(self, **kwargs):
+ self.before_agent_calls += 1
+
+ tracking_plugin = TrackingPlugin(name='tracking')
+
+ mock_model = testing_utils.MockModel.create(
+ responses=[function_call_no_schema, 'response1', 'response2']
+ )
+
+ tool_agent = Agent(
+ name='tool_agent',
+ model=mock_model,
+ )
+
+ root_agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[AgentTool(agent=tool_agent)], # Default include_plugins=True
+ )
+
+ runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
+ runner.run('test1')
+
+ # Plugin should be called for both root_agent and tool_agent
+ assert tracking_plugin.before_agent_calls == 2
+
+
+def test_include_plugins_explicit_true():
+ """Test that plugins are propagated when include_plugins=True."""
+
+ class TrackingPlugin(BasePlugin):
+
+ def __init__(self, name: str):
+ super().__init__(name)
+ self.before_agent_calls = 0
+
+ async def before_agent_callback(self, **kwargs):
+ self.before_agent_calls += 1
+
+ tracking_plugin = TrackingPlugin(name='tracking')
+
+ mock_model = testing_utils.MockModel.create(
+ responses=[function_call_no_schema, 'response1', 'response2']
+ )
+
+ tool_agent = Agent(
+ name='tool_agent',
+ model=mock_model,
+ )
+
+ root_agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[AgentTool(agent=tool_agent, include_plugins=True)],
+ )
+
+ runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
+ runner.run('test1')
+
+ # Plugin should be called for both root_agent and tool_agent
+ assert tracking_plugin.before_agent_calls == 2
+
+
+def test_include_plugins_false():
+ """Test that plugins are NOT propagated when include_plugins=False."""
+
+ class TrackingPlugin(BasePlugin):
+
+ def __init__(self, name: str):
+ super().__init__(name)
+ self.before_agent_calls = 0
+
+ async def before_agent_callback(self, **kwargs):
+ self.before_agent_calls += 1
+
+ tracking_plugin = TrackingPlugin(name='tracking')
+
+ mock_model = testing_utils.MockModel.create(
+ responses=[function_call_no_schema, 'response1', 'response2']
+ )
+
+ tool_agent = Agent(
+ name='tool_agent',
+ model=mock_model,
+ )
+
+ root_agent = Agent(
+ name='root_agent',
+ model=mock_model,
+ tools=[AgentTool(agent=tool_agent, include_plugins=False)],
+ )
+
+ runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin])
+ runner.run('test1')
+
+ # Plugin should only be called for root_agent, not tool_agent
+ assert tracking_plugin.before_agent_calls == 1
+
+
+def test_agent_tool_description_with_input_schema():
+ """Test that agent description is propagated when using input_schema."""
+
+ class CustomInput(BaseModel):
+ """This is the Pydantic model docstring."""
+
+ custom_input: str
+
+ agent_description = 'This is the agent description that should be used'
+ tool_agent = Agent(
+ name='tool_agent',
+ model=testing_utils.MockModel.create(responses=['test response']),
+ description=agent_description,
+ input_schema=CustomInput,
+ )
+
+ agent_tool = AgentTool(agent=tool_agent)
+ declaration = agent_tool._get_declaration()
+
+ # The description should come from the agent, not the Pydantic model
+ assert declaration.description == agent_description
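
For application code, the switch exercised above reads as in this minimal sketch (the sub-agent, agent names, and model id are placeholders; Agent and AgentTool are assumed to come from their usual google.adk modules):

    from google.adk.agents import Agent
    from google.adk.tools.agent_tool import AgentTool

    summarizer = Agent(name='summarizer', model='gemini-2.0-flash')

    root = Agent(
        name='root',
        model='gemini-2.0-flash',
        tools=[
            # Runner-level plugins also fire inside the wrapped agent by
            # default; include_plugins=False scopes them to the parent only,
            # which is exactly what the three tests above pin down.
            AgentTool(agent=summarizer, include_plugins=False),
        ],
    )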
diff --git a/tests/unittests/tools/test_api_registry.py b/tests/unittests/tools/test_api_registry.py
new file mode 100644
index 0000000000..df54786049
--- /dev/null
+++ b/tests/unittests/tools/test_api_registry.py
@@ -0,0 +1,205 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+from google.adk.tools.api_registry import ApiRegistry
+from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
+import httpx
+
+MOCK_MCP_SERVERS_LIST = {
+ "mcpServers": [
+ {
+ "name": "test-mcp-server-1",
+ "urls": ["mcp.server1.com"],
+ },
+ {
+ "name": "test-mcp-server-2",
+ "urls": ["mcp.server2.com"],
+ },
+ {
+ "name": "test-mcp-server-no-url",
+ },
+ ]
+}
+
+
+class TestApiRegistry(unittest.IsolatedAsyncioTestCase):
+ """Unit tests for ApiRegistry."""
+
+ def setUp(self):
+ self.project_id = "test-project"
+ self.location = "global"
+ self.mock_credentials = MagicMock()
+ self.mock_credentials.token = "mock_token"
+ self.mock_credentials.refresh = MagicMock()
+ mock_auth_patcher = patch(
+ "google.auth.default",
+ return_value=(self.mock_credentials, None),
+ autospec=True,
+ )
+ mock_auth_patcher.start()
+ self.addCleanup(mock_auth_patcher.stop)
+
+ @patch("httpx.Client", autospec=True)
+ def test_init_success(self, MockHttpClient):
+ mock_response = MagicMock()
+ mock_response.raise_for_status = MagicMock()
+ mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
+ mock_client_instance = MockHttpClient.return_value
+ mock_client_instance.__enter__.return_value = mock_client_instance
+ mock_client_instance.get.return_value = mock_response
+
+ api_registry = ApiRegistry(
+ api_registry_project_id=self.project_id, location=self.location
+ )
+
+ self.assertEqual(len(api_registry._mcp_servers), 3)
+ self.assertIn("test-mcp-server-1", api_registry._mcp_servers)
+ self.assertIn("test-mcp-server-2", api_registry._mcp_servers)
+ self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers)
+ mock_client_instance.get.assert_called_once_with(
+ f"https://cloudapiregistry.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers",
+ headers={
+ "Authorization": "Bearer mock_token",
+ "Content-Type": "application/json",
+ },
+ )
+
+ @patch("httpx.Client", autospec=True)
+ def test_init_http_error(self, MockHttpClient):
+ mock_client_instance = MockHttpClient.return_value
+ mock_client_instance.__enter__.return_value = mock_client_instance
+ mock_client_instance.get.side_effect = httpx.RequestError(
+ "Connection failed"
+ )
+
+ with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"):
+ ApiRegistry(
+ api_registry_project_id=self.project_id, location=self.location
+ )
+
+ @patch("httpx.Client", autospec=True)
+ def test_init_bad_response(self, MockHttpClient):
+ mock_response = MagicMock()
+ mock_response.raise_for_status = MagicMock(
+ side_effect=httpx.HTTPStatusError(
+ "Not Found", request=MagicMock(), response=MagicMock()
+ )
+ )
+ mock_client_instance = MockHttpClient.return_value
+ mock_client_instance.__enter__.return_value = mock_client_instance
+ mock_client_instance.get.return_value = mock_response
+
+ with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"):
+ ApiRegistry(
+ api_registry_project_id=self.project_id, location=self.location
+ )
+ mock_response.raise_for_status.assert_called_once()
+
+ @patch("google.adk.tools.api_registry.McpToolset", autospec=True)
+ @patch("httpx.Client", autospec=True)
+ async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset):
+ mock_response = MagicMock()
+ mock_response.raise_for_status = MagicMock()
+ mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
+ mock_client_instance = MockHttpClient.return_value
+ mock_client_instance.__enter__.return_value = mock_client_instance
+ mock_client_instance.get.return_value = mock_response
+
+ api_registry = ApiRegistry(
+ api_registry_project_id=self.project_id, location=self.location
+ )
+
+ toolset = api_registry.get_toolset("test-mcp-server-1")
+
+ MockMcpToolset.assert_called_once_with(
+ connection_params=StreamableHTTPConnectionParams(
+ url="https://mcp.server1.com",
+ headers={"Authorization": "Bearer mock_token"},
+ ),
+ tool_filter=None,
+ tool_name_prefix=None,
+ header_provider=None,
+ )
+ self.assertEqual(toolset, MockMcpToolset.return_value)
+
+ @patch("google.adk.tools.api_registry.McpToolset", autospec=True)
+ @patch("httpx.Client", autospec=True)
+ async def test_get_toolset_with_filter_and_prefix(
+ self, MockHttpClient, MockMcpToolset
+ ):
+ mock_response = MagicMock()
+ mock_response.raise_for_status = MagicMock()
+ mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
+ mock_client_instance = MockHttpClient.return_value
+ mock_client_instance.__enter__.return_value = mock_client_instance
+ mock_client_instance.get.return_value = mock_response
+
+ api_registry = ApiRegistry(
+ api_registry_project_id=self.project_id, location=self.location
+ )
+ tool_filter = ["tool1"]
+ tool_name_prefix = "prefix_"
+ toolset = api_registry.get_toolset(
+ "test-mcp-server-1",
+ tool_filter=tool_filter,
+ tool_name_prefix=tool_name_prefix,
+ )
+
+ MockMcpToolset.assert_called_once_with(
+ connection_params=StreamableHTTPConnectionParams(
+ url="https://mcp.server1.com",
+ headers={"Authorization": "Bearer mock_token"},
+ ),
+ tool_filter=tool_filter,
+ tool_name_prefix=tool_name_prefix,
+ header_provider=None,
+ )
+ self.assertEqual(toolset, MockMcpToolset.return_value)
+
+ @patch("httpx.Client", autospec=True)
+ async def test_get_toolset_server_not_found(self, MockHttpClient):
+ mock_response = MagicMock()
+ mock_response.raise_for_status = MagicMock()
+ mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
+ mock_client_instance = MockHttpClient.return_value
+ mock_client_instance.__enter__.return_value = mock_client_instance
+ mock_client_instance.get.return_value = mock_response
+
+ api_registry = ApiRegistry(
+ api_registry_project_id=self.project_id, location=self.location
+ )
+
+ with self.assertRaisesRegex(ValueError, "not found in API Registry"):
+ api_registry.get_toolset("non-existent-server")
+
+ @patch("httpx.Client", autospec=True)
+ async def test_get_toolset_server_no_url(self, MockHttpClient):
+ mock_response = MagicMock()
+ mock_response.raise_for_status = MagicMock()
+ mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
+ mock_client_instance = MockHttpClient.return_value
+ mock_client_instance.__enter__.return_value = mock_client_instance
+ mock_client_instance.get.return_value = mock_response
+
+ api_registry = ApiRegistry(
+ api_registry_project_id=self.project_id, location=self.location
+ )
+
+ with self.assertRaisesRegex(ValueError, "has no URLs"):
+ api_registry.get_toolset("test-mcp-server-no-url")
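
Outside the mocked HTTP client, the registry is intended to hand back ready-to-use MCP toolsets. A rough usage sketch, assuming application-default credentials and a hypothetical registered server name and tool name:

    from google.adk.agents import Agent
    from google.adk.tools.api_registry import ApiRegistry

    registry = ApiRegistry(
        api_registry_project_id="my-project", location="global"
    )
    toolset = registry.get_toolset(
        "weather-mcp-server",          # hypothetical server name
        tool_filter=["get_forecast"],  # optional allow-list, as tested above
        tool_name_prefix="weather_",   # optional prefix, as tested above
    )

    agent = Agent(
        name="weather_agent",
        model="gemini-2.0-flash",      # placeholder model id
        tools=[toolset],
    )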
diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py
index f57c3d3838..8a562c7cf4 100644
--- a/tests/unittests/tools/test_build_function_declaration.py
+++ b/tests/unittests/tools/test_build_function_declaration.py
@@ -411,3 +411,19 @@ def transfer_to_agent(agent_name: str, tool_context: ToolContext):
# Changed: Now uses Any type instead of NULL for no return annotation
assert function_decl.response is not None
assert function_decl.response.type is None # Any type maps to None in schema
+
+
+def test_transfer_to_agent_tool_with_enum_constraint():
+ """Test TransferToAgentTool adds enum constraint to agent_name."""
+ from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool
+
+ agent_names = ['agent_a', 'agent_b', 'agent_c']
+ tool = TransferToAgentTool(agent_names=agent_names)
+
+ function_decl = tool._get_declaration()
+
+ assert function_decl.name == 'transfer_to_agent'
+ assert function_decl.parameters.type == 'OBJECT'
+ assert function_decl.parameters.properties['agent_name'].type == 'STRING'
+ assert function_decl.parameters.properties['agent_name'].enum == agent_names
+ assert 'tool_context' not in function_decl.parameters.properties
diff --git a/tests/unittests/tools/test_mcp_toolset.py b/tests/unittests/tools/test_mcp_toolset.py
index a3a6598e35..7bfd912669 100644
--- a/tests/unittests/tools/test_mcp_toolset.py
+++ b/tests/unittests/tools/test_mcp_toolset.py
@@ -14,31 +14,12 @@
"""Unit tests for McpToolset."""
-import sys
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
+from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
import pytest
-# Skip all tests in this module if Python version is less than 3.10
-pytestmark = pytest.mark.skipif(
- sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+"
-)
-
-# Import dependencies with version checking
-try:
- from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
-except ImportError as e:
- if sys.version_info < (3, 10):
- # Create dummy classes to prevent NameError during test collection
- # Tests will be skipped anyway due to pytestmark
- class DummyClass:
- pass
-
- McpToolset = DummyClass
- else:
- raise e
-
@pytest.mark.asyncio
async def test_mcp_toolset_with_prefix():
diff --git a/tests/unittests/tools/test_transfer_to_agent_tool.py b/tests/unittests/tools/test_transfer_to_agent_tool.py
new file mode 100644
index 0000000000..14b7b3abea
--- /dev/null
+++ b/tests/unittests/tools/test_transfer_to_agent_tool.py
@@ -0,0 +1,164 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for TransferToAgentTool enum constraint functionality."""
+
+from unittest.mock import patch
+
+from google.adk.tools.function_tool import FunctionTool
+from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool
+from google.genai import types
+
+
+def test_transfer_to_agent_tool_enum_constraint():
+ """Test that TransferToAgentTool adds enum constraint to agent_name."""
+ agent_names = ['agent_a', 'agent_b', 'agent_c']
+ tool = TransferToAgentTool(agent_names=agent_names)
+
+ decl = tool._get_declaration()
+
+ assert decl is not None
+ assert decl.name == 'transfer_to_agent'
+ assert decl.parameters is not None
+ assert decl.parameters.type == types.Type.OBJECT
+ assert 'agent_name' in decl.parameters.properties
+
+ agent_name_schema = decl.parameters.properties['agent_name']
+ assert agent_name_schema.type == types.Type.STRING
+ assert agent_name_schema.enum == agent_names
+
+ # Verify that agent_name is marked as required
+ assert decl.parameters.required == ['agent_name']
+
+
+def test_transfer_to_agent_tool_single_agent():
+ """Test TransferToAgentTool with a single agent."""
+ tool = TransferToAgentTool(agent_names=['single_agent'])
+
+ decl = tool._get_declaration()
+
+ assert decl is not None
+ agent_name_schema = decl.parameters.properties['agent_name']
+ assert agent_name_schema.enum == ['single_agent']
+
+
+def test_transfer_to_agent_tool_multiple_agents():
+ """Test TransferToAgentTool with multiple agents."""
+ agent_names = ['agent_1', 'agent_2', 'agent_3', 'agent_4', 'agent_5']
+ tool = TransferToAgentTool(agent_names=agent_names)
+
+ decl = tool._get_declaration()
+
+ assert decl is not None
+ agent_name_schema = decl.parameters.properties['agent_name']
+ assert agent_name_schema.enum == agent_names
+ assert len(agent_name_schema.enum) == 5
+
+
+def test_transfer_to_agent_tool_empty_list():
+ """Test TransferToAgentTool with an empty agent list."""
+ tool = TransferToAgentTool(agent_names=[])
+
+ decl = tool._get_declaration()
+
+ assert decl is not None
+ agent_name_schema = decl.parameters.properties['agent_name']
+ assert agent_name_schema.enum == []
+
+
+def test_transfer_to_agent_tool_preserves_description():
+ """Test that TransferToAgentTool preserves the original description."""
+ tool = TransferToAgentTool(agent_names=['agent_a', 'agent_b'])
+
+ decl = tool._get_declaration()
+
+ assert decl is not None
+ assert decl.description is not None
+ assert 'Transfer the question to another agent' in decl.description
+
+
+def test_transfer_to_agent_tool_preserves_parameter_type():
+ """Test that TransferToAgentTool preserves the parameter type."""
+ tool = TransferToAgentTool(agent_names=['agent_a'])
+
+ decl = tool._get_declaration()
+
+ assert decl is not None
+ agent_name_schema = decl.parameters.properties['agent_name']
+ # Should still be a string type, just with enum constraint
+ assert agent_name_schema.type == types.Type.STRING
+
+
+def test_transfer_to_agent_tool_no_extra_parameters():
+ """Test that TransferToAgentTool doesn't add extra parameters."""
+ tool = TransferToAgentTool(agent_names=['agent_a'])
+
+ decl = tool._get_declaration()
+
+ assert decl is not None
+ # Should only have agent_name parameter (tool_context is ignored)
+ assert len(decl.parameters.properties) == 1
+ assert 'agent_name' in decl.parameters.properties
+ assert 'tool_context' not in decl.parameters.properties
+
+
+def test_transfer_to_agent_tool_maintains_inheritance():
+ """Test that TransferToAgentTool inherits from FunctionTool correctly."""
+ tool = TransferToAgentTool(agent_names=['agent_a'])
+
+ assert isinstance(tool, FunctionTool)
+ assert hasattr(tool, '_get_declaration')
+ assert hasattr(tool, 'process_llm_request')
+
+
+def test_transfer_to_agent_tool_handles_parameters_json_schema():
+ """Test that TransferToAgentTool handles parameters_json_schema format."""
+ agent_names = ['agent_x', 'agent_y', 'agent_z']
+
+ # Create a mock FunctionDeclaration with parameters_json_schema
+ mock_decl = type('MockDecl', (), {})()
+ mock_decl.parameters = None # No Schema object
+ mock_decl.parameters_json_schema = {
+ 'type': 'object',
+ 'properties': {
+ 'agent_name': {
+ 'type': 'string',
+ 'description': 'Agent name to transfer to',
+ }
+ },
+ 'required': ['agent_name'],
+ }
+
+ # Temporarily patch FunctionTool._get_declaration
+ with patch.object(
+ FunctionTool,
+ '_get_declaration',
+ return_value=mock_decl,
+ ):
+ tool = TransferToAgentTool(agent_names=agent_names)
+ result = tool._get_declaration()
+
+ # Verify enum was added to parameters_json_schema
+ assert result.parameters_json_schema is not None
+ assert 'agent_name' in result.parameters_json_schema['properties']
+ assert (
+ result.parameters_json_schema['properties']['agent_name']['enum']
+ == agent_names
+ )
+ assert (
+ result.parameters_json_schema['properties']['agent_name']['type']
+ == 'string'
+ )
+ # Verify required field is preserved
+ assert result.parameters_json_schema['required'] == ['agent_name']
diff --git a/tests/unittests/tools/test_vertex_ai_search_tool.py b/tests/unittests/tools/test_vertex_ai_search_tool.py
index 0df19288a3..1ec1572b90 100644
--- a/tests/unittests/tools/test_vertex_ai_search_tool.py
+++ b/tests/unittests/tools/test_vertex_ai_search_tool.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.models.llm_request import LlmRequest
@@ -24,6 +26,10 @@
from google.genai import types
import pytest
+VERTEX_SEARCH_TOOL_LOGGER_NAME = (
+ 'google_adk.google.adk.tools.vertex_ai_search_tool'
+)
+
async def _create_tool_context() -> ToolContext:
session_service = InMemorySessionService()
@@ -121,12 +127,34 @@ def test_init_with_data_store_id(self):
tool = VertexAiSearchTool(data_store_id='test_data_store')
assert tool.data_store_id == 'test_data_store'
assert tool.search_engine_id is None
+ assert tool.data_store_specs is None
def test_init_with_search_engine_id(self):
"""Test initialization with search engine ID."""
tool = VertexAiSearchTool(search_engine_id='test_search_engine')
assert tool.search_engine_id == 'test_search_engine'
assert tool.data_store_id is None
+ assert tool.data_store_specs is None
+
+ def test_init_with_engine_and_specs(self):
+ """Test initialization with search engine ID and specs."""
+ specs = [
+ types.VertexAISearchDataStoreSpec(
+ dataStore=(
+ 'projects/p/locations/l/collections/c/dataStores/spec_store'
+ )
+ )
+ ]
+ engine_id = (
+ 'projects/p/locations/l/collections/c/engines/test_search_engine'
+ )
+ tool = VertexAiSearchTool(
+ search_engine_id=engine_id,
+ data_store_specs=specs,
+ )
+ assert tool.search_engine_id == engine_id
+ assert tool.data_store_id is None
+ assert tool.data_store_specs == specs
def test_init_with_neither_raises_error(self):
"""Test that initialization without either ID raises ValueError."""
@@ -146,10 +174,34 @@ def test_init_with_both_raises_error(self):
data_store_id='test_data_store', search_engine_id='test_search_engine'
)
+ def test_init_with_specs_but_no_engine_raises_error(self):
+ """Test that specs without engine ID raises ValueError."""
+ specs = [
+ types.VertexAISearchDataStoreSpec(
+ dataStore=(
+ 'projects/p/locations/l/collections/c/dataStores/spec_store'
+ )
+ )
+ ]
+ with pytest.raises(
+ ValueError,
+ match=(
+ 'search_engine_id must be specified if data_store_specs is'
+ ' specified'
+ ),
+ ):
+ VertexAiSearchTool(
+ data_store_id='test_data_store', data_store_specs=specs
+ )
+
@pytest.mark.asyncio
- async def test_process_llm_request_with_simple_gemini_model(self):
+ async def test_process_llm_request_with_simple_gemini_model(self, caplog):
"""Test processing LLM request with simple Gemini model name."""
- tool = VertexAiSearchTool(data_store_id='test_data_store')
+ caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME)
+
+ tool = VertexAiSearchTool(
+ data_store_id='test_data_store', filter='f', max_results=5
+ )
tool_context = await _create_tool_context()
llm_request = LlmRequest(
@@ -162,17 +214,56 @@ async def test_process_llm_request_with_simple_gemini_model(self):
assert llm_request.config.tools is not None
assert len(llm_request.config.tools) == 1
- assert llm_request.config.tools[0].retrieval is not None
- assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None
+ retrieval_tool = llm_request.config.tools[0]
+ assert retrieval_tool.retrieval is not None
+ assert retrieval_tool.retrieval.vertex_ai_search is not None
+ assert (
+ retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store'
+ )
+ assert retrieval_tool.retrieval.vertex_ai_search.engine is None
+ assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f'
+ assert retrieval_tool.retrieval.vertex_ai_search.max_results == 5
+
+ # Verify debug log
+ debug_records = [
+ r
+ for r in caplog.records
+ if 'Adding Vertex AI Search tool config' in r.message
+ ]
+ assert len(debug_records) == 1
+ log_message = debug_records[0].getMessage()
+ assert 'datastore=test_data_store' in log_message
+ assert 'engine=None' in log_message
+ assert 'filter=f' in log_message
+ assert 'max_results=5' in log_message
+ assert 'data_store_specs=None' in log_message
@pytest.mark.asyncio
- async def test_process_llm_request_with_path_based_gemini_model(self):
+ async def test_process_llm_request_with_path_based_gemini_model(self, caplog):
"""Test processing LLM request with path-based Gemini model name."""
- tool = VertexAiSearchTool(data_store_id='test_data_store')
+ caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME)
+
+ specs = [
+ types.VertexAISearchDataStoreSpec(
+ dataStore=(
+ 'projects/p/locations/l/collections/c/dataStores/spec_store'
+ )
+ )
+ ]
+ engine_id = 'projects/p/locations/l/collections/c/engines/test_engine'
+ tool = VertexAiSearchTool(
+ search_engine_id=engine_id,
+ data_store_specs=specs,
+ filter='f2',
+ max_results=10,
+ )
tool_context = await _create_tool_context()
llm_request = LlmRequest(
- model='projects/265104255505/locations/us-central1/publishers/google/models/gemini-2.0-flash-001',
+ model=(
+ 'projects/265104255505/locations/us-central1/publishers/'
+ 'google/models/gemini-2.0-flash-001'
+ ),
config=types.GenerateContentConfig(),
)
@@ -182,8 +273,28 @@ async def test_process_llm_request_with_path_based_gemini_model(self):
assert llm_request.config.tools is not None
assert len(llm_request.config.tools) == 1
- assert llm_request.config.tools[0].retrieval is not None
- assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None
+ retrieval_tool = llm_request.config.tools[0]
+ assert retrieval_tool.retrieval is not None
+ assert retrieval_tool.retrieval.vertex_ai_search is not None
+ assert retrieval_tool.retrieval.vertex_ai_search.datastore is None
+ assert retrieval_tool.retrieval.vertex_ai_search.engine == engine_id
+ assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f2'
+ assert retrieval_tool.retrieval.vertex_ai_search.max_results == 10
+ assert retrieval_tool.retrieval.vertex_ai_search.data_store_specs == specs
+
+ # Verify debug log
+ debug_records = [
+ r
+ for r in caplog.records
+ if 'Adding Vertex AI Search tool config' in r.message
+ ]
+ assert len(debug_records) == 1
+ log_message = debug_records[0].getMessage()
+ assert 'datastore=None' in log_message
+ assert f'engine={engine_id}' in log_message
+ assert 'filter=f2' in log_message
+ assert 'max_results=10' in log_message
+ assert 'data_store_specs=1 spec(s): [spec_store]' in log_message
@pytest.mark.asyncio
async def test_process_llm_request_with_gemini_1_and_other_tools_raises_error(
@@ -291,9 +402,11 @@ async def test_process_llm_request_with_path_based_non_gemini_model_raises_error
@pytest.mark.asyncio
async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds(
- self,
+ self, caplog
):
"""Test that Gemini 2.x with other tools succeeds."""
+ caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME)
+
tool = VertexAiSearchTool(data_store_id='test_data_store')
tool_context = await _create_tool_context()
@@ -316,5 +429,23 @@ async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds(
assert llm_request.config.tools is not None
assert len(llm_request.config.tools) == 2
assert llm_request.config.tools[0] == existing_tool
- assert llm_request.config.tools[1].retrieval is not None
- assert llm_request.config.tools[1].retrieval.vertex_ai_search is not None
+ retrieval_tool = llm_request.config.tools[1]
+ assert retrieval_tool.retrieval is not None
+ assert retrieval_tool.retrieval.vertex_ai_search is not None
+ assert (
+ retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store'
+ )
+
+ # Verify debug log
+ debug_records = [
+ r
+ for r in caplog.records
+ if 'Adding Vertex AI Search tool config' in r.message
+ ]
+ assert len(debug_records) == 1
+ log_message = debug_records[0].getMessage()
+ assert 'datastore=test_data_store' in log_message
+ assert 'engine=None' in log_message
+ assert 'filter=None' in log_message
+ assert 'max_results=None' in log_message
+ assert 'data_store_specs=None' in log_message
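
For reference, the engine-plus-specs configuration asserted above corresponds to a tool set up roughly like this (a sketch; the resource names, filter value, and model id are placeholders, and the imports assume the usual public entry points):

    from google.adk.agents import Agent
    from google.adk.tools import VertexAiSearchTool
    from google.genai import types

    search_tool = VertexAiSearchTool(
        search_engine_id=(
            'projects/my-project/locations/global/collections/'
            'default_collection/engines/my-engine'
        ),
        data_store_specs=[
            types.VertexAISearchDataStoreSpec(
                dataStore=(
                    'projects/my-project/locations/global/collections/'
                    'default_collection/dataStores/my-store'
                )
            )
        ],
        filter='my-filter',  # optional, as exercised above
        max_results=5,       # optional, as exercised above
    )

    agent = Agent(
        name='enterprise_search_agent',
        model='gemini-2.0-flash',
        tools=[search_tool],
    )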
diff --git a/tests/unittests/utils/test_client_labels_utils.py b/tests/unittests/utils/test_client_labels_utils.py
new file mode 100644
index 0000000000..b1d6acb001
--- /dev/null
+++ b/tests/unittests/utils/test_client_labels_utils.py
@@ -0,0 +1,68 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+from google.adk import version
+from google.adk.utils import _client_labels_utils
+import pytest
+
+
+def test_get_client_labels_default():
+ """Test get_client_labels returns default labels."""
+ labels = _client_labels_utils.get_client_labels()
+ assert len(labels) == 2
+ assert f"google-adk/{version.__version__}" == labels[0]
+ assert f"gl-python/{sys.version.split()[0]}" == labels[1]
+
+
+def test_get_client_labels_with_agent_engine_id(monkeypatch):
+ """Test get_client_labels returns agent engine tag when env var is set."""
+ monkeypatch.setenv(
+ _client_labels_utils._AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME,
+ "test-agent-id",
+ )
+ labels = _client_labels_utils.get_client_labels()
+ assert len(labels) == 2
+ assert (
+ f"google-adk/{version.__version__}+{_client_labels_utils._AGENT_ENGINE_TELEMETRY_TAG}"
+ == labels[0]
+ )
+ assert f"gl-python/{sys.version.split()[0]}" == labels[1]
+
+
+def test_get_client_labels_with_context():
+ """Test get_client_labels includes label from context."""
+ with _client_labels_utils.client_label_context("my-label/1.0"):
+ labels = _client_labels_utils.get_client_labels()
+ assert len(labels) == 3
+ assert f"google-adk/{version.__version__}" == labels[0]
+ assert f"gl-python/{sys.version.split()[0]}" == labels[1]
+ assert "my-label/1.0" == labels[2]
+
+
+def test_client_label_context_nested_error():
+ """Test client_label_context raises error when nested."""
+ with pytest.raises(ValueError, match="Client label already exists"):
+ with _client_labels_utils.client_label_context("my-label/1.0"):
+ with _client_labels_utils.client_label_context("another-label/1.0"):
+ pass
+
+
+def test_eval_client_label():
+ """Test EVAL_CLIENT_LABEL has correct format."""
+ assert (
+ f"google-adk-eval/{version.__version__}"
+ == _client_labels_utils.EVAL_CLIENT_LABEL
+ )