diff --git a/.github/workflows/analyze-releases-for-adk-docs-updates.yml b/.github/workflows/analyze-releases-for-adk-docs-updates.yml index c8a86fac66..21414ae534 100644 --- a/.github/workflows/analyze-releases-for-adk-docs-updates.yml +++ b/.github/workflows/analyze-releases-for-adk-docs-updates.yml @@ -16,10 +16,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/check-file-contents.yml b/.github/workflows/check-file-contents.yml index 861e2247a0..bb575e0f20 100644 --- a/.github/workflows/check-file-contents.yml +++ b/.github/workflows/check-file-contents.yml @@ -24,7 +24,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 2 diff --git a/.github/workflows/copybara-pr-handler.yml b/.github/workflows/copybara-pr-handler.yml index 670389c527..4ca3c48803 100644 --- a/.github/workflows/copybara-pr-handler.yml +++ b/.github/workflows/copybara-pr-handler.yml @@ -25,7 +25,7 @@ jobs: steps: - name: Check for Copybara commits and close PRs - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: github-token: ${{ secrets.ADK_TRIAGE_AGENT }} script: | diff --git a/.github/workflows/discussion_answering.yml b/.github/workflows/discussion_answering.yml index 71c06ba9f6..d9bfffc361 100644 --- a/.github/workflows/discussion_answering.yml +++ b/.github/workflows/discussion_answering.yml @@ -15,10 +15,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index e1a087742c..b8b24da5ce 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -26,12 +26,12 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.x' diff --git a/.github/workflows/pr-triage.yml b/.github/workflows/pr-triage.yml index a8a2082094..55b088b505 100644 --- a/.github/workflows/pr-triage.yml +++ b/.github/workflows/pr-triage.yml @@ -20,10 +20,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/pyink.yml b/.github/workflows/pyink.yml index ef9e72e453..0822757fa0 100644 --- a/.github/workflows/pyink.yml +++ b/.github/workflows/pyink.yml @@ -26,12 +26,12 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.x' diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 42b6174813..3fc6bd943f 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -25,14 +25,14 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - name: Checkout code - uses: actions/checkout@v4 + uses: 
actions/checkout@v5 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} @@ -48,14 +48,6 @@ jobs: - name: Run unit tests with pytest run: | source .venv/bin/activate - if [[ "${{ matrix.python-version }}" == "3.9" ]]; then - pytest tests/unittests \ - --ignore=tests/unittests/a2a \ - --ignore=tests/unittests/tools/mcp_tool \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py - else - pytest tests/unittests \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py - fi \ No newline at end of file + pytest tests/unittests \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py \ No newline at end of file diff --git a/.github/workflows/stale-bot.yml b/.github/workflows/stale-bot.yml new file mode 100644 index 0000000000..6948b56459 --- /dev/null +++ b/.github/workflows/stale-bot.yml @@ -0,0 +1,43 @@ +name: ADK Stale Issue Auditor + +on: + workflow_dispatch: + + schedule: + # This runs at 6:00 AM UTC (10 PM PST) + - cron: '0 6 * * *' + +jobs: + audit-stale-issues: + runs-on: ubuntu-latest + timeout-minutes: 60 + + permissions: + issues: write + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install requests google-adk + + - name: Run Auditor Agent Script + env: + GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }} + GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }} + OWNER: ${{ github.repository_owner }} + REPO: adk-python + CONCURRENCY_LIMIT: 3 + LLM_MODEL_NAME: "gemini-2.5-flash" + PYTHONPATH: contributing/samples + + run: python -m adk_stale_agent.main \ No newline at end of file diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml index 57e729e9b5..46153f413a 100644 --- a/.github/workflows/triage.yml +++ b/.github/workflows/triage.yml @@ -2,21 +2,32 @@ name: ADK Issue Triaging Agent on: issues: - types: [opened, reopened] + types: [opened, labeled] + schedule: + # Run every 6 hours to triage untriaged issues + - cron: '0 */6 * * *' jobs: agent-triage-issues: runs-on: ubuntu-latest + # Run for: + # - Scheduled runs (batch processing) + # - New issues (need component labeling) + # - Issues labeled with "planned" (need owner assignment) + if: >- + github.event_name == 'schedule' || + github.event.action == 'opened' || + github.event.label.name == 'planned' permissions: issues: write contents: read steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' @@ -30,8 +41,8 @@ jobs: GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }} GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }} GOOGLE_GENAI_USE_VERTEXAI: 0 - OWNER: 'google' - REPO: 'adk-python' + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} INTERACTIVE: 0 EVENT_NAME: ${{ github.event_name }} # 'issues', 'schedule', etc. 
ISSUE_NUMBER: ${{ github.event.issue.number }} diff --git a/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml b/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml index 9b1f042917..bce7598c2f 100644 --- a/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml +++ b/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml @@ -13,7 +13,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Clone adk-docs repository run: git clone https://github.com/google/adk-docs.git /tmp/adk-docs @@ -22,7 +22,7 @@ jobs: run: git clone https://github.com/google/adk-python.git /tmp/adk-python - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/CHANGELOG.md b/CHANGELOG.md index ced7b7026b..93dc505adc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,146 @@ # Changelog +## [1.20.0](https://github.com/google/adk-python/compare/v1.19.0...v1.20.0) (2025-12-01) + + +### Features +* **[Core]** + * Add enum constraint to `agent_name` for `transfer_to_agent` ([4a42d0d](https://github.com/google/adk-python/commit/4a42d0d9d81b7aab98371427f70a7707dbfb8bc4)) + * Add validation for unique sub-agent names ([#3557](https://github.com/google/adk-python/issues/3557)) ([2247a45](https://github.com/google/adk-python/commit/2247a45922afdf0a733239b619f45601d9b325ec)) + * Support streaming function call arguments in progressive SSE streaming feature ([786aaed](https://github.com/google/adk-python/commit/786aaed335e1ce64b7e92dff2f4af8316b2ef593)) + +* **[Models]** + * Enable multi-provider support for Claude and LiteLLM ([d29261a](https://github.com/google/adk-python/commit/d29261a3dc9c5a603feef27ea657c4a03bb8a089)) + +* **[Tools]** + * Create APIRegistryToolset to add tools from Cloud API registry to agent ([ec4ccd7](https://github.com/google/adk-python/commit/ec4ccd718feeadeb6b2b59fcc0e9ff29a4fd0bac)) + * Add an option to disallow propagating runner plugins to AgentTool runner ([777dba3](https://github.com/google/adk-python/commit/777dba3033a9a14667fb009ba017f648177be41d)) + +* **[Web]** + * Added an endpoint to list apps with details ([b57fe5f](https://github.com/google/adk-python/commit/b57fe5f4598925ec7592917bb32c7f0d6eca287a)) + + +### Bug Fixes + +* Allow image parts in user messages for Anthropic Claude ([5453b5b](https://github.com/google/adk-python/commit/5453b5bfdedc91d9d668c9eac39e3bb009a7bbbf)) +* Mark the Content as non-empty if its first part contains text or inline_data or file_data or func call/response ([631b583](https://github.com/google/adk-python/commit/631b58336d36bfd93e190582be34069613d38559)) +* Fixes double response processing issue in `base_llm_flow.py` where, in Bidi-streaming (live) mode, the multi-agent structure causes duplicated responses after tool calling. 
([cf21ca3](https://github.com/google/adk-python/commit/cf21ca358478919207049695ba6b31dc6e0b2673)) +* Fix out of bounds error in _run_async_impl ([8fc6128](https://github.com/google/adk-python/commit/8fc6128b62ba576480d196d4a2597564fd0a7006)) +* Fix paths for public docs ([cd54f48](https://github.com/google/adk-python/commit/cd54f48fed0c87b54fb19743c9c75e790c5d9135)) +* Ensure request bodies without explicit names are named 'body' ([084c2de](https://github.com/google/adk-python/commit/084c2de0dac84697906e2b4beebf008bbd9ae8e1)), closes [#2213](https://github.com/google/adk-python/issues/2213) +* Optimize Stale Agent with GraphQL and Search API to resolve 429 Quota errors ([cb19d07](https://github.com/google/adk-python/commit/cb19d0714c90cd578551753680f39d8d6076c79b)) +* Update AgentTool to use Agent's description when input_schema is provided in FunctionDeclaration ([52674e7](https://github.com/google/adk-python/commit/52674e7fac6b7689f0e3871d41c4523e13471a7e)) +* Update LiteLLM system instruction role from "developer" to "system" ([2e1f730](https://github.com/google/adk-python/commit/2e1f730c3bc0eb454b76d7f36b7b9f1da7304cfe)), closes [#3657](https://github.com/google/adk-python/issues/3657) +* Update session last update time when appending events ([a3e4ad3](https://github.com/google/adk-python/commit/a3e4ad3cd130714affcaa880f696aeb498cd93af)), closes [#2721](https://github.com/google/adk-python/issues/2721) +* Update the retry_on_closed_resource decorator to retry on all errors ([a3aa077](https://github.com/google/adk-python/commit/a3aa07722a7de3e08807e86fd10f28938f0b267d)) +* Windows Path Handling and Normalize Cross-Platform Path Resolution in AgentLoader ([a1c09b7](https://github.com/google/adk-python/commit/a1c09b724bb37513eaabaff9643eeaa68014f14d)) + + +### Documentation + +* Add Code Wiki badge to README ([caf23ac](https://github.com/google/adk-python/commit/caf23ac49fe08bc7f625c61eed4635c26852c3ba)) + + +## [1.19.0](https://github.com/google/adk-python/compare/v1.18.0...v1.19.0) (2025-11-19) + +### Features + +* **[Core]** + * Add `id` and `custom_metadata` fields to `MemoryEntry` ([4dd28a3](https://github.com/google/adk-python/commit/4dd28a3970d0f76c571caf80b3e1bea1b79e9dde)) + * Add progressive SSE streaming feature ([a5ac1d5](https://github.com/google/adk-python/commit/a5ac1d5e14f5ce7cd875d81a494a773710669dc1)) + * Add a2a_request_meta_provider to RemoteAgent init ([d12468e](https://github.com/google/adk-python/commit/d12468ee5a2b906b6699ccdb94c6a5a4c2822465)) + * Add feature decorator for the feature registry system ([871da73](https://github.com/google/adk-python/commit/871da731f1c09c6a62d51b137d9d2e7c9fb3897a)) + * Breaking: Raise minimum Python version to 3_10 ([8402832](https://github.com/google/adk-python/commit/840283228ee77fb3dbd737cfe7eb8736d9be5ec8)) + * Refactor and rename BigQuery agent analytics plugin ([6b14f88](https://github.com/google/adk-python/commit/6b14f887262722ccb85dcd6cef9c0e9b103cfa6e)) + * Pass custom_metadata through forwarding artifact service ([c642f13](https://github.com/google/adk-python/commit/c642f13f216fb64bc93ac46c1c57702c8a2add8c)) + * Update save_files_as_artifacts_plugin to never keep inline data ([857de04](https://github.com/google/adk-python/commit/857de04debdeba421075c2283c9bd8518d586624)) + +* **[Evals]** + * Add support for InOrder and AnyOrder match in ToolTrajectoryAvgScore Metric ([e2d3b2d](https://github.com/google/adk-python/commit/e2d3b2d862f7fc93807d16089307d4df25367a24)) + +* **[Integrations]** + * Enhance BQ Plugin Schema, Error 
Handling, and Logging ([5ac5129](https://github.com/google/adk-python/commit/5ac5129fb01913516d6f5348a825ca83d024d33a)) + * Schema Enhancements with Descriptions, Partitioning, and Truncation Indicator ([7c993b0](https://github.com/google/adk-python/commit/7c993b01d1b9d582b4e2348f73c0591d47bf2f3a)) + +* **[Services]** + * Add file-backed artifact service ([99ca6aa](https://github.com/google/adk-python/commit/99ca6aa6e6b4027f37d091d9c93da6486def20d7)) + * Add service factory for configurable session and artifact backends ([a12ae81](https://github.com/google/adk-python/commit/a12ae812d367d2d00ab246f85a73ed679dd3828a)) + * Add SqliteSessionService and a migration script to migrate existing DB using DatabaseSessionService to SqliteSessionService ([e218254](https://github.com/google/adk-python/commit/e2182544952c0174d1a8307fbba319456dca748b)) + * Add transcription fields to session events ([3ad30a5](https://github.com/google/adk-python/commit/3ad30a58f95b8729f369d00db799546069d7b23a)) + * Full async implementation of DatabaseSessionService ([7495941](https://github.com/google/adk-python/commit/74959414d8ded733d584875a49fb4638a12d3ce5)) + +* **[Models]** + * Add experimental feature to use `parameters_json_schema` and `response_json_schema` for McpTool ([1dd97f5](https://github.com/google/adk-python/commit/1dd97f5b45226c25e4c51455c78ebf3ff56ab46a)) + * Add support for parsing inline JSON tool calls in LiteLLM responses ([22eb7e5](https://github.com/google/adk-python/commit/22eb7e5b06c9e048da5bb34fe7ae9135d00acb4e)) + * Expose artifact URLs to the model when available ([e3caf79](https://github.com/google/adk-python/commit/e3caf791395ce3cc0b10410a852be6e7b0d8d3b1)) + +* **[Tools]** + * Add BigQuery related label handling ([ffbab4c](https://github.com/google/adk-python/commit/ffbab4cf4ed6ceb313241c345751214d3c0e11ce)) + * Allow setting max_billed_bytes in BigQuery tools config ([ffbb0b3](https://github.com/google/adk-python/commit/ffbb0b37e128de50ebf57d76cba8b743a8b970d5)) + * Propagate `application_name` set for the BigQuery Tools as BigQuery job labels ([f13a11e](https://github.com/google/adk-python/commit/f13a11e1dc27c5aa46345154fbe0eecfe1690cbb)) + * Set per-tool user agent in BQ calls and tool label in BQ jobs ([c0be1df](https://github.com/google/adk-python/commit/c0be1df0521cfd4b84585f404d4385b80d08ba59)) + +* **[Observability]** + * Migrate BigQuery logging to Storage Write API ([a2ce34a](https://github.com/google/adk-python/commit/a2ce34a0b9a8403f830ff637d0e2094e82dee8e7)) + +### Bug Fixes + +* Add `jsonschema` dependency for Agent Builder config validation ([0fa7e46](https://github.com/google/adk-python/commit/0fa7e4619d589dc834f7508a18bc2a3b93ec7fd9)) +* Add None check for `event` in `remote_a2a_agent.py` ([744f94f](https://github.com/google/adk-python/commit/744f94f0c8736087724205bbbad501640b365270)) +* Add vertexai initialization for code being deployed to AgentEngine ([b8e4aed](https://github.com/google/adk-python/commit/b8e4aedfbf0eb55b34599ee24e163b41072a699c)) +* Change LiteLLM content and tool parameter handling ([a19be12](https://github.com/google/adk-python/commit/a19be12c1f04bb62a8387da686499857c24b45c0)) +* Change name for builder agent ([131d39c](https://github.com/google/adk-python/commit/131d39c3db1ae25e3911fa7f72afbe05e24a1c37)) +* Ensure event compaction completes by awaiting task ([b5f5df9](https://github.com/google/adk-python/commit/b5f5df9fa8f616b855c186fcef45bade00653c77)) +* Fix deploy to cloud run on Windows 
([29fea7e](https://github.com/google/adk-python/commit/29fea7ec1fb27989f07c90494b2d6acbe76c03d8)) +* Fix error handling when MCP server is unreachable ([ee8106b](https://github.com/google/adk-python/commit/ee8106be77f253e3687e72ae0e236687d254965c)) +* Fix error when query job destination is None ([0ccc43c](https://github.com/google/adk-python/commit/0ccc43cf49dc0882dc896455d6603a602d8a28e7)) +* Fix Improve logic for checking if a MCP session is disconnected ([a754c96](https://github.com/google/adk-python/commit/a754c96d3c4fd00f9c2cd924fc428b68cc5115fb)) +* Fix McpToolset crashing with anyio.BrokenResourceError ([8e0648d](https://github.com/google/adk-python/commit/8e0648df23d0694afd3e245ec4a3c41aa935120a)) +* Fix Safely handle `FunctionDeclaration` without a `required` attribute ([93aad61](https://github.com/google/adk-python/commit/93aad611983dc1daf415d3a73105db45bbdd1988)) +* Fix status code in error message in RestApiTool ([9b75456](https://github.com/google/adk-python/commit/9b754564b3cc5a06ad0c6ae2cd2d83082f9f5943)) +* Fix Use `async for` to loop through event iterator to get all events in vertex_ai_session_service ([9211f4c](https://github.com/google/adk-python/commit/9211f4ce8cc6d918df314d6a2ff13da2e0ef35fa)) +* Fix: Fixes DeprecationWarning when using send method ([2882995](https://github.com/google/adk-python/commit/28829952890c39dbdb4463b2b67ff241d0e9ef6d)) +* Improve logic for checking if a MCP session is disconnected ([a48a1a9](https://github.com/google/adk-python/commit/a48a1a9e889d4126e6f30b56c93718dfbacef624)) +* Improve handling of partial and complete transcriptions in live calls ([1819ecb](https://github.com/google/adk-python/commit/1819ecb4b8c009d02581c2d060fae49cd7fdf653)) +* Keep vertex session event after the session update time ([0ec0195](https://github.com/google/adk-python/commit/0ec01956e86df6ae8e6553c70e410f1f8238ba88)) +* Let part converters also return multiple parts so they can support more usecases ([824ab07](https://github.com/google/adk-python/commit/824ab072124e037cc373c493f43de38f8b61b534)) +* Load agent/app before creating session ([236f562](https://github.com/google/adk-python/commit/236f562cd275f84837be46f7dfb0065f85425169)) +* Remove app name from FileArtifactService directory structure ([12db84f](https://github.com/google/adk-python/commit/12db84f5cd6d8b6e06142f6f6411f6b78ff3f177)) +* Remove hardcoded `google-cloud-aiplatform` version in agent engine requirements ([e15e19d](https://github.com/google/adk-python/commit/e15e19da05ee1b763228467e83f6f73e0eced4b5)) +* Stop updating write mode in the global settings during tool execution ([5adbf95](https://github.com/google/adk-python/commit/5adbf95a0ab0657dd7df5c4a6bac109d424d436e)) +* Update description for `load_artifacts` tool ([c485889](https://github.com/google/adk-python/commit/c4858896ff085bedcfbc42b2010af8bd78febdd0)) + +### Improvements + +* Add BigQuery related label handling ([ffbab4c](https://github.com/google/adk-python/commit/ffbab4cf4ed6ceb313241c345751214d3c0e11ce)) +* Add demo for rewind ([8eb1bdb](https://github.com/google/adk-python/commit/8eb1bdbc58dc709006988f5b6eec5fda25bd0c89)) +* Add debug logging for live connection ([5d5708b](https://github.com/google/adk-python/commit/5d5708b2ab26cb714556311c490b4d6f0a1f9666)) +* Add debug logging for missing function call events ([f3d6fcf](https://github.com/google/adk-python/commit/f3d6fcf44411d07169c14ae12189543f44f96c27)) +* Add default retry options as fall back to llm_request that are made during evals 
([696852a](https://github.com/google/adk-python/commit/696852a28095a024cbe76413ee7617356e19a9e3)) +* Add plugin for returning GenAI Parts from tools into the model request ([116b26c](https://github.com/google/adk-python/commit/116b26c33e166bf1a22964e2b67013907fbfcb80)) +* Add support for abstract types in AFC ([2efc184](https://github.com/google/adk-python/commit/2efc184a46173529bdfc622db0d6f3866e7ee778)) +* Add support for structured output schemas in LiteLLM models ([7ea4aed](https://github.com/google/adk-python/commit/7ea4aed35ba70ec5a38dc1b3b0a9808183c2bab1)) +* Add tests for `max_query_result_rows` in BigQuery tool config ([fd33610](https://github.com/google/adk-python/commit/fd33610e967ad814bc02422f5d14dae046bee833)) +* Add type hints in `cleanup_unused_files.py` ([2dea573](https://github.com/google/adk-python/commit/2dea5733b759a7a07d74f36a4d6da7b081afc732)) +* Add util to build our llms.txt and llms-full.txt files +* ADK changes ([f1f4467](https://github.com/google/adk-python/commit/f1f44675e4a86b75e72cfd838efd8a0399f23e24)) +* Defer import of `google.cloud.storage` in `GCSArtifactService` ([999af55](https://github.com/google/adk-python/commit/999af5588005e7b29451bdbf9252265187ca992d)) +* Defer import of `live`, `Client` and `_transformers` in `google.genai` ([22c6dbe](https://github.com/google/adk-python/commit/22c6dbe83cd1a8900d0ac6fd23d2092f095189fa)) +* Enhance the messaging with possible fixes for RESOURCE_EXHAUSTED errors from Gemini ([b2c45f8](https://github.com/google/adk-python/commit/b2c45f8d910eb7bca4805c567279e65aff72b58a)) +* Improve gepa tau-bench colab for external use ([e02f177](https://github.com/google/adk-python/commit/e02f177790d9772dd253c9102b80df1a9418aa7f)) +* Improve gepa voter agent demo colab ([d118479](https://github.com/google/adk-python/commit/d118479ccf3a970ce9b24ac834b4b6764edb5de4)) +* Lazy import DatabaseSessionService in the adk/sessions/ module ([5f05749](https://github.com/google/adk-python/commit/5f057498a274d3b3db0be0866f04d5225334f54a)) +* Move adk_agent_builder_assistant to built_in_agents ([b2b7f2d](https://github.com/google/adk-python/commit/b2b7f2d6aa5b919a00a92abaf2543993746e939e)) +* Plumb memory service from LocalEvalService to EvaluationGenerator ([dc3f60c](https://github.com/google/adk-python/commit/dc3f60cc939335da49399a69c0b4abc0e7f25ea4)) +* Removes the unrealistic todo comment of visibility management ([e511eb1](https://github.com/google/adk-python/commit/e511eb1f70f2a3fccc9464ddaf54d0165db22feb)) +* Returns agent state regardless if ctx.is_resumable ([d6b928b](https://github.com/google/adk-python/commit/d6b928bdf7cdbf8f1925d4c5227c7d580093348e)) +* Stop logging the full content of LLM blobs ([0826755](https://github.com/google/adk-python/commit/082675546f501a70f4bc8969b9431a2e4808bd13)) +* Update ADK web to match main branch ([14e3802](https://github.com/google/adk-python/commit/14e3802643a2d8ce436d030734fafd163080a1ad)) +* Update agent instructions and retry limit in `plugin_reflect_tool_retry` sample ([01bac62](https://github.com/google/adk-python/commit/01bac62f0c14cce5d454a389b64a9f44a03a3673)) +* Update conformance test CLI to handle long-running tool calls ([dd706bd](https://github.com/google/adk-python/commit/dd706bdc4563a2a815459482237190a63994cb6f)) +* Update Gemini Live model names in live bidi streaming sample ([aa77834](https://github.com/google/adk-python/commit/aa77834e2ecd4b77dfb4e689ef37549b3ebd6134)) + + ## [1.18.0](https://github.com/google/adk-python/compare/v1.17.0...v1.18.0) (2025-11-05) ### Features diff 
--git a/README.md b/README.md index 3113b1ccdf..9046cc49a3 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![PyPI](https://img.shields.io/pypi/v/google-adk)](https://pypi.org/project/google-adk/) [![Python Unit Tests](https://github.com/google/adk-python/actions/workflows/python-unit-tests.yml/badge.svg)](https://github.com/google/adk-python/actions/workflows/python-unit-tests.yml) [![r/agentdevelopmentkit](https://img.shields.io/badge/Reddit-r%2Fagentdevelopmentkit-FF4500?style=flat&logo=reddit&logoColor=white)](https://www.reddit.com/r/agentdevelopmentkit/) -[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/google/adk-python) +Ask Code Wiki

diff --git a/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt b/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt new file mode 100644 index 0000000000..8f5f585ff6 --- /dev/null +++ b/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt @@ -0,0 +1,68 @@ +You are a highly intelligent repository auditor for '{OWNER}/{REPO}'. +Your job is to analyze a specific issue and report findings before taking action. + +**Primary Directive:** Ignore any events from users ending in `[bot]`. +**Reporting Directive:** Output a concise summary starting with "Analysis for Issue #[number]:". + +**THRESHOLDS:** +- Stale Threshold: {stale_threshold_days} days. +- Close Threshold: {close_threshold_days} days. + +**WORKFLOW:** +1. **Context Gathering**: Call `get_issue_state`. +2. **Decision**: Follow this strict decision tree using the data returned by the tool. + +--- **DECISION TREE** --- + +**STEP 1: CHECK IF ALREADY STALE** +- **Condition**: Is `is_stale` (from tool) **True**? +- **Action**: + - **Check Role**: Look at `last_action_role`. + + - **IF 'author' OR 'other_user'**: + - **Context**: The user has responded. The issue is now ACTIVE. + - **Action 1**: Call `remove_label_from_issue` with '{STALE_LABEL_NAME}'. + - **Action 2 (ALERT CHECK)**: Look at `maintainer_alert_needed`. + - **IF True**: User edited description silently. + -> **Action**: Call `alert_maintainer_of_edit`. + - **IF False**: User commented normally. No alert needed. + - **Report**: "Analysis for Issue #[number]: ACTIVE. User activity detected. Removed stale label." + + - **IF 'maintainer'**: + - **Check Time**: Check `days_since_stale_label`. + - **If `days_since_stale_label` > {close_threshold_days}**: + - **Action**: Call `close_as_stale`. + - **Report**: "Analysis for Issue #[number]: STALE. Close threshold met. Closing." + - **Else**: + - **Report**: "Analysis for Issue #[number]: STALE. Waiting for close threshold. No action." + +**STEP 2: CHECK IF ACTIVE (NOT STALE)** +- **Condition**: `is_stale` is **False**. +- **Action**: + - **Check Role**: If `last_action_role` is 'author' or 'other_user': + - **Context**: The issue is Active. + - **Action (ALERT CHECK)**: Look at `maintainer_alert_needed`. + - **IF True**: The user edited the description silently, and we haven't alerted yet. + -> **Action**: Call `alert_maintainer_of_edit`. + -> **Report**: "Analysis for Issue #[number]: ACTIVE. Silent update detected (Description Edit). Alerted maintainer." + - **IF False**: + -> **Report**: "Analysis for Issue #[number]: ACTIVE. Last action was by user. No action." + + - **Check Role**: If `last_action_role` is 'maintainer': + - **Proceed to STEP 3.** + +**STEP 3: ANALYZE MAINTAINER INTENT** +- **Context**: The last person to act was a Maintainer. +- **Action**: Read the text in `last_comment_text`. + - **Question Check**: Does the text ask a question, request clarification, ask for logs, or suggest trying a fix? + - **Time Check**: Is `days_since_activity` > {stale_threshold_days}? + + - **DECISION**: + - **IF (Question == YES) AND (Time == YES)**: + - **Action**: Call `add_stale_label_and_comment`. + - **Check**: If '{REQUEST_CLARIFICATION_LABEL}' is not in `current_labels`, call `add_label_to_issue` for it. + - **Report**: "Analysis for Issue #[number]: STALE. Maintainer asked question [days_since_activity] days ago. Marking stale." + - **IF (Question == YES) BUT (Time == NO)**: + - **Report**: "Analysis for Issue #[number]: PENDING. Maintainer asked question, but threshold not met yet. No action." 
+ - **IF (Question == NO)** (e.g., "I am working on this"): + - **Report**: "Analysis for Issue #[number]: ACTIVE. Maintainer gave status update (not a question). No action." \ No newline at end of file diff --git a/contributing/samples/adk_stale_agent/README.md b/contributing/samples/adk_stale_agent/README.md new file mode 100644 index 0000000000..afc47b11cc --- /dev/null +++ b/contributing/samples/adk_stale_agent/README.md @@ -0,0 +1,89 @@ +# ADK Stale Issue Auditor Agent + +This directory contains an autonomous, **GraphQL-powered** agent designed to audit a GitHub repository for stale issues. It maintains repository hygiene by ensuring all open items are actionable and responsive. + +Unlike traditional "Stale Bots" that only look at timestamps, this agent uses a **Unified History Trace** and an **LLM (Large Language Model)** to understand the *context* of a conversation. It distinguishes between a maintainer asking a question (stale candidate) vs. a maintainer providing a status update (active). + +--- + +## Core Logic & Features + +The agent operates as a "Repository Auditor," proactively scanning open issues using a high-efficiency decision tree. + +### 1. Smart State Verification (GraphQL) +Instead of making multiple expensive API calls, the agent uses a single **GraphQL** query per issue to reconstruct the entire history of the conversation. It combines: +* **Comments** +* **Description/Body Edits** ("Ghost Edits") +* **Title Renames** +* **State Changes** (Reopens) + +It sorts these events chronologically to determine the **Last Active Actor**. + +### 2. The "Last Actor" Rule +The agent follows a precise logic flow based on who acted last: + +* **If Author/User acted last:** The issue is **ACTIVE**. + * This includes comments, title changes, and *silent* description edits. + * **Action:** The agent immediately removes the `stale` label. + * **Silent Update Alert:** If the user edited the description but *did not* comment, the agent posts a specific alert: *"Notification: The author has updated the issue description..."* to ensure maintainers are notified (since GitHub does not trigger notifications for body edits). + * **Spam Prevention:** The agent checks if it has already alerted about a specific silent edit to avoid spamming the thread. + +* **If Maintainer acted last:** The issue is **POTENTIALLY STALE**. + * The agent passes the text of the maintainer's last comment to the LLM. + +### 3. Semantic Intent Analysis (LLM) +If the maintainer was the last person to speak, the LLM analyzes the comment text to determine intent: +* **Question/Request:** "Can you provide logs?" / "Please try v2.0." + * **Verdict:** **STALE** (Waiting on Author). + * **Action:** If the time threshold is met, the agent adds the `stale` label. It also checks for the `request clarification` label and adds it if missing. +* **Status Update:** "We are working on a fix." / "Added to backlog." + * **Verdict:** **ACTIVE** (Waiting on Maintainer). + * **Action:** No action taken. The issue remains open without stale labels. + +### 4. Lifecycle Management +* **Marking Stale:** After `STALE_HOURS_THRESHOLD` (default: 7 days) of inactivity following a maintainer's question. +* **Closing:** After `CLOSE_HOURS_AFTER_STALE_THRESHOLD` (default: 7 days) of continued inactivity while marked stale. + +--- + +## Performance & Safety + +* **GraphQL Optimized:** Fetches comments, edits, labels, and timeline events in a single network request to minimize latency and API quota usage. 
+* **Search API Filtering:** Uses the GitHub Search API to pre-filter issues created recently, ensuring the bot doesn't waste cycles analyzing brand-new issues. +* **Rate Limit Aware:** Includes intelligent sleeping and retry logic (exponential backoff) to handle GitHub API rate limits (HTTP 429) gracefully. +* **Execution Metrics:** Logs the time taken and API calls consumed for every issue processed. + +--- + +## Configuration + +The agent is configured via environment variables, typically set as secrets in GitHub Actions. + +### Required Secrets + +| Secret Name | Description | +| :--- | :--- | +| `GITHUB_TOKEN` | A GitHub Personal Access Token (PAT) or Service Account Token with `repo` scope. | +| `GOOGLE_API_KEY` | An API key for the Google AI (Gemini) model used for reasoning. | + +### Optional Configuration + +These variables control the timing thresholds and model selection. + +| Variable Name | Description | Default | +| :--- | :--- | :--- | +| `STALE_HOURS_THRESHOLD` | Hours of inactivity after a maintainer's question before marking as `stale`. | `168` (7 days) | +| `CLOSE_HOURS_AFTER_STALE_THRESHOLD` | Hours after being marked `stale` before the issue is closed. | `168` (7 days) | +| `LLM_MODEL_NAME`| The specific Gemini model version to use. | `gemini-2.5-flash` | +| `OWNER` | Repository owner (auto-detected in Actions). | (Environment dependent) | +| `REPO` | Repository name (auto-detected in Actions). | (Environment dependent) | + +--- + +## Deployment + +To deploy this agent, a GitHub Actions workflow file (`.github/workflows/stale-bot.yml`) is recommended. + +### Directory Structure Note +Because this agent resides within the `adk-python` package structure, the workflow must ensure the script is executed correctly to handle imports. + diff --git a/contributing/samples/adk_stale_agent/__init__.py b/contributing/samples/adk_stale_agent/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/adk_stale_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/adk_stale_agent/agent.py b/contributing/samples/adk_stale_agent/agent.py new file mode 100644 index 0000000000..8769adc193 --- /dev/null +++ b/contributing/samples/adk_stale_agent/agent.py @@ -0,0 +1,597 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
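+
+# Module overview: this file wires the stale-issue auditor together: the
+# GraphQL state tool (`get_issue_state`), the labeling/comment/close tools,
+# and the `root_agent` driven by PROMPT_INSTRUCTION.txt.
+#
+# Minimal local-run sketch (illustrative assumption, mirroring the
+# stale-bot.yml workflow above; requires GITHUB_TOKEN and GOOGLE_API_KEY):
+#   export OWNER=google REPO=adk-python
+#   PYTHONPATH=contributing/samples python -m adk_stale_agent.main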
+ +from datetime import datetime +from datetime import timezone +import logging +import os +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from adk_stale_agent.settings import CLOSE_HOURS_AFTER_STALE_THRESHOLD +from adk_stale_agent.settings import GITHUB_BASE_URL +from adk_stale_agent.settings import GRAPHQL_COMMENT_LIMIT +from adk_stale_agent.settings import GRAPHQL_EDIT_LIMIT +from adk_stale_agent.settings import GRAPHQL_TIMELINE_LIMIT +from adk_stale_agent.settings import LLM_MODEL_NAME +from adk_stale_agent.settings import OWNER +from adk_stale_agent.settings import REPO +from adk_stale_agent.settings import REQUEST_CLARIFICATION_LABEL +from adk_stale_agent.settings import STALE_HOURS_THRESHOLD +from adk_stale_agent.settings import STALE_LABEL_NAME +from adk_stale_agent.utils import delete_request +from adk_stale_agent.utils import error_response +from adk_stale_agent.utils import get_request +from adk_stale_agent.utils import patch_request +from adk_stale_agent.utils import post_request +import dateutil.parser +from google.adk.agents.llm_agent import Agent +from requests.exceptions import RequestException + +logger = logging.getLogger("google_adk." + __name__) + +# --- Constants --- +# Used to detect if the bot has already posted an alert to avoid spamming. +BOT_ALERT_SIGNATURE = ( + "**Notification:** The author has updated the issue description" +) + +# --- Global Cache --- +_MAINTAINERS_CACHE: Optional[List[str]] = None + + +def _get_cached_maintainers() -> List[str]: + """ + Fetches the list of repository maintainers. + + This function relies on `utils.get_request` for network resilience. + `get_request` is configured with an HTTPAdapter that automatically performs + exponential backoff retries (up to 6 times) for 5xx errors and rate limits. + + If the retries are exhausted or the data format is invalid, this function + raises a RuntimeError to prevent the bot from running with incorrect permissions. + + Returns: + List[str]: A list of GitHub usernames with push access. + + Raises: + RuntimeError: If the API fails after all retries or returns invalid data. + """ + global _MAINTAINERS_CACHE + if _MAINTAINERS_CACHE is not None: + return _MAINTAINERS_CACHE + + logger.info("Initializing Maintainers Cache...") + + try: + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/collaborators" + params = {"permission": "push"} + + data = get_request(url, params) + + if isinstance(data, list): + _MAINTAINERS_CACHE = [u["login"] for u in data if "login" in u] + logger.info(f"Cached {len(_MAINTAINERS_CACHE)} maintainers.") + return _MAINTAINERS_CACHE + else: + logger.error( + f"Invalid API response format: Expected list, got {type(data)}" + ) + raise ValueError(f"GitHub API returned non-list data: {data}") + + except Exception as e: + logger.critical( + f"FATAL: Failed to verify repository maintainers. Error: {e}" + ) + raise RuntimeError( + "Maintainer verification failed. processing aborted." + ) from e + + +def load_prompt_template(filename: str) -> str: + """ + Loads the raw text content of a prompt file. + + Args: + filename (str): The name of the file (e.g., 'PROMPT_INSTRUCTION.txt'). + + Returns: + str: The file content. 
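+
+    Raises:
+        FileNotFoundError: If the prompt file is not present next to this
+            module (the path is resolved relative to this file's directory).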
+ """ + file_path = os.path.join(os.path.dirname(__file__), filename) + with open(file_path, "r") as f: + return f.read() + + +PROMPT_TEMPLATE = load_prompt_template("PROMPT_INSTRUCTION.txt") + + +def _fetch_graphql_data(item_number: int) -> Dict[str, Any]: + """ + Executes the GraphQL query to fetch raw issue data, including comments, + edits, and timeline events. + + Args: + item_number (int): The GitHub issue number. + + Returns: + Dict[str, Any]: The raw 'issue' object from the GraphQL response. + + Raises: + RequestException: If the GraphQL query returns errors or the issue is not found. + """ + query = """ + query($owner: String!, $name: String!, $number: Int!, $commentLimit: Int!, $timelineLimit: Int!, $editLimit: Int!) { + repository(owner: $owner, name: $name) { + issue(number: $number) { + author { login } + createdAt + labels(first: 20) { nodes { name } } + + comments(last: $commentLimit) { + nodes { + author { login } + body + createdAt + lastEditedAt + } + } + + userContentEdits(last: $editLimit) { + nodes { + editor { login } + editedAt + } + } + + timelineItems(itemTypes: [LABELED_EVENT, RENAMED_TITLE_EVENT, REOPENED_EVENT], last: $timelineLimit) { + nodes { + __typename + ... on LabeledEvent { + createdAt + actor { login } + label { name } + } + ... on RenamedTitleEvent { + createdAt + actor { login } + } + ... on ReopenedEvent { + createdAt + actor { login } + } + } + } + } + } + } + """ + + variables = { + "owner": OWNER, + "name": REPO, + "number": item_number, + "commentLimit": GRAPHQL_COMMENT_LIMIT, + "editLimit": GRAPHQL_EDIT_LIMIT, + "timelineLimit": GRAPHQL_TIMELINE_LIMIT, + } + + response = post_request( + f"{GITHUB_BASE_URL}/graphql", {"query": query, "variables": variables} + ) + + if "errors" in response: + raise RequestException(f"GraphQL Error: {response['errors'][0]['message']}") + + data = response.get("data", {}).get("repository", {}).get("issue", {}) + if not data: + raise RequestException(f"Issue #{item_number} not found.") + + return data + + +def _build_history_timeline( + data: Dict[str, Any], +) -> Tuple[List[Dict[str, Any]], List[datetime], Optional[datetime]]: + """ + Parses raw GraphQL data into a unified, chronologically sorted history list. + Also extracts specific event times needed for logic checks. + + Args: + data (Dict[str, Any]): The raw issue data from `_fetch_graphql_data`. + + Returns: + Tuple[List[Dict], List[datetime], Optional[datetime]]: + - history: A list of normalized event dictionaries sorted by time. + - label_events: A list of timestamps when the stale label was applied. + - last_bot_alert_time: Timestamp of the last bot silent-edit alert (if any). + """ + issue_author = data.get("author", {}).get("login") + history = [] + label_events = [] + last_bot_alert_time = None + + # 1. Baseline: Issue Creation + history.append({ + "type": "created", + "actor": issue_author, + "time": dateutil.parser.isoparse(data["createdAt"]), + "data": None, + }) + + # 2. 
Process Comments + for c in data.get("comments", {}).get("nodes", []): + if not c: + continue + + actor = c.get("author", {}).get("login") + c_body = c.get("body", "") + c_time = dateutil.parser.isoparse(c.get("createdAt")) + + # Track bot alerts for spam prevention + if BOT_ALERT_SIGNATURE in c_body: + if last_bot_alert_time is None or c_time > last_bot_alert_time: + last_bot_alert_time = c_time + + if actor and not actor.endswith("[bot]"): + # Use edit time if available, otherwise creation time + e_time = c.get("lastEditedAt") + actual_time = dateutil.parser.isoparse(e_time) if e_time else c_time + history.append({ + "type": "commented", + "actor": actor, + "time": actual_time, + "data": c_body, + }) + + # 3. Process Body Edits ("Ghost Edits") + for e in data.get("userContentEdits", {}).get("nodes", []): + if not e: + continue + actor = e.get("editor", {}).get("login") + if actor and not actor.endswith("[bot]"): + history.append({ + "type": "edited_description", + "actor": actor, + "time": dateutil.parser.isoparse(e.get("editedAt")), + "data": None, + }) + + # 4. Process Timeline Events + for t in data.get("timelineItems", {}).get("nodes", []): + if not t: + continue + + etype = t.get("__typename") + actor = t.get("actor", {}).get("login") + time_val = dateutil.parser.isoparse(t.get("createdAt")) + + if etype == "LabeledEvent": + if t.get("label", {}).get("name") == STALE_LABEL_NAME: + label_events.append(time_val) + continue + + if actor and not actor.endswith("[bot]"): + pretty_type = ( + "renamed_title" if etype == "RenamedTitleEvent" else "reopened" + ) + history.append({ + "type": pretty_type, + "actor": actor, + "time": time_val, + "data": None, + }) + + # Sort chronologically + history.sort(key=lambda x: x["time"]) + return history, label_events, last_bot_alert_time + + +def _replay_history_to_find_state( + history: List[Dict[str, Any]], maintainers: List[str], issue_author: str +) -> Dict[str, Any]: + """ + Replays the unified event history to determine the absolute last actor and their role. + + Args: + history (List[Dict]): Chronologically sorted list of events. + maintainers (List[str]): List of maintainer usernames. + issue_author (str): Username of the issue author. + + Returns: + Dict[str, Any]: A dictionary containing the last state of the issue: + - last_action_role (str): 'author', 'maintainer', or 'other_user'. + - last_activity_time (datetime): Timestamp of the last human action. + - last_action_type (str): The type of the last action (e.g., 'commented'). + - last_comment_text (Optional[str]): The text of the last comment. + """ + last_action_role = "author" + last_activity_time = history[0]["time"] + last_action_type = "created" + last_comment_text = None + + for event in history: + actor = event["actor"] + etype = event["type"] + + role = "other_user" + if actor == issue_author: + role = "author" + elif actor in maintainers: + role = "maintainer" + + last_action_role = role + last_activity_time = event["time"] + last_action_type = etype + + # Only store text if it was a comment (resets on other events like labels/edits) + if etype == "commented": + last_comment_text = event["data"] + else: + last_comment_text = None + + return { + "last_action_role": last_action_role, + "last_activity_time": last_activity_time, + "last_action_type": last_action_type, + "last_comment_text": last_comment_text, + } + + +def get_issue_state(item_number: int) -> Dict[str, Any]: + """ + Retrieves the comprehensive state of a GitHub issue using GraphQL. 
+ + This function orchestrates the fetching, parsing, and analysis of the issue's + history to determine if it is stale, active, or pending maintainer review. + + Args: + item_number (int): The GitHub issue number. + + Returns: + Dict[str, Any]: A comprehensive state dictionary for the LLM agent. + Contains keys such as 'last_action_role', 'is_stale', 'days_since_activity', + and 'maintainer_alert_needed'. + """ + try: + maintainers = _get_cached_maintainers() + + # 1. Fetch + raw_data = _fetch_graphql_data(item_number) + + issue_author = raw_data.get("author", {}).get("login") + labels_list = [ + l["name"] for l in raw_data.get("labels", {}).get("nodes", []) + ] + + # 2. Parse & Sort + history, label_events, last_bot_alert_time = _build_history_timeline( + raw_data + ) + + # 3. Analyze (Replay) + state = _replay_history_to_find_state(history, maintainers, issue_author) + + # 4. Final Calculations & Alert Logic + current_time = datetime.now(timezone.utc) + days_since_activity = ( + current_time - state["last_activity_time"] + ).total_seconds() / 86400 + + # Stale Checks + is_stale = STALE_LABEL_NAME in labels_list + days_since_stale_label = 0.0 + if is_stale and label_events: + latest_label_time = max(label_events) + days_since_stale_label = ( + current_time - latest_label_time + ).total_seconds() / 86400 + + # Silent Edit Alert Logic + maintainer_alert_needed = False + if ( + state["last_action_role"] in ["author", "other_user"] + and state["last_action_type"] == "edited_description" + ): + if ( + last_bot_alert_time + and last_bot_alert_time > state["last_activity_time"] + ): + logger.info( + f"#{item_number}: Silent edit detected, but Bot already alerted. No" + " spam." + ) + else: + maintainer_alert_needed = True + logger.info(f"#{item_number}: Silent edit detected. Alert needed.") + + logger.debug( + f"#{item_number} VERDICT: Role={state['last_action_role']}, " + f"Idle={days_since_activity:.2f}d" + ) + + return { + "status": "success", + "last_action_role": state["last_action_role"], + "last_action_type": state["last_action_type"], + "maintainer_alert_needed": maintainer_alert_needed, + "is_stale": is_stale, + "days_since_activity": days_since_activity, + "days_since_stale_label": days_since_stale_label, + "last_comment_text": state["last_comment_text"], + "current_labels": labels_list, + "stale_threshold_days": STALE_HOURS_THRESHOLD / 24, + "close_threshold_days": CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24, + } + + except RequestException as e: + return error_response(f"Network Error: {e}") + except Exception as e: + logger.error( + f"Unexpected error analyzing #{item_number}: {e}", exc_info=True + ) + return error_response(f"Analysis Error: {e}") + + +# --- Tool Definitions --- + + +def _format_days(hours: float) -> str: + """ + Formats a duration in hours into a clean day string. + + Example: + 168.0 -> "7" + 12.0 -> "0.5" + """ + days = hours / 24 + return f"{days:.1f}" if days % 1 != 0 else f"{int(days)}" + + +def add_label_to_issue(item_number: int, label_name: str) -> dict[str, Any]: + """ + Adds a label to the issue. + + Args: + item_number (int): The GitHub issue number. + label_name (str): The name of the label to add. 
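+
+    Returns:
+        dict[str, Any]: {"status": "success"} on success, or the payload from
+            `error_response` if the GitHub API call fails.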
+ """ + logger.debug(f"Adding label '{label_name}' to issue #{item_number}.") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels" + try: + post_request(url, [label_name]) + return {"status": "success"} + except RequestException as e: + return error_response(f"Error adding label: {e}") + + +def remove_label_from_issue( + item_number: int, label_name: str +) -> dict[str, Any]: + """ + Removes a label from the issue. + + Args: + item_number (int): The GitHub issue number. + label_name (str): The name of the label to remove. + """ + logger.debug(f"Removing label '{label_name}' from issue #{item_number}.") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels/{label_name}" + try: + delete_request(url) + return {"status": "success"} + except RequestException as e: + return error_response(f"Error removing label: {e}") + + +def add_stale_label_and_comment(item_number: int) -> dict[str, Any]: + """ + Marks the issue as stale with a comment and label. + + Args: + item_number (int): The GitHub issue number. + """ + stale_days_str = _format_days(STALE_HOURS_THRESHOLD) + close_days_str = _format_days(CLOSE_HOURS_AFTER_STALE_THRESHOLD) + + comment = ( + "This issue has been automatically marked as stale because it has not" + f" had recent activity for {stale_days_str} days after a maintainer" + " requested clarification. It will be closed if no further activity" + f" occurs within {close_days_str} days." + ) + try: + post_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/comments", + {"body": comment}, + ) + post_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels", + [STALE_LABEL_NAME], + ) + return {"status": "success"} + except RequestException as e: + return error_response(f"Error marking issue as stale: {e}") + + +def alert_maintainer_of_edit(item_number: int) -> dict[str, Any]: + """ + Posts a comment alerting maintainers of a silent description update. + + Args: + item_number (int): The GitHub issue number. + """ + # Uses the constant signature to ensure detection logic in get_issue_state works. + comment = f"{BOT_ALERT_SIGNATURE}. Maintainers, please review." + try: + post_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/comments", + {"body": comment}, + ) + return {"status": "success"} + except RequestException as e: + return error_response(f"Error posting alert: {e}") + + +def close_as_stale(item_number: int) -> dict[str, Any]: + """ + Closes the issue as not planned/stale. + + Args: + item_number (int): The GitHub issue number. + """ + days_str = _format_days(CLOSE_HOURS_AFTER_STALE_THRESHOLD) + + comment = ( + "This has been automatically closed because it has been marked as stale" + f" for over {days_str} days." 
+ ) + try: + post_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/comments", + {"body": comment}, + ) + patch_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}", + {"state": "closed"}, + ) + return {"status": "success"} + except RequestException as e: + return error_response(f"Error closing issue: {e}") + + +root_agent = Agent( + model=LLM_MODEL_NAME, + name="adk_repository_auditor_agent", + description="Audits open issues.", + instruction=PROMPT_TEMPLATE.format( + OWNER=OWNER, + REPO=REPO, + STALE_LABEL_NAME=STALE_LABEL_NAME, + REQUEST_CLARIFICATION_LABEL=REQUEST_CLARIFICATION_LABEL, + stale_threshold_days=STALE_HOURS_THRESHOLD / 24, + close_threshold_days=CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24, + ), + tools=[ + add_label_to_issue, + add_stale_label_and_comment, + alert_maintainer_of_edit, + close_as_stale, + get_issue_state, + remove_label_from_issue, + ], +) diff --git a/contributing/samples/adk_stale_agent/main.py b/contributing/samples/adk_stale_agent/main.py new file mode 100644 index 0000000000..d4fe58dd63 --- /dev/null +++ b/contributing/samples/adk_stale_agent/main.py @@ -0,0 +1,195 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import time +from typing import Tuple + +from adk_stale_agent.agent import root_agent +from adk_stale_agent.settings import CONCURRENCY_LIMIT +from adk_stale_agent.settings import OWNER +from adk_stale_agent.settings import REPO +from adk_stale_agent.settings import SLEEP_BETWEEN_CHUNKS +from adk_stale_agent.settings import STALE_HOURS_THRESHOLD +from adk_stale_agent.utils import get_api_call_count +from adk_stale_agent.utils import get_old_open_issue_numbers +from adk_stale_agent.utils import reset_api_call_count +from google.adk.cli.utils import logs +from google.adk.runners import InMemoryRunner +from google.genai import types + +logs.setup_adk_logger(level=logging.INFO) +logger = logging.getLogger("google_adk." + __name__) + +APP_NAME = "stale_bot_app" +USER_ID = "stale_bot_user" + + +async def process_single_issue(issue_number: int) -> Tuple[float, int]: + """ + Processes a single GitHub issue using the AI agent and logs execution metrics. + + Args: + issue_number (int): The GitHub issue number to audit. + + Returns: + Tuple[float, int]: A tuple containing: + - duration (float): Time taken to process the issue in seconds. + - api_calls (int): The number of API calls made during this specific execution. + + Raises: + Exception: catches generic exceptions to prevent one failure from stopping the batch. + """ + start_time = time.perf_counter() + + start_api_calls = get_api_call_count() + + logger.info(f"Processing Issue #{issue_number}...") + logger.debug(f"#{issue_number}: Initializing runner and session.") + + try: + runner = InMemoryRunner(agent=root_agent, app_name=APP_NAME) + session = await runner.session_service.create_session( + user_id=USER_ID, app_name=APP_NAME + ) + + prompt_text = f"Audit Issue #{issue_number}." 
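+    # Wrap the audit request in a genai Content object so the runner delivers
+    # it to the agent as a single user turn; the prompt's decision tree then
+    # drives the registered tools.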
+ prompt_message = types.Content( + role="user", parts=[types.Part(text=prompt_text)] + ) + + logger.debug(f"#{issue_number}: Sending prompt to agent.") + + async for event in runner.run_async( + user_id=USER_ID, session_id=session.id, new_message=prompt_message + ): + if ( + event.content + and event.content.parts + and hasattr(event.content.parts[0], "text") + ): + text = event.content.parts[0].text + if text: + clean_text = text[:150].replace("\n", " ") + logger.info(f"#{issue_number} Decision: {clean_text}...") + + except Exception as e: + logger.error(f"Error processing issue #{issue_number}: {e}", exc_info=True) + + duration = time.perf_counter() - start_time + + end_api_calls = get_api_call_count() + issue_api_calls = end_api_calls - start_api_calls + + logger.info( + f"Issue #{issue_number} finished in {duration:.2f}s " + f"with ~{issue_api_calls} API calls." + ) + + return duration, issue_api_calls + + +async def main(): + """ + Main entry point to run the stale issue bot concurrently. + + Fetches old issues and processes them in batches to respect API rate limits + and concurrency constraints. + """ + logger.info(f"--- Starting Stale Bot for {OWNER}/{REPO} ---") + logger.info(f"Concurrency level set to {CONCURRENCY_LIMIT}") + + reset_api_call_count() + + filter_days = STALE_HOURS_THRESHOLD / 24 + logger.debug(f"Fetching issues older than {filter_days:.2f} days...") + + try: + all_issues = get_old_open_issue_numbers(OWNER, REPO, days_old=filter_days) + except Exception as e: + logger.critical(f"Failed to fetch issue list: {e}", exc_info=True) + return + + total_count = len(all_issues) + + search_api_calls = get_api_call_count() + + if total_count == 0: + logger.info("No issues matched the criteria. Run finished.") + return + + logger.info( + f"Found {total_count} issues to process. " + f"(Initial search used {search_api_calls} API calls)." + ) + + total_processing_time = 0.0 + total_issue_api_calls = 0 + processed_count = 0 + + # Process the list in chunks of size CONCURRENCY_LIMIT + for i in range(0, total_count, CONCURRENCY_LIMIT): + chunk = all_issues[i : i + CONCURRENCY_LIMIT] + current_chunk_num = i // CONCURRENCY_LIMIT + 1 + + logger.info( + f"--- Starting chunk {current_chunk_num}: Processing issues {chunk} ---" + ) + + tasks = [process_single_issue(issue_num) for issue_num in chunk] + + results = await asyncio.gather(*tasks) + + for duration, api_calls in results: + total_processing_time += duration + total_issue_api_calls += api_calls + + processed_count += len(chunk) + logger.info( + f"--- Finished chunk {current_chunk_num}. Progress:" + f" {processed_count}/{total_count} ---" + ) + + if (i + CONCURRENCY_LIMIT) < total_count: + logger.debug( + f"Sleeping for {SLEEP_BETWEEN_CHUNKS}s to respect rate limits..." + ) + await asyncio.sleep(SLEEP_BETWEEN_CHUNKS) + + total_api_calls_for_run = search_api_calls + total_issue_api_calls + avg_time_per_issue = ( + total_processing_time / total_count if total_count > 0 else 0 + ) + + logger.info("--- Stale Agent Run Finished ---") + logger.info(f"Successfully processed {processed_count} issues.") + logger.info(f"Total API calls made this run: {total_api_calls_for_run}") + logger.info( + f"Average processing time per issue: {avg_time_per_issue:.2f} seconds." 
+ ) + + +if __name__ == "__main__": + start_time = time.perf_counter() + + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.warning("Bot execution interrupted manually.") + except Exception as e: + logger.critical(f"Unexpected fatal error: {e}", exc_info=True) + + duration = time.perf_counter() - start_time + logger.info(f"Full audit finished in {duration/60:.2f} minutes.") diff --git a/contributing/samples/adk_stale_agent/settings.py b/contributing/samples/adk_stale_agent/settings.py new file mode 100644 index 0000000000..599c6ef2ea --- /dev/null +++ b/contributing/samples/adk_stale_agent/settings.py @@ -0,0 +1,63 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dotenv import load_dotenv + +# Load environment variables from a .env file for local testing +load_dotenv(override=True) + +# --- GitHub API Configuration --- +GITHUB_BASE_URL = "https://api.github.com" +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +if not GITHUB_TOKEN: + raise ValueError("GITHUB_TOKEN environment variable not set") + +OWNER = os.getenv("OWNER", "google") +REPO = os.getenv("REPO", "adk-python") +LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +STALE_LABEL_NAME = "stale" +REQUEST_CLARIFICATION_LABEL = "request clarification" + +# --- THRESHOLDS IN HOURS --- +# Default: 168 hours (7 days) +# The number of hours of inactivity after a maintainer comment before an issue is marked as stale. +STALE_HOURS_THRESHOLD = float(os.getenv("STALE_HOURS_THRESHOLD", 168)) + +# Default: 168 hours (7 days) +# The number of hours of inactivity after an issue is marked 'stale' before it is closed. +CLOSE_HOURS_AFTER_STALE_THRESHOLD = float( + os.getenv("CLOSE_HOURS_AFTER_STALE_THRESHOLD", 168) +) + +# --- Performance Configuration --- +# The number of issues to process concurrently. +# Higher values are faster but increase the immediate rate of API calls +CONCURRENCY_LIMIT = int(os.getenv("CONCURRENCY_LIMIT", 3)) + +# --- GraphQL Query Limits --- +# The number of most recent comments to fetch for context analysis. +GRAPHQL_COMMENT_LIMIT = int(os.getenv("GRAPHQL_COMMENT_LIMIT", 30)) + +# The number of most recent description edits to fetch. +GRAPHQL_EDIT_LIMIT = int(os.getenv("GRAPHQL_EDIT_LIMIT", 10)) + +# The number of most recent timeline events (labels, renames, reopens) to fetch. +GRAPHQL_TIMELINE_LIMIT = int(os.getenv("GRAPHQL_TIMELINE_LIMIT", 20)) + +# --- Rate Limiting --- +# Time in seconds to wait between processing chunks. +SLEEP_BETWEEN_CHUNKS = float(os.getenv("SLEEP_BETWEEN_CHUNKS", 1.5)) diff --git a/contributing/samples/adk_stale_agent/utils.py b/contributing/samples/adk_stale_agent/utils.py new file mode 100644 index 0000000000..a396c22ac7 --- /dev/null +++ b/contributing/samples/adk_stale_agent/utils.py @@ -0,0 +1,260 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from datetime import timedelta +from datetime import timezone +import logging +import threading +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + +from adk_stale_agent.settings import GITHUB_TOKEN +from adk_stale_agent.settings import STALE_HOURS_THRESHOLD +import dateutil.parser +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +logger = logging.getLogger("google_adk." + __name__) + +# --- API Call Counter for Monitoring --- +_api_call_count = 0 +_counter_lock = threading.Lock() + + +def get_api_call_count() -> int: + """ + Returns the total number of API calls made since the last reset. + + Returns: + int: The global count of API calls. + """ + with _counter_lock: + return _api_call_count + + +def reset_api_call_count() -> None: + """Resets the global API call counter to zero.""" + global _api_call_count + with _counter_lock: + _api_call_count = 0 + + +def _increment_api_call_count() -> None: + """ + Atomically increments the global API call counter. + Required because the agent may run tools in parallel threads. + """ + global _api_call_count + with _counter_lock: + _api_call_count += 1 + + +# --- Production-Ready HTTP Session with Exponential Backoff --- + +# Configure the retry strategy: +retry_strategy = Retry( + total=6, + backoff_factor=2, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=[ + "HEAD", + "GET", + "POST", + "PUT", + "DELETE", + "OPTIONS", + "TRACE", + "PATCH", + ], +) + +adapter = HTTPAdapter(max_retries=retry_strategy) + +# Create a single, reusable Session object for connection pooling +_session = requests.Session() +_session.mount("https://", adapter) +_session.mount("http://", adapter) + +_session.headers.update({ + "Authorization": f"token {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json", +}) + + +def get_request(url: str, params: Optional[Dict[str, Any]] = None) -> Any: + """ + Sends a GET request to the GitHub API with automatic retries. + + Args: + url (str): The URL endpoint. + params (Optional[Dict[str, Any]]): Query parameters. + + Returns: + Any: The JSON response parsed into a dict or list. + + Raises: + requests.exceptions.RequestException: If retries are exhausted. + """ + _increment_api_call_count() + try: + response = _session.get(url, params=params or {}, timeout=60) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"GET request failed for {url}: {e}") + raise + + +def post_request(url: str, payload: Any) -> Any: + """ + Sends a POST request to the GitHub API with automatic retries. + + Args: + url (str): The URL endpoint. + payload (Any): The JSON payload. + + Returns: + Any: The JSON response. 
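+
+  Raises:
+      requests.exceptions.RequestException: If retries are exhausted.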
+ """ + _increment_api_call_count() + try: + response = _session.post(url, json=payload, timeout=60) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"POST request failed for {url}: {e}") + raise + + +def patch_request(url: str, payload: Any) -> Any: + """ + Sends a PATCH request to the GitHub API with automatic retries. + + Args: + url (str): The URL endpoint. + payload (Any): The JSON payload. + + Returns: + Any: The JSON response. + """ + _increment_api_call_count() + try: + response = _session.patch(url, json=payload, timeout=60) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"PATCH request failed for {url}: {e}") + raise + + +def delete_request(url: str) -> Any: + """ + Sends a DELETE request to the GitHub API with automatic retries. + + Args: + url (str): The URL endpoint. + + Returns: + Any: A success dict if 204, else the JSON response. + """ + _increment_api_call_count() + try: + response = _session.delete(url, timeout=60) + response.raise_for_status() + if response.status_code == 204: + return {"status": "success", "message": "Deletion successful."} + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"DELETE request failed for {url}: {e}") + raise + + +def error_response(error_message: str) -> Dict[str, Any]: + """ + Creates a standardized error response dictionary for tool outputs. + + Args: + error_message (str): The error details. + + Returns: + Dict[str, Any]: Standardized error object. + """ + return {"status": "error", "message": error_message} + + +def get_old_open_issue_numbers( + owner: str, repo: str, days_old: Optional[float] = None +) -> List[int]: + """ + Finds open issues older than the specified threshold using server-side filtering. + + OPTIMIZATION: + Instead of fetching ALL issues and filtering in Python (which wastes API calls), + this uses the GitHub Search API `created: dict[str, Any]: - """List most recent `issue_count` number of unlabeled issues in the repo. +def list_untriaged_issues(issue_count: int) -> dict[str, Any]: + """List open issues that need triaging. + + Returns issues that need any of the following actions: + 1. Issues without component labels (need labeling + type setting) + 2. Issues with 'planned' label but no assignee (need owner assignment) Args: issue_count: number of issues to return Returns: The status of this request, with a list of issues when successful. + Each issue includes flags indicating what actions are needed. 
""" url = f"{GITHUB_BASE_URL}/search/issues" - query = f"repo:{OWNER}/{REPO} is:open is:issue no:label" + query = f"repo:{OWNER}/{REPO} is:open is:issue" params = { "q": query, "sort": "created", "order": "desc", - "per_page": issue_count, + "per_page": 100, # Fetch more to filter "page": 1, } @@ -101,27 +109,48 @@ def list_unlabeled_issues(issue_count: int) -> dict[str, Any]: response = get_request(url, params) except requests.exceptions.RequestException as e: return error_response(f"Error: {e}") - issues = response.get("items", None) + issues = response.get("items", []) - unlabeled_issues = [] + component_labels = set(LABEL_TO_OWNER.keys()) + untriaged_issues = [] for issue in issues: - if not issue.get("labels", None): - unlabeled_issues.append(issue) - return {"status": "success", "issues": unlabeled_issues} - - -def add_label_and_owner_to_issue( - issue_number: int, label: str -) -> dict[str, Any]: - """Add the specified label and owner to the given issue number. + issue_labels = {label["name"] for label in issue.get("labels", [])} + assignees = issue.get("assignees", []) + + existing_component_labels = issue_labels & component_labels + has_component = bool(existing_component_labels) + has_planned = "planned" in issue_labels + + # Determine what actions are needed + needs_component_label = not has_component + needs_owner = has_planned and not assignees + + # Include issue if it needs any action + if needs_component_label or needs_owner: + issue["has_planned_label"] = has_planned + issue["has_component_label"] = has_component + issue["existing_component_label"] = ( + list(existing_component_labels)[0] + if existing_component_labels + else None + ) + issue["needs_component_label"] = needs_component_label + issue["needs_owner"] = needs_owner + untriaged_issues.append(issue) + if len(untriaged_issues) >= issue_count: + break + return {"status": "success", "issues": untriaged_issues} + + +def add_label_to_issue(issue_number: int, label: str) -> dict[str, Any]: + """Add the specified component label to the given issue number. Args: issue_number: issue number of the GitHub issue. label: label to assign Returns: - The the status of this request, with the applied label and assigned owner - when successful. + The status of this request, with the applied label when successful. """ print(f"Attempting to add label '{label}' to issue #{issue_number}") if label not in LABEL_TO_OWNER: @@ -139,15 +168,38 @@ def add_label_and_owner_to_issue( except requests.exceptions.RequestException as e: return error_response(f"Error: {e}") + return { + "status": "success", + "message": response, + "applied_label": label, + } + + +def add_owner_to_issue(issue_number: int, label: str) -> dict[str, Any]: + """Assign an owner to the issue based on the component label. + + This should only be called for issues that have the 'planned' label. + + Args: + issue_number: issue number of the GitHub issue. + label: component label that determines the owner to assign + + Returns: + The status of this request, with the assigned owner when successful. + """ + print( + f"Attempting to assign owner for label '{label}' to issue #{issue_number}" + ) + if label not in LABEL_TO_OWNER: + return error_response( + f"Error: Label '{label}' is not a valid component label." + ) + owner = LABEL_TO_OWNER.get(label, None) if not owner: return { "status": "warning", - "message": ( - f"{response}\n\nLabel '{label}' does not have an owner. Will not" - " assign." - ), - "applied_label": label, + "message": f"Label '{label}' does not have an owner. 
Will not assign.", } assignee_url = ( @@ -163,7 +215,6 @@ def add_label_and_owner_to_issue( return { "status": "success", "message": response, - "applied_label": label, "assigned_owner": owner, } @@ -217,31 +268,50 @@ def change_issue_type(issue_number: int, issue_type: str) -> dict[str, Any]: - If it's about Model Context Protocol (e.g. MCP tool, MCP toolset, MCP session management etc.), label it with both "mcp" and "tools". - If it's about A2A integrations or workflows, label it with "a2a". - If it's about BigQuery integrations, label it with "bq". + - If it's about workflow agents or workflow execution, label it with "workflow". + - If it's about authentication, label it with "auth". - If you can't find an appropriate labels for the issue, follow the previous instruction that starts with "IMPORTANT:". - Call the `add_label_and_owner_to_issue` tool to label the issue, which will also assign the issue to the owner of the label. + ## Triaging Workflow + + Each issue will have flags indicating what actions are needed: + - `needs_component_label`: true if the issue needs a component label + - `needs_owner`: true if the issue needs an owner assigned (has 'planned' label but no assignee) + + For each issue, perform ONLY the required actions based on the flags: + + 1. **If `needs_component_label` is true**: + - Use `add_label_to_issue` to add the appropriate component label + - Use `change_issue_type` to set the issue type: + - Bug report → "Bug" + - Feature request → "Feature" + - Otherwise → do not change the issue type + + 2. **If `needs_owner` is true**: + - Use `add_owner_to_issue` to assign an owner based on the component label + - Note: If the issue already has a component label (`has_component_label: true`), use that existing label to determine the owner - After you label the issue, call the `change_issue_type` tool to change the issue type: - - If the issue is a bug report, change the issue type to "Bug". - - If the issue is a feature request, change the issue type to "Feature". - - Otherwise, **do not change the issue type**. + Do NOT add a component label if `needs_component_label` is false. + Do NOT assign an owner if `needs_owner` is false. Response quality requirements: - Summarize the issue in your own words without leaving template placeholders (never output text like "[fill in later]"). - Justify the chosen label with a short explanation referencing the issue details. - - Mention the assigned owner when a label maps to one. + - Mention the assigned owner only when you actually assign one (i.e., when + the issue has the 'planned' label). - If no label is applied, clearly state why. Present the following in an easy to read format highlighting issue number and your label. 
- the issue summary in a few sentence - your label recommendation and justification - - the owner of the label if you assign the issue to an owner + - the owner of the label if you assign the issue to an owner (only for planned issues) """, tools=[ - list_unlabeled_issues, - add_label_and_owner_to_issue, + list_untriaged_issues, + add_label_to_issue, + add_owner_to_issue, change_issue_type, ], ) diff --git a/contributing/samples/adk_triaging_agent/main.py b/contributing/samples/adk_triaging_agent/main.py index 317f5893e2..3a2d4da570 100644 --- a/contributing/samples/adk_triaging_agent/main.py +++ b/contributing/samples/adk_triaging_agent/main.py @@ -16,6 +16,7 @@ import time from adk_triaging_agent import agent +from adk_triaging_agent.agent import LABEL_TO_OWNER from adk_triaging_agent.settings import EVENT_NAME from adk_triaging_agent.settings import GITHUB_BASE_URL from adk_triaging_agent.settings import ISSUE_BODY @@ -37,21 +38,49 @@ async def fetch_specific_issue_details(issue_number: int): - """Fetches details for a single issue if it's unlabelled.""" + """Fetches details for a single issue if it needs triaging.""" url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}" print(f"Fetching details for specific issue: {url}") try: issue_data = get_request(url) - if not issue_data.get("labels", None): - print(f"Issue #{issue_number} is unlabelled. Proceeding.") + labels = issue_data.get("labels", []) + label_names = {label["name"] for label in labels} + assignees = issue_data.get("assignees", []) + + # Check issue state + component_labels = set(LABEL_TO_OWNER.keys()) + has_planned = "planned" in label_names + existing_component_labels = label_names & component_labels + has_component = bool(existing_component_labels) + has_assignee = len(assignees) > 0 + + # Determine what actions are needed + needs_component_label = not has_component + needs_owner = has_planned and not has_assignee + + if needs_component_label or needs_owner: + print( + f"Issue #{issue_number} needs triaging. " + f"needs_component_label={needs_component_label}, " + f"needs_owner={needs_owner}" + ) return { "number": issue_data["number"], "title": issue_data["title"], "body": issue_data.get("body", ""), + "has_planned_label": has_planned, + "has_component_label": has_component, + "existing_component_label": ( + list(existing_component_labels)[0] + if existing_component_labels + else None + ), + "needs_component_label": needs_component_label, + "needs_owner": needs_owner, } else: - print(f"Issue #{issue_number} is already labelled. Skipping.") + print(f"Issue #{issue_number} is already fully triaged. Skipping.") return None except requests.exceptions.RequestException as e: print(f"Error fetching issue #{issue_number}: {e}") @@ -108,26 +137,32 @@ async def main(): specific_issue = await fetch_specific_issue_details(issue_number) if specific_issue is None: print( - f"No unlabelled issue details found for #{issue_number} or an error" - " occurred. Skipping agent interaction." + f"No issue details found for #{issue_number} that needs triaging," + " or an error occurred. Skipping agent interaction." ) return issue_title = ISSUE_TITLE or specific_issue["title"] issue_body = ISSUE_BODY or specific_issue["body"] + needs_component_label = specific_issue.get("needs_component_label", True) + needs_owner = specific_issue.get("needs_owner", False) + existing_component_label = specific_issue.get("existing_component_label") + prompt = ( - f"A new GitHub issue #{issue_number} has been opened or" - f' reopened. 
Title: "{issue_title}"\nBody:' - f' "{issue_body}"\n\nBased on the rules, recommend an' - " appropriate label and its justification." - " Then, use the 'add_label_to_issue' tool to apply the label " - "directly to this issue. Only label it, do not" - " process any other issues." + f"Triage GitHub issue #{issue_number}.\n\n" + f'Title: "{issue_title}"\n' + f'Body: "{issue_body}"\n\n' + f"Issue state: needs_component_label={needs_component_label}, " + f"needs_owner={needs_owner}, " + f"existing_component_label={existing_component_label}" ) else: print(f"EVENT: Processing batch of issues (event: {EVENT_NAME}).") issue_count = parse_number_string(ISSUE_COUNT_TO_PROCESS, default_value=3) - prompt = f"Please triage the most recent {issue_count} issues." + prompt = ( + f"Please use 'list_untriaged_issues' to find {issue_count} issues that" + " need triaging, then triage each one according to your instructions." + ) response = await call_agent_async(runner, USER_ID, session.id, prompt) print(f"<<<< Agent Final Output: {response}\n") diff --git a/contributing/samples/api_registry_agent/README.md b/contributing/samples/api_registry_agent/README.md new file mode 100644 index 0000000000..78b3c22382 --- /dev/null +++ b/contributing/samples/api_registry_agent/README.md @@ -0,0 +1,21 @@ +# BigQuery API Registry Agent + +This agent demonstrates how to use `ApiRegistry` to discover and interact with Google Cloud services like BigQuery via tools exposed by an MCP server registered in an API Registry. + +## Prerequisites + +- A Google Cloud project with the API Registry API enabled. +- An MCP server exposing BigQuery tools registered in API Registry. + +## Configuration & Running + +1. **Configure:** Edit `agent.py` and replace `your-google-cloud-project-id` and `your-mcp-server-name` with your Google Cloud Project ID and the name of your registered MCP server. +2. **Run in CLI:** + ```bash + adk run contributing/samples/api_registry_agent -- --log-level DEBUG + ``` +3. **Run in Web UI:** + ```bash + adk web contributing/samples/ + ``` + Navigate to `http://127.0.0.1:8080` and select the `api_registry_agent` agent. diff --git a/contributing/samples/api_registry_agent/__init__.py b/contributing/samples/api_registry_agent/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/api_registry_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/api_registry_agent/agent.py b/contributing/samples/api_registry_agent/agent.py new file mode 100644 index 0000000000..6504822092 --- /dev/null +++ b/contributing/samples/api_registry_agent/agent.py @@ -0,0 +1,39 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.api_registry import ApiRegistry + +# TODO: Fill in with your GCloud project id and MCP server name +PROJECT_ID = "your-google-cloud-project-id" +MCP_SERVER_NAME = "your-mcp-server-name" + +# Header required for BigQuery MCP server +header_provider = lambda context: { + "x-goog-user-project": PROJECT_ID, +} +api_registry = ApiRegistry(PROJECT_ID, header_provider=header_provider) +registry_tools = api_registry.get_toolset( + mcp_server_name=MCP_SERVER_NAME, +) +root_agent = LlmAgent( + model="gemini-2.0-flash", + name="bigquery_assistant", + instruction=""" +Help user access their BigQuery data via API Registry tools. + """, + tools=[registry_tools], +) diff --git a/contributing/samples/computer_use/README.md b/contributing/samples/computer_use/README.md index ff7ae6c0a5..38b6fe79c6 100644 --- a/contributing/samples/computer_use/README.md +++ b/contributing/samples/computer_use/README.md @@ -19,7 +19,7 @@ The computer use agent consists of: Install the required Python packages from the requirements file: ```bash -uv pip install -r internal/samples/computer_use/requirements.txt +uv pip install -r contributing/samples/computer_use/requirements.txt ``` ### 2. Install Playwright Dependencies @@ -45,7 +45,7 @@ playwright install chromium To start the computer use agent, run the following command from the project root: ```bash -adk web internal/samples +adk web contributing/samples ``` This will start the ADK web interface where you can interact with the computer_use agent. diff --git a/contributing/samples/hello_world_stream_fc_args/__init__.py b/contributing/samples/hello_world_stream_fc_args/__init__.py new file mode 100755 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/hello_world_stream_fc_args/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/hello_world_stream_fc_args/agent.py b/contributing/samples/hello_world_stream_fc_args/agent.py new file mode 100755 index 0000000000..f613842171 --- /dev/null +++ b/contributing/samples/hello_world_stream_fc_args/agent.py @@ -0,0 +1,55 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk import Agent +from google.genai import types + + +def concat_number_and_string(num: int, s: str) -> str: + """Concatenate a number and a string. + + Args: + num: The number to concatenate. + s: The string to concatenate. + + Returns: + The concatenated string. + """ + return str(num) + ': ' + s + + +root_agent = Agent( + model='gemini-3-pro-preview', + name='hello_world_stream_fc_args', + description='Demo agent showcasing streaming function call arguments.', + instruction=""" + You are a helpful assistant. + You can use the `concat_number_and_string` tool to concatenate a number and a string. + You should always call the concat_number_and_string tool to concatenate a number and a string. + You should never concatenate on your own. + """, + tools=[ + concat_number_and_string, + ], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True, + ), + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + stream_function_call_arguments=True, + ), + ), + ), +) diff --git a/contributing/samples/live_bidi_streaming_single_agent/agent.py b/contributing/samples/live_bidi_streaming_single_agent/agent.py index c295adc136..9246fca9d5 100755 --- a/contributing/samples/live_bidi_streaming_single_agent/agent.py +++ b/contributing/samples/live_bidi_streaming_single_agent/agent.py @@ -65,8 +65,8 @@ async def check_prime(nums: list[int]) -> str: root_agent = Agent( - # model='gemini-live-2.5-flash-preview-native-audio-09-2025', # vertex - model='gemini-2.5-flash-native-audio-preview-09-2025', # for AI studio + model='gemini-live-2.5-flash-preview-native-audio-09-2025', # vertex + # model='gemini-2.5-flash-native-audio-preview-09-2025', # for AI studio # key name='roll_dice_agent', description=( diff --git a/contributing/samples/migrate_session_db/sample-output/alembic.ini b/contributing/samples/migrate_session_db/sample-output/alembic.ini index 6405320948..e346ee8ac6 100644 --- a/contributing/samples/migrate_session_db/sample-output/alembic.ini +++ b/contributing/samples/migrate_session_db/sample-output/alembic.ini @@ -21,7 +21,7 @@ prepend_sys_path = . # timezone to use when rendering the date within the migration file # as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# If specified, requires the python>=3.10 and tzdata library. # Any required deps can installed by adding `alembic[tz]` to the pip requirements # string value is passed to ZoneInfo() # leave blank for localtime diff --git a/contributing/samples/spanner_rag_agent/README.md b/contributing/samples/spanner_rag_agent/README.md index 8148983736..99b60794fe 100644 --- a/contributing/samples/spanner_rag_agent/README.md +++ b/contributing/samples/spanner_rag_agent/README.md @@ -57,9 +57,9 @@ model endpoint. 
CREATE MODEL EmbeddingsModel INPUT( content STRING(MAX), ) OUTPUT( -embeddings STRUCT, values ARRAY>, +embeddings STRUCT>, ) REMOTE OPTIONS ( -endpoint = '//aiplatform.googleapis.com/projects//locations/us-central1/publishers/google/models/text-embedding-004' +endpoint = '//aiplatform.googleapis.com/projects//locations//publishers/google/models/text-embedding-005' ); ``` @@ -187,40 +187,203 @@ type. ## Which tool to use and When? -There are a few options to perform similarity search (see the `agent.py` for -implementation details): - -1. Wraps the built-in `similarity_search` in the Spanner Toolset. - - - This provides an easy and controlled way to perform similarity search. - You can specify different configurations related to vector search based - on your need without having to figure out all the details for a vector - search query. - -2. Wraps the built-in `execute_sql` in the Spanner Toolset. - - - `execute_sql` is a lower-level tool that you can have more control over - with. With the flexibility, you can specify a complicated (parameterized) - SQL query for your need, and let the `LlmAgent` pass the parameters. - -3. Use the Spanner Toolset (and all the tools that come with it) directly. - - - The most flexible and generic way. Instead of fixing configurations via - code, you can also specify the configurations via `instruction` to - the `LlmAgent` and let LLM to decide which tool to use and what parameters - to pass to different tools. It might even combine different tools together! - Note that in this usage, SQL generation is powered by the LlmAgent, which - can be more suitable for data analysis and assistant scenarios. - - To restrict the ability of an `LlmAgent`, `SpannerToolSet` also supports - `tool_filter` to explicitly specify allowed tools. As an example, the - following code specifies that only `execute_sql` and `get_table_schema` - are allowed: - - ```py - toolset = SpannerToolset( - credentials_config=credentials_config, - tool_filter=["execute_sql", "get_table_schema"], - spanner_tool_settings=SpannerToolSettings(), - ) - ``` - +There are a few options to perform similarity search: + +1. Use the built-in `vector_store_similarity_search` in the Spanner Toolset with explicit `SpannerVectorStoreSettings` configuration. + + - This provides an easy way to perform similarity search. You can specify + different configurations related to vector search based on your Spanner + database vector store table setup. + + Example pseudocode (see the `agent.py` for details): + + ```py + from google.adk.agents.llm_agent import LlmAgent + from google.adk.tools.spanner.settings import Capabilities + from google.adk.tools.spanner.settings import SpannerToolSettings + from google.adk.tools.spanner.settings import SpannerVectorStoreSettings + from google.adk.tools.spanner.spanner_toolset import SpannerToolset + + # credentials_config = SpannerCredentialsConfig(...) + + # Define Spanner tool config with the vector store settings. 
+ vector_store_settings = SpannerVectorStoreSettings( + project_id="", + instance_id="", + database_id="", + table_name="products", + content_column="productDescription", + embedding_column="productDescriptionEmbedding", + vector_length=768, + vertex_ai_embedding_model_name="text-embedding-005", + selected_columns=[ + "productId", + "productName", + "productDescription", + ], + nearest_neighbors_algorithm="EXACT_NEAREST_NEIGHBORS", + top_k=3, + distance_type="COSINE", + additional_filter="inventoryCount > 0", + ) + + tool_settings = SpannerToolSettings( + capabilities=[Capabilities.DATA_READ], + vector_store_settings=vector_store_settings, + ) + + # Get the Spanner toolset with the Spanner tool settings and credentials config. + spanner_toolset = SpannerToolset( + credentials_config=credentials_config, + spanner_tool_settings=tool_settings, + # Use `vector_store_similarity_search` only + tool_filter=["vector_store_similarity_search"], + ) + + root_agent = LlmAgent( + model="gemini-2.5-flash", + name="spanner_knowledge_base_agent", + description=( + "Agent to answer questions about product-specific recommendations." + ), + instruction=""" + You are a helpful assistant that answers user questions about product-specific recommendations. + 1. Always use the `vector_store_similarity_search` tool to find relevant information. + 2. If no relevant information is found, say you don't know. + 3. Present all the relevant information naturally and well formatted in your response. + """, + tools=[spanner_toolset], + ) + ``` + +2. Use the built-in `similarity_search` in the Spanner Toolset. + + - `similarity_search` is a lower-level tool, which provide the most flexible + and generic way. Specify all the necessary tool's parameters is required + when interacting with `LlmAgent` before performing the tool call. This is + more suitable for data analysis, ad-hoc query and assistant scenarios. + + Example pseudocode: + + ```py + from google.adk.agents.llm_agent import LlmAgent + from google.adk.tools.spanner.settings import Capabilities + from google.adk.tools.spanner.settings import SpannerToolSettings + from google.adk.tools.spanner.spanner_toolset import SpannerToolset + + # credentials_config = SpannerCredentialsConfig(...) + + tool_settings = SpannerToolSettings( + capabilities=[Capabilities.DATA_READ], + ) + + spanner_toolset = SpannerToolset( + credentials_config=credentials_config, + spanner_tool_settings=tool_settings, + # Use `similarity_search` only + tool_filter=["similarity_search"], + ) + + root_agent = LlmAgent( + model="gemini-2.5-flash", + name="spanner_knowledge_base_agent", + description=( + "Agent to answer questions by retrieving relevant information " + "from the Spanner database." + ), + instruction=""" + You are a helpful assistant that answers user questions to find the most relavant information from a Spanner database. + 1. Always use the `similarity_search` tool to find relevant information. + 2. If no relevant information is found, say you don't know. + 3. Present all the relevant information naturally and well formatted in your response. + """, + tools=[spanner_toolset], + ) + ``` + +3. Wraps the built-in `similarity_search` in the Spanner Toolset. + + - This provides a more controlled way to perform similarity search via code. + You can extend the tool as a wrapped function tool to have customized logic. 
+ + Example pseudocode: + + ```py + from google.adk.agents.llm_agent import LlmAgent + + from google.adk.tools.google_tool import GoogleTool + from google.adk.tools.spanner import search_tool + import google.auth + from google.auth.credentials import Credentials + + # credentials_config = SpannerCredentialsConfig(...) + + # Create a wrapped function tool for the agent on top of the built-in + # similarity_search tool in the Spanner toolset. + # This customized tool is used to perform a Spanner KNN vector search on a + # embedded knowledge base stored in a Spanner database table. + def wrapped_spanner_similarity_search( + search_query: str, + credentials: Credentials, + ) -> str: + """Perform a similarity search on the product catalog. + + Args: + search_query: The search query to find relevant content. + + Returns: + Relevant product catalog content with sources + """ + + # ... Customized logic ... + + # Instead of fixing all parameters, you can also expose some of them for + # the LLM to decide. + return search_tool.similarity_search( + project_id="", + instance_id="", + database_id="", + table_name="products", + query=search_query, + embedding_column_to_search="productDescriptionEmbedding", + columns= [ + "productId", + "productName", + "productDescription", + ] + embedding_options={ + "vertex_ai_embedding_model_name": "text-embedding-005", + }, + credentials=credentials, + additional_filter="inventoryCount > 0", + search_options={ + "top_k": 3, + "distance_type": "EUCLIDEAN", + }, + ) + + # ... + + root_agent = LlmAgent( + model="gemini-2.5-flash", + name="spanner_knowledge_base_agent", + description=( + "Agent to answer questions about product-specific recommendations." + ), + instruction=""" + You are a helpful assistant that answers user questions about product-specific recommendations. + 1. Always use the `wrapped_spanner_similarity_search` tool to find relevant information. + 2. If no relevant information is found, say you don't know. + 3. Present all the relevant information naturally and well formatted in your response. + """, + tools=[ + # Add customized Spanner tool based on the built-in similarity_search + # in the Spanner toolset. + GoogleTool( + func=wrapped_spanner_similarity_search, + credentials_config=credentials_config, + tool_settings=tool_settings, + ), + ], + ) + ``` diff --git a/contributing/samples/spanner_rag_agent/agent.py b/contributing/samples/spanner_rag_agent/agent.py index 58ec95294e..1460242184 100644 --- a/contributing/samples/spanner_rag_agent/agent.py +++ b/contributing/samples/spanner_rag_agent/agent.py @@ -13,23 +13,15 @@ # limitations under the License. 
import os -from typing import Any -from typing import Dict -from typing import Optional from google.adk.agents.llm_agent import LlmAgent from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.tools.base_tool import BaseTool -from google.adk.tools.google_tool import GoogleTool -from google.adk.tools.spanner import query_tool -from google.adk.tools.spanner import search_tool from google.adk.tools.spanner.settings import Capabilities from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig -from google.adk.tools.tool_context import ToolContext +from google.adk.tools.spanner.spanner_toolset import SpannerToolset import google.auth -from google.auth.credentials import Credentials -from pydantic import BaseModel # Define an appropriate credential type # Set to None to use the application default credentials (ADC) for a quick @@ -37,9 +29,6 @@ CREDENTIALS_TYPE = None -# Define Spanner tool config with read capability set to allowed. -tool_settings = SpannerToolSettings(capabilities=[Capabilities.DATA_READ]) - if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2: # Initialize the tools to do interactive OAuth # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET @@ -67,172 +56,46 @@ credentials=application_default_credentials ) +# Follow the instructions in README.md to set up the example Spanner database. +# Replace the following settings with your specific Spanner database. + +# Define Spanner vector store settings. +vector_store_settings = SpannerVectorStoreSettings( + project_id="", + instance_id="", + database_id="", + table_name="products", + content_column="productDescription", + embedding_column="productDescriptionEmbedding", + vector_length=768, + vertex_ai_embedding_model_name="text-embedding-005", + selected_columns=[ + "productId", + "productName", + "productDescription", + ], + nearest_neighbors_algorithm="EXACT_NEAREST_NEIGHBORS", + top_k=3, + distance_type="COSINE", + additional_filter="inventoryCount > 0", +) -### Section 1: Extending the built-in Spanner Toolset for Custom Use Cases ### -# This example illustrates how to extend the built-in Spanner toolset to create -# a customized Spanner tool. This method is advantageous when you need to deal -# with a specific use case: -# -# 1. Streamline the end user experience by pre-configuring the tool with fixed -# parameters (such as a specific database, instance, or project) and a -# dedicated SQL query, making it perfect for a single, focused use case -# like vector search on a specific table. -# 2. Enhance functionality by adding custom logic to manage tool inputs, -# execution, and result processing, providing greater control over the -# tool's behavior. -class SpannerRagSetting(BaseModel): - """Customized Spanner RAG settings for an example use case.""" - - # Replace the following settings for your Spanner database used in the sample. - project_id: str = "" - instance_id: str = "" - database_id: str = "" - - # Follow the instructions in README.md, the table name is "products" and the - # Spanner embedding model name is "EmbeddingsModel" in this sample. - table_name: str = "products" - # Learn more about Spanner Vertex AI integration for embedding and Spanner - # vector search. 
- # https://cloud.google.com/spanner/docs/ml-tutorial-embeddings - # https://cloud.google.com/spanner/docs/vector-search/overview - embedding_model_name: str = "EmbeddingsModel" - - selected_columns: list[str] = [ - "productId", - "productName", - "productDescription", - ] - embedding_column_name: str = "productDescriptionEmbedding" - - additional_filter_expression: str = "inventoryCount > 0" - vector_distance_function: str = "EUCLIDEAN_DISTANCE" - top_k: int = 3 - - -RAG_SETTINGS = SpannerRagSetting() - - -### (Option 1) Use the built-in similarity_search tool ### -# Create a wrapped function tool for the agent on top of the built-in -# similarity_search tool in the Spanner toolset. -# This customized tool is used to perform a Spanner KNN vector search on a -# embedded knowledge base stored in a Spanner database table. -def wrapped_spanner_similarity_search( - search_query: str, - credentials: Credentials, # GoogleTool handles `credentials` automatically - settings: SpannerToolSettings, # GoogleTool handles `settings` automatically - tool_context: ToolContext, # GoogleTool handles `tool_context` automatically -) -> str: - """Perform a similarity search on the product catalog. - - Args: - search_query: The search query to find relevant content. - - Returns: - Relevant product catalog content with sources - """ - columns = RAG_SETTINGS.selected_columns.copy() - - # Instead of fixing all parameters, you can also expose some of them for - # the LLM to decide. - return search_tool.similarity_search( - RAG_SETTINGS.project_id, - RAG_SETTINGS.instance_id, - RAG_SETTINGS.database_id, - RAG_SETTINGS.table_name, - search_query, - RAG_SETTINGS.embedding_column_name, - columns, - { - "spanner_embedding_model_name": RAG_SETTINGS.embedding_model_name, - }, - credentials, - settings, - tool_context, - RAG_SETTINGS.additional_filter_expression, - { - "top_k": RAG_SETTINGS.top_k, - "distance_type": RAG_SETTINGS.vector_distance_function, - }, - ) - - -### (Option 2) Use the built-in execute_sql tool ### -# Create a wrapped function tool for the agent on top of the built-in -# execute_sql tool in the Spanner toolset. -# This customized tool is used to perform a Spanner KNN vector search on a -# embedded knowledge base stored in a Spanner database table. -# -# Compared with similarity_search, using execute_sql (a lower level tool) means -# that you have more control, but you also need to do more work (e.g. to write -# the SQL query from scratch). Consider using this option if your scenario is -# more complicated than a plain similarity search. -def wrapped_spanner_execute_sql_tool( - search_query: str, - credentials: Credentials, # GoogleTool handles `credentials` automatically - settings: SpannerToolSettings, # GoogleTool handles `settings` automatically - tool_context: ToolContext, # GoogleTool handles `tool_context` automatically -) -> str: - """Perform a similarity search on the product catalog. - - Args: - search_query: The search query to find relevant content. 
- - Returns: - Relevant product catalog content with sources - """ - - embedding_query = f"""SELECT embeddings.values - FROM ML.PREDICT( - MODEL {RAG_SETTINGS.embedding_model_name}, - (SELECT "{search_query}" as content) - ) - """ - - distance_alias = "distance" - columns = [f"{column}" for column in RAG_SETTINGS.selected_columns] - columns += [f"""{RAG_SETTINGS.vector_distance_function}( - {RAG_SETTINGS.embedding_column_name}, - ({embedding_query})) AS {distance_alias} - """] - columns = ", ".join(columns) - - knn_query = f""" - SELECT {columns} - FROM {RAG_SETTINGS.table_name} - WHERE {RAG_SETTINGS.additional_filter_expression} - ORDER BY {distance_alias} - LIMIT {RAG_SETTINGS.top_k} - """ - - # Customized tool based on the built-in Spanner toolset. - return query_tool.execute_sql( - project_id=RAG_SETTINGS.project_id, - instance_id=RAG_SETTINGS.instance_id, - database_id=RAG_SETTINGS.database_id, - query=knn_query, - credentials=credentials, - settings=settings, - tool_context=tool_context, - ) - - -def inspect_tool_params( - tool: BaseTool, - args: Dict[str, Any], - tool_context: ToolContext, -) -> Optional[Dict]: - """A callback function to inspect tool parameters before execution.""" - print("Inspect for tool: " + tool.name) - - actual_search_query_in_args = args.get("search_query") - # Inspect the `search_query` when calling the tool for tutorial purposes. - print(f"Tool args `search_query`: {actual_search_query_in_args}") +# Define Spanner tool config with the vector store settings. +tool_settings = SpannerToolSettings( + capabilities=[Capabilities.DATA_READ], + vector_store_settings=vector_store_settings, +) - pass +# Get the Spanner toolset with the Spanner tool settings and credentials config. +# Filter the tools to only include the `vector_store_similarity_search` tool. +spanner_toolset = SpannerToolset( + credentials_config=credentials_config, + spanner_tool_settings=tool_settings, + # Comment to include all allowed tools. + tool_filter=["vector_store_similarity_search"], +) -### Section 2: Create the root agent ### root_agent = LlmAgent( model="gemini-2.5-flash", name="spanner_knowledge_base_agent", @@ -241,27 +104,10 @@ def inspect_tool_params( ), instruction=""" You are a helpful assistant that answers user questions about product-specific recommendations. - 1. Always use the `wrapped_spanner_similarity_search` tool to find relevant information. - 2. If no relevant information is found, say you don't know. - 3. Present all the relevant information naturally and well formatted in your response. + 1. Always use the `vector_store_similarity_search` tool to find information. + 2. Directly present all the information results from the `vector_store_similarity_search` tool naturally and well formatted in your response. + 3. If no information result is returned by the `vector_store_similarity_search` tool, say you don't know. """, - tools=[ - # # (Option 1) - # # Add customized Spanner tool based on the built-in similarity_search - # # in the Spanner toolset. - GoogleTool( - func=wrapped_spanner_similarity_search, - credentials_config=credentials_config, - tool_settings=tool_settings, - ), - # # (Option 2) - # # Add customized Spanner tool based on the built-in execute_sql in - # # the Spanner toolset. - # GoogleTool( - # func=wrapped_spanner_execute_sql_tool, - # credentials_config=credentials_config, - # tool_settings=tool_settings, - # ), - ], - before_tool_callback=inspect_tool_params, + # Use the Spanner toolset for vector similarity search. 
+ tools=[spanner_toolset], ) diff --git a/contributing/samples/telemetry/main.py b/contributing/samples/telemetry/main.py index e580060dc4..c6e05f0f62 100755 --- a/contributing/samples/telemetry/main.py +++ b/contributing/samples/telemetry/main.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from contextlib import aclosing import os import time @@ -46,19 +47,16 @@ async def run_prompt(session: Session, new_message: str): role='user', parts=[types.Part.from_text(text=new_message)] ) print('** User says:', content.model_dump(exclude_none=True)) - # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is - # no longer supported. - agen = runner.run_async( - user_id=user_id_1, - session_id=session.id, - new_message=content, - ) - try: + async with aclosing( + runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ) + ) as agen: async for event in agen: if event.content.parts and event.content.parts[0].text: print(f'** {event.author}: {event.content.parts[0].text}') - finally: - await agen.aclose() async def run_prompt_bytes(session: Session, new_message: str): content = types.Content( @@ -70,20 +68,17 @@ async def run_prompt_bytes(session: Session, new_message: str): ], ) print('** User says:', content.model_dump(exclude_none=True)) - # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is - # no longer supported. - agen = runner.run_async( - user_id=user_id_1, - session_id=session.id, - new_message=content, - run_config=RunConfig(save_input_blobs_as_artifacts=True), - ) - try: + async with aclosing( + runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + run_config=RunConfig(save_input_blobs_as_artifacts=True), + ) + ) as agen: async for event in agen: if event.content.parts and event.content.parts[0].text: print(f'** {event.author}: {event.content.parts[0].text}') - finally: - await agen.aclose() start_time = time.time() print('Start time:', start_time) diff --git a/llms-full.txt b/llms-full.txt index 4c744512e4..b84e9496ee 100644 --- a/llms-full.txt +++ b/llms-full.txt @@ -5620,7 +5620,7 @@ pip install google-cloud-aiplatform[adk,agent_engines] ``` !!!info - Agent Engine only supported Python version >=3.9 and <=3.12. + Agent Engine only supported Python version >=3.10 and <=3.12. ### Initialization @@ -8073,7 +8073,7 @@ setting up a basic agent with multiple tools, and running it locally either in t This quickstart assumes a local IDE (VS Code, PyCharm, IntelliJ IDEA, etc.) -with Python 3.9+ or Java 17+ and terminal access. This method runs the +with Python 3.10+ or Java 17+ and terminal access. This method runs the application entirely on your machine and is recommended for internal development. ## 1. Set up Environment & Install ADK {#venv-install} @@ -16475,7 +16475,7 @@ This guide covers two primary integration patterns: Before you begin, ensure you have the following set up: * **Set up ADK:** Follow the standard ADK [setup instructions](../get-started/quickstart.md/#venv-install) in the quickstart. -* **Install/update Python/Java:** MCP requires Python version of 3.9 or higher for Python or Java 17+. +* **Install/update Python/Java:** MCP requires Python version of 3.10 or higher for Python or Java 17+. * **Setup Node.js and npx:** **(Python only)** Many community MCP servers are distributed as Node.js packages and run using `npx`. Install Node.js (which includes npx) if you haven't already. For details, see [https://nodejs.org/en](https://nodejs.org/en). 
* **Verify Installations:** **(Python only)** Confirm `adk` and `npx` are in your PATH within the activated virtual environment: diff --git a/pyproject.toml b/pyproject.toml index 4f8e42bcf7..06ddb04ef2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,11 +40,11 @@ dependencies = [ "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool "google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database "google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription - "google-cloud-storage>=3.0.0, <4.0.0", # For GCS Artifact service - "google-genai>=1.45.0, <2.0.0", # Google GenAI SDK + "google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service + "google-genai>=1.51.0, <2.0.0", # Google GenAI SDK "graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering "jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation - "mcp>=1.8.0, <2.0.0", # For MCP Toolset + "mcp>=1.10.0, <2.0.0", # For MCP Toolset "opentelemetry-api>=1.37.0, <=1.37.0", # OpenTelemetry - limit upper version for sdk and api to not risk breaking changes from unstable _logs package. "opentelemetry-exporter-gcp-logging>=1.9.0a0, <2.0.0", "opentelemetry-exporter-gcp-monitoring>=1.9.0a0, <2.0.0", diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index a21042cc10..dfe6f4a0a2 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -26,23 +26,11 @@ from typing import Optional from typing import Union -from .utils import _get_adk_metadata_key - -try: - from a2a import types as a2a_types -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' - ) from e - else: - raise e - +from a2a import types as a2a_types from google.genai import types as genai_types from ..experimental import a2a_experimental +from .utils import _get_adk_metadata_key logger = logging.getLogger('google_adk.' + __name__) diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py index 39db41dac6..1746ec0bca 100644 --- a/src/google/adk/a2a/converters/request_converter.py +++ b/src/google/adk/a2a/converters/request_converter.py @@ -15,23 +15,12 @@ from __future__ import annotations from collections.abc import Callable -import sys from typing import Any from typing import Optional -from pydantic import BaseModel - -try: - from a2a.server.agent_execution import RequestContext -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' 
- ) from e - else: - raise e - +from a2a.server.agent_execution import RequestContext from google.genai import types as genai_types +from pydantic import BaseModel from ...runners import RunConfig from ..experimental import a2a_experimental diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index 608a818864..b6880aaa5c 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -23,34 +23,22 @@ from typing import Optional import uuid -from ...utils.context_utils import Aclosing - -try: - from a2a.server.agent_execution import AgentExecutor - from a2a.server.agent_execution.context import RequestContext - from a2a.server.events.event_queue import EventQueue - from a2a.types import Artifact - from a2a.types import Message - from a2a.types import Role - from a2a.types import TaskArtifactUpdateEvent - from a2a.types import TaskState - from a2a.types import TaskStatus - from a2a.types import TaskStatusUpdateEvent - from a2a.types import TextPart - -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' - ) from e - else: - raise e +from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Artifact +from a2a.types import Message +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart from google.adk.runners import Runner from pydantic import BaseModel from typing_extensions import override +from ...utils.context_utils import Aclosing from ..converters.event_converter import AdkEventToA2AEventsConverter from ..converters.event_converter import convert_event_to_a2a_events from ..converters.part_converter import A2APartToGenAIPartConverter diff --git a/src/google/adk/a2a/utils/agent_card_builder.py b/src/google/adk/a2a/utils/agent_card_builder.py index aa7f657f99..c007870931 100644 --- a/src/google/adk/a2a/utils/agent_card_builder.py +++ b/src/google/adk/a2a/utils/agent_card_builder.py @@ -15,25 +15,15 @@ from __future__ import annotations import re -import sys from typing import Dict from typing import List from typing import Optional -try: - from a2a.types import AgentCapabilities - from a2a.types import AgentCard - from a2a.types import AgentProvider - from a2a.types import AgentSkill - from a2a.types import SecurityScheme -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' 
- ) from e - else: - raise e - +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentProvider +from a2a.types import AgentSkill +from a2a.types import SecurityScheme from ...agents.base_agent import BaseAgent from ...agents.llm_agent import LlmAgent diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 72a2292fb3..1a1ba35618 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -15,30 +15,18 @@ from __future__ import annotations import logging -import sys - -try: - from a2a.server.apps import A2AStarletteApplication - from a2a.server.request_handlers import DefaultRequestHandler - from a2a.server.tasks import InMemoryTaskStore - from a2a.types import AgentCard -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - "A2A requires Python 3.10 or above. Please upgrade your Python version." - ) from e - else: - raise e - from typing import Optional from typing import Union +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCard from starlette.applications import Starlette from ...agents.base_agent import BaseAgent from ...artifacts.in_memory_artifact_service import InMemoryArtifactService from ...auth.credential_service.in_memory_credential_service import InMemoryCredentialService -from ...cli.utils.logs import setup_adk_logger from ...memory.in_memory_memory_service import InMemoryMemoryService from ...runners import Runner from ...sessions.in_memory_session_service import InMemorySessionService @@ -117,7 +105,8 @@ def to_a2a( app = to_a2a(agent, agent_card=my_custom_agent_card) """ # Set up ADK logging to ensure logs are visible when using uvicorn directly - setup_adk_logger(logging.INFO) + adk_logger = logging.getLogger("google_adk") + adk_logger.setLevel(logging.INFO) async def create_runner() -> Runner: """Create a runner for the agent.""" diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index 5710a21b7f..b5f8e88cde 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - from .base_agent import BaseAgent from .invocation_context import InvocationContext from .live_request_queue import LiveRequest @@ -22,6 +19,7 @@ from .llm_agent import Agent from .llm_agent import LlmAgent from .loop_agent import LoopAgent +from .mcp_instruction_provider import McpInstructionProvider from .parallel_agent import ParallelAgent from .run_config import RunConfig from .sequential_agent import SequentialAgent @@ -31,6 +29,7 @@ 'BaseAgent', 'LlmAgent', 'LoopAgent', + 'McpInstructionProvider', 'ParallelAgent', 'SequentialAgent', 'InvocationContext', @@ -38,16 +37,3 @@ 'LiveRequestQueue', 'RunConfig', ] - -if sys.version_info < (3, 10): - logger = logging.getLogger('google_adk.' + __name__) - logger.warning( - 'MCP requires Python 3.10 or above. Please upgrade your Python' - ' version in order to use it.' 
- ) -else: - from .mcp_instruction_provider import McpInstructionProvider - - __all__.extend([ - 'McpInstructionProvider', - ]) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index a644cb8b90..e15f9af981 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -15,6 +15,7 @@ from __future__ import annotations import inspect +import logging from typing import Any from typing import AsyncGenerator from typing import Awaitable @@ -49,6 +50,8 @@ if TYPE_CHECKING: from .invocation_context import InvocationContext +logger = logging.getLogger('google_adk.' + __name__) + _SingleAgentCallback: TypeAlias = Callable[ [CallbackContext], Union[Awaitable[Optional[types.Content]], Optional[types.Content]], @@ -563,6 +566,45 @@ def validate_name(cls, value: str): ) return value + @field_validator('sub_agents', mode='after') + @classmethod + def validate_sub_agents_unique_names( + cls, value: list[BaseAgent] + ) -> list[BaseAgent]: + """Validates that all sub-agents have unique names. + + Args: + value: The list of sub-agents to validate. + + Returns: + The validated list of sub-agents. + + """ + if not value: + return value + + seen_names: set[str] = set() + duplicates: set[str] = set() + + for sub_agent in value: + name = sub_agent.name + if name in seen_names: + duplicates.add(name) + else: + seen_names.add(name) + + if duplicates: + duplicate_names_str = ', '.join( + f'`{name}`' for name in sorted(duplicates) + ) + logger.warning( + 'Found duplicate sub-agent names: %s. ' + 'All sub-agents must have unique names.', + duplicate_names_str, + ) + + return value + def __set_parent_agent_for_sub_agents(self) -> BaseAgent: for sub_agent in self.sub_agents: if sub_agent.parent_agent is not None: diff --git a/src/google/adk/agents/config_agent_utils.py b/src/google/adk/agents/config_agent_utils.py index 7982a9cf59..38ba2e2578 100644 --- a/src/google/adk/agents/config_agent_utils.py +++ b/src/google/adk/agents/config_agent_utils.py @@ -132,7 +132,7 @@ def resolve_agent_reference( else: return from_config( os.path.join( - referencing_agent_config_abs_path.rsplit("/", 1)[0], + os.path.dirname(referencing_agent_config_abs_path), ref_config.config_path, ) ) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 2f8a969fad..005d073cc7 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -910,7 +910,9 @@ def _parse_config( from .config_agent_utils import resolve_callbacks from .config_agent_utils import resolve_code_reference - if config.model: + if config.model_code: + kwargs['model'] = resolve_code_reference(config.model_code) + elif config.model: kwargs['model'] = config.model if config.instruction: kwargs['instruction'] = config.instruction diff --git a/src/google/adk/agents/llm_agent_config.py b/src/google/adk/agents/llm_agent_config.py index 7d2493597e..59c6d58869 100644 --- a/src/google/adk/agents/llm_agent_config.py +++ b/src/google/adk/agents/llm_agent_config.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +from typing import Any from typing import List from typing import Literal from typing import Optional @@ -22,6 +23,7 @@ from google.genai import types from pydantic import ConfigDict from pydantic import Field +from pydantic import model_validator from ..tools.tool_configs import ToolConfig from .base_agent_config import BaseAgentConfig @@ -52,11 +54,48 @@ class LlmAgentConfig(BaseAgentConfig): model: Optional[str] = Field( 
default=None, description=( - 'Optional. LlmAgent.model. If not set, the model will be inherited' - ' from the ancestor.' + 'Optional. LlmAgent.model. Provide a model name string (e.g.' + ' "gemini-2.0-flash"). If not set, the model will be inherited from' + ' the ancestor. To construct a model instance from code, use' + ' model_code.' ), ) + model_code: Optional[CodeConfig] = Field( + default=None, + description=( + 'Optional. A CodeConfig that instantiates a BaseLlm implementation' + ' such as LiteLlm with custom arguments (API base, fallbacks,' + ' etc.). Cannot be set together with `model`.' + ), + ) + + @model_validator(mode='before') + @classmethod + def _normalize_model_code(cls, data: Any) -> dict[str, Any] | Any: + if not isinstance(data, dict): + return data + + model_value = data.get('model') + model_code = data.get('model_code') + if isinstance(model_value, dict) and model_code is None: + logger.warning( + 'Detected legacy `model` mapping. Use `model_code` to provide a' + ' CodeConfig for custom model construction.' + ) + data = dict(data) + data['model_code'] = model_value + data['model'] = None + + return data + + @model_validator(mode='after') + def _validate_model_sources(self) -> LlmAgentConfig: + if self.model and self.model_code: + raise ValueError('Only one of `model` or `model_code` should be set.') + + return self + instruction: str = Field( description=( 'Required. LlmAgent.instruction. Dynamic instructions with' diff --git a/src/google/adk/agents/mcp_instruction_provider.py b/src/google/adk/agents/mcp_instruction_provider.py index e9f40663c9..20896a7a04 100644 --- a/src/google/adk/agents/mcp_instruction_provider.py +++ b/src/google/adk/agents/mcp_instruction_provider.py @@ -22,24 +22,12 @@ from typing import Dict from typing import TextIO +from mcp import types + +from ..tools.mcp_tool.mcp_session_manager import MCPSessionManager from .llm_agent import InstructionProvider from .readonly_context import ReadonlyContext -# Attempt to import MCP Session Manager from the MCP library, and hints user to -# upgrade their Python version to 3.10 if it fails. -try: - from mcp import types - - from ..tools.mcp_tool.mcp_session_manager import MCPSessionManager -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - "MCP Session Manager requires Python 3.10 or above. Please upgrade" - " your Python version." - ) from e - else: - raise e - class McpInstructionProvider(InstructionProvider): """Fetches agent instructions from an MCP server.""" diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index f7270a75c9..09e65a67a4 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -48,34 +48,13 @@ def _create_branch_ctx_for_sub_agent( return invocation_context -# TODO - remove once Python <3.11 is no longer supported. -async def _merge_agent_run_pre_3_11( +async def _merge_agent_run( agent_runs: list[AsyncGenerator[Event, None]], ) -> AsyncGenerator[Event, None]: - """Merges the agent run event generator. - This version works in Python 3.9 and 3.10 and uses custom replacement for - asyncio.TaskGroup for tasks cancellation and exception handling. - - This implementation guarantees for each agent, it won't move on until the - generated event is processed by upstream runner. - - Args: - agent_runs: A list of async generators that yield events from each agent. - - Yields: - Event: The next event from the merged generator. 
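# A hedged, standalone sketch of the `model` / `model_code` exclusivity added to
# LlmAgentConfig above. The class below is a simplified stand-in (a plain dict in
# place of CodeConfig, protected_namespaces cleared so `model_code` is a legal
# field name); only the validator logic mirrors the diff.
from typing import Any
from typing import Optional

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import model_validator


class ModelSourceConfig(BaseModel):
  model_config = ConfigDict(protected_namespaces=())

  model: Optional[str] = None
  model_code: Optional[dict[str, Any]] = None  # Stand-in for CodeConfig.

  @model_validator(mode='before')
  @classmethod
  def _normalize_model_code(cls, data: Any) -> Any:
    # Legacy configs put a mapping under `model`; move it to `model_code`.
    if not isinstance(data, dict):
      return data
    model_value = data.get('model')
    if isinstance(model_value, dict) and data.get('model_code') is None:
      data = dict(data)
      data['model_code'] = model_value
      data['model'] = None
    return data

  @model_validator(mode='after')
  def _validate_model_sources(self) -> 'ModelSourceConfig':
    if self.model and self.model_code:
      raise ValueError('Only one of `model` or `model_code` should be set.')
    return self


# Usage: a legacy mapping under `model` is rewritten into `model_code`.
legacy = ModelSourceConfig.model_validate({'model': {'name': 'LiteLlm'}})
assert legacy.model is None and legacy.model_code == {'name': 'LiteLlm'}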
- """ + """Merges agent runs using asyncio.TaskGroup on Python 3.11+.""" sentinel = object() queue = asyncio.Queue() - def propagate_exceptions(tasks): - # Propagate exceptions and errors from tasks. - for task in tasks: - if task.done(): - # Ignore the result (None) of correctly finished tasks and re-raise - # exceptions and errors. - task.result() - # Agents are processed in parallel. # Events for each agent are put on queue sequentially. async def process_an_agent(events_for_one_agent): @@ -89,39 +68,34 @@ async def process_an_agent(events_for_one_agent): # Mark agent as finished. await queue.put((sentinel, None)) - tasks = [] - try: + async with asyncio.TaskGroup() as tg: for events_for_one_agent in agent_runs: - tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent))) + tg.create_task(process_an_agent(events_for_one_agent)) sentinel_count = 0 # Run until all agents finished processing. while sentinel_count < len(agent_runs): - propagate_exceptions(tasks) event, resume_signal = await queue.get() # Agent finished processing. if event is sentinel: sentinel_count += 1 else: yield event - # Signal to agent that event has been processed by runner and it can - # continue now. + # Signal to agent that it should generate next event. resume_signal.set() - finally: - for task in tasks: - task.cancel() -async def _merge_agent_run( +# TODO - remove once Python <3.11 is no longer supported. +async def _merge_agent_run_pre_3_11( agent_runs: list[AsyncGenerator[Event, None]], ) -> AsyncGenerator[Event, None]: - """Merges the agent run event generator. + """Merges agent runs for Python 3.10 without asyncio.TaskGroup. - This implementation guarantees for each agent, it won't move on until the - generated event is processed by upstream runner. + Uses custom cancellation and exception handling to mirror TaskGroup + semantics. Each agent waits until the runner processes emitted events. Args: - agent_runs: A list of async generators that yield events from each agent. + agent_runs: Async generators that yield events from each agent. Yields: Event: The next event from the merged generator. @@ -129,6 +103,14 @@ async def _merge_agent_run( sentinel = object() queue = asyncio.Queue() + def propagate_exceptions(tasks): + # Propagate exceptions and errors from tasks. + for task in tasks: + if task.done(): + # Ignore the result (None) of correctly finished tasks and re-raise + # exceptions and errors. + task.result() + # Agents are processed in parallel. # Events for each agent are put on queue sequentially. async def process_an_agent(events_for_one_agent): @@ -142,21 +124,27 @@ async def process_an_agent(events_for_one_agent): # Mark agent as finished. await queue.put((sentinel, None)) - async with asyncio.TaskGroup() as tg: + tasks = [] + try: for events_for_one_agent in agent_runs: - tg.create_task(process_an_agent(events_for_one_agent)) + tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent))) sentinel_count = 0 # Run until all agents finished processing. while sentinel_count < len(agent_runs): + propagate_exceptions(tasks) event, resume_signal = await queue.get() # Agent finished processing. if event is sentinel: sentinel_count += 1 else: yield event - # Signal to agent that it should generate next event. + # Signal to agent that event has been processed by runner and it can + # continue now. 
resume_signal.set() + finally: + for task in tasks: + task.cancel() class ParallelAgent(BaseAgent): @@ -195,13 +183,11 @@ async def _run_async_impl( pause_invocation = False try: - # TODO remove if once Python <3.11 is no longer supported. merge_func = ( _merge_agent_run if sys.version_info >= (3, 11) else _merge_agent_run_pre_3_11 ) - async with Aclosing(merge_func(agent_runs)) as agen: async for event in agen: yield event diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 7b6ff5cdd9..8d133060ec 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -26,30 +26,22 @@ from urllib.parse import urlparse import uuid -try: - from a2a.client import Client as A2AClient - from a2a.client import ClientEvent as A2AClientEvent - from a2a.client.card_resolver import A2ACardResolver - from a2a.client.client import ClientConfig as A2AClientConfig - from a2a.client.client_factory import ClientFactory as A2AClientFactory - from a2a.client.errors import A2AClientError - from a2a.types import AgentCard - from a2a.types import Message as A2AMessage - from a2a.types import Part as A2APart - from a2a.types import Role - from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent - from a2a.types import TaskState - from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent - from a2a.types import TransportProtocol as A2ATransport -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - "A2A requires Python 3.10 or above. Please upgrade your Python version." - ) from e - else: - raise e +from a2a.client import Client as A2AClient +from a2a.client import ClientEvent as A2AClientEvent +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.client import ClientConfig as A2AClientConfig +from a2a.client.client_factory import ClientFactory as A2AClientFactory +from a2a.client.errors import A2AClientError +from a2a.types import AgentCard +from a2a.types import Message as A2AMessage +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent +from a2a.types import TransportProtocol as A2ATransport +from google.genai import types as genai_types +import httpx try: from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -57,9 +49,6 @@ # Fallback for older versions of a2a-sdk. AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json" -from google.genai import types as genai_types -import httpx - from ..a2a.converters.event_converter import convert_a2a_message_to_event from ..a2a.converters.event_converter import convert_a2a_task_to_event from ..a2a.converters.event_converter import convert_event_to_a2a_message @@ -417,7 +406,9 @@ async def _handle_a2a_response( # This is the initial response for a streaming task or the complete # response for a non-streaming task, which is the full task state. # We process this to get the initial message. - event = convert_a2a_task_to_event(task, self.name, ctx) + event = convert_a2a_task_to_event( + task, self.name, ctx, self._a2a_part_converter + ) # for streaming task, we update the event with the task status. # We update the event as Thought updates. 
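# A condensed sketch of the TaskGroup-based merge that parallel_agent.py now uses
# on Python 3.11+ (see the hunk above). Producer tasks pause on an asyncio.Event
# until the consumer has handled each item, mirroring the resume_signal handshake;
# generator names and data here are illustrative only.
import asyncio
from typing import AsyncGenerator


async def merge(gens: list[AsyncGenerator[str, None]]) -> AsyncGenerator[str, None]:
  sentinel = object()
  queue: asyncio.Queue = asyncio.Queue()

  async def pump(gen: AsyncGenerator[str, None]) -> None:
    async for item in gen:
      resume = asyncio.Event()
      await queue.put((item, resume))
      await resume.wait()  # Back-pressure: wait until the item is consumed.
    await queue.put((sentinel, None))  # Mark this producer as finished.

  async with asyncio.TaskGroup() as tg:
    for gen in gens:
      tg.create_task(pump(gen))
    finished = 0
    while finished < len(gens):
      item, resume = await queue.get()
      if item is sentinel:
        finished += 1
      else:
        yield item
        resume.set()


async def _demo() -> None:
  async def letters():
    for ch in 'ab':
      yield ch

  async def digits():
    for ch in '12':
      yield ch

  print([item async for item in merge([letters(), digits()])])


asyncio.run(_demo())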
if task and task.status and task.status.state == TaskState.submitted: @@ -429,7 +420,7 @@ async def _handle_a2a_response( ): # This is a streaming task status update with a message. event = convert_a2a_message_to_event( - update.status.message, self.name, ctx + update.status.message, self.name, ctx, self._a2a_part_converter ) if event.content and update.status.state in [ TaskState.submitted, @@ -447,7 +438,9 @@ async def _handle_a2a_response( # signals: # 1. append: True for partial updates, False for full updates. # 2. last_chunk: True for full updates, False for partial updates. - event = convert_a2a_task_to_event(task, self.name, ctx) + event = convert_a2a_task_to_event( + task, self.name, ctx, self._a2a_part_converter + ) else: # This is a streaming update without a message (e.g. status change) # or a partial artifact update. We don't emit an event for these @@ -463,7 +456,9 @@ async def _handle_a2a_response( # Otherwise, it's a regular A2AMessage for non-streaming responses. elif isinstance(a2a_response, A2AMessage): - event = convert_a2a_message_to_event(a2a_response, self.name, ctx) + event = convert_a2a_message_to_event( + a2a_response, self.name, ctx, self._a2a_part_converter + ) event.custom_metadata = event.custom_metadata or {} if a2a_response.context_id: diff --git a/src/google/adk/apps/compaction.py b/src/google/adk/apps/compaction.py index a6f55f9ad6..4511b1b96e 100644 --- a/src/google/adk/apps/compaction.py +++ b/src/google/adk/apps/compaction.py @@ -72,9 +72,10 @@ async def _run_compaction_for_sliding_window( beginning. - A `CompactedEvent` is generated, summarizing events within `invocation_id` range [1, 2]. - - The session now contains: `[E(inv=1, role=user), E(inv=1, role=model), - E(inv=2, role=user), E(inv=2, role=model), E(inv=2, role=user), - CompactedEvent(inv=[1, 2])]`. + - The session now contains: `[ + E(inv=1, role=user), E(inv=1, role=model), + E(inv=2, role=user), E(inv=2, role=model), + CompactedEvent(inv=[1, 2])]`. 2. **After `invocation_id` 3 events are added:** - No compaction happens yet, because only 1 new invocation (`inv=3`) @@ -91,10 +92,13 @@ async def _run_compaction_for_sliding_window( - The new compaction range is from `invocation_id` 2 to 4. - A new `CompactedEvent` is generated, summarizing events within `invocation_id` range [2, 4]. - - The session now contains: `[E(inv=1, role=user), E(inv=1, role=model), - E(inv=2, role=user), E(inv=2, role=model), E(inv=2, role=user), - CompactedEvent(inv=[1, 2]), E(inv=3, role=user), E(inv=3, role=model), - E(inv=4, role=user), E(inv=4, role=model), CompactedEvent(inv=[2, 4])]`. + - The session now contains: `[ + E(inv=1, role=user), E(inv=1, role=model), + E(inv=2, role=user), E(inv=2, role=model), + CompactedEvent(inv=[1, 2]), + E(inv=3, role=user), E(inv=3, role=model), + E(inv=4, role=user), E(inv=4, role=model), + CompactedEvent(inv=[2, 4])]`. Args: diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index 825e4a7a71..53a830c066 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -32,6 +32,7 @@ from pydantic import ValidationError from typing_extensions import override +from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService @@ -100,14 +101,14 @@ def _resolve_scoped_artifact_path( to `scope_root`. 
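# A condensed sketch of the path-containment rule file_artifact_service.py enforces
# above, using a local InputValidationError stand-in (the real class is added under
# errors/input_validation_error.py later in this diff and subclasses ValueError, so
# existing `except ValueError` callers keep working). Names here are illustrative.
from pathlib import Path
from pathlib import PurePosixPath


class InputValidationError(ValueError):
  """Stand-in mirroring google.adk.errors.input_validation_error."""


def resolve_scoped(filename: str, scope_root: Path) -> Path:
  pure = PurePosixPath(filename)
  root = scope_root.resolve(strict=False)
  if pure.is_absolute():
    raise InputValidationError(
        f'Absolute artifact filename {filename!r} is not permitted.'
    )
  candidate = (root / pure).resolve(strict=False)
  try:
    candidate.relative_to(root)
  except ValueError as exc:
    raise InputValidationError(
        f'Artifact filename {filename!r} escapes storage directory {root}'
    ) from exc
  return candidate


print(resolve_scoped('reports/summary.txt', Path('/tmp/artifacts')))
try:
  resolve_scoped('../outside.txt', Path('/tmp/artifacts'))
except ValueError as exc:  # InputValidationError is still a ValueError.
  print(f'rejected: {exc}')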
Raises: - ValueError: If `filename` resolves outside of `scope_root`. + InputValidationError: If `filename` resolves outside of `scope_root`. """ stripped = _strip_user_namespace(filename).strip() pure_path = _to_posix_path(stripped) scope_root_resolved = scope_root.resolve(strict=False) if pure_path.is_absolute(): - raise ValueError( + raise InputValidationError( f"Absolute artifact filename {filename!r} is not permitted; " "provide a path relative to the storage scope." ) @@ -118,7 +119,7 @@ def _resolve_scoped_artifact_path( try: relative = candidate.relative_to(scope_root_resolved) except ValueError as exc: - raise ValueError( + raise InputValidationError( f"Artifact filename {filename!r} escapes storage directory " f"{scope_root_resolved}" ) from exc @@ -230,7 +231,7 @@ def _scope_root( if _is_user_scoped(session_id, filename): return _user_artifacts_dir(base) if not session_id: - raise ValueError( + raise InputValidationError( "Session ID must be provided for session-scoped artifacts." ) return _session_artifacts_dir(base, session_id) @@ -371,7 +372,9 @@ def _save_artifact_sync( content_path.write_text(artifact.text, encoding="utf-8") mime_type = None else: - raise ValueError("Artifact must have either inline_data or text content.") + raise InputValidationError( + "Artifact must have either inline_data or text content." + ) canonical_uri = self._canonical_uri( user_id=user_id, diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index fc18dab6fc..2bf713a5e8 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -30,6 +30,7 @@ from google.genai import types from typing_extensions import override +from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService @@ -161,7 +162,7 @@ def _get_blob_prefix( return f"{app_name}/{user_id}/user/{filename}" if session_id is None: - raise ValueError( + raise InputValidationError( "Session ID must be provided for session-scoped artifacts." ) return f"{app_name}/{user_id}/{session_id}/{filename}" @@ -230,7 +231,9 @@ def _save_artifact( " GcsArtifactService." ) else: - raise ValueError("Artifact must have either inline_data or text.") + raise InputValidationError( + "Artifact must have either inline_data or text." + ) return version diff --git a/src/google/adk/artifacts/in_memory_artifact_service.py b/src/google/adk/artifacts/in_memory_artifact_service.py index 246e8a85fb..2c7dd14127 100644 --- a/src/google/adk/artifacts/in_memory_artifact_service.py +++ b/src/google/adk/artifacts/in_memory_artifact_service.py @@ -18,12 +18,13 @@ from typing import Any from typing import Optional -from google.adk.artifacts import artifact_util from google.genai import types from pydantic import BaseModel from pydantic import Field from typing_extensions import override +from . import artifact_util +from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService @@ -86,7 +87,7 @@ def _artifact_path( return f"{app_name}/{user_id}/user/{filename}" if session_id is None: - raise ValueError( + raise InputValidationError( "Session ID must be provided for session-scoped artifacts." 
) return f"{app_name}/{user_id}/{session_id}/{filename}" @@ -125,7 +126,7 @@ async def save_artifact( elif artifact.file_data is not None: if artifact_util.is_artifact_ref(artifact): if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri): - raise ValueError( + raise InputValidationError( f"Invalid artifact reference URI: {artifact.file_data.file_uri}" ) # If it's a valid artifact URI, we store the artifact part as-is. @@ -133,7 +134,7 @@ async def save_artifact( else: artifact_version.mime_type = artifact.file_data.mime_type else: - raise ValueError("Not supported artifact type.") + raise InputValidationError("Not supported artifact type.") self.artifacts[path].append( _ArtifactEntry(data=artifact, artifact_version=artifact_version) @@ -172,7 +173,7 @@ async def load_artifact( artifact_data.file_data.file_uri ) if not parsed_uri: - raise ValueError( + raise InputValidationError( "Invalid artifact reference URI:" f" {artifact_data.file_data.file_uri}" ) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 1b422fe335..78fe426628 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -61,9 +61,11 @@ from ..agents.run_config import RunConfig from ..agents.run_config import StreamingMode from ..apps.app import App +from ..artifacts.base_artifact_service import ArtifactVersion from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService from ..errors.already_exists_error import AlreadyExistsError +from ..errors.input_validation_error import InputValidationError from ..errors.not_found_error import NotFoundError from ..evaluation.base_eval_service import InferenceConfig from ..evaluation.base_eval_service import InferenceRequest @@ -194,6 +196,19 @@ class CreateSessionRequest(common.BaseModel): ) +class SaveArtifactRequest(common.BaseModel): + """Request payload for saving a new artifact.""" + + filename: str = Field(description="Artifact filename.") + artifact: types.Part = Field( + description="Artifact payload encoded as google.genai.types.Part." 
+ ) + custom_metadata: Optional[dict[str, Any]] = Field( + default=None, + description="Optional metadata to associate with the artifact version.", + ) + + class AddSessionToEvalSetRequest(common.BaseModel): eval_id: str session_id: str @@ -280,6 +295,17 @@ class ListMetricsInfoResponse(common.BaseModel): metrics_info: list[MetricInfo] +class AppInfo(common.BaseModel): + name: str + root_agent_name: str + description: str + language: Literal["yaml", "python"] + + +class ListAppsResponse(common.BaseModel): + apps: list[AppInfo] + + def _setup_telemetry( otel_to_cloud: bool = False, internal_exporters: Optional[list[SpanProcessor]] = None, @@ -699,7 +725,14 @@ async def internal_lifespan(app: FastAPI): ) @app.get("/list-apps") - async def list_apps() -> list[str]: + async def list_apps( + detailed: bool = Query( + default=False, description="Return detailed app information" + ) + ) -> list[str] | ListAppsResponse: + if detailed: + apps_info = self.agent_loader.list_agents_detailed() + return ListAppsResponse(apps=[AppInfo(**app) for app in apps_info]) return self.agent_loader.list_agents() @app.get("/debug/trace/{event_id}", tags=[TAG_DEBUG]) @@ -1298,6 +1331,53 @@ async def load_artifact_version( raise HTTPException(status_code=404, detail="Artifact not found") return artifact + @app.post( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", + response_model=ArtifactVersion, + response_model_exclude_none=True, + ) + async def save_artifact( + app_name: str, + user_id: str, + session_id: str, + req: SaveArtifactRequest, + ) -> ArtifactVersion: + try: + version = await self.artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=req.filename, + artifact=req.artifact, + custom_metadata=req.custom_metadata, + ) + except InputValidationError as ive: + raise HTTPException(status_code=400, detail=str(ive)) from ive + except Exception as exc: # pylint: disable=broad-exception-caught + logger.error( + "Internal error while saving artifact %s for app=%s user=%s" + " session=%s: %s", + req.filename, + app_name, + user_id, + session_id, + exc, + exc_info=True, + ) + raise HTTPException(status_code=500, detail=str(exc)) from exc + artifact_version = await self.artifact_service.get_artifact_version( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=req.filename, + version=version, + ) + if artifact_version is None: + raise HTTPException( + status_code=500, detail="Artifact metadata unavailable" + ) + return artifact_version + @app.get( "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", response_model_exclude_none=True, diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index ed294d3922..a1b63a4c46 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -15,6 +15,7 @@ from __future__ import annotations from datetime import datetime +from pathlib import Path from typing import Optional from typing import Union @@ -22,21 +23,21 @@ from google.genai import types from pydantic import BaseModel -from ..agents.base_agent import BaseAgent from ..agents.llm_agent import LlmAgent from ..apps.app import App from ..artifacts.base_artifact_service import BaseArtifactService -from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..runners import Runner from 
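# A hedged usage sketch for the new save-artifact endpoint defined above, assuming a
# local `adk api_server` (or `adk web`) on port 8000 and an existing app, user, and
# session; the app/user/session names are placeholders, and the `{"text": ...}`
# payload assumes google.genai types.Part accepts a plain text part from JSON.
import httpx

resp = httpx.post(
    'http://127.0.0.1:8000/apps/my_app/users/test_user/sessions/s1/artifacts',
    json={
        'filename': 'notes.txt',
        'artifact': {'text': 'hello from the client'},
        'custom_metadata': {'source': 'sketch'},
    },
)
resp.raise_for_status()
print(resp.json())  # ArtifactVersion metadata for the newly saved version.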
..sessions.base_session_service import BaseSessionService -from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session import Session from ..utils.context_utils import Aclosing from ..utils.env_utils import is_env_enabled +from .service_registry import load_services_module from .utils import envs from .utils.agent_loader import AgentLoader +from .utils.service_factory import create_artifact_service_from_options +from .utils.service_factory import create_session_service_from_options class InputFile(BaseModel): @@ -66,7 +67,7 @@ async def run_input_file( ) with open(input_path, 'r', encoding='utf-8') as f: input_file = InputFile.model_validate_json(f.read()) - input_file.state['_time'] = datetime.now() + input_file.state['_time'] = datetime.now().isoformat() session = await session_service.create_session( app_name=app_name, user_id=user_id, state=input_file.state @@ -134,6 +135,8 @@ async def run_cli( saved_session_file: Optional[str] = None, save_session: bool, session_id: Optional[str] = None, + session_service_uri: Optional[str] = None, + artifact_service_uri: Optional[str] = None, ) -> None: """Runs an interactive CLI for a certain agent. @@ -148,24 +151,48 @@ async def run_cli( contains a previously saved session, exclusive with input_file. save_session: bool, whether to save the session on exit. session_id: Optional[str], the session ID to save the session to on exit. + session_service_uri: Optional[str], custom session service URI. + artifact_service_uri: Optional[str], custom artifact service URI. """ + agent_parent_path = Path(agent_parent_dir).resolve() + agent_root = agent_parent_path / agent_folder_name + load_services_module(str(agent_root)) + user_id = 'test_user' - artifact_service = InMemoryArtifactService() - session_service = InMemorySessionService() - credential_service = InMemoryCredentialService() + # Create session and artifact services using factory functions + # Sessions persist under //.adk/session.db by default. 
+ session_service = create_session_service_from_options( + base_dir=agent_parent_path, + session_service_uri=session_service_uri, + ) - user_id = 'test_user' - agent_or_app = AgentLoader(agents_dir=agent_parent_dir).load_agent( + artifact_service = create_artifact_service_from_options( + base_dir=agent_root, + artifact_service_uri=artifact_service_uri, + ) + + credential_service = InMemoryCredentialService() + agents_dir = str(agent_parent_path) + agent_or_app = AgentLoader(agents_dir=agents_dir).load_agent( agent_folder_name ) session_app_name = ( agent_or_app.name if isinstance(agent_or_app, App) else agent_folder_name ) - session = await session_service.create_session( - app_name=session_app_name, user_id=user_id - ) if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'): - envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) + envs.load_dotenv_for_agent(agent_folder_name, agents_dir) + + # Helper function for printing events + def _print_event(event) -> None: + content = event.content + if not content or not content.parts: + return + text_parts = [part.text for part in content.parts if part.text] + if not text_parts: + return + author = event.author or 'system' + click.echo(f'[{author}]: {"".join(text_parts)}') + if input_file: session = await run_input_file( app_name=session_app_name, @@ -177,16 +204,22 @@ async def run_cli( input_path=input_file, ) elif saved_session_file: + # Load the saved session from file with open(saved_session_file, 'r', encoding='utf-8') as f: loaded_session = Session.model_validate_json(f.read()) + # Create a new session in the service, copying state from the file + session = await session_service.create_session( + app_name=session_app_name, + user_id=user_id, + state=loaded_session.state if loaded_session else None, + ) + + # Append events from the file to the new session and display them if loaded_session: for event in loaded_session.events: await session_service.append_event(session, event) - content = event.content - if not content or not content.parts or not content.parts[0].text: - continue - click.echo(f'[{event.author}]: {content.parts[0].text}') + _print_event(event) await run_interactively( agent_or_app, @@ -196,6 +229,9 @@ async def run_cli( credential_service, ) else: + session = await session_service.create_session( + app_name=session_app_name, user_id=user_id + ) click.echo(f'Running agent {agent_or_app.name}, type exit to exit.') await run_interactively( agent_or_app, @@ -207,9 +243,7 @@ async def run_cli( if save_session: session_id = session_id or input('Session ID to save: ') - session_path = ( - f'{agent_parent_dir}/{agent_folder_name}/{session_id}.session.json' - ) + session_path = agent_root / f'{session_id}.session.json' # Fetch the session again to get all the details. 
session = await session_service.get_session( @@ -217,7 +251,9 @@ async def run_cli( user_id=session.user_id, session_id=session.id, ) - with open(session_path, 'w', encoding='utf-8') as f: - f.write(session.model_dump_json(indent=2, exclude_none=True)) + session_path.write_text( + session.model_dump_json(indent=2, exclude_none=True, by_alias=True), + encoding='utf-8', + ) print('Session saved to', session_path) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 8da65d0e21..45dce7fda6 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -915,7 +915,7 @@ def to_agent_engine( ) else: if project and region and not agent_engine_id.startswith('projects/'): - agent_engine_id = f'projects/{project}/locations/{region}/agentEngines/{agent_engine_id}' + agent_engine_id = f'projects/{project}/locations/{region}/reasoningEngines/{agent_engine_id}' client.agent_engines.update(name=agent_engine_id, config=agent_config) click.secho(f'✅ Updated agent engine: {agent_engine_id}', fg='green') finally: diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 529ee7319c..5d228f72f3 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -24,6 +24,7 @@ import os from pathlib import Path import tempfile +import textwrap from typing import Optional import click @@ -108,6 +109,18 @@ def parse_args(self, ctx, args): logger = logging.getLogger("google_adk." + __name__) +_ADK_WEB_WARNING = ( + "ADK Web is for development purposes. It has access to all data and" + " should not be used in production." +) + + +def _warn_if_with_ui(with_ui: bool) -> None: + """Warn when deploying with the developer UI enabled.""" + if with_ui: + click.secho(f"WARNING: {_ADK_WEB_WARNING}", fg="yellow", err=True) + + @click.group(context_settings={"max_content_width": 240}) @click.version_option(version.__version__) def main(): @@ -354,7 +367,62 @@ def validate_exclusive(ctx, param, value): return value +def adk_services_options(): + """Decorator to add ADK services options to click commands.""" + + def decorator(func): + @click.option( + "--session_service_uri", + help=textwrap.dedent( + """\ + Optional. The URI of the session service. + - Leave unset to use the in-memory session service (default). + - Use 'agentengine://' to connect to Agent Engine + sessions. can either be the full qualified resource + name 'projects/abc/locations/us-central1/reasoningEngines/123' or + the resource id '123'. + - Use 'memory://' to run with the in-memory session service. + - Use 'sqlite://' to connect to a SQLite DB. + - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported database URIs.""" + ), + ) + @click.option( + "--artifact_service_uri", + type=str, + help=textwrap.dedent( + """\ + Optional. The URI of the artifact service. + - Leave unset to store artifacts under '.adk/artifacts' locally. + - Use 'gs://' to connect to the GCS artifact service. + - Use 'memory://' to force the in-memory artifact service. + - Use 'file://' to store artifacts in a custom local directory.""" + ), + default=None, + ) + @click.option( + "--memory_service_uri", + type=str, + help=textwrap.dedent("""\ + Optional. The URI of the memory service. + - Use 'rag://' to connect to Vertex AI Rag Memory Service. + - Use 'agentengine://' to connect to Agent Engine + sessions. 
can either be the full qualified resource + name 'projects/abc/locations/us-central1/reasoningEngines/123' or + the resource id '123'. + - Use 'memory://' to force the in-memory memory service."""), + default=None, + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + return decorator + + @main.command("run", cls=HelpfulCommand) +@adk_services_options() @click.option( "--save_session", type=bool, @@ -409,6 +477,9 @@ def cli_run( session_id: Optional[str], replay: Optional[str], resume: Optional[str], + session_service_uri: Optional[str] = None, + artifact_service_uri: Optional[str] = None, + memory_service_uri: Optional[str] = None, ): """Runs an interactive CLI for a certain agent. @@ -420,6 +491,14 @@ def cli_run( """ logs.log_to_tmp_folder() + # Validation warning for memory_service_uri (not supported for adk run) + if memory_service_uri: + click.secho( + "WARNING: --memory_service_uri is not supported for adk run.", + fg="yellow", + err=True, + ) + agent_parent_folder = os.path.dirname(agent) agent_folder_name = os.path.basename(agent) @@ -431,6 +510,8 @@ def cli_run( saved_session_file=resume, save_session=save_session, session_id=session_id, + session_service_uri=session_service_uri, + artifact_service_uri=artifact_service_uri, ) ) @@ -557,7 +638,7 @@ def cli_eval( from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import load_eval_set_from_file from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager - from ..evaluation.user_simulator_provider import UserSimulatorProvider + from ..evaluation.simulation.user_simulator_provider import UserSimulatorProvider from .cli_eval import _collect_eval_results from .cli_eval import _collect_inferences from .cli_eval import get_root_agent @@ -865,55 +946,6 @@ def wrapper(*args, **kwargs): return decorator -def adk_services_options(): - """Decorator to add ADK services options to click commands.""" - - def decorator(func): - @click.option( - "--session_service_uri", - help=( - """Optional. The URI of the session service. - - Use 'agentengine://' to connect to Agent Engine - sessions. can either be the full qualified resource - name 'projects/abc/locations/us-central1/reasoningEngines/123' or - the resource id '123'. - - Use 'sqlite://' to connect to an aio-sqlite - based session service, which is good for local development. - - Use 'postgresql://:@:/' - to connect to a PostgreSQL DB. - - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls - for more details on other database URIs supported by SQLAlchemy.""" - ), - ) - @click.option( - "--artifact_service_uri", - type=str, - help=( - "Optional. The URI of the artifact service," - " supported URIs: gs:// for GCS artifact service." - ), - default=None, - ) - @click.option( - "--memory_service_uri", - type=str, - help=("""Optional. The URI of the memory service. - - Use 'rag://' to connect to Vertex AI Rag Memory Service. - - Use 'agentengine://' to connect to Agent Engine - sessions. 
can either be the full qualified resource - name 'projects/abc/locations/us-central1/reasoningEngines/123' or - the resource id '123'."""), - default=None, - ) - @functools.wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return wrapper - - return decorator - - def deprecated_adk_services_options(): """Deprecated ADK services options.""" @@ -921,7 +953,7 @@ def warn(alternative_param, ctx, param, value): if value: click.echo( click.style( - f"WARNING: Deprecated option {param.name} is used. Please use" + f"WARNING: Deprecated option --{param.name} is used. Please use" f" {alternative_param} instead.", fg="yellow", ), @@ -1116,6 +1148,8 @@ def cli_web( adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ + session_service_uri = session_service_uri or session_db_url + artifact_service_uri = artifact_service_uri or artifact_storage_uri logs.setup_adk_logger(getattr(logging, log_level.upper())) @asynccontextmanager @@ -1140,8 +1174,6 @@ async def _lifespan(app: FastAPI): fg="green", ) - session_service_uri = session_service_uri or session_db_url - artifact_service_uri = artifact_service_uri or artifact_storage_uri app = get_fast_api_app( agents_dir=agents_dir, session_service_uri=session_service_uri, @@ -1215,10 +1247,10 @@ def cli_api_server( adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir """ - logs.setup_adk_logger(getattr(logging, log_level.upper())) - session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri + logs.setup_adk_logger(getattr(logging, log_level.upper())) + config = uvicorn.Config( get_fast_api_app( agents_dir=agents_dir, @@ -1408,6 +1440,8 @@ def cli_deploy_cloud_run( err=True, ) + _warn_if_with_ui(with_ui) + session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri @@ -1792,6 +1826,7 @@ def cli_deploy_gke( --cluster_name=[cluster_name] path/to/my_agent """ try: + _warn_if_with_ui(with_ui) cli_deploy.to_gke( agent_folder=agent, project=project, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index eec6bb646b..df06b1cf4c 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -14,6 +14,7 @@ from __future__ import annotations +import importlib import json import logging import os @@ -51,6 +52,26 @@ logger = logging.getLogger("google_adk." 
+ __name__) +_LAZY_SERVICE_IMPORTS: dict[str, str] = { + "AgentLoader": ".utils.agent_loader", + "InMemoryArtifactService": "..artifacts.in_memory_artifact_service", + "InMemoryMemoryService": "..memory.in_memory_memory_service", + "InMemorySessionService": "..sessions.in_memory_session_service", + "LocalEvalSetResultsManager": "..evaluation.local_eval_set_results_manager", + "LocalEvalSetsManager": "..evaluation.local_eval_sets_manager", +} + + +def __getattr__(name: str): + """Lazily import defaults so patching in tests keeps working.""" + if name not in _LAZY_SERVICE_IMPORTS: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + module = importlib.import_module(_LAZY_SERVICE_IMPORTS[name], __package__) + attr = getattr(module, name) + globals()[name] = attr + return attr + def get_fast_api_app( *, @@ -321,25 +342,14 @@ async def get_agent_builder( ) if a2a: - try: - from a2a.server.apps import A2AStarletteApplication - from a2a.server.request_handlers import DefaultRequestHandler - from a2a.server.tasks import InMemoryTaskStore - from a2a.types import AgentCard - from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH - - from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor - - except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - "A2A requires Python 3.10 or above. Please upgrade your Python" - " version." - ) from e - else: - raise e + from a2a.server.apps import A2AStarletteApplication + from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server.tasks import InMemoryTaskStore + from a2a.types import AgentCard + from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH + + from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor + # locate all a2a agent apps in the agents directory base_path = Path.cwd() / agents_dir # the root agents directory should be an existing folder diff --git a/src/google/adk/cli/service_registry.py b/src/google/adk/cli/service_registry.py index 9f23b73035..3e7921e075 100644 --- a/src/google/adk/cli/service_registry.py +++ b/src/google/adk/cli/service_registry.py @@ -76,7 +76,6 @@ def my_session_factory(uri: str, **kwargs): from ..artifacts.base_artifact_service import BaseArtifactService from ..memory.base_memory_service import BaseMemoryService -from ..sessions import InMemorySessionService from ..sessions.base_session_service import BaseSessionService from ..utils import yaml_utils @@ -272,18 +271,14 @@ def gcs_artifact_factory(uri: str, **kwargs): kwargs_copy = kwargs.copy() kwargs_copy.pop("agents_dir", None) + kwargs_copy.pop("per_agent", None) parsed_uri = urlparse(uri) bucket_name = parsed_uri.netloc return GcsArtifactService(bucket_name=bucket_name, **kwargs_copy) - def file_artifact_factory(uri: str, **kwargs): + def file_artifact_factory(uri: str, **_): from ..artifacts.file_artifact_service import FileArtifactService - per_agent = kwargs.get("per_agent", False) - if per_agent: - raise ValueError( - "file:// artifact URIs are not supported in multi-agent mode." 
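# The fast_api.py hunk above uses PEP 562 module-level __getattr__ so the heavy
# default services are imported only on first access (and tests can still patch
# them). A self-contained, file-level sketch of the same pattern, with `json`
# standing in for the real service modules:
import importlib

_LAZY_IMPORTS: dict[str, str] = {'dumps': 'json', 'loads': 'json'}


def __getattr__(name: str):
  if name not in _LAZY_IMPORTS:
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
  attr = getattr(importlib.import_module(_LAZY_IMPORTS[name]), name)
  globals()[name] = attr  # Cache so later lookups bypass __getattr__.
  return attr

# From another module: `import this_module; this_module.dumps({"a": 1})` triggers
# the import of `json` on first access only.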
- ) parsed_uri = urlparse(uri) if parsed_uri.netloc not in ("", "localhost"): raise ValueError( diff --git a/src/google/adk/cli/utils/__init__.py b/src/google/adk/cli/utils/__init__.py index 8aa11b252b..1800f5d04c 100644 --- a/src/google/adk/cli/utils/__init__.py +++ b/src/google/adk/cli/utils/__init__.py @@ -18,8 +18,10 @@ from ...agents.base_agent import BaseAgent from ...agents.llm_agent import LlmAgent +from .dot_adk_folder import DotAdkFolder from .state import create_empty_state __all__ = [ 'create_empty_state', + 'DotAdkFolder', ] diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py index 9f01705d4f..d6965e5bbb 100644 --- a/src/google/adk/cli/utils/agent_loader.py +++ b/src/google/adk/cli/utils/agent_loader.py @@ -20,6 +20,8 @@ import os from pathlib import Path import sys +from typing import Any +from typing import Literal from typing import Optional from typing import Union @@ -56,7 +58,7 @@ class AgentLoader(BaseAgentLoader): """ def __init__(self, agents_dir: str): - self.agents_dir = agents_dir.rstrip("/") + self.agents_dir = str(Path(agents_dir)) self._original_sys_path = None self._agent_cache: dict[str, Union[BaseAgent, App]] = {} @@ -270,12 +272,13 @@ def _perform_load(self, agent_name: str) -> Union[BaseAgent, App]: f"No root_agent found for '{agent_name}'. Searched in" f" '{actual_agent_name}.agent.root_agent'," f" '{actual_agent_name}.root_agent' and" - f" '{actual_agent_name}/root_agent.yaml'.\n\nExpected directory" - f" structure:\n /\n {actual_agent_name}/\n " - " agent.py (with root_agent) OR\n root_agent.yaml\n\nThen run:" - f" adk web \n\nEnsure '{agents_dir}/{actual_agent_name}' is" - " structured correctly, an .env file can be loaded if present, and a" - f" root_agent is exposed.{hint}" + f" '{actual_agent_name}{os.sep}root_agent.yaml'.\n\nExpected directory" + f" structure:\n {os.sep}\n " + f" {actual_agent_name}{os.sep}\n agent.py (with root_agent) OR\n " + " root_agent.yaml\n\nThen run: adk web \n\nEnsure" + f" '{os.path.join(agents_dir, actual_agent_name)}' is structured" + " correctly, an .env file can be loaded if present, and a root_agent" + f" is exposed.{hint}" ) def _record_origin_metadata( @@ -341,6 +344,50 @@ def list_agents(self) -> list[str]: agent_names.sort() return agent_names + def list_agents_detailed(self) -> list[dict[str, Any]]: + """Lists all agents with detailed metadata (name, description, type).""" + agent_names = self.list_agents() + apps_info = [] + + for agent_name in agent_names: + try: + loaded = self.load_agent(agent_name) + if isinstance(loaded, App): + agent = loaded.root_agent + else: + agent = loaded + + language = self._determine_agent_language(agent_name) + + app_info = { + "name": agent_name, + "root_agent_name": agent.name, + "description": agent.description, + "language": language, + } + apps_info.append(app_info) + + except Exception as e: + logger.error("Failed to load agent '%s': %s", agent_name, e) + continue + + return apps_info + + def _determine_agent_language( + self, agent_name: str + ) -> Literal["yaml", "python"]: + """Determine the type of agent based on file structure.""" + base_path = Path.cwd() / self.agents_dir / agent_name + + if (base_path / "root_agent.yaml").exists(): + return "yaml" + elif (base_path / "agent.py").exists(): + return "python" + elif (base_path / "__init__.py").exists(): + return "python" + + raise ValueError(f"Could not determine agent type for '{agent_name}'.") + def remove_agent_from_cache(self, agent_name: str): # Clear module cache for the 
agent and its submodules keys_to_delete = [ diff --git a/src/google/adk/cli/utils/base_agent_loader.py b/src/google/adk/cli/utils/base_agent_loader.py index d62a6b8651..bcef0dae42 100644 --- a/src/google/adk/cli/utils/base_agent_loader.py +++ b/src/google/adk/cli/utils/base_agent_loader.py @@ -18,6 +18,7 @@ from abc import ABC from abc import abstractmethod +from typing import Any from typing import Union from ...agents.base_agent import BaseAgent @@ -34,3 +35,15 @@ def load_agent(self, agent_name: str) -> Union[BaseAgent, App]: @abstractmethod def list_agents(self) -> list[str]: """Lists all agents available in the agent loader in alphabetical order.""" + + def list_agents_detailed(self) -> list[dict[str, Any]]: + agent_names = self.list_agents() + return [ + { + 'name': name, + 'display_name': None, + 'description': None, + 'type': None, + } + for name in agent_names + ] diff --git a/src/google/adk/cli/utils/local_storage.py b/src/google/adk/cli/utils/local_storage.py index 9e6b3f3d54..ec7099b8c8 100644 --- a/src/google/adk/cli/utils/local_storage.py +++ b/src/google/adk/cli/utils/local_storage.py @@ -57,14 +57,39 @@ def create_local_database_session_service( return SqliteSessionService(db_path=str(session_db_path)) +def create_local_session_service( + *, + base_dir: Path | str, + per_agent: bool = False, +) -> BaseSessionService: + """Creates a local SQLite-backed session service. + + Args: + base_dir: The base directory for the agent(s). + per_agent: If True, creates a PerAgentDatabaseSessionService that stores + sessions in each agent's .adk folder. If False, creates a single + SqliteSessionService at base_dir/.adk/session.db. + + Returns: + A BaseSessionService instance backed by SQLite. + """ + if per_agent: + logger.info( + "Using per-agent session storage rooted at %s", + base_dir, + ) + return PerAgentDatabaseSessionService(agents_root=base_dir) + + return create_local_database_session_service(base_dir=base_dir) + + def create_local_artifact_service( - *, base_dir: Path | str, per_agent: bool = False + *, base_dir: Path | str ) -> BaseArtifactService: """Creates a file-backed artifact service rooted in `.adk/artifacts`. Args: base_dir: Directory whose `.adk` folder will store artifacts. - per_agent: Indicates whether the service is being used in multi-agent mode. Returns: A `FileArtifactService` scoped to the derived root directory. @@ -72,13 +97,7 @@ def create_local_artifact_service( manager = DotAdkFolder(base_dir) artifact_root = manager.artifacts_dir artifact_root.mkdir(parents=True, exist_ok=True) - if per_agent: - logger.info( - "Using shared file artifact service rooted at %s for multi-agent mode", - artifact_root, - ) - else: - logger.info("Using file artifact service at %s", artifact_root) + logger.info("Using file artifact service at %s", artifact_root) return FileArtifactService(root_dir=artifact_root) diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py new file mode 100644 index 0000000000..60f4ddd3cf --- /dev/null +++ b/src/google/adk/cli/utils/service_factory.py @@ -0,0 +1,120 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any +from typing import Optional + +from ...artifacts.base_artifact_service import BaseArtifactService +from ...memory.base_memory_service import BaseMemoryService +from ...sessions.base_session_service import BaseSessionService +from ..service_registry import get_service_registry +from .local_storage import create_local_artifact_service +from .local_storage import create_local_session_service + +logger = logging.getLogger("google_adk." + __name__) + + +def create_session_service_from_options( + *, + base_dir: Path | str, + session_service_uri: Optional[str] = None, + session_db_kwargs: Optional[dict[str, Any]] = None, +) -> BaseSessionService: + """Creates a session service based on CLI/web options.""" + base_path = Path(base_dir) + registry = get_service_registry() + + kwargs: dict[str, Any] = { + "agents_dir": str(base_path), + } + if session_db_kwargs: + kwargs.update(session_db_kwargs) + + if session_service_uri: + logger.info("Using session service URI: %s", session_service_uri) + service = registry.create_session_service(session_service_uri, **kwargs) + if service is not None: + return service + + # Fallback to DatabaseSessionService if the registry doesn't support the + # session service URI scheme. This keeps support for SQLAlchemy-compatible + # databases like AlloyDB or Cloud Spanner without explicit registration. + from ...sessions.database_session_service import DatabaseSessionService + + fallback_kwargs = dict(kwargs) + fallback_kwargs.pop("agents_dir", None) + logger.info( + "Falling back to DatabaseSessionService for URI: %s", + session_service_uri, + ) + return DatabaseSessionService(db_url=session_service_uri, **fallback_kwargs) + + # Default to per-agent local SQLite storage in //.adk/. 
+ return create_local_session_service(base_dir=base_path, per_agent=True) + + +def create_memory_service_from_options( + *, + base_dir: Path | str, + memory_service_uri: Optional[str] = None, +) -> BaseMemoryService: + """Creates a memory service based on CLI/web options.""" + base_path = Path(base_dir) + registry = get_service_registry() + + if memory_service_uri: + logger.info("Using memory service URI: %s", memory_service_uri) + service = registry.create_memory_service( + memory_service_uri, + agents_dir=str(base_path), + ) + if service is None: + raise ValueError(f"Unsupported memory service URI: {memory_service_uri}") + return service + + logger.info("Using in-memory memory service") + from ...memory.in_memory_memory_service import InMemoryMemoryService + + return InMemoryMemoryService() + + +def create_artifact_service_from_options( + *, + base_dir: Path | str, + artifact_service_uri: Optional[str] = None, +) -> BaseArtifactService: + """Creates an artifact service based on CLI/web options.""" + base_path = Path(base_dir) + registry = get_service_registry() + + if artifact_service_uri: + logger.info("Using artifact service URI: %s", artifact_service_uri) + service = registry.create_artifact_service( + artifact_service_uri, + agents_dir=str(base_path), + ) + if service is None: + logger.warning( + "Unsupported artifact service URI: %s, falling back to in-memory", + artifact_service_uri, + ) + from ...artifacts.in_memory_artifact_service import InMemoryArtifactService + + return InMemoryArtifactService() + return service + + return create_local_artifact_service(base_dir=base_path) diff --git a/src/google/adk/code_executors/unsafe_local_code_executor.py b/src/google/adk/code_executors/unsafe_local_code_executor.py index 6dd2ae9d8c..b47fbd17e9 100644 --- a/src/google/adk/code_executors/unsafe_local_code_executor.py +++ b/src/google/adk/code_executors/unsafe_local_code_executor.py @@ -72,7 +72,7 @@ def execute_code( _prepare_globals(code_execution_input.code, globals_) stdout = io.StringIO() with redirect_stdout(stdout): - exec(code_execution_input.code, globals_) + exec(code_execution_input.code, globals_, globals_) output = stdout.getvalue() except Exception as e: error = str(e) diff --git a/src/google/adk/errors/input_validation_error.py b/src/google/adk/errors/input_validation_error.py new file mode 100644 index 0000000000..76b1625a10 --- /dev/null +++ b/src/google/adk/errors/input_validation_error.py @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + + +class InputValidationError(ValueError): + """Represents an error raised when user input fails validation.""" + + def __init__(self, message="Invalid input."): + """Initializes the InputValidationError exception. + + Args: + message (str): A message describing why the input is invalid. 
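# A hedged usage sketch for the new service_factory helpers above, assuming this
# change is installed and an `agents/my_agent` folder exists. With no URI, sessions
# default to per-agent SQLite storage under each agent's .adk folder and artifacts
# to a file-backed store under .adk/artifacts; URIs such as 'memory://' or
# 'gs://...' are resolved through the service registry.
from pathlib import Path

from google.adk.cli.utils.service_factory import create_artifact_service_from_options
from google.adk.cli.utils.service_factory import create_session_service_from_options

session_service = create_session_service_from_options(base_dir=Path('agents'))
artifact_service = create_artifact_service_from_options(
    base_dir=Path('agents/my_agent'),
    artifact_service_uri='memory://',  # Force the in-memory artifact service.
)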
+ """ + self.message = message + super().__init__(self.message) diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py index cafa712f56..514681bfa9 100644 --- a/src/google/adk/evaluation/agent_evaluator.py +++ b/src/google/adk/evaluation/agent_evaluator.py @@ -50,7 +50,7 @@ from .evaluator import EvalStatus from .in_memory_eval_sets_manager import InMemoryEvalSetsManager from .local_eval_sets_manager import convert_eval_set_to_pydantic_schema -from .user_simulator_provider import UserSimulatorProvider +from .simulation.user_simulator_provider import UserSimulatorProvider logger = logging.getLogger("google_adk." + __name__) diff --git a/src/google/adk/evaluation/eval_config.py b/src/google/adk/evaluation/eval_config.py index d5b94af5e1..13b2e92274 100644 --- a/src/google/adk/evaluation/eval_config.py +++ b/src/google/adk/evaluation/eval_config.py @@ -27,7 +27,7 @@ from ..evaluation.eval_metrics import EvalMetric from .eval_metrics import BaseCriterion from .eval_metrics import Threshold -from .user_simulator import BaseUserSimulatorConfig +from .simulation.user_simulator import BaseUserSimulatorConfig logger = logging.getLogger("google_adk." + __name__) diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index e9c7dc5436..5d8b48c150 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -45,9 +45,9 @@ from .eval_case import SessionInput from .eval_set import EvalSet from .request_intercepter_plugin import _RequestIntercepterPlugin -from .user_simulator import Status as UserSimulatorStatus -from .user_simulator import UserSimulator -from .user_simulator_provider import UserSimulatorProvider +from .simulation.user_simulator import Status as UserSimulatorStatus +from .simulation.user_simulator import UserSimulator +from .simulation.user_simulator_provider import UserSimulatorProvider _USER_AUTHOR = "user" _DEFAULT_AUTHOR = "agent" diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index 806a8d690d..f454266e00 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -31,6 +31,8 @@ from ..memory.base_memory_service import BaseMemoryService from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService +from ..utils._client_labels_utils import client_label_context +from ..utils._client_labels_utils import EVAL_CLIENT_LABEL from ..utils.feature_decorator import experimental from .base_eval_service import BaseEvalService from .base_eval_service import EvaluateConfig @@ -53,7 +55,7 @@ from .evaluator import PerInvocationResult from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY from .metric_evaluator_registry import MetricEvaluatorRegistry -from .user_simulator_provider import UserSimulatorProvider +from .simulation.user_simulator_provider import UserSimulatorProvider logger = logging.getLogger('google_adk.' + __name__) @@ -249,11 +251,12 @@ async def _evaluate_single_inference_result( for eval_metric in evaluate_config.eval_metrics: # Perform evaluation of the metric. 
try: - evaluation_result = await self._evaluate_metric( - eval_metric=eval_metric, - actual_invocations=inference_result.inferences, - expected_invocations=eval_case.conversation, - ) + with client_label_context(EVAL_CLIENT_LABEL): + evaluation_result = await self._evaluate_metric( + eval_metric=eval_metric, + actual_invocations=inference_result.inferences, + expected_invocations=eval_case.conversation, + ) except Exception as e: # We intentionally catch the Exception as we don't want failures to # affect other metric evaluation. @@ -403,17 +406,18 @@ async def _perform_inference_single_eval_item( ) try: - inferences = ( - await EvaluationGenerator._generate_inferences_from_root_agent( - root_agent=root_agent, - user_simulator=self._user_simulator_provider.provide(eval_case), - initial_session=initial_session, - session_id=session_id, - session_service=self._session_service, - artifact_service=self._artifact_service, - memory_service=self._memory_service, - ) - ) + with client_label_context(EVAL_CLIENT_LABEL): + inferences = ( + await EvaluationGenerator._generate_inferences_from_root_agent( + root_agent=root_agent, + user_simulator=self._user_simulator_provider.provide(eval_case), + initial_session=initial_session, + session_id=session_id, + session_service=self._session_service, + artifact_service=self._artifact_service, + memory_service=self._memory_service, + ) + ) inference_result.inferences = inferences inference_result.status = InferenceStatus.SUCCESS diff --git a/src/google/adk/evaluation/simulation/__init__.py b/src/google/adk/evaluation/simulation/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/src/google/adk/evaluation/simulation/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
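For reference, the user simulator modules now live under the new `google.adk.evaluation.simulation` subpackage (see the renames below), so code that imported them from `google.adk.evaluation` directly needs to switch to the new module paths. A minimal sketch of the import change, assuming the modules are imported by their full module path rather than re-exported elsewhere:

```python
# Old module paths (before this change):
# from google.adk.evaluation.user_simulator import BaseUserSimulatorConfig
# from google.adk.evaluation.user_simulator_provider import UserSimulatorProvider

# New module paths under the `simulation` subpackage:
from google.adk.evaluation.simulation.user_simulator import BaseUserSimulatorConfig
from google.adk.evaluation.simulation.user_simulator_provider import UserSimulatorProvider
```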
diff --git a/src/google/adk/evaluation/llm_backed_user_simulator.py b/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py similarity index 96% rename from src/google/adk/evaluation/llm_backed_user_simulator.py rename to src/google/adk/evaluation/simulation/llm_backed_user_simulator.py index 2fbfcc44d1..4af228772d 100644 --- a/src/google/adk/evaluation/llm_backed_user_simulator.py +++ b/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py @@ -22,14 +22,14 @@ from pydantic import Field from typing_extensions import override -from ..events.event import Event -from ..models.llm_request import LlmRequest -from ..models.registry import LLMRegistry -from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental -from ._retry_options_utils import add_default_retry_options_if_not_present -from .conversation_scenarios import ConversationScenario -from .evaluator import Evaluator +from ...events.event import Event +from ...models.llm_request import LlmRequest +from ...models.registry import LLMRegistry +from ...utils.context_utils import Aclosing +from ...utils.feature_decorator import experimental +from .._retry_options_utils import add_default_retry_options_if_not_present +from ..conversation_scenarios import ConversationScenario +from ..evaluator import Evaluator from .user_simulator import BaseUserSimulatorConfig from .user_simulator import NextUserMessage from .user_simulator import Status diff --git a/src/google/adk/evaluation/static_user_simulator.py b/src/google/adk/evaluation/simulation/static_user_simulator.py similarity index 93% rename from src/google/adk/evaluation/static_user_simulator.py rename to src/google/adk/evaluation/simulation/static_user_simulator.py index 4c5e2cb54d..e1de18a706 100644 --- a/src/google/adk/evaluation/static_user_simulator.py +++ b/src/google/adk/evaluation/simulation/static_user_simulator.py @@ -19,10 +19,10 @@ from typing_extensions import override -from ..events.event import Event -from ..utils.feature_decorator import experimental -from .eval_case import StaticConversation -from .evaluator import Evaluator +from ...events.event import Event +from ...utils.feature_decorator import experimental +from ..eval_case import StaticConversation +from ..evaluator import Evaluator from .user_simulator import BaseUserSimulatorConfig from .user_simulator import NextUserMessage from .user_simulator import Status diff --git a/src/google/adk/evaluation/user_simulator.py b/src/google/adk/evaluation/simulation/user_simulator.py similarity index 95% rename from src/google/adk/evaluation/user_simulator.py rename to src/google/adk/evaluation/simulation/user_simulator.py index c5ab013d7c..57656b76de 100644 --- a/src/google/adk/evaluation/user_simulator.py +++ b/src/google/adk/evaluation/simulation/user_simulator.py @@ -26,10 +26,10 @@ from pydantic import model_validator from pydantic import ValidationError -from ..events.event import Event -from ..utils.feature_decorator import experimental -from .common import EvalBaseModel -from .evaluator import Evaluator +from ...events.event import Event +from ...utils.feature_decorator import experimental +from ..common import EvalBaseModel +from ..evaluator import Evaluator class BaseUserSimulatorConfig(BaseModel): diff --git a/src/google/adk/evaluation/user_simulator_provider.py b/src/google/adk/evaluation/simulation/user_simulator_provider.py similarity index 97% rename from src/google/adk/evaluation/user_simulator_provider.py rename to 
src/google/adk/evaluation/simulation/user_simulator_provider.py index 1aea8c8c92..b1bfd3226c 100644 --- a/src/google/adk/evaluation/user_simulator_provider.py +++ b/src/google/adk/evaluation/simulation/user_simulator_provider.py @@ -16,8 +16,8 @@ from typing import Optional -from ..utils.feature_decorator import experimental -from .eval_case import EvalCase +from ...utils.feature_decorator import experimental +from ..eval_case import EvalCase from .llm_backed_user_simulator import LlmBackedUserSimulator from .static_user_simulator import StaticUserSimulator from .user_simulator import BaseUserSimulatorConfig diff --git a/src/google/adk/flows/llm_flows/agent_transfer.py b/src/google/adk/flows/llm_flows/agent_transfer.py index 037b8c6d50..ea144bf75d 100644 --- a/src/google/adk/flows/llm_flows/agent_transfer.py +++ b/src/google/adk/flows/llm_flows/agent_transfer.py @@ -24,9 +24,8 @@ from ...agents.invocation_context import InvocationContext from ...events.event import Event from ...models.llm_request import LlmRequest -from ...tools.function_tool import FunctionTool from ...tools.tool_context import ToolContext -from ...tools.transfer_to_agent_tool import transfer_to_agent +from ...tools.transfer_to_agent_tool import TransferToAgentTool from ._base_llm_processor import BaseLlmRequestProcessor if typing.TYPE_CHECKING: @@ -50,13 +49,18 @@ async def run_async( if not transfer_targets: return + transfer_to_agent_tool = TransferToAgentTool( + agent_names=[agent.name for agent in transfer_targets] + ) + llm_request.append_instructions([ _build_target_agents_instructions( - invocation_context.agent, transfer_targets + transfer_to_agent_tool.name, + invocation_context.agent, + transfer_targets, ) ]) - transfer_to_agent_tool = FunctionTool(func=transfer_to_agent) tool_context = ToolContext(invocation_context) await transfer_to_agent_tool.process_llm_request( tool_context=tool_context, llm_request=llm_request @@ -80,10 +84,13 @@ def _build_target_agents_info(target_agent: BaseAgent) -> str: def _build_target_agents_instructions( - agent: LlmAgent, target_agents: list[BaseAgent] + tool_name: str, + agent: LlmAgent, + target_agents: list[BaseAgent], ) -> str: # Build list of available agent names for the NOTE - # target_agents already includes parent agent if applicable, so no need to add it again + # target_agents already includes parent agent if applicable, + # so no need to add it again available_agent_names = [target_agent.name for target_agent in target_agents] # Sort for consistency @@ -101,15 +108,16 @@ def _build_target_agents_instructions( _build_target_agents_info(target_agent) for target_agent in target_agents ])} -If you are the best to answer the question according to your description, you -can answer it. +If you are the best to answer the question according to your description, +you can answer it. If another agent is better for answering the question according to its -description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the -question to that agent. When transferring, do not generate any text other than -the function call. +description, call `{tool_name}` function to transfer the question to that +agent. When transferring, do not generate any text other than the function +call. -**NOTE**: the only available agents for `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function are {formatted_agent_names}. +**NOTE**: the only available agents for `{tool_name}` function are +{formatted_agent_names}. 
""" if agent.parent_agent and not agent.disallow_transfer_to_parent: @@ -119,9 +127,6 @@ def _build_target_agents_instructions( return si -_TRANSFER_TO_AGENT_FUNCTION_NAME = transfer_to_agent.__name__ - - def _get_transfer_targets(agent: LlmAgent) -> list[BaseAgent]: from ...agents.llm_agent import LlmAgent diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index db50e77809..824cd26be1 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -156,7 +156,7 @@ async def run_live( break logger.debug('Receive new event: %s', event) yield event - # send back the function response + # send back the function response to models if event.get_function_responses(): logger.debug( 'Sending back last function response event: %s', event @@ -164,6 +164,16 @@ async def run_live( invocation_context.live_request_queue.send_content( event.content ) + # We handle agent transfer here in `run_live` rather than + # in `_postprocess_live` to prevent duplication of function + # response processing. If agent transfer were handled in + # `_postprocess_live`, events yielded from child agent's + # `run_live` would bubble up to parent agent's `run_live`, + # causing `event.get_function_responses()` to be true in both + # child and parent, and `send_content()` to be called twice for + # the same function response. By handling agent transfer here, + # we ensure that only child agent processes its own function + # responses after the transfer. if ( event.content and event.content.parts @@ -174,7 +184,21 @@ async def run_live( await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY) # cancel the tasks that belongs to the closed connection. send_task.cancel() + logger.debug('Closing live connection') await llm_connection.close() + logger.debug('Live connection closed.') + # transfer to the sub agent. + transfer_to_agent = event.actions.transfer_to_agent + if transfer_to_agent: + logger.debug('Transferring to agent: %s', transfer_to_agent) + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent + ) + async with Aclosing( + agent_to_run.run_live(invocation_context) + ) as agen: + async for item in agen: + yield item if ( event.content and event.content.parts @@ -638,15 +662,6 @@ async def _postprocess_live( ) yield final_event - transfer_to_agent = function_response_event.actions.transfer_to_agent - if transfer_to_agent: - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent - ) - async with Aclosing(agent_to_run.run_live(invocation_context)) as agen: - async for item in agen: - yield item - async def _postprocess_run_processors_async( self, invocation_context: InvocationContext, llm_response: LlmResponse ) -> AsyncGenerator[Event, None]: diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index 9274cd462d..fefa014c45 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -224,7 +224,8 @@ def _contains_empty_content(event: Event) -> bool: This can happen to the events that only changed session state. When both content and transcriptions are empty, the event will be considered - as empty. + as empty. The content is considered empty if none of its parts contain text, + inline data, file data, function call, or function response. Args: event: The event to check. 
@@ -239,7 +240,14 @@ def _contains_empty_content(event: Event) -> bool: not event.content or not event.content.role or not event.content.parts - or event.content.parts[0].text == '' + or all( + not p.text + and not p.inline_data + and not p.file_data + and not p.function_call + and not p.function_response + for p in [event.content.parts[0]] + ) ) and (not event.output_transcription and not event.input_transcription) diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index b8f434c563..5df012e027 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -48,14 +48,15 @@ def __init__( Args: project: The project ID of the Memory Bank to use. location: The location of the Memory Bank to use. - agent_engine_id: The ID of the agent engine to use for the Memory Bank. + agent_engine_id: The ID of the agent engine to use for the Memory Bank, e.g. '456' in - 'projects/my-project/locations/us-central1/reasoningEngines/456'. + 'projects/my-project/locations/us-central1/reasoningEngines/456'. To + extract from api_resource.name, use: + ``agent_engine.api_resource.name.split('/')[-1]`` express_mode_api_key: The API key to use for Express Mode. If not provided, the API key from the GOOGLE_API_KEY environment variable will - be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true. - Do not use Google AI Studio API key for this field. For more details, - visit + be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true. Do + not use Google AI Studio API key for this field. For more details, visit https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview """ self._project = project @@ -65,6 +66,14 @@ def __init__( project, location, express_mode_api_key ) + if agent_engine_id and '/' in agent_engine_id: + logger.warning( + "agent_engine_id appears to be a full resource path: '%s'. " + "Expected just the ID (e.g., '456'). 
" + "Extract the ID using: agent_engine.api_resource.name.split('/')[-1]", + agent_engine_id, + ) + @override async def add_session_to_memory(self, session: Session): if not self._agent_engine_id: diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py index 9f3c2a2c48..1be0cc698e 100644 --- a/src/google/adk/models/__init__.py +++ b/src/google/adk/models/__init__.py @@ -33,3 +33,23 @@ LLMRegistry.register(Gemini) LLMRegistry.register(Gemma) LLMRegistry.register(ApigeeLlm) + +# Optionally register Claude if anthropic package is installed +try: + from .anthropic_llm import Claude + + LLMRegistry.register(Claude) + __all__.append('Claude') +except Exception: + # Claude support requires: pip install google-adk[extensions] + pass + +# Optionally register LiteLlm if litellm package is installed +try: + from .lite_llm import LiteLlm + + LLMRegistry.register(LiteLlm) + __all__.append('LiteLlm') +except Exception: + # LiteLLM support requires: pip install google-adk[extensions] + pass diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index 6f343367a3..163fbe4571 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -28,7 +28,8 @@ from typing import TYPE_CHECKING from typing import Union -from anthropic import AnthropicVertex +from anthropic import AsyncAnthropic +from anthropic import AsyncAnthropicVertex from anthropic import NOT_GIVEN from anthropic import types as anthropic_types from google.genai import types @@ -41,7 +42,7 @@ if TYPE_CHECKING: from .llm_request import LlmRequest -__all__ = ["Claude"] +__all__ = ["AnthropicLlm", "Claude"] logger = logging.getLogger("google_adk." + __name__) @@ -155,9 +156,11 @@ def content_to_message_param( ) -> anthropic_types.MessageParam: message_block = [] for part in content.parts or []: - # Image data is not supported in Claude for model turns. - if _is_image_part(part): - logger.warning("Image data is not supported in Claude for model turns.") + # Image data is not supported in Claude for assistant turns. + if content.role != "user" and _is_image_part(part): + logger.warning( + "Image data is not supported in Claude for assistant turns." + ) continue message_block.append(part_to_message_block(part)) @@ -262,15 +265,15 @@ def function_declaration_to_tool_param( ) -class Claude(BaseLlm): - """Integration with Claude models served from Vertex AI. +class AnthropicLlm(BaseLlm): + """Integration with Claude models via the Anthropic API. Attributes: model: The name of the Claude model. max_tokens: The maximum number of tokens to generate. """ - model: str = "claude-3-5-sonnet-v2@20241022" + model: str = "claude-sonnet-4-20250514" max_tokens: int = 8192 @classmethod @@ -302,7 +305,7 @@ async def generate_content_async( else NOT_GIVEN ) # TODO(b/421255973): Enable streaming for anthropic models. - message = self._anthropic_client.messages.create( + message = await self._anthropic_client.messages.create( model=llm_request.model, system=llm_request.config.system_instruction, messages=messages, @@ -313,7 +316,23 @@ async def generate_content_async( yield message_to_generate_content_response(message) @cached_property - def _anthropic_client(self) -> AnthropicVertex: + def _anthropic_client(self) -> AsyncAnthropic: + return AsyncAnthropic() + + +class Claude(AnthropicLlm): + """Integration with Claude models served from Vertex AI. + + Attributes: + model: The name of the Claude model. + max_tokens: The maximum number of tokens to generate. 
+ """ + + model: str = "claude-3-5-sonnet-v2@20241022" + + @cached_property + @override + def _anthropic_client(self) -> AsyncAnthropicVertex: if ( "GOOGLE_CLOUD_PROJECT" not in os.environ or "GOOGLE_CLOUD_LOCATION" not in os.environ @@ -323,7 +342,7 @@ def _anthropic_client(self) -> AnthropicVertex: " Anthropic on Vertex." ) - return AnthropicVertex( + return AsyncAnthropicVertex( project_id=os.environ["GOOGLE_CLOUD_PROJECT"], region=os.environ["GOOGLE_CLOUD_LOCATION"], ) diff --git a/src/google/adk/models/base_llm_connection.py b/src/google/adk/models/base_llm_connection.py index afce550b13..1bf522740e 100644 --- a/src/google/adk/models/base_llm_connection.py +++ b/src/google/adk/models/base_llm_connection.py @@ -72,7 +72,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: Yields: LlmResponse: The model response. """ - pass + # We need to yield here to help type checkers infer the correct type. + yield @abstractmethod async def close(self): diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 0b72c79f83..55d4b62e96 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -21,6 +21,7 @@ from google.genai import types from ..utils.context_utils import Aclosing +from ..utils.variant_utils import GoogleLLMVariant from .base_llm_connection import BaseLlmConnection from .llm_response import LlmResponse @@ -36,10 +37,15 @@ class GeminiLlmConnection(BaseLlmConnection): """The Gemini model connection.""" - def __init__(self, gemini_session: live.AsyncSession): + def __init__( + self, + gemini_session: live.AsyncSession, + api_backend: GoogleLLMVariant = GoogleLLMVariant.VERTEX_AI, + ): self._gemini_session = gemini_session self._input_transcription_text: str = '' self._output_transcription_text: str = '' + self._api_backend = api_backend async def send_history(self, history: list[types.Content]): """Sends the conversation history to the gemini model. @@ -171,6 +177,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: yield self.__build_full_text_response(text) text = '' yield llm_response + # Note: in some cases, tool_call may arrive before + # generation_complete, causing transcription to appear after + # tool_call in the session log. if message.server_content.input_transcription: if message.server_content.input_transcription.text: self._input_transcription_text += ( @@ -215,6 +224,32 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: partial=False, ) self._output_transcription_text = '' + # The Gemini API might not send a transcription finished signal. + # Instead, we rely on generation_complete, turn_complete or + # interrupted signals to flush any pending transcriptions. 
+ if self._api_backend == GoogleLLMVariant.GEMINI_API and ( + message.server_content.interrupted + or message.server_content.turn_complete + or message.server_content.generation_complete + ): + if self._input_transcription_text: + yield LlmResponse( + input_transcription=types.Transcription( + text=self._input_transcription_text, + finished=True, + ), + partial=False, + ) + self._input_transcription_text = '' + if self._output_transcription_text: + yield LlmResponse( + output_transcription=types.Transcription( + text=self._output_transcription_text, + finished=True, + ), + partial=False, + ) + self._output_transcription_text = '' if message.server_content.turn_complete: if text: yield self.__build_full_text_response(text) @@ -244,7 +279,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ] yield LlmResponse(content=types.Content(role='model', parts=parts)) if message.session_resumption_update: - logger.info('Received session resumption message: %s', message) + logger.debug('Received session resumption message: %s', message) yield ( LlmResponse( live_session_resumption_update=message.session_resumption_update diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 1bdd311104..6b21cf62c7 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -19,8 +19,6 @@ import copy from functools import cached_property import logging -import os -import sys from typing import AsyncGenerator from typing import cast from typing import Optional @@ -31,7 +29,7 @@ from google.genai.errors import ClientError from typing_extensions import override -from .. import version +from ..utils._client_labels_utils import get_client_labels from ..utils.context_utils import Aclosing from ..utils.streaming_utils import StreamingResponseAggregator from ..utils.variant_utils import GoogleLLMVariant @@ -49,8 +47,7 @@ _NEW_LINE = '\n' _EXCLUDED_PART_FIELD = {'inline_data': {'data'}} -_AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine' -_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID' + _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """ On how to mitigate this issue, please refer to: @@ -245,7 +242,7 @@ def api_client(self) -> Client: return Client( http_options=types.HttpOptions( - headers=self._tracking_headers, + headers=self._tracking_headers(), retry_options=self.retry_options, ) ) @@ -258,16 +255,12 @@ def _api_backend(self) -> GoogleLLMVariant: else GoogleLLMVariant.GEMINI_API ) - @cached_property def _tracking_headers(self) -> dict[str, str]: - framework_label = f'google-adk/{version.__version__}' - if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME): - framework_label = f'{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}' - language_label = 'gl-python/' + sys.version.split()[0] - version_header_value = f'{framework_label} {language_label}' + labels = get_client_labels() + header_value = ' '.join(labels) tracking_headers = { - 'x-goog-api-client': version_header_value, - 'user-agent': version_header_value, + 'x-goog-api-client': header_value, + 'user-agent': header_value, } return tracking_headers @@ -286,7 +279,7 @@ def _live_api_client(self) -> Client: return Client( http_options=types.HttpOptions( - headers=self._tracking_headers, api_version=self._live_api_version + headers=self._tracking_headers(), api_version=self._live_api_version ) ) @@ -310,7 +303,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: if not llm_request.live_connect_config.http_options.headers: 
llm_request.live_connect_config.http_options.headers = {} llm_request.live_connect_config.http_options.headers.update( - self._tracking_headers + self._tracking_headers() ) llm_request.live_connect_config.http_options.api_version = ( self._live_api_version @@ -325,6 +318,23 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: types.Part.from_text(text=llm_request.config.system_instruction) ], ) + if ( + llm_request.live_connect_config.session_resumption + and llm_request.live_connect_config.session_resumption.transparent + ): + logger.debug( + 'session resumption config: %s', + llm_request.live_connect_config.session_resumption, + ) + logger.debug( + 'self._api_backend: %s', + self._api_backend, + ) + if self._api_backend == GoogleLLMVariant.GEMINI_API: + raise ValueError( + 'Transparent session resumption is only supported for Vertex AI' + ' backend. Please use Vertex AI backend.' + ) llm_request.live_connect_config.tools = llm_request.config.tools logger.info('Connecting to live for model: %s', llm_request.model) logger.debug('Connecting to live with llm_request:%s', llm_request) @@ -332,7 +342,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: async with self._live_api_client.aio.live.connect( model=llm_request.model, config=llm_request.live_connect_config ) as live_session: - yield GeminiLlmConnection(live_session) + yield GeminiLlmConnection(live_session, api_backend=self._api_backend) async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None: """Adapt the google computer use predefined functions to the adk computer use toolset.""" @@ -380,7 +390,7 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None: def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: """Merge tracking headers to the given headers.""" headers = headers or {} - for key, tracking_header_value in self._tracking_headers.items(): + for key, tracking_header_value in self._tracking_headers().items(): custom_value = headers.get(key, None) if not custom_value: headers[key] = tracking_header_value diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index c263a41b2a..162db05945 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -39,8 +39,8 @@ from litellm import acompletion from litellm import ChatCompletionAssistantMessage from litellm import ChatCompletionAssistantToolCall -from litellm import ChatCompletionDeveloperMessage from litellm import ChatCompletionMessageToolCall +from litellm import ChatCompletionSystemMessage from litellm import ChatCompletionToolMessage from litellm import ChatCompletionUserMessage from litellm import completion @@ -96,6 +96,64 @@ def _decode_inline_text_data(raw_bytes: bytes) -> str: return raw_bytes.decode("latin-1", errors="replace") +def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]: + """Yields textual fragments from provider specific reasoning payloads.""" + if reasoning_value is None: + return + + if isinstance(reasoning_value, types.Content): + if not reasoning_value.parts: + return + for part in reasoning_value.parts: + if part and part.text: + yield part.text + return + + if isinstance(reasoning_value, str): + yield reasoning_value + return + + if isinstance(reasoning_value, list): + for value in reasoning_value: + yield from _iter_reasoning_texts(value) + return + + if isinstance(reasoning_value, dict): + # LiteLLM currently nests “reasoning” text under a few known keys. 
+ # (Documented in https://docs.litellm.ai/docs/openai#reasoning-outputs) + for key in ("text", "content", "reasoning", "reasoning_content"): + text_value = reasoning_value.get(key) + if isinstance(text_value, str): + yield text_value + return + + text_attr = getattr(reasoning_value, "text", None) + if isinstance(text_attr, str): + yield text_attr + elif isinstance(reasoning_value, (int, float, bool)): + yield str(reasoning_value) + + +def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: + """Converts provider reasoning payloads into Gemini thought parts.""" + return [ + types.Part(text=text, thought=True) + for text in _iter_reasoning_texts(reasoning_value) + if text + ] + + +def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: + """Fetches the reasoning payload from a LiteLLM message or dict.""" + if message is None: + return None + if hasattr(message, "reasoning_content"): + return getattr(message, "reasoning_content") + if isinstance(message, dict): + return message.get("reasoning_content") + return None + + class ChatCompletionFileUrlObject(TypedDict, total=False): file_data: str file_id: str @@ -113,6 +171,10 @@ class TextChunk(BaseModel): text: str +class ReasoningChunk(BaseModel): + parts: List[types.Part] + + class UsageMetadataChunk(BaseModel): prompt_tokens: int completion_tokens: int @@ -570,10 +632,8 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: } -def _schema_to_dict(schema: types.Schema) -> dict: - """Recursively converts a types.Schema to a pure-python dict - - with all enum values written as lower-case strings. +def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict: + """Recursively converts a schema object or dict to a pure-python dict. Args: schema: The schema to convert. @@ -581,38 +641,36 @@ def _schema_to_dict(schema: types.Schema) -> dict: Returns: The dictionary representation of the schema. 
""" - # Dump without json encoding so we still get Enum members - schema_dict = schema.model_dump(exclude_none=True) + schema_dict = ( + schema.model_dump(exclude_none=True) + if isinstance(schema, types.Schema) + else dict(schema) + ) + enum_values = schema_dict.get("enum") + if isinstance(enum_values, (list, tuple)): + schema_dict["enum"] = [value for value in enum_values if value is not None] - # ---- normalise this level ------------------------------------------------ - if "type" in schema_dict: - # schema_dict["type"] can be an Enum or a str + if "type" in schema_dict and schema_dict["type"] is not None: t = schema_dict["type"] - schema_dict["type"] = (t.value if isinstance(t, types.Type) else t).lower() + schema_dict["type"] = ( + t.value if isinstance(t, types.Type) else str(t) + ).lower() - # ---- recurse into `items` ----------------------------------------------- if "items" in schema_dict: - schema_dict["items"] = _schema_to_dict( - schema.items - if isinstance(schema.items, types.Schema) - else types.Schema.model_validate(schema_dict["items"]) + items = schema_dict["items"] + schema_dict["items"] = ( + _schema_to_dict(items) + if isinstance(items, (types.Schema, dict)) + else items ) - # ---- recurse into `properties` ------------------------------------------ if "properties" in schema_dict: new_props = {} for key, value in schema_dict["properties"].items(): - # value is a dict → rebuild a Schema object and recurse - if isinstance(value, dict): - new_props[key] = _schema_to_dict(types.Schema.model_validate(value)) - # value is already a Schema instance - elif isinstance(value, types.Schema): + if isinstance(value, (types.Schema, dict)): new_props[key] = _schema_to_dict(value) - # plain dict without nested schemas else: new_props[key] = value - if "type" in new_props[key]: - new_props[key]["type"] = new_props[key]["type"].lower() schema_dict["properties"] = new_props return schema_dict @@ -660,7 +718,6 @@ def _function_declaration_to_tool_param( }, } - # Handle required field from parameters required_fields = ( getattr(function_declaration.parameters, "required", None) if function_declaration.parameters @@ -668,8 +725,6 @@ def _function_declaration_to_tool_param( ) if required_fields: tool_params["function"]["parameters"]["required"] = required_fields - # parameters_json_schema already has required field in the json schema, - # no need to add it separately return tool_params @@ -678,7 +733,14 @@ def _model_response_to_chunk( response: ModelResponse, ) -> Generator[ Tuple[ - Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]], + Optional[ + Union[ + TextChunk, + FunctionChunk, + UsageMetadataChunk, + ReasoningChunk, + ] + ], Optional[str], ], None, @@ -703,11 +765,18 @@ def _model_response_to_chunk( message_content: Optional[OpenAIMessageContent] = None tool_calls: list[ChatCompletionMessageToolCall] = [] + reasoning_parts: List[types.Part] = [] if message is not None: ( message_content, tool_calls, ) = _split_message_content_and_tool_calls(message) + reasoning_value = _extract_reasoning_value(message) + if reasoning_value: + reasoning_parts = _convert_reasoning_value_to_parts(reasoning_value) + + if reasoning_parts: + yield ReasoningChunk(parts=reasoning_parts), finish_reason if message_content: yield TextChunk(text=message_content), finish_reason @@ -771,8 +840,13 @@ def _model_response_to_generate_content_response( if not message: raise ValueError("No message in response") + thought_parts = _convert_reasoning_value_to_parts( + _extract_reasoning_value(message) + ) 
llm_response = _message_to_generate_content_response( - message, model_version=response.model + message, + model_version=response.model, + thought_parts=thought_parts or None, ) if finish_reason: # If LiteLLM already provides a FinishReason enum (e.g., for Gemini), use @@ -797,7 +871,11 @@ def _model_response_to_generate_content_response( def _message_to_generate_content_response( - message: Message, *, is_partial: bool = False, model_version: str = None + message: Message, + *, + is_partial: bool = False, + model_version: str = None, + thought_parts: Optional[List[types.Part]] = None, ) -> LlmResponse: """Converts a litellm message to LlmResponse. @@ -810,7 +888,13 @@ def _message_to_generate_content_response( The LlmResponse. """ - parts = [] + parts: List[types.Part] = [] + if not thought_parts: + thought_parts = _convert_reasoning_value_to_parts( + _extract_reasoning_value(message) + ) + if thought_parts: + parts.extend(thought_parts) message_content, tool_calls = _split_message_content_and_tool_calls(message) if isinstance(message_content, str) and message_content: parts.append(types.Part.from_text(text=message_content)) @@ -899,8 +983,8 @@ def _get_completion_inputs( if llm_request.config.system_instruction: messages.insert( 0, - ChatCompletionDeveloperMessage( - role="developer", + ChatCompletionSystemMessage( + role="system", content=llm_request.config.system_instruction, ), ) @@ -972,15 +1056,9 @@ def _build_function_declaration_log( k: v.model_dump(exclude_none=True) for k, v in func_decl.parameters.properties.items() }) - elif func_decl.parameters_json_schema: - param_str = str(func_decl.parameters_json_schema) - return_str = "None" if func_decl.response: return_str = str(func_decl.response.model_dump(exclude_none=True)) - elif func_decl.response_json_schema: - return_str = str(func_decl.response_json_schema) - return f"{func_decl.name}: {param_str} -> {return_str}" @@ -1182,6 +1260,7 @@ async def generate_content_async( if stream: text = "" + reasoning_parts: List[types.Part] = [] # Track function calls by index function_calls = {} # index -> {name, args, id} completion_args["stream"] = True @@ -1223,6 +1302,14 @@ async def generate_content_async( is_partial=True, model_version=part.model, ) + elif isinstance(chunk, ReasoningChunk): + if chunk.parts: + reasoning_parts.extend(chunk.parts) + yield LlmResponse( + content=types.Content(role="model", parts=list(chunk.parts)), + partial=True, + model_version=part.model, + ) elif isinstance(chunk, UsageMetadataChunk): usage_metadata = types.GenerateContentResponseUsageMetadata( prompt_token_count=chunk.prompt_tokens, @@ -1256,16 +1343,27 @@ async def generate_content_async( tool_calls=tool_calls, ), model_version=part.model, + thought_parts=list(reasoning_parts) + if reasoning_parts + else None, ) ) text = "" + reasoning_parts = [] function_calls.clear() - elif finish_reason == "stop" and text: + elif finish_reason == "stop" and (text or reasoning_parts): + message_content = text if text else None aggregated_llm_response = _message_to_generate_content_response( - ChatCompletionAssistantMessage(role="assistant", content=text), + ChatCompletionAssistantMessage( + role="assistant", content=message_content + ), model_version=part.model, + thought_parts=list(reasoning_parts) + if reasoning_parts + else None, ) text = "" + reasoning_parts = [] # waiting until streaming ends to yield the llm_response as litellm tends # to send chunk that contains usage_metadata after the chunk with @@ -1290,11 +1388,19 @@ async def generate_content_async( 
def supported_models(cls) -> list[str]: """Provides the list of supported models. - LiteLlm supports all models supported by litellm. We do not keep track of - these models here. So we return an empty list. + This registers common provider prefixes. LiteLlm can handle many more, + but these patterns activate the integration for the most common use cases. + See https://docs.litellm.ai/docs/providers for a full list. Returns: A list of supported models. """ - return [] + return [ + # For OpenAI models (e.g., "openai/gpt-4o") + r"openai/.*", + # For Groq models via Groq API (e.g., "groq/llama3-70b-8192") + r"groq/.*", + # For Anthropic models (e.g., "anthropic/claude-3-opus-20240229") + r"anthropic/.*", + ] diff --git a/src/google/adk/models/registry.py b/src/google/adk/models/registry.py index 22e24d4c18..852996ff40 100644 --- a/src/google/adk/models/registry.py +++ b/src/google/adk/models/registry.py @@ -99,4 +99,26 @@ def resolve(model: str) -> type[BaseLlm]: if re.compile(regex).fullmatch(model): return llm_class - raise ValueError(f'Model {model} not found.') + # Provide helpful error messages for known patterns + error_msg = f'Model {model} not found.' + + # Check if it matches known patterns that require optional dependencies + if re.match(r'^claude-', model): + error_msg += ( + '\n\nClaude models require the anthropic package.' + '\nInstall it with: pip install google-adk[extensions]' + '\nOr: pip install anthropic>=0.43.0' + ) + elif '/' in model: + # Any model with provider/model format likely needs LiteLLM + error_msg += ( + '\n\nProvider-style models (e.g., "provider/model-name") require' + ' the litellm package.' + '\nInstall it with: pip install google-adk[extensions]' + '\nOr: pip install litellm>=1.75.5' + '\n\nSupported providers include: openai, groq, anthropic, and 100+' + ' others.' + '\nSee https://docs.litellm.ai/docs/providers for a full list.' + ) + + raise ValueError(error_msg) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 2bb0168928..4cf5a29546 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -19,6 +19,7 @@ import logging from pathlib import Path import queue +import sys from typing import Any from typing import AsyncGenerator from typing import Callable @@ -66,6 +67,23 @@ logger = logging.getLogger('google_adk.' + __name__) + + +def _is_tool_call_or_response(event: Event) -> bool: + return bool(event.get_function_calls() or event.get_function_responses()) + + +def _is_transcription(event: Event) -> bool: + return ( + event.input_transcription is not None + or event.output_transcription is not None + ) + + +def _has_non_empty_transcription_text(transcription) -> bool: + return bool( + transcription and transcription.text and transcription.text.strip() + ) + + class Runner: """The Runner class is used to run agents. @@ -625,6 +643,7 @@ async def _exec_with_plugin( invocation_context: The invocation context session: The current session execute_fn: A callable that returns an AsyncGenerator of Events + is_live_call: Whether this is a live call Yields: Events from the execution, including any generated by plugins @@ -650,13 +669,74 @@ async def _exec_with_plugin( yield early_exit_event else: # Step 2: Otherwise continue with normal execution + # Note for live/bidi: + # the transcription may arrive later than the action (function call + # event and thus function response event). In this case, the order of + # transcription and function call events will be wrong if we just + # append them as they arrive.
To address this, we should check if there is + # transcription going on. If there is transcription going on, we + # should hold off appending the function call event until the + # transcription is finished. The transcription in progress can be + # identified by checking if the transcription event is partial. When + # the next transcription event is not partial, it means the previous + # transcription is finished. Then if there are any buffered function + # call events, we should append them after this finished (non-partial) + # transcription event. + buffered_events: list[Event] = [] + is_transcribing: bool = False + + async with Aclosing(execute_fn(invocation_context)) as agen: async for event in agen: - if not event.partial: - if self._should_append_event(event, is_live_call): + if is_live_call: + if event.partial and _is_transcription(event): + is_transcribing = True + if is_transcribing and _is_tool_call_or_response(event): + # only buffer function call and function response events which are + # non-partial + buffered_events.append(event) + continue + # Note for live/bidi: an audio response is considered a + # non-partial event (event.partial=None). + # event.partial=False and event.partial=None are considered + # non-partial events; event.partial=True is considered a partial + # event. + if event.partial is not True: + if _is_transcription(event) and ( + _has_non_empty_transcription_text(event.input_transcription) + or _has_non_empty_transcription_text( + event.output_transcription + ) + ): + # transcription end signal, append buffered events + is_transcribing = False + logger.debug( + 'Appending transcription finished event: %s', event + ) + if self._should_append_event(event, is_live_call): + await self.session_service.append_event( + session=session, event=event + ) + + for buffered_event in buffered_events: + logger.debug('Appending buffered event: %s', buffered_event) + await self.session_service.append_event( + session=session, event=buffered_event + ) + buffered_events = [] + else: + # non-transcription events or empty transcription events, for + # example, an event that stores a blob reference, should be appended. + if self._should_append_event(event, is_live_call): + logger.debug('Appending non-buffered event: %s', event) + await self.session_service.append_event( + session=session, event=event + ) + else: + if event.partial is not True: await self.session_service.append_event( session=session, event=event ) + # Step 3: Run the on_event callbacks to optionally modify the event. modified_event = await plugin_manager.run_on_event_callback( invocation_context=invocation_context, event=event @@ -985,16 +1065,16 @@ async def run_debug( over session management, event streaming, and error handling. Args: - user_messages: Message(s) to send to the agent. Can be: - - Single string: "What is 2+2?" - - List of strings: ["Hello!", "What's my name?"] + user_messages: Message(s) to send to the agent. Can be: - Single string: + "What is 2+2?" - List of strings: ["Hello!", "What's my name?"] user_id: User identifier. Defaults to "debug_user_id". - session_id: Session identifier for conversation persistence. - Defaults to "debug_session_id". Reuse the same ID to continue a conversation. + session_id: Session identifier for conversation persistence. Defaults to + "debug_session_id". Reuse the same ID to continue a conversation. run_config: Optional configuration for the agent execution. - quiet: If True, suppresses console output. Defaults to False (output shown).
- verbose: If True, shows detailed tool calls and responses. Defaults to False - for cleaner output showing only final agent responses. + quiet: If True, suppresses console output. Defaults to False (output + shown). + verbose: If True, shows detailed tool calls and responses. Defaults to + False for cleaner output showing only final agent responses. Returns: list[Event]: All events from all messages. @@ -1011,7 +1091,8 @@ async def run_debug( >>> await runner.run_debug(["Hello!", "What's my name?"]) Continue a debug session: - >>> await runner.run_debug("What did we discuss?") # Continues default session + >>> await runner.run_debug("What did we discuss?") # Continues default + session Separate debug sessions: >>> await runner.run_debug("Hi", user_id="alice", session_id="debug1") @@ -1353,7 +1434,12 @@ async def close(self): logger.info('Runner closed.') - async def __aenter__(self): + if sys.version_info < (3, 11): + Self = 'Runner' # pylint: disable=invalid-name + else: + from typing import Self # pylint: disable=g-import-not-at-top + + async def __aenter__(self) -> Self: """Async context manager entry.""" return self diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index b929f23409..a352918211 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -722,6 +722,11 @@ async def append_event(self, session: Session, event: Event) -> Event: if session_state_delta: storage_session.state = storage_session.state | session_state_delta + if storage_session._dialect_name == "sqlite": + update_time = datetime.utcfromtimestamp(event.timestamp) + else: + update_time = datetime.fromtimestamp(event.timestamp) + storage_session.update_time = update_time sql_session.add(StorageEvent.from_event(session, event)) await sql_session.commit() diff --git a/src/google/adk/sessions/migration/README.md b/src/google/adk/sessions/migration/README.md new file mode 100644 index 0000000000..6a9079534e --- /dev/null +++ b/src/google/adk/sessions/migration/README.md @@ -0,0 +1,109 @@ +# Process for a New Schema Version + +This document outlines the steps required to introduce a new database schema +version for `DatabaseSessionService`. Let's assume you are introducing schema +version `2.0`, migrating from `1.0`. + +## 1. Update SQLAlchemy Models + +Modify the SQLAlchemy model classes (`StorageSession`, `StorageEvent`, +`StorageAppState`, `StorageUserState`, `StorageMetadata`) in +`database_session_service.py` to reflect the new `2.0` schema. This could +involve adding new `mapped_column` definitions, changing types, or adding new +classes for new tables. + +## 2. Create a New Migration Script + +You need to create a script that migrates data from schema `1.0` to `2.0`. + +* Create a new file, for example: + `google/adk/sessions/migration/migrate_1_0_to_2_0.py`. +* This script must contain a `migrate(source_db_url: str, dest_db_url: str)` + function, similar to `migrate_from_sqlalchemy_pickle.py`. +* Inside this function: + * Connect to the `source_db_url` (which has schema 1.0) and `dest_db_url` + engines using SQLAlchemy. + * **Important**: Create the tables in the destination database using the + new 2.0 schema definition by calling + `dss.Base.metadata.create_all(dest_engine)`. + * Read data from the source tables (schema 1.0). The recommended way to do + this without relying on outdated models is to use `sqlalchemy.text`, + like: + + ```python + from sqlalchemy import text + ... 
+ rows = source_session.execute(text("SELECT * FROM sessions")).mappings().all() + ``` + + * For each row read from the source, transform the data as necessary to + fit the `2.0` schema, and create an instance of the corresponding new + SQLAlchemy model (e.g., `dss.StorageSession(...)`). + * Add these new `2.0` objects to the destination session, ideally using + `dest_session.merge()` to upsert. + * After migrating data for all tables, ensure the destination database is + marked with the new schema version: + + ```python + from google.adk.sessions import database_session_service as dss + from google.adk.sessions.migration import _schema_check + ... + dest_session.merge( + dss.StorageMetadata( + key=_schema_check.SCHEMA_VERSION_KEY, + value="2.0", + ) + ) + dest_session.commit() + ``` + +## 3. Update Schema Version Constant + +You need to update `CURRENT_SCHEMA_VERSION` in +`google/adk/sessions/migration/_schema_check.py` to reflect the new version: + +```python +CURRENT_SCHEMA_VERSION = "2.0" +``` + +This will also update `LATEST_VERSION` in `migration_runner.py`, as it uses this +constant. + +## 4. Register the New Migration in Migration Runner + +In `google/adk/sessions/migration/migration_runner.py`, import your new +migration script and add it to the `MIGRATIONS` dictionary. This tells the +runner how to get from version `1.0` to `2.0`. For example: + +```python +from google.adk.sessions.migration import _schema_check +from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle +from google.adk.sessions.migration import migrate_1_0_to_2_0 +... +MIGRATIONS = { + _schema_check.SCHEMA_VERSION_0_1_PICKLE: ( + _schema_check.SCHEMA_VERSION_1_0_JSON, + migrate_from_sqlalchemy_pickle.migrate, + ), + _schema_check.SCHEMA_VERSION_1_0_JSON: ( + "2.0", + migrate_1_0_to_2_0.migrate, + ), +} +``` + +## 5. Update `DatabaseSessionService` Business Logic + +If your schema change affects how data should be read or written during normal +operation (e.g., you added a new column that needs to be populated on session +creation), update the methods within `DatabaseSessionService` (`create_session`, +`get_session`, `append_event`, etc.) in `database_session_service.py` +accordingly. + +## 6. CLI Command Changes + +No changes are needed for the Click command definition in `cli_tools_click.py`. +The `adk migrate session` command calls `migration_runner.upgrade()`, which will +now automatically detect the source database version and apply the necessary +migration steps (e.g., `0.1 -> 1.0 -> 2.0`, or `1.0 -> 2.0`) to reach +`LATEST_VERSION`. 
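Putting the steps above together, a minimal sketch of what `migrate_1_0_to_2_0.py` could look like is shown below. The 2.0 schema and the per-row transformation are hypothetical and only the `sessions` table is shown; the model and constant names (`dss.Base`, `dss.StorageSession`, `dss.StorageMetadata`, `_schema_check.SCHEMA_VERSION_KEY`) come from the steps described above.

```python
# Illustrative sketch only: adapt the row transformation to the actual 2.0 schema.
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.orm import Session

from google.adk.sessions import database_session_service as dss
from google.adk.sessions.migration import _schema_check


def migrate(source_db_url: str, dest_db_url: str) -> None:
  """Migrates session data from schema 1.0 to the (hypothetical) 2.0."""
  source_engine = create_engine(source_db_url)
  dest_engine = create_engine(dest_db_url)

  # Create destination tables from the new 2.0 model definitions.
  dss.Base.metadata.create_all(dest_engine)

  with Session(source_engine) as src, Session(dest_engine) as dest:
    # Read raw 1.0 rows without depending on outdated model classes.
    rows = src.execute(text("SELECT * FROM sessions")).mappings().all()
    for row in rows:
      # Transform each row as needed for the 2.0 schema, then upsert.
      dest.merge(dss.StorageSession(**dict(row)))

    # Mark the destination database with the new schema version.
    dest.merge(
        dss.StorageMetadata(
            key=_schema_check.SCHEMA_VERSION_KEY,
            value="2.0",
        )
    )
    dest.commit()
```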
\ No newline at end of file diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 10d05f6dfd..8ba6531f52 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -323,7 +323,7 @@ async def append_event(self, session: Session, event: Event) -> Event: # Trim temp state before persisting event = self._trim_temp_delta_state(event) - now = time.time() + event_timestamp = event.timestamp async with self._get_db_connection() as db: # Check for stale session @@ -355,11 +355,15 @@ async def append_event(self, session: Session, event: Event) -> Event: if app_state_delta: await self._upsert_app_state( - db, session.app_name, app_state_delta, now + db, session.app_name, app_state_delta, event_timestamp ) if user_state_delta: await self._upsert_user_state( - db, session.app_name, session.user_id, user_state_delta, now + db, + session.app_name, + session.user_id, + user_state_delta, + event_timestamp, ) if session_state_delta: await self._update_session_state_in_db( @@ -368,7 +372,7 @@ async def append_event(self, session: Session, event: Event) -> Event: session.user_id, session.id, session_state_delta, - now, + event_timestamp, ) has_session_state_delta = True @@ -392,12 +396,17 @@ async def append_event(self, session: Session, event: Event) -> Event: await db.execute( "UPDATE sessions SET update_time=? WHERE app_name=? AND user_id=?" " AND id=?", - (now, session.app_name, session.user_id, session.id), + ( + event_timestamp, + session.app_name, + session.user_id, + session.id, + ), ) await db.commit() - # Update timestamp with commit time - session.last_update_time = now + # Update timestamp based on event time + session.last_update_time = event_timestamp # Also update the in-memory session await super().append_event(session=session, event=event) diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py index 1777bd93c5..32264adcbd 100644 --- a/src/google/adk/tools/__init__.py +++ b/src/google/adk/tools/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from ..auth.auth_tool import AuthToolArguments from .agent_tool import AgentTool + from .api_registry import ApiRegistry from .apihub_tool.apihub_toolset import APIHubToolset from .base_tool import BaseTool from .discovery_engine_search_tool import DiscoveryEngineSearchTool @@ -37,6 +38,7 @@ from .preload_memory_tool import preload_memory_tool as preload_memory from .tool_context import ToolContext from .transfer_to_agent_tool import transfer_to_agent + from .transfer_to_agent_tool import TransferToAgentTool from .url_context_tool import url_context from .vertex_ai_search_tool import VertexAiSearchTool @@ -75,10 +77,15 @@ 'preload_memory': ('.preload_memory_tool', 'preload_memory_tool'), 'ToolContext': ('.tool_context', 'ToolContext'), 'transfer_to_agent': ('.transfer_to_agent_tool', 'transfer_to_agent'), + 'TransferToAgentTool': ( + '.transfer_to_agent_tool', + 'TransferToAgentTool', + ), 'url_context': ('.url_context_tool', 'url_context'), 'VertexAiSearchTool': ('.vertex_ai_search_tool', 'VertexAiSearchTool'), 'MCPToolset': ('.mcp_tool.mcp_toolset', 'MCPToolset'), 'McpToolset': ('.mcp_tool.mcp_toolset', 'McpToolset'), + 'ApiRegistry': ('.api_registry', 'ApiRegistry'), } __all__ = list(_LAZY_MAPPING.keys()) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 702f6e43aa..46d8616619 100644 --- a/src/google/adk/tools/agent_tool.py +++ 
b/src/google/adk/tools/agent_tool.py @@ -45,11 +45,22 @@ class AgentTool(BaseTool): Attributes: agent: The agent to wrap. skip_summarization: Whether to skip summarization of the agent output. + include_plugins: Whether to propagate plugins from the parent runner context + to the agent's runner. When True (default), the agent will inherit all + plugins from its parent. Set to False to run the agent with an isolated + plugin environment. """ - def __init__(self, agent: BaseAgent, skip_summarization: bool = False): + def __init__( + self, + agent: BaseAgent, + skip_summarization: bool = False, + *, + include_plugins: bool = True, + ): self.agent = agent self.skip_summarization: bool = skip_summarization + self.include_plugins = include_plugins super().__init__(name=agent.name, description=agent.description) @@ -68,6 +79,8 @@ def _get_declaration(self) -> types.FunctionDeclaration: result = _automatic_function_calling_util.build_function_declaration( func=self.agent.input_schema, variant=self._api_variant ) + # Override the description with the agent's description + result.description = self.agent.description else: result = types.FunctionDeclaration( parameters=types.Schema( @@ -130,6 +143,11 @@ async def run_async( invocation_context.app_name if invocation_context else None ) child_app_name = parent_app_name or self.agent.name + plugins = ( + tool_context._invocation_context.plugin_manager.plugins + if self.include_plugins + else None + ) runner = Runner( app_name=child_app_name, agent=self.agent, @@ -137,7 +155,7 @@ async def run_async( session_service=InMemorySessionService(), memory_service=InMemoryMemoryService(), credential_service=tool_context._invocation_context.credential_service, - plugins=list(tool_context._invocation_context.plugin_manager.plugins), + plugins=plugins, ) state_dict = { @@ -192,7 +210,9 @@ def from_config( agent_tool_config.agent, config_abs_path ) return cls( - agent=agent, skip_summarization=agent_tool_config.skip_summarization + agent=agent, + skip_summarization=agent_tool_config.skip_summarization, + include_plugins=agent_tool_config.include_plugins, ) @@ -204,3 +224,6 @@ class AgentToolConfig(BaseToolConfig): skip_summarization: bool = False """Whether to skip summarization of the agent output.""" + + include_plugins: bool = True + """Whether to include plugins from parent runner context.""" diff --git a/src/google/adk/tools/api_registry.py b/src/google/adk/tools/api_registry.py new file mode 100644 index 0000000000..e3f0076404 --- /dev/null +++ b/src/google/adk/tools/api_registry.py @@ -0,0 +1,123 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import google.auth
+import google.auth.transport.requests
+import httpx
+
+from ..agents.readonly_context import ReadonlyContext
+from .base_toolset import ToolPredicate
+from .mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
+from .mcp_tool.mcp_toolset import McpToolset
+
+API_REGISTRY_URL = "https://cloudapiregistry.googleapis.com"
+
+
+class ApiRegistry:
+  """Registry that provides McpToolsets for MCP servers registered in API Registry."""
+
+  def __init__(
+      self,
+      api_registry_project_id: str,
+      location: str = "global",
+      header_provider: Optional[
+          Callable[[ReadonlyContext], Dict[str, str]]
+      ] = None,
+  ):
+    """Initialize the API Registry.
+
+    Args:
+      api_registry_project_id: The project ID for the Google Cloud API Registry.
+      location: The location of the API Registry resources.
+      header_provider: Optional function to provide additional headers for MCP
+        server calls.
+    """
+    self.api_registry_project_id = api_registry_project_id
+    self.location = location
+    self._credentials, _ = google.auth.default()
+    self._mcp_servers: Dict[str, Dict[str, Any]] = {}
+    self._header_provider = header_provider
+
+    url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers"
+    try:
+      request = google.auth.transport.requests.Request()
+      self._credentials.refresh(request)
+      headers = {
+          "Authorization": f"Bearer {self._credentials.token}",
+          "Content-Type": "application/json",
+      }
+      with httpx.Client() as client:
+        response = client.get(url, headers=headers)
+        response.raise_for_status()
+        mcp_servers_list = response.json().get("mcpServers", [])
+        for server in mcp_servers_list:
+          server_name = server.get("name", "")
+          if server_name:
+            self._mcp_servers[server_name] = server
+    except (httpx.HTTPError, ValueError) as e:
+      # Handle error in fetching or parsing tool definitions
+      raise RuntimeError(
+          f"Error fetching MCP servers from API Registry: {e}"
+      ) from e
+
+  def get_toolset(
+      self,
+      mcp_server_name: str,
+      tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
+      tool_name_prefix: Optional[str] = None,
+  ) -> McpToolset:
+    """Return the MCP Toolset based on the params.
+
+    Args:
+      mcp_server_name: The name of the MCP server in API Registry to get
+        tools from.
+      tool_filter: Optional filter to select specific tools. Can be a list of
+        tool names or a ToolPredicate function.
+      tool_name_prefix: Optional prefix to prepend to the names of the tools
+        returned by the toolset.
+
+    Returns:
+      McpToolset: A toolset for the MCP server specified.
+    """
+    server = self._mcp_servers.get(mcp_server_name)
+    if not server:
+      raise ValueError(
+          f"MCP server {mcp_server_name} not found in API Registry."
+ ) + if not server.get("urls"): + raise ValueError(f"MCP server {mcp_server_name} has no URLs.") + + mcp_server_url = server["urls"][0] + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + headers = { + "Authorization": f"Bearer {self._credentials.token}", + } + return McpToolset( + connection_params=StreamableHTTPConnectionParams( + url="https://" + mcp_server_url, + headers=headers, + ), + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + header_provider=self._header_provider, + ) diff --git a/src/google/adk/tools/bigquery/config.py b/src/google/adk/tools/bigquery/config.py index 7768f214ed..39b6a3d9b6 100644 --- a/src/google/adk/tools/bigquery/config.py +++ b/src/google/adk/tools/bigquery/config.py @@ -101,6 +101,16 @@ class BigQueryToolConfig(BaseModel): locations, see https://cloud.google.com/bigquery/docs/locations. """ + job_labels: Optional[dict[str, str]] = None + """Labels to apply to BigQuery jobs for tracking and monitoring. + + These labels will be added to all BigQuery jobs executed by the tools. + Labels must be key-value pairs where both keys and values are strings. + Labels can be used for billing, monitoring, and resource organization. + For more information about labels, see + https://cloud.google.com/bigquery/docs/labels-intro. + """ + @field_validator('maximum_bytes_billed') @classmethod def validate_maximum_bytes_billed(cls, v): @@ -121,3 +131,13 @@ def validate_application_name(cls, v): if v and ' ' in v: raise ValueError('Application name should not contain spaces.') return v + + @field_validator('job_labels') + @classmethod + def validate_job_labels(cls, v): + """Validate that job_labels keys are not empty.""" + if v is not None: + for key in v.keys(): + if not key: + raise ValueError('Label keys cannot be empty.') + return v diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 666dc3c5a1..5bcd734e70 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -68,7 +68,10 @@ def _execute_sql( bq_connection_properties = [] # BigQuery job labels if applicable - bq_job_labels = {} + bq_job_labels = ( + settings.job_labels.copy() if settings and settings.job_labels else {} + ) + if caller_id: bq_job_labels["adk-bigquery-tool"] = caller_id if settings and settings.application_name: diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py index eaef479274..875b82e5b9 100644 --- a/src/google/adk/tools/crewai_tool.py +++ b/src/google/adk/tools/crewai_tool.py @@ -30,16 +30,9 @@ try: from crewai.tools import BaseTool as CrewaiBaseTool except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - 'Crewai Tools require Python 3.10+. Please upgrade your Python version.' - ) from e - else: - raise ImportError( - "Crewai Tools require pip install 'google-adk[extensions]'." - ) from e + raise ImportError( + "Crewai Tools require pip install 'google-adk[extensions]'." + ) from e class CrewaiTool(FunctionTool): diff --git a/src/google/adk/tools/mcp_tool/__init__.py b/src/google/adk/tools/mcp_tool/__init__.py index f1e56b99c4..1170b2e1af 100644 --- a/src/google/adk/tools/mcp_tool/__init__.py +++ b/src/google/adk/tools/mcp_tool/__init__.py @@ -39,15 +39,7 @@ except ImportError as e: import logging - import sys logger = logging.getLogger('google_adk.' + __name__) - - if sys.version_info < (3, 10): - logger.warning( - 'MCP Tool requires Python 3.10 or above. 
Please upgrade your Python' - ' version.' - ) - else: - logger.debug('MCP Tool is not installed') - logger.debug(e) + logger.debug('MCP Tool is not installed') + logger.debug(e) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index d95d48f282..c9c4c2ae66 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -29,24 +29,13 @@ from typing import Union import anyio +from mcp import ClientSession +from mcp import StdioServerParameters +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client from pydantic import BaseModel -try: - from mcp import ClientSession - from mcp import StdioServerParameters - from mcp.client.sse import sse_client - from mcp.client.stdio import stdio_client - from mcp.client.streamable_http import streamablehttp_client -except ImportError as e: - - if sys.version_info < (3, 10): - raise ImportError( - 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' - ' version.' - ) from e - else: - raise e - logger = logging.getLogger('google_adk.' + __name__) @@ -108,10 +97,10 @@ class StreamableHTTPConnectionParams(BaseModel): terminate_on_close: bool = True -def retry_on_closed_resource(func): - """Decorator to automatically retry action when MCP session is closed. +def retry_on_errors(func): + """Decorator to automatically retry action when MCP session errors occur. - When MCP session was closed, the decorator will automatically retry the + When MCP session errors occur, the decorator will automatically retry the action once. The create_session method will handle creating a new session if the old one was disconnected. @@ -126,11 +115,11 @@ def retry_on_closed_resource(func): async def wrapper(self, *args, **kwargs): try: return await func(self, *args, **kwargs) - except (anyio.ClosedResourceError, anyio.BrokenResourceError): - # If the session connection is closed or unusable, we will retry the - # function to reconnect to the server. create_session will handle - # detecting and replacing disconnected sessions. - logger.info('Retrying %s due to closed resource', func.__name__) + except Exception as e: + # If an error is thrown, we will retry the function to reconnect to the + # server. create_session will handle detecting and replacing disconnected + # sessions. 
+ logger.info('Retrying %s due to error: %s', func.__name__, e) return await func(self, *args, **kwargs) return wrapper diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index ad420a3d0f..b15f2c73fe 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -17,7 +17,6 @@ import base64 import inspect import logging -import sys from typing import Any from typing import Callable from typing import Dict @@ -27,35 +26,21 @@ from fastapi.openapi.models import APIKeyIn from google.genai.types import FunctionDeclaration +from mcp.types import Tool as McpBaseTool from typing_extensions import override from ...agents.readonly_context import ReadonlyContext -from ...features import FeatureName -from ...features import is_feature_enabled -from .._gemini_schema_util import _to_gemini_schema -from .mcp_session_manager import MCPSessionManager -from .mcp_session_manager import retry_on_closed_resource - -# Attempt to import MCP Tool from the MCP library, and hints user to upgrade -# their Python version to 3.10 if it fails. -try: - from mcp.types import Tool as McpBaseTool -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - "MCP Tool requires Python 3.10 or above. Please upgrade your Python" - " version." - ) from e - else: - raise e - - from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme from ...auth.auth_tool import AuthConfig +from ...features import FeatureName +from ...features import is_feature_enabled +from .._gemini_schema_util import _to_gemini_schema from ..base_authenticated_tool import BaseAuthenticatedTool # import from ..tool_context import ToolContext +from .mcp_session_manager import MCPSessionManager +from .mcp_session_manager import retry_on_errors logger = logging.getLogger("google_adk." + __name__) @@ -195,7 +180,7 @@ async def run_async( return {"error": "This tool call is rejected."} return await super().run_async(args=args, tool_context=tool_context) - @retry_on_closed_resource + @retry_on_errors @override async def _run_async_impl( self, *, args, tool_context: ToolContext, credential: AuthCredential diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index daa88f9031..035b75878b 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -25,6 +25,8 @@ from typing import Union import warnings +from mcp import StdioServerParameters +from mcp.types import ListToolsResult from pydantic import model_validator from typing_extensions import override @@ -37,27 +39,10 @@ from ..tool_configs import BaseToolConfig from ..tool_configs import ToolArgsConfig from .mcp_session_manager import MCPSessionManager -from .mcp_session_manager import retry_on_closed_resource +from .mcp_session_manager import retry_on_errors from .mcp_session_manager import SseConnectionParams from .mcp_session_manager import StdioConnectionParams from .mcp_session_manager import StreamableHTTPConnectionParams - -# Attempt to import MCP Tool from the MCP library, and hints user to upgrade -# their Python version to 3.10 if it fails. -try: - from mcp import StdioServerParameters - from mcp.types import ListToolsResult -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - "MCP Tool requires Python 3.10 or above. Please upgrade your Python" - " version." 
- ) from e - else: - raise e - from .mcp_tool import MCPTool logger = logging.getLogger("google_adk." + __name__) @@ -155,7 +140,7 @@ def __init__( self._auth_credential = auth_credential self._require_confirmation = require_confirmation - @retry_on_closed_resource + @retry_on_errors async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None, diff --git a/src/google/adk/tools/openapi_tool/common/common.py b/src/google/adk/tools/openapi_tool/common/common.py index 7187b1bd1b..1df3125e3d 100644 --- a/src/google/adk/tools/openapi_tool/common/common.py +++ b/src/google/adk/tools/openapi_tool/common/common.py @@ -64,11 +64,9 @@ class ApiParameter(BaseModel): required: bool = False def model_post_init(self, _: Any): - self.py_name = ( - self.py_name - if self.py_name - else rename_python_keywords(_to_snake_case(self.original_name)) - ) + if not self.py_name: + inferred_name = rename_python_keywords(_to_snake_case(self.original_name)) + self.py_name = inferred_name or self._default_py_name() if isinstance(self.param_schema, str): self.param_schema = Schema.model_validate_json(self.param_schema) @@ -77,6 +75,16 @@ def model_post_init(self, _: Any): self.type_hint = TypeHintHelper.get_type_hint(self.param_schema) return self + def _default_py_name(self) -> str: + location_defaults = { + 'body': 'body', + 'query': 'query_param', + 'path': 'path_param', + 'header': 'header_param', + 'cookie': 'cookie_param', + } + return location_defaults.get(self.param_location or '', 'value') + @model_serializer def _serialize(self): return { diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py index 06d692a2b0..326ff6787e 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py @@ -139,10 +139,19 @@ def _process_request_body(self): ) ) else: + # Prefer explicit body name to avoid empty keys when schema lacks type + # information (e.g., oneOf/anyOf/allOf) while retaining legacy behavior + # for simple scalar types. + if schema and (schema.oneOf or schema.anyOf or schema.allOf): + param_name = 'body' + elif not schema or not schema.type: + param_name = 'body' + else: + param_name = '' + self._params.append( - # Empty name for unnamed body param ApiParameter( - original_name='', + original_name=param_name, param_location='body', param_schema=schema, description=description, diff --git a/src/google/adk/tools/spanner/search_tool.py b/src/google/adk/tools/spanner/search_tool.py index b3cf797edf..2b6b9777dc 100644 --- a/src/google/adk/tools/spanner/search_tool.py +++ b/src/google/adk/tools/spanner/search_tool.py @@ -20,23 +20,34 @@ from typing import List from typing import Optional -from google.adk.tools.spanner import client -from google.adk.tools.spanner.settings import SpannerToolSettings -from google.adk.tools.tool_context import ToolContext from google.auth.credentials import Credentials from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from google.cloud.spanner_v1.database import Database -# Embedding options -_SPANNER_EMBEDDING_MODEL_NAME = "spanner_embedding_model_name" -_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT = "vertex_ai_embedding_model_endpoint" +from . import client +from . 
import utils +from .settings import APPROXIMATE_NEAREST_NEIGHBORS +from .settings import EXACT_NEAREST_NEIGHBORS +from .settings import SpannerToolSettings + +# Embedding model settings. +# Only for Spanner GoogleSQL dialect database, and use Spanner ML.PREDICT +# function. +_SPANNER_GSQL_EMBEDDING_MODEL_NAME = "spanner_googlesql_embedding_model_name" +# Only for Spanner PostgreSQL dialect database, and use spanner.ML_PREDICT_ROW +# to inferencing with Vertex AI embedding model endpoint. +_SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT = ( + "spanner_postgresql_vertex_ai_embedding_model_endpoint" +) +# For both Spanner GoogleSQL and PostgreSQL dialects, use Vertex AI embedding +# model to generate embeddings for vector similarity search. +_VERTEX_AI_EMBEDDING_MODEL_NAME = "vertex_ai_embedding_model_name" +_OUTPUT_DIMENSIONALITY = "output_dimensionality" # Search options _TOP_K = "top_k" _DISTANCE_TYPE = "distance_type" _NEAREST_NEIGHBORS_ALGORITHM = "nearest_neighbors_algorithm" -_EXACT_NEAREST_NEIGHBORS = "EXACT_NEAREST_NEIGHBORS" -_APPROXIMATE_NEAREST_NEIGHBORS = "APPROXIMATE_NEAREST_NEIGHBORS" _NUM_LEAVES_TO_SEARCH = "num_leaves_to_search" # Constants @@ -48,12 +59,12 @@ def _generate_googlesql_for_embedding_query( - spanner_embedding_model_name: str, + spanner_gsql_embedding_model_name: str, ) -> str: return f""" SELECT embeddings.values FROM ML.PREDICT( - MODEL {spanner_embedding_model_name}, + MODEL {spanner_gsql_embedding_model_name}, (SELECT CAST(@{_GOOGLESQL_PARAMETER_TEXT_QUERY} AS STRING) as content) ) """ @@ -61,37 +72,60 @@ def _generate_googlesql_for_embedding_query( def _generate_postgresql_for_embedding_query( vertex_ai_embedding_model_endpoint: str, + output_dimensionality: Optional[int], ) -> str: + instances_json = f""" + 'instances', + JSONB_BUILD_ARRAY( + JSONB_BUILD_OBJECT( + 'content', + ${_POSTGRESQL_PARAMETER_TEXT_QUERY}::TEXT + ) + ) + """ + + params_list = [] + if output_dimensionality is not None: + params_list.append(f""" + 'parameters', + JSONB_BUILD_OBJECT( + 'outputDimensionality', + {output_dimensionality} + ) + """) + + jsonb_build_args = ",\n".join([instances_json] + params_list) + return f""" - SELECT spanner.FLOAT32_ARRAY( spanner.ML_PREDICT_ROW( - '{vertex_ai_embedding_model_endpoint}', - JSONB_BUILD_OBJECT( - 'instances', - JSONB_BUILD_ARRAY( JSONB_BUILD_OBJECT( - 'content', - ${_POSTGRESQL_PARAMETER_TEXT_QUERY}::TEXT - )) + SELECT spanner.FLOAT32_ARRAY( + spanner.ML_PREDICT_ROW( + '{vertex_ai_embedding_model_endpoint}', + JSONB_BUILD_OBJECT( + {jsonb_build_args} + ) + ) -> 'predictions' -> 0 -> 'embeddings' -> 'values' ) - ) -> 'predictions'->0->'embeddings'->'values' ) """ def _get_embedding_for_query( database: Database, dialect: DatabaseDialect, - spanner_embedding_model_name: Optional[str], - vertex_ai_embedding_model_endpoint: Optional[str], + spanner_gsql_embedding_model_name: Optional[str], + spanner_pg_vertex_ai_embedding_model_endpoint: Optional[str], query: str, + output_dimensionality: Optional[int] = None, ) -> List[float]: """Gets the embedding for the query.""" if dialect == DatabaseDialect.POSTGRESQL: embedding_query = _generate_postgresql_for_embedding_query( - vertex_ai_embedding_model_endpoint + spanner_pg_vertex_ai_embedding_model_endpoint, + output_dimensionality, ) params = {f"p{_POSTGRESQL_PARAMETER_TEXT_QUERY}": query} else: embedding_query = _generate_googlesql_for_embedding_query( - spanner_embedding_model_name + spanner_gsql_embedding_model_name ) params = {_GOOGLESQL_PARAMETER_TEXT_QUERY: query} with database.snapshot() as 
snapshot: @@ -101,8 +135,8 @@ def _get_embedding_for_query( def _get_postgresql_distance_function(distance_type: str) -> str: return { - "COSINE_DISTANCE": "spanner.cosine_distance", - "EUCLIDEAN_DISTANCE": "spanner.euclidean_distance", + "COSINE": "spanner.cosine_distance", + "EUCLIDEAN": "spanner.euclidean_distance", "DOT_PRODUCT": "spanner.dot_product", }[distance_type] @@ -110,13 +144,13 @@ def _get_postgresql_distance_function(distance_type: str) -> str: def _get_googlesql_distance_function(distance_type: str, ann: bool) -> str: if ann: return { - "COSINE_DISTANCE": "APPROX_COSINE_DISTANCE", - "EUCLIDEAN_DISTANCE": "APPROX_EUCLIDEAN_DISTANCE", + "COSINE": "APPROX_COSINE_DISTANCE", + "EUCLIDEAN": "APPROX_EUCLIDEAN_DISTANCE", "DOT_PRODUCT": "APPROX_DOT_PRODUCT", }[distance_type] return { - "COSINE_DISTANCE": "COSINE_DISTANCE", - "EUCLIDEAN_DISTANCE": "EUCLIDEAN_DISTANCE", + "COSINE": "COSINE_DISTANCE", + "EUCLIDEAN": "EUCLIDEAN_DISTANCE", "DOT_PRODUCT": "DOT_PRODUCT", }[distance_type] @@ -172,7 +206,7 @@ def _generate_sql_for_ann( """Generates a SQL query for ANN search.""" if dialect == DatabaseDialect.POSTGRESQL: raise NotImplementedError( - f"{_APPROXIMATE_NEAREST_NEIGHBORS} is not supported for PostgreSQL" + f"{APPROXIMATE_NEAREST_NEIGHBORS} is not supported for PostgreSQL" " dialect." ) distance_function = _get_googlesql_distance_function(distance_type, ann=True) @@ -206,8 +240,6 @@ def similarity_search( columns: List[str], embedding_options: Dict[str, str], credentials: Credentials, - settings: SpannerToolSettings, - tool_context: ToolContext, additional_filter: Optional[str] = None, search_options: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: @@ -234,21 +266,34 @@ def similarity_search( columns (List[str]): A list of column names, representing the additional columns to return in the search results. embedding_options (Dict[str, str]): A dictionary of options to use for - the embedding service. The following options are supported: - - spanner_embedding_model_name: (For GoogleSQL dialect) The + the embedding service. **Exactly one of the following three keys + MUST be present in this dictionary**: + `vertex_ai_embedding_model_name`, `spanner_googlesql_embedding_model_name`, + or `spanner_postgresql_vertex_ai_embedding_model_endpoint`. + - vertex_ai_embedding_model_name (str): (Supported both **GoogleSQL and + PostgreSQL** dialects Spanner database) The name of a + public Vertex AI embedding model (e.g., `'text-embedding-005'`). + If specified, the tool generates embeddings client-side using the + Vertex AI embedding model. + - spanner_googlesql_embedding_model_name (str): (For GoogleSQL dialect) The name of the embedding model that is registered in Spanner via a `CREATE MODEL` statement. For more details, see https://cloud.google.com/spanner/docs/ml-tutorial-embeddings#generate_and_store_text_embeddings - - vertex_ai_embedding_model_endpoint: (For PostgreSQL dialect) - The fully qualified endpoint of the Vertex AI embedding model, - in the format of + If specified, embedding generation is performed using Spanner's + `ML.PREDICT` function. + - spanner_postgresql_vertex_ai_embedding_model_endpoint (str): + (For PostgreSQL dialect) The fully qualified endpoint of the Vertex AI + embedding model, in the format of `projects/$project/locations/$location/publishers/google/models/$model_name`, where $project is the project hosting the Vertex AI endpoint, $location is the location of the endpoint, and $model_name is the name of the text embedding model. 
+ If specified, embedding generation is performed using Spanner's + `spanner.ML_PREDICT_ROW` function. + - output_dimensionality: Optional. The output dimensionality of the + embedding. If not specified, the embedding model's default output + dimensionality will be used. credentials (Credentials): The credentials to use for the request. - settings (SpannerToolSettings): The configuration for the tool. - tool_context (ToolContext): The context for the tool. additional_filter (Optional[str]): An optional filter to apply to the search query. If provided, this will be added to the WHERE clause of the final query. @@ -257,9 +302,9 @@ def similarity_search( - top_k: The number of most similar documents to return. The default value is 4. - distance_type: The distance type to use to perform the - similarity search. Valid values include "COSINE_DISTANCE", - "EUCLIDEAN_DISTANCE", and "DOT_PRODUCT". Default value is - "COSINE_DISTANCE". + similarity search. Valid values include "COSINE", + "EUCLIDEAN", and "DOT_PRODUCT". Default value is + "COSINE". - nearest_neighbors_algorithm: The nearest neighbors search algorithm to use. Valid values include "EXACT_NEAREST_NEIGHBORS" and "APPROXIMATE_NEAREST_NEIGHBORS". Default value is @@ -287,15 +332,13 @@ def similarity_search( ... embedding_column_to_search="product_description_embedding", ... columns=["product_name", "product_description", "price_in_cents"], ... credentials=credentials, - ... settings=settings, - ... tool_context=tool_context, ... additional_filter="price_in_cents < 100000", ... embedding_options={ - ... "spanner_embedding_model_name": "my_embedding_model" + ... "vertex_ai_embedding_model_name": "text-embedding-005" ... }, ... search_options={ ... "top_k": 2, - ... "distance_type": "COSINE_DISTANCE" + ... "distance_type": "COSINE" ... } ... ) { @@ -336,33 +379,68 @@ def similarity_search( embedding_options = {} if search_options is None: search_options = {} - spanner_embedding_model_name = embedding_options.get( - _SPANNER_EMBEDDING_MODEL_NAME + + exclusive_embedding_model_keys = { + _VERTEX_AI_EMBEDDING_MODEL_NAME, + _SPANNER_GSQL_EMBEDDING_MODEL_NAME, + _SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT, + } + if ( + len( + exclusive_embedding_model_keys.intersection( + embedding_options.keys() + ) + ) + != 1 + ): + raise ValueError("Exactly one embedding model option must be specified.") + + vertex_ai_embedding_model_name = embedding_options.get( + _VERTEX_AI_EMBEDDING_MODEL_NAME + ) + spanner_gsql_embedding_model_name = embedding_options.get( + _SPANNER_GSQL_EMBEDDING_MODEL_NAME + ) + spanner_pg_vertex_ai_embedding_model_endpoint = embedding_options.get( + _SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT ) if ( database.database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL - and spanner_embedding_model_name is None + and vertex_ai_embedding_model_name is None + and spanner_gsql_embedding_model_name is None ): raise ValueError( - f"embedding_options['{_SPANNER_EMBEDDING_MODEL_NAME}']" - " must be specified for GoogleSQL dialect." + f"embedding_options['{_VERTEX_AI_EMBEDDING_MODEL_NAME}'] or" + f" embedding_options['{_SPANNER_GSQL_EMBEDDING_MODEL_NAME}'] must be" + " specified for GoogleSQL dialect Spanner database." 
) - vertex_ai_embedding_model_endpoint = embedding_options.get( - _VERTEX_AI_EMBEDDING_MODEL_ENDPOINT - ) if ( database.database_dialect == DatabaseDialect.POSTGRESQL - and vertex_ai_embedding_model_endpoint is None + and vertex_ai_embedding_model_name is None + and spanner_pg_vertex_ai_embedding_model_endpoint is None ): raise ValueError( - f"embedding_options['{_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT}']" - " must be specified for PostgreSQL dialect." + f"embedding_options['{_VERTEX_AI_EMBEDDING_MODEL_NAME}'] or" + f" embedding_options['{_SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT}']" + " must be specified for PostgreSQL dialect Spanner database." + ) + output_dimensionality = embedding_options.get(_OUTPUT_DIMENSIONALITY) + if ( + output_dimensionality is not None + and spanner_gsql_embedding_model_name is not None + ): + # Currently, Spanner GSQL Model ML.PREDICT does not support + # output_dimensionality parameter for inference embedding models. + raise ValueError( + f"embedding_options[{_OUTPUT_DIMENSIONALITY}] is not supported when" + f" embedding_options['{_SPANNER_GSQL_EMBEDDING_MODEL_NAME}'] is" + " specified." ) # Use cosine distance by default. distance_type = search_options.get(_DISTANCE_TYPE) if distance_type is None: - distance_type = "COSINE_DISTANCE" + distance_type = "COSINE" top_k = search_options.get(_TOP_K) if top_k is None: @@ -370,26 +448,36 @@ def similarity_search( # Use EXACT_NEAREST_NEIGHBORS (i.e. kNN) by default. nearest_neighbors_algorithm = search_options.get( - _NEAREST_NEIGHBORS_ALGORITHM, _EXACT_NEAREST_NEIGHBORS + _NEAREST_NEIGHBORS_ALGORITHM, + EXACT_NEAREST_NEIGHBORS, ) if nearest_neighbors_algorithm not in ( - _EXACT_NEAREST_NEIGHBORS, - _APPROXIMATE_NEAREST_NEIGHBORS, + EXACT_NEAREST_NEIGHBORS, + APPROXIMATE_NEAREST_NEIGHBORS, ): raise NotImplementedError( f"Unsupported search_options['{_NEAREST_NEIGHBORS_ALGORITHM}']:" f" {nearest_neighbors_algorithm}" ) - embedding = _get_embedding_for_query( - database, - database.database_dialect, - spanner_embedding_model_name, - vertex_ai_embedding_model_endpoint, - query, - ) + # Generate embedding for the query according to the embedding options. + if vertex_ai_embedding_model_name: + embedding = utils.embed_contents( + vertex_ai_embedding_model_name, + [query], + output_dimensionality, + )[0] + else: + embedding = _get_embedding_for_query( + database, + database.database_dialect, + spanner_gsql_embedding_model_name, + spanner_pg_vertex_ai_embedding_model_endpoint, + query, + output_dimensionality, + ) - if nearest_neighbors_algorithm == _EXACT_NEAREST_NEIGHBORS: + if nearest_neighbors_algorithm == EXACT_NEAREST_NEIGHBORS: sql = _generate_sql_for_knn( database.database_dialect, table_name, @@ -438,5 +526,100 @@ def similarity_search( except Exception as ex: return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), + } + + +def vector_store_similarity_search( + query: str, + credentials: Credentials, + settings: SpannerToolSettings, +) -> Dict[str, Any]: + """Performs a semantic similarity search to retrieve relevant context from the Spanner vector store. + + This function performs vector similarity search directly on a vector store + table in Spanner database and returns the relevant data. + + Args: + query (str): The search string based on the user's question. + credentials (Credentials): The credentials to use for the request. + settings (SpannerToolSettings): The configuration for the tool. + + Returns: + Dict[str, Any]: A dictionary representing the result of the search. 
+ On success, it contains {"status": "SUCCESS", "rows": [...]}. The last + column of each row is the distance between the query and the row result. + On error, it contains {"status": "ERROR", "error_details": "..."}. + + Examples: + >>> vector_store_similarity_search( + ... query="Spanner database optimization techniques for high QPS", + ... credentials=credentials, + ... settings=settings + ... ) + { + "status": "SUCCESS", + "rows": [ + ( + "Optimizing Query Performance", + 0.12, + ), + ( + "Schema Design Best Practices", + 0.25, + ), + ( + "Using Secondary Indexes Effectively", + 0.31, + ), + ... + ], + } + """ + + try: + if not settings or not settings.vector_store_settings: + raise ValueError("Spanner vector store settings are not set.") + + # Get the embedding model settings. + embedding_options = { + _VERTEX_AI_EMBEDDING_MODEL_NAME: ( + settings.vector_store_settings.vertex_ai_embedding_model_name + ), + _OUTPUT_DIMENSIONALITY: settings.vector_store_settings.vector_length, + } + + # Get the search settings. + search_options = { + _TOP_K: settings.vector_store_settings.top_k, + _DISTANCE_TYPE: settings.vector_store_settings.distance_type, + _NEAREST_NEIGHBORS_ALGORITHM: ( + settings.vector_store_settings.nearest_neighbors_algorithm + ), + } + if ( + settings.vector_store_settings.nearest_neighbors_algorithm + == APPROXIMATE_NEAREST_NEIGHBORS + ): + search_options[_NUM_LEAVES_TO_SEARCH] = ( + settings.vector_store_settings.num_leaves_to_search + ) + + return similarity_search( + project_id=settings.vector_store_settings.project_id, + instance_id=settings.vector_store_settings.instance_id, + database_id=settings.vector_store_settings.database_id, + table_name=settings.vector_store_settings.table_name, + query=query, + embedding_column_to_search=settings.vector_store_settings.embedding_column, + columns=settings.vector_store_settings.selected_columns, + embedding_options=embedding_options, + credentials=credentials, + additional_filter=settings.vector_store_settings.additional_filter, + search_options=search_options, + ) + except Exception as ex: + return { + "status": "ERROR", + "error_details": repr(ex), } diff --git a/src/google/adk/tools/spanner/settings.py b/src/google/adk/tools/spanner/settings.py index 5d097258f4..a76331ba65 100644 --- a/src/google/adk/tools/spanner/settings.py +++ b/src/google/adk/tools/spanner/settings.py @@ -16,20 +16,115 @@ from enum import Enum from typing import List +from typing import Literal +from typing import Optional from pydantic import BaseModel +from pydantic import model_validator from ...utils.feature_decorator import experimental +# Vector similarity search nearest neighbors search algorithms. +EXACT_NEAREST_NEIGHBORS = "EXACT_NEAREST_NEIGHBORS" +APPROXIMATE_NEAREST_NEIGHBORS = "APPROXIMATE_NEAREST_NEIGHBORS" +NearestNeighborsAlgorithm = Literal[ + EXACT_NEAREST_NEIGHBORS, + APPROXIMATE_NEAREST_NEIGHBORS, +] + class Capabilities(Enum): """Capabilities indicating what type of operation tools are allowed to be performed on Spanner.""" - DATA_READ = 'data_read' + DATA_READ = "data_read" """Read only data operations tools are allowed.""" -@experimental('Tool settings defaults may have breaking change in the future.') +class SpannerVectorStoreSettings(BaseModel): + """Settings for Spanner Vector Store. + + This is used for vector similarity search in a Spanner vector store table. + Provide the vector store table and the embedding model settings to use with + the `vector_store_similarity_search` tool. + """ + + project_id: str + """Required. 
The GCP project id in which the Spanner database resides."""
+
+  instance_id: str
+  """Required. The instance id of the Spanner database."""
+
+  database_id: str
+  """Required. The database id of the Spanner database."""
+
+  table_name: str
+  """Required. The name of the vector store table to use for vector similarity search."""
+
+  content_column: str
+  """Required. The name of the content column in the vector store table. By default, this column value is also returned as part of the vector similarity search result."""
+
+  embedding_column: str
+  """Required. The name of the embedding column to search in the vector store table."""
+
+  vector_length: int
+  """Required. The dimension of the vectors in the `embedding_column`."""
+
+  vertex_ai_embedding_model_name: str
+  """Required. The Vertex AI embedding model name, which is used to generate embeddings for vector store and vector similarity search.
+  For example, 'text-embedding-005'.
+
+  Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
+  Otherwise, a runtime error might be raised during a query.
+  """
+
+  selected_columns: List[str] = []
+  """Required. The vector store table columns to return in the vector similarity search result.
+
+  By default, only the `content_column` value and the distance value are returned.
+  If specified, the list of selected columns and the distance value are returned.
+  For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value.
+  """
+
+  nearest_neighbors_algorithm: NearestNeighborsAlgorithm = (
+      "EXACT_NEAREST_NEIGHBORS"
+  )
+  """The algorithm used to perform vector similarity search. This value can be EXACT_NEAREST_NEIGHBORS or APPROXIMATE_NEAREST_NEIGHBORS.
+
+  For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
+  For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
+  """
+
+  top_k: int = 4
+  """Required. The number of neighbors to return for each vector similarity search query. The default value is 4."""
+
+  distance_type: str = "COSINE"
+  """Required. The distance metric used to build the vector index or perform vector similarity search. This value can be COSINE, DOT_PRODUCT, or EUCLIDEAN."""
+
+  num_leaves_to_search: Optional[int] = None
+  """Optional. This option specifies how many leaf nodes of the index are searched.
+
+  Note: this option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
+  For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices
+  """
+
+  additional_filter: Optional[str] = None
+  """Optional. An optional filter to apply to the search query. If provided, this will be added to the WHERE clause of the final query."""
+
+  @model_validator(mode="after")
+  def __post_init__(self):
+    """Validate the embedding settings."""
+    if not self.vector_length or self.vector_length <= 0:
+      raise ValueError(
+          "Invalid vector length in the Spanner vector store settings."
+ ) + + if not self.selected_columns: + self.selected_columns = [self.content_column] + + return self + + +@experimental("Tool settings defaults may have breaking change in the future.") class SpannerToolSettings(BaseModel): """Settings for Spanner tools.""" @@ -44,3 +139,6 @@ class SpannerToolSettings(BaseModel): max_executed_query_result_rows: int = 50 """Maximum number of rows to return from a query result.""" + + vector_store_settings: Optional[SpannerVectorStoreSettings] = None + """Settings for Spanner vector store and vector similarity search.""" diff --git a/src/google/adk/tools/spanner/spanner_toolset.py b/src/google/adk/tools/spanner/spanner_toolset.py index 861314abb3..6496014f74 100644 --- a/src/google/adk/tools/spanner/spanner_toolset.py +++ b/src/google/adk/tools/spanner/spanner_toolset.py @@ -47,6 +47,8 @@ class SpannerToolset(BaseToolset): - spanner_list_named_schemas - spanner_get_table_schema - spanner_execute_sql + - spanner_similarity_search + - spanner_vector_store_similarity_search """ def __init__( @@ -121,6 +123,16 @@ async def get_tools( tool_settings=self._tool_settings, ) ) + if self._tool_settings.vector_store_settings: + # Only add the vector store similarity search tool if the vector store + # settings are specified. + all_tools.append( + GoogleTool( + func=search_tool.vector_store_similarity_search, + credentials_config=self._credentials_config, + tool_settings=self._tool_settings, + ) + ) return [ tool diff --git a/src/google/adk/tools/spanner/utils.py b/src/google/adk/tools/spanner/utils.py index 64c03859e1..f1b710ec85 100644 --- a/src/google/adk/tools/spanner/utils.py +++ b/src/google/adk/tools/spanner/utils.py @@ -105,3 +105,27 @@ def execute_sql( "status": "ERROR", "error_details": str(ex), } + + +def embed_contents( + vertex_ai_embedding_model_name: str, + contents: list[str], + output_dimensionality: Optional[int] = None, +) -> list[list[float]]: + """Embed the given contents into list of vectors using the Vertex AI embedding model endpoint.""" + try: + from google.genai import Client + from google.genai.types import EmbedContentConfig + + client = Client() + config = EmbedContentConfig() + if output_dimensionality: + config.output_dimensionality = output_dimensionality + response = client.models.embed_content( + model=vertex_ai_embedding_model_name, + contents=contents, + config=config, + ) + return [list(e.values) for e in response.embeddings] + except Exception as ex: + raise RuntimeError(f"Failed to embed content: {ex!r}") from ex diff --git a/src/google/adk/tools/transfer_to_agent_tool.py b/src/google/adk/tools/transfer_to_agent_tool.py index 99ee234b30..2124e6aab9 100644 --- a/src/google/adk/tools/transfer_to_agent_tool.py +++ b/src/google/adk/tools/transfer_to_agent_tool.py @@ -14,6 +14,12 @@ from __future__ import annotations +from typing import Optional + +from google.genai import types +from typing_extensions import override + +from .function_tool import FunctionTool from .tool_context import ToolContext @@ -23,7 +29,61 @@ def transfer_to_agent(agent_name: str, tool_context: ToolContext) -> None: This tool hands off control to another agent when it's more suitable to answer the user's question according to the agent's description. + Note: + For most use cases, you should use TransferToAgentTool instead of this + function directly. TransferToAgentTool provides additional enum constraints + that prevent LLMs from hallucinating invalid agent names. + Args: agent_name: the agent name to transfer to. 
""" tool_context.actions.transfer_to_agent = agent_name + + +class TransferToAgentTool(FunctionTool): + """A specialized FunctionTool for agent transfer with enum constraints. + + This tool enhances the base transfer_to_agent function by adding JSON Schema + enum constraints to the agent_name parameter. This prevents LLMs from + hallucinating invalid agent names by restricting choices to only valid agents. + + Attributes: + agent_names: List of valid agent names that can be transferred to. + """ + + def __init__( + self, + agent_names: list[str], + ): + """Initialize the TransferToAgentTool. + + Args: + agent_names: List of valid agent names that can be transferred to. + """ + super().__init__(func=transfer_to_agent) + self._agent_names = agent_names + + @override + def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + """Add enum constraint to the agent_name parameter. + + Returns: + FunctionDeclaration with enum constraint on agent_name parameter. + """ + function_decl = super()._get_declaration() + if not function_decl: + return function_decl + + # Handle parameters (types.Schema object) + if function_decl.parameters: + agent_name_schema = function_decl.parameters.properties.get('agent_name') + if agent_name_schema: + agent_name_schema.enum = self._agent_names + + # Handle parameters_json_schema (dict) + if function_decl.parameters_json_schema: + properties = function_decl.parameters_json_schema.get('properties', {}) + if 'agent_name' in properties: + properties['agent_name']['enum'] = self._agent_names + + return function_decl diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index aff5be1552..e0c228be0e 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import logging from typing import Optional from typing import TYPE_CHECKING @@ -25,6 +26,8 @@ from .base_tool import BaseTool from .tool_context import ToolContext +logger = logging.getLogger('google_adk.' + __name__) + if TYPE_CHECKING: from ..models import LlmRequest @@ -102,6 +105,30 @@ async def process_llm_request( ) llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] + + # Format data_store_specs concisely for logging + if self.data_store_specs: + spec_ids = [ + spec.data_store.split('/')[-1] if spec.data_store else 'unnamed' + for spec in self.data_store_specs + ] + specs_info = ( + f'{len(self.data_store_specs)} spec(s): [{", ".join(spec_ids)}]' + ) + else: + specs_info = None + + logger.debug( + 'Adding Vertex AI Search tool config to LLM request: ' + 'datastore=%s, engine=%s, filter=%s, max_results=%s, ' + 'data_store_specs=%s', + self.data_store_id, + self.search_engine_id, + self.filter, + self.max_results, + specs_info, + ) + llm_request.config.tools.append( types.Tool( retrieval=types.Retrieval( diff --git a/src/google/adk/utils/_client_labels_utils.py b/src/google/adk/utils/_client_labels_utils.py new file mode 100644 index 0000000000..72858c3c1d --- /dev/null +++ b/src/google/adk/utils/_client_labels_utils.py @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager +import contextvars +import os +import sys +from typing import List + +from .. import version + +_ADK_LABEL = "google-adk" +_LANGUAGE_LABEL = "gl-python" +_AGENT_ENGINE_TELEMETRY_TAG = "remote_reasoning_engine" +_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = "GOOGLE_CLOUD_AGENT_ENGINE_ID" + + +EVAL_CLIENT_LABEL = f"google-adk-eval/{version.__version__}" +"""Label used to denote calls emerging to external system as a part of Evals.""" + +# The ContextVar holds client label collected for the current request. +_LABEL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar( + "_LABEL_CONTEXT", default=None +) + + +def _get_default_labels() -> List[str]: + """Returns a list of labels that are always added.""" + framework_label = f"{_ADK_LABEL}/{version.__version__}" + + if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME): + framework_label = f"{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}" + + language_label = f"{_LANGUAGE_LABEL}/" + sys.version.split()[0] + return [framework_label, language_label] + + +@contextmanager +def client_label_context(client_label: str): + """Runs the operation within the context of the given client label.""" + current_client_label = _LABEL_CONTEXT.get() + + if current_client_label is not None: + raise ValueError( + "Client label already exists. You can only add one client label." + ) + + token = _LABEL_CONTEXT.set(client_label) + + try: + yield + finally: + # Restore the previous state of the context variable + _LABEL_CONTEXT.reset(token) + + +def get_client_labels() -> List[str]: + """Returns the current list of client labels that can be added to HTTP Headers.""" + labels = _get_default_labels() + current_client_label = _LABEL_CONTEXT.get() + + if current_client_label: + labels.append(current_client_label) + + return labels diff --git a/src/google/adk/utils/context_utils.py b/src/google/adk/utils/context_utils.py index bd8dacb9d8..a75feae3dd 100644 --- a/src/google/adk/utils/context_utils.py +++ b/src/google/adk/utils/context_utils.py @@ -20,30 +20,7 @@ from __future__ import annotations -from contextlib import AbstractAsyncContextManager -from typing import Any -from typing import AsyncGenerator +from contextlib import aclosing - -class Aclosing(AbstractAsyncContextManager): - """Async context manager for safely finalizing an asynchronously cleaned-up - resource such as an async generator, calling its ``aclose()`` method. - Needed to correctly close contexts for OTel spans. - See https://github.com/google/adk-python/issues/1670#issuecomment-3115891100. - - Based on - https://docs.python.org/3/library/contextlib.html#contextlib.aclosing - which is available in Python 3.10+. - - TODO: replace all occurrences with contextlib.aclosing once Python 3.9 is no - longer supported. 
- """ - - def __init__(self, async_generator: AsyncGenerator[Any, None]): - self.async_generator = async_generator - - async def __aenter__(self): - return self.async_generator - - async def __aexit__(self, *exc_info): - await self.async_generator.aclose() +# Re-export aclosing for backward compatibility +Aclosing = aclosing diff --git a/src/google/adk/utils/streaming_utils.py b/src/google/adk/utils/streaming_utils.py index eb75365467..eae80aa7cc 100644 --- a/src/google/adk/utils/streaming_utils.py +++ b/src/google/adk/utils/streaming_utils.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Any from typing import AsyncGenerator from typing import Optional @@ -43,6 +44,12 @@ def __init__(self): self._current_text_is_thought: Optional[bool] = None self._finish_reason: Optional[types.FinishReason] = None + # For streaming function call arguments + self._current_fc_name: Optional[str] = None + self._current_fc_args: dict[str, Any] = {} + self._current_fc_id: Optional[str] = None + self._current_thought_signature: Optional[str] = None + def _flush_text_buffer_to_sequence(self): """Flush current text buffer to parts sequence. @@ -61,6 +68,171 @@ def _flush_text_buffer_to_sequence(self): self._current_text_buffer = '' self._current_text_is_thought = None + def _get_value_from_partial_arg( + self, partial_arg: types.PartialArg, json_path: str + ): + """Extract value from a partial argument. + + Args: + partial_arg: The partial argument object + json_path: JSONPath for this argument + + Returns: + Tuple of (value, has_value) where has_value indicates if a value exists + """ + value = None + has_value = False + + if partial_arg.string_value is not None: + # For streaming strings, append chunks to existing value + string_chunk = partial_arg.string_value + has_value = True + + # Get current value for this path (if any) + path_without_prefix = ( + json_path[2:] if json_path.startswith('$.') else json_path + ) + path_parts = path_without_prefix.split('.') + + # Try to get existing value + existing_value = self._current_fc_args + for part in path_parts: + if isinstance(existing_value, dict) and part in existing_value: + existing_value = existing_value[part] + else: + existing_value = None + break + + # Append to existing string or set new value + if isinstance(existing_value, str): + value = existing_value + string_chunk + else: + value = string_chunk + + elif partial_arg.number_value is not None: + value = partial_arg.number_value + has_value = True + elif partial_arg.bool_value is not None: + value = partial_arg.bool_value + has_value = True + elif partial_arg.null_value is not None: + value = None + has_value = True + + return value, has_value + + def _set_value_by_json_path(self, json_path: str, value: Any): + """Set a value in _current_fc_args using JSONPath notation. + + Args: + json_path: JSONPath string like "$.location" or "$.location.latitude" + value: The value to set + """ + # Remove leading "$." from jsonPath + if json_path.startswith('$.'): + path = json_path[2:] + else: + path = json_path + + # Split path into components + path_parts = path.split('.') + + # Navigate to the correct location and set the value + current = self._current_fc_args + for part in path_parts[:-1]: + if part not in current: + current[part] = {} + current = current[part] + + # Set the final value + current[path_parts[-1]] = value + + def _flush_function_call_to_sequence(self): + """Flush current function call to parts sequence. 
+ + This creates a complete FunctionCall part from accumulated partial args. + """ + if self._current_fc_name: + # Create function call part with accumulated args + fc_part = types.Part.from_function_call( + name=self._current_fc_name, + args=self._current_fc_args.copy(), + ) + + # Set the ID if provided (directly on the function_call object) + if self._current_fc_id and fc_part.function_call: + fc_part.function_call.id = self._current_fc_id + + # Set thought_signature if provided (on the Part, not FunctionCall) + if self._current_thought_signature: + fc_part.thought_signature = self._current_thought_signature + + self._parts_sequence.append(fc_part) + + # Reset FC state + self._current_fc_name = None + self._current_fc_args = {} + self._current_fc_id = None + self._current_thought_signature = None + + def _process_streaming_function_call(self, fc: types.FunctionCall): + """Process a streaming function call with partialArgs. + + Args: + fc: The function call object with partial_args + """ + # Save function name if present (first chunk) + if fc.name: + self._current_fc_name = fc.name + if fc.id: + self._current_fc_id = fc.id + + # Process each partial argument + for partial_arg in getattr(fc, 'partial_args', []): + json_path = partial_arg.json_path + if not json_path: + continue + + # Extract value from partial arg + value, has_value = self._get_value_from_partial_arg( + partial_arg, json_path + ) + + # Set the value using JSONPath (only if a value was provided) + if has_value: + self._set_value_by_json_path(json_path, value) + + # Check if function call is complete + fc_will_continue = getattr(fc, 'will_continue', False) + if not fc_will_continue: + # Function call complete, flush it + self._flush_text_buffer_to_sequence() + self._flush_function_call_to_sequence() + + def _process_function_call_part(self, part: types.Part): + """Process a function call part (streaming or non-streaming). + + Args: + part: The part containing a function call + """ + fc = part.function_call + + # Check if this is a streaming FC (has partialArgs) + if hasattr(fc, 'partial_args') and fc.partial_args: + # Streaming function call arguments + + # Save thought_signature from the part (first chunk should have it) + if part.thought_signature and not self._current_thought_signature: + self._current_thought_signature = part.thought_signature + self._process_streaming_function_call(fc) + else: + # Non-streaming function call (standard format with args) + # Skip empty function calls (used as streaming end markers) + if fc.name: + # Flush any buffered text first, then add the FC part + self._flush_text_buffer_to_sequence() + self._parts_sequence.append(part) + async def process_response( self, response: types.GenerateContentResponse ) -> AsyncGenerator[LlmResponse, None]: @@ -101,8 +273,12 @@ async def process_response( if not self._current_text_buffer: self._current_text_is_thought = part.thought self._current_text_buffer += part.text + elif part.function_call: + # Process function call (handles both streaming Args and + # non-streaming Args) + self._process_function_call_part(part) else: - # Non-text part (function_call, bytes, etc.) + # Other non-text parts (bytes, etc.) 
# Flush any buffered text first, then add the non-text part self._flush_text_buffer_to_sequence() self._parts_sequence.append(part) @@ -155,8 +331,9 @@ def close(self) -> Optional[LlmResponse]: if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING): # Always generate final aggregated response in progressive mode if self._response and self._response.candidates: - # Flush any remaining text buffer to complete the sequence + # Flush any remaining buffers to complete the sequence self._flush_text_buffer_to_sequence() + self._flush_function_call_to_sequence() # Use the parts sequence which preserves original ordering final_parts = self._parts_sequence diff --git a/src/google/adk/version.py b/src/google/adk/version.py index 0a21522cb6..a287db284c 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.18.0" +__version__ = "1.20.0" diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py index cb3f7a6858..49b7d3c2b6 100644 --- a/tests/unittests/a2a/converters/test_event_converter.py +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -12,50 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import Mock from unittest.mock import patch +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Role +from a2a.types import Task +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent +from google.adk.a2a.converters.event_converter import _create_artifact_id +from google.adk.a2a.converters.event_converter import _create_error_status_event +from google.adk.a2a.converters.event_converter import _create_status_update_event +from google.adk.a2a.converters.event_converter import _get_adk_metadata_key +from google.adk.a2a.converters.event_converter import _get_context_metadata +from google.adk.a2a.converters.event_converter import _process_long_running_tool +from google.adk.a2a.converters.event_converter import _serialize_metadata_value +from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR +from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event +from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events +from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message +from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.types import DataPart - from a2a.types import Message - from a2a.types import Role - from a2a.types import Task - from a2a.types import TaskState - from a2a.types import TaskStatusUpdateEvent - from google.adk.a2a.converters.event_converter import _create_artifact_id - from google.adk.a2a.converters.event_converter import _create_error_status_event - from google.adk.a2a.converters.event_converter import _create_status_update_event - from 
google.adk.a2a.converters.event_converter import _get_adk_metadata_key - from google.adk.a2a.converters.event_converter import _get_context_metadata - from google.adk.a2a.converters.event_converter import _process_long_running_tool - from google.adk.a2a.converters.event_converter import _serialize_metadata_value - from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR - from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event - from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events - from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message - from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE - from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX - from google.adk.agents.invocation_context import InvocationContext - from google.adk.events.event import Event - from google.adk.events.event_actions import EventActions -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestEventConverter: """Test suite for event_converter module.""" diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 5a8bad1096..541ab7709d 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -13,38 +13,21 @@ # limitations under the License. import json -import sys from unittest.mock import Mock from unittest.mock import patch +from a2a import types as a2a_types +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part +from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.genai import types as genai_types import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a import types as a2a_types - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY - from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part - from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part - from google.adk.a2a.converters.utils import _get_adk_metadata_key 
- from google.genai import types as genai_types -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestConvertA2aPartToGenaiPart: """Test cases for convert_a2a_part_to_genai_part function.""" diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py index a7c21e4dbc..173b122d7c 100644 --- a/tests/unittests/a2a/converters/test_request_converter.py +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -12,33 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import Mock from unittest.mock import patch +from a2a.server.agent_execution import RequestContext +from google.adk.a2a.converters.request_converter import _get_user_id +from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request +from google.adk.runners import RunConfig +from google.genai import types as genai_types import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.server.agent_execution import RequestContext - from google.adk.a2a.converters.request_converter import _get_user_id - from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request - from google.adk.runners import RunConfig - from google.genai import types as genai_types -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestGetUserId: """Test cases for _get_user_id function.""" diff --git a/tests/unittests/a2a/converters/test_utils.py b/tests/unittests/a2a/converters/test_utils.py index 6c8511161a..0d896852aa 100644 --- a/tests/unittests/a2a/converters/test_utils.py +++ b/tests/unittests/a2a/converters/test_utils.py @@ -12,31 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import sys - +from google.adk.a2a.converters.utils import _from_a2a_context_id +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.converters.utils import _to_a2a_context_id +from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.a2a.converters.utils import _from_a2a_context_id - from google.adk.a2a.converters.utils import _get_adk_metadata_key - from google.adk.a2a.converters.utils import _to_a2a_context_id - from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX - from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestUtilsFunctions: """Test suite for utils module functions.""" diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 4bcc7a91d7..58d7521f7d 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -12,41 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Message +from a2a.types import TaskState +from a2a.types import TextPart +from google.adk.a2a.converters.request_converter import AgentRunRequest +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig +from google.adk.events.event import Event +from google.adk.runners import RunConfig +from google.adk.runners import Runner +from google.genai.types import Content import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.server.agent_execution.context import RequestContext - from a2a.server.events.event_queue import EventQueue - from a2a.types import Message - from a2a.types import TaskState - from a2a.types import TextPart - from google.adk.a2a.converters.request_converter import AgentRunRequest - from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor - from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig - from google.adk.events.event import Event - from google.adk.runners import RunConfig - from google.adk.runners import Runner - from google.genai.types import Content -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. 
- pass - else: - raise e - class TestA2aAgentExecutor: """Test suite for A2aAgentExecutor class.""" diff --git a/tests/unittests/a2a/executor/test_task_result_aggregator.py b/tests/unittests/a2a/executor/test_task_result_aggregator.py index 9d03db9dc8..b809b62728 100644 --- a/tests/unittests/a2a/executor/test_task_result_aggregator.py +++ b/tests/unittests/a2a/executor/test_task_result_aggregator.py @@ -12,35 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import Mock +from a2a.types import Message +from a2a.types import Part +from a2a.types import Role +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.types import Message - from a2a.types import Part - from a2a.types import Role - from a2a.types import TaskState - from a2a.types import TaskStatus - from a2a.types import TaskStatusUpdateEvent - from a2a.types import TextPart - from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - def create_test_message(text: str): """Helper function to create a test Message object.""" diff --git a/tests/unittests/a2a/utils/test_agent_card_builder.py b/tests/unittests/a2a/utils/test_agent_card_builder.py index e0b62468e5..3bf3202897 100644 --- a/tests/unittests/a2a/utils/test_agent_card_builder.py +++ b/tests/unittests/a2a/utils/test_agent_card_builder.py @@ -12,55 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import sys from unittest.mock import Mock from unittest.mock import patch +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentProvider +from a2a.types import AgentSkill +from a2a.types import SecurityScheme +from google.adk.a2a.utils.agent_card_builder import _build_agent_description +from google.adk.a2a.utils.agent_card_builder import _build_llm_agent_description_with_instructions +from google.adk.a2a.utils.agent_card_builder import _build_loop_description +from google.adk.a2a.utils.agent_card_builder import _build_orchestration_skill +from google.adk.a2a.utils.agent_card_builder import _build_parallel_description +from google.adk.a2a.utils.agent_card_builder import _build_sequential_description +from google.adk.a2a.utils.agent_card_builder import _convert_example_tool_examples +from google.adk.a2a.utils.agent_card_builder import _extract_examples_from_instruction +from google.adk.a2a.utils.agent_card_builder import _get_agent_skill_name +from google.adk.a2a.utils.agent_card_builder import _get_agent_type +from google.adk.a2a.utils.agent_card_builder import _get_default_description +from google.adk.a2a.utils.agent_card_builder import _get_input_modes +from google.adk.a2a.utils.agent_card_builder import _get_output_modes +from google.adk.a2a.utils.agent_card_builder import _get_workflow_description +from google.adk.a2a.utils.agent_card_builder import _replace_pronouns +from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.loop_agent import LoopAgent +from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.tools.example_tool import ExampleTool import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.types import AgentCapabilities - from a2a.types import AgentCard - from a2a.types import AgentProvider - from a2a.types import AgentSkill - from a2a.types import SecurityScheme - from google.adk.a2a.utils.agent_card_builder import _build_agent_description - from google.adk.a2a.utils.agent_card_builder import _build_llm_agent_description_with_instructions - from google.adk.a2a.utils.agent_card_builder import _build_loop_description - from google.adk.a2a.utils.agent_card_builder import _build_orchestration_skill - from google.adk.a2a.utils.agent_card_builder import _build_parallel_description - from google.adk.a2a.utils.agent_card_builder import _build_sequential_description - from google.adk.a2a.utils.agent_card_builder import _convert_example_tool_examples - from google.adk.a2a.utils.agent_card_builder import _extract_examples_from_instruction - from google.adk.a2a.utils.agent_card_builder import _get_agent_skill_name - from google.adk.a2a.utils.agent_card_builder import _get_agent_type - from google.adk.a2a.utils.agent_card_builder import _get_default_description - from google.adk.a2a.utils.agent_card_builder import _get_input_modes - from google.adk.a2a.utils.agent_card_builder import _get_output_modes - from google.adk.a2a.utils.agent_card_builder import _get_workflow_description - from google.adk.a2a.utils.agent_card_builder import _replace_pronouns - from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder - from 
google.adk.agents.base_agent import BaseAgent - from google.adk.agents.llm_agent import LlmAgent - from google.adk.agents.loop_agent import LoopAgent - from google.adk.agents.parallel_agent import ParallelAgent - from google.adk.agents.sequential_agent import SequentialAgent - from google.adk.tools.example_tool import ExampleTool -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestAgentCardBuilder: """Test suite for AgentCardBuilder class.""" diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index ee80b0233b..503e572f2f 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -12,42 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCard +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor +from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder +from google.adk.a2a.utils.agent_to_a2a import to_a2a +from google.adk.agents.base_agent import BaseAgent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService import pytest - -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.server.apps import A2AStarletteApplication - from a2a.server.request_handlers import DefaultRequestHandler - from a2a.server.tasks import InMemoryTaskStore - from a2a.types import AgentCard - from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor - from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder - from google.adk.a2a.utils.agent_to_a2a import to_a2a - from google.adk.agents.base_agent import BaseAgent - from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService - from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService - from google.adk.memory.in_memory_memory_service import InMemoryMemoryService - from google.adk.runners import Runner - from google.adk.sessions.in_memory_session_service import InMemorySessionService - from starlette.applications import Starlette -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. 
- pass - else: - raise e +from starlette.applications import Starlette class TestToA2A: diff --git a/tests/unittests/agents/test_agent_config.py b/tests/unittests/agents/test_agent_config.py index c2300f5f5d..86fda7fc9b 100644 --- a/tests/unittests/agents/test_agent_config.py +++ b/tests/unittests/agents/test_agent_config.py @@ -12,18 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ntpath +import os from pathlib import Path +from textwrap import dedent from typing import Literal from typing import Type +from unittest import mock from google.adk.agents import config_agent_utils from google.adk.agents.agent_config import AgentConfig from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent_config import BaseAgentConfig +from google.adk.agents.common_configs import AgentRefConfig from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.loop_agent import LoopAgent from google.adk.agents.parallel_agent import ParallelAgent from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.lite_llm import LiteLlm import pytest import yaml @@ -254,6 +260,53 @@ def test_agent_config_discriminator_llm_agent_with_sub_agents( assert config.root.agent_class == agent_class_value +def test_agent_config_litellm_model_with_custom_args(tmp_path: Path): + yaml_content = """\ +name: managed_api_agent +description: Agent using LiteLLM managed endpoint +instruction: Respond concisely. +model_code: + name: google.adk.models.lite_llm.LiteLlm + args: + - name: model + value: kimi/k2 + - name: api_base + value: https://proxy.litellm.ai/v1 +""" + config_file = tmp_path / "litellm_agent.yaml" + config_file.write_text(yaml_content) + + agent = config_agent_utils.from_config(str(config_file)) + + assert isinstance(agent, LlmAgent) + assert isinstance(agent.model, LiteLlm) + assert agent.model.model == "kimi/k2" + assert agent.model._additional_args.get("api_base") == ( + "https://proxy.litellm.ai/v1" + ) + + +def test_agent_config_legacy_model_mapping_still_supported(tmp_path: Path): + yaml_content = """\ +name: managed_api_agent +description: Agent using LiteLLM managed endpoint +instruction: Respond concisely. 
+model: + name: google.adk.models.lite_llm.LiteLlm + args: + - name: model + value: kimi/k2 +""" + config_file = tmp_path / "legacy_litellm_agent.yaml" + config_file.write_text(yaml_content) + + agent = config_agent_utils.from_config(str(config_file)) + + assert isinstance(agent, LlmAgent) + assert isinstance(agent.model, LiteLlm) + assert agent.model.model == "kimi/k2" + + def test_agent_config_discriminator_custom_agent(): class MyCustomAgentConfig(BaseAgentConfig): agent_class: Literal["mylib.agents.MyCustomAgent"] = ( @@ -280,3 +333,91 @@ class MyCustomAgentConfig(BaseAgentConfig): config.root.model_dump() ) assert my_custom_config.other_field == "other value" + + +@pytest.mark.parametrize( + ("config_rel_path", "child_rel_path", "child_name", "instruction"), + [ + ( + Path("main.yaml"), + Path("sub_agents/child.yaml"), + "child_agent", + "I am a child agent", + ), + ( + Path("level1/level2/nested_main.yaml"), + Path("sub/nested_child.yaml"), + "nested_child", + "I am nested", + ), + ], +) +def test_resolve_agent_reference_resolves_relative_paths( + config_rel_path: Path, + child_rel_path: Path, + child_name: str, + instruction: str, + tmp_path: Path, +): + """Verify resolve_agent_reference resolves relative sub-agent paths.""" + config_file = tmp_path / config_rel_path + config_file.parent.mkdir(parents=True, exist_ok=True) + + child_config_path = config_file.parent / child_rel_path + child_config_path.parent.mkdir(parents=True, exist_ok=True) + child_config_path.write_text(dedent(f""" + agent_class: LlmAgent + name: {child_name} + model: gemini-2.0-flash + instruction: {instruction} + """).lstrip()) + + config_file.write_text(dedent(f""" + agent_class: LlmAgent + name: main_agent + model: gemini-2.0-flash + instruction: I am the main agent + sub_agents: + - config_path: {child_rel_path.as_posix()} + """).lstrip()) + + ref_config = AgentRefConfig(config_path=child_rel_path.as_posix()) + agent = config_agent_utils.resolve_agent_reference( + ref_config, str(config_file) + ) + + assert agent.name == child_name + + config_dir = os.path.dirname(str(config_file.resolve())) + assert config_dir == str(config_file.parent.resolve()) + + expected_child_path = os.path.join(config_dir, *child_rel_path.parts) + assert os.path.exists(expected_child_path) + + +def test_resolve_agent_reference_uses_windows_dirname(): + """Ensure Windows-style config references resolve via os.path.dirname.""" + ref_config = AgentRefConfig(config_path="sub\\child.yaml") + recorded: dict[str, str] = {} + + def fake_from_config(path: str): + recorded["path"] = path + return "sentinel" + + with ( + mock.patch.object( + config_agent_utils, + "from_config", + autospec=True, + side_effect=fake_from_config, + ), + mock.patch.object(config_agent_utils.os, "path", ntpath), + ): + referencing = r"C:\workspace\agents\main.yaml" + result = config_agent_utils.resolve_agent_reference(ref_config, referencing) + + expected_path = ntpath.join( + ntpath.dirname(referencing), ref_config.config_path + ) + assert result == "sentinel" + assert recorded["path"] == expected_path diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 663179f670..259bdd51c2 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -16,6 +16,7 @@ from enum import Enum from functools import partial +import logging from typing import AsyncGenerator from typing import List from typing import Optional @@ -854,6 +855,112 @@ def test_set_parent_agent_for_sub_agent_twice( ) 
+def test_validate_sub_agents_unique_names_single_duplicate( + request: pytest.FixtureRequest, + caplog: pytest.LogCaptureFixture, +): + """Test that duplicate sub-agent names logs a warning.""" + duplicate_name = f'{request.function.__name__}_duplicate_agent' + sub_agent_1 = _TestingAgent(name=duplicate_name) + sub_agent_2 = _TestingAgent(name=duplicate_name) + + with caplog.at_level(logging.WARNING): + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=[sub_agent_1, sub_agent_2], + ) + assert f'Found duplicate sub-agent names: `{duplicate_name}`' in caplog.text + + +def test_validate_sub_agents_unique_names_multiple_duplicates( + request: pytest.FixtureRequest, + caplog: pytest.LogCaptureFixture, +): + """Test that multiple duplicate sub-agent names are all reported.""" + duplicate_name_1 = f'{request.function.__name__}_duplicate_1' + duplicate_name_2 = f'{request.function.__name__}_duplicate_2' + + sub_agents = [ + _TestingAgent(name=duplicate_name_1), + _TestingAgent(name=f'{request.function.__name__}_unique'), + _TestingAgent(name=duplicate_name_1), # First duplicate + _TestingAgent(name=duplicate_name_2), + _TestingAgent(name=duplicate_name_2), # Second duplicate + ] + + with caplog.at_level(logging.WARNING): + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + # Verify each duplicate name appears exactly once in the error message + assert caplog.text.count(duplicate_name_1) == 1 + assert caplog.text.count(duplicate_name_2) == 1 + # Verify both duplicate names are present + assert duplicate_name_1 in caplog.text + assert duplicate_name_2 in caplog.text + + +def test_validate_sub_agents_unique_names_triple_duplicate( + request: pytest.FixtureRequest, + caplog: pytest.LogCaptureFixture, +): + """Test that a name appearing three times is reported only once.""" + duplicate_name = f'{request.function.__name__}_triple_duplicate' + + sub_agents = [ + _TestingAgent(name=duplicate_name), + _TestingAgent(name=f'{request.function.__name__}_unique'), + _TestingAgent(name=duplicate_name), # Second occurrence + _TestingAgent(name=duplicate_name), # Third occurrence + ] + + with caplog.at_level(logging.WARNING): + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + # Verify the duplicate name appears exactly once in the error message + # (not three times even though it appears three times in the list) + assert caplog.text.count(duplicate_name) == 1 + assert duplicate_name in caplog.text + + +def test_validate_sub_agents_unique_names_no_duplicates( + request: pytest.FixtureRequest, +): + """Test that unique sub-agent names pass validation.""" + sub_agents = [ + _TestingAgent(name=f'{request.function.__name__}_sub_agent_1'), + _TestingAgent(name=f'{request.function.__name__}_sub_agent_2'), + _TestingAgent(name=f'{request.function.__name__}_sub_agent_3'), + ] + + parent = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + assert len(parent.sub_agents) == 3 + assert parent.sub_agents[0].name == f'{request.function.__name__}_sub_agent_1' + assert parent.sub_agents[1].name == f'{request.function.__name__}_sub_agent_2' + assert parent.sub_agents[2].name == f'{request.function.__name__}_sub_agent_3' + + +def test_validate_sub_agents_unique_names_empty_list( + request: pytest.FixtureRequest, +): + """Test that empty sub-agents list passes validation.""" + parent = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=[], + ) + 
+ assert len(parent.sub_agents) == 0 + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index c57254dbc8..577923f7bf 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -22,6 +22,9 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.models.anthropic_llm import Claude +from google.adk.models.google_llm import Gemini +from google.adk.models.lite_llm import LiteLlm from google.adk.models.llm_request import LlmRequest from google.adk.models.registry import LLMRegistry from google.adk.sessions.in_memory_session_service import InMemorySessionService @@ -411,3 +414,47 @@ async def test_handle_vais_only(self): assert len(tools) == 1 assert tools[0].name == 'vertex_ai_search' assert tools[0].__class__.__name__ == 'VertexAiSearchTool' + + +# Tests for multi-provider model support via string model names +@pytest.mark.parametrize( + 'model_name', + [ + 'gemini-1.5-flash', + 'gemini-2.0-flash-exp', + ], +) +def test_agent_with_gemini_string_model(model_name): + """Test that Agent accepts Gemini model strings and resolves to Gemini.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, Gemini) + assert agent.canonical_model.model == model_name + + +@pytest.mark.parametrize( + 'model_name', + [ + 'claude-3-5-sonnet-v2@20241022', + 'claude-sonnet-4@20250514', + ], +) +def test_agent_with_claude_string_model(model_name): + """Test that Agent accepts Claude model strings and resolves to Claude.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, Claude) + assert agent.canonical_model.model == model_name + + +@pytest.mark.parametrize( + 'model_name', + [ + 'openai/gpt-4o', + 'groq/llama3-70b-8192', + 'anthropic/claude-3-opus-20240229', + ], +) +def test_agent_with_litellm_string_model(model_name): + """Test that Agent accepts LiteLLM provider strings.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, LiteLlm) + assert agent.canonical_model.model == model_name diff --git a/tests/unittests/agents/test_mcp_instruction_provider.py b/tests/unittests/agents/test_mcp_instruction_provider.py index 1f2d098c2a..256d812630 100644 --- a/tests/unittests/agents/test_mcp_instruction_provider.py +++ b/tests/unittests/agents/test_mcp_instruction_provider.py @@ -13,34 +13,14 @@ # limitations under the License. 
"""Unit tests for McpInstructionProvider.""" -import sys from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch +from google.adk.agents.mcp_instruction_provider import McpInstructionProvider from google.adk.agents.readonly_context import ReadonlyContext import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), - reason="MCP instruction provider requires Python 3.10+", -) - -# Import dependencies with version checking -try: - from google.adk.agents.mcp_instruction_provider import McpInstructionProvider -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - McpInstructionProvider = DummyClass - else: - raise e - class TestMcpInstructionProvider: """Unit tests for McpInstructionProvider.""" diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 561a381870..e7865f39ba 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -14,70 +14,38 @@ import json from pathlib import Path -import sys import tempfile from unittest.mock import AsyncMock from unittest.mock import create_autospec from unittest.mock import Mock from unittest.mock import patch +from a2a.client.client import ClientConfig +from a2a.client.client import Consumer +from a2a.client.client_factory import ClientFactory +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentSkill +from a2a.types import Artifact +from a2a.types import Message as A2AMessage +from a2a.types import Part as A2ATaskStatus +from a2a.types import SendMessageSuccessResponse +from a2a.types import Task as A2ATask +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX +from google.adk.agents.remote_a2a_agent import AgentCardResolutionError +from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.events.event import Event from google.adk.sessions.session import Session from google.genai import types as genai_types import httpx import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.client.client import ClientConfig - from a2a.client.client import Consumer - from a2a.client.client_factory import ClientFactory - from a2a.types import AgentCapabilities - from a2a.types import AgentCard - from a2a.types import AgentSkill - from a2a.types import Artifact - from a2a.types import Message as A2AMessage - from a2a.types import Part as A2ATaskStatus - from a2a.types import SendMessageSuccessResponse - from a2a.types import Task as A2ATask - from a2a.types import TaskArtifactUpdateEvent - from a2a.types import TaskState - from a2a.types import TaskStatus - from a2a.types import TaskStatusUpdateEvent - from a2a.types import TextPart - from google.adk.agents.invocation_context import InvocationContext - from google.adk.agents.remote_a2a_agent 
import A2A_METADATA_PREFIX - from google.adk.agents.remote_a2a_agent import AgentCardResolutionError - from google.adk.agents.remote_a2a_agent import RemoteA2aAgent -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during module compilation. - # These are needed because the module has type annotations and module-level - # helper functions that reference imported types. - class DummyTypes: - pass - - AgentCapabilities = DummyTypes() - AgentCard = DummyTypes() - AgentSkill = DummyTypes() - A2AMessage = DummyTypes() - SendMessageSuccessResponse = DummyTypes() - A2ATask = DummyTypes() - TaskStatusUpdateEvent = DummyTypes() - Artifact = DummyTypes() - TaskArtifactUpdateEvent = DummyTypes() - InvocationContext = DummyTypes() - RemoteA2aAgent = DummyTypes() - AgentCardResolutionError = Exception - A2A_METADATA_PREFIX = "" - else: - raise e - # Helper function to create a proper AgentCard for testing def create_test_agent_card( @@ -723,6 +691,7 @@ async def test_handle_a2a_response_success_with_message(self): mock_a2a_message, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -760,6 +729,7 @@ async def test_handle_a2a_response_with_task_completed_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check the parts are not updated as Thought assert result.content.parts[0].thought is None @@ -864,6 +834,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check the parts are updated as Thought assert result.content.parts[0].thought is True @@ -909,6 +880,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self): mock_a2a_message, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -954,6 +926,7 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message( mock_a2a_message, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1009,7 +982,10 @@ async def test_handle_a2a_response_with_artifact_update(self): assert result == mock_event mock_convert.assert_called_once_with( - mock_a2a_task, self.agent.name, self.mock_context + mock_a2a_task, + self.agent.name, + self.mock_context, + self.agent._a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1039,6 +1015,8 @@ class TestRemoteA2aAgentMessageHandlingFromFactory: def setup_method(self): """Setup test fixtures.""" + self.mock_a2a_part_converter = Mock() + self.agent_card = create_test_agent_card() self.agent = RemoteA2aAgent( name="test_agent", @@ -1046,6 +1024,7 @@ def setup_method(self): a2a_client_factory=ClientFactory( config=ClientConfig(httpx_client=httpx.AsyncClient()), ), + a2a_part_converter=self.mock_a2a_part_converter, ) # Mock session and context @@ -1173,7 +1152,10 @@ async def test_handle_a2a_response_success_with_message(self): assert result == mock_event mock_convert.assert_called_once_with( - mock_a2a_message, self.agent.name, self.mock_context + mock_a2a_message, + self.agent.name, + self.mock_context, + self.mock_a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1211,6 +1193,7 @@ async def 
test_handle_a2a_response_with_task_completed_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check the parts are not updated as Thought assert result.content.parts[0].thought is None @@ -1251,6 +1234,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, + self.agent._a2a_part_converter, ) # Check the parts are updated as Thought assert result.content.parts[0].thought is True @@ -1296,6 +1280,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self): mock_a2a_message, self.agent.name, self.mock_context, + self.agent._a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1341,6 +1326,7 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message( mock_a2a_message, self.agent.name, self.mock_context, + self.agent._a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1396,7 +1382,10 @@ async def test_handle_a2a_response_with_artifact_update(self): assert result == mock_event mock_convert.assert_called_once_with( - mock_a2a_task, self.agent.name, self.mock_context + mock_a2a_task, + self.agent.name, + self.mock_context, + self.agent._a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index f7b457f73b..c68ad512c0 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -32,6 +32,7 @@ from google.adk.artifacts.file_artifact_service import FileArtifactService from google.adk.artifacts.gcs_artifact_service import GcsArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.errors.input_validation_error import InputValidationError from google.genai import types import pytest @@ -732,7 +733,7 @@ async def test_file_save_artifact_rejects_out_of_scope_paths( """FileArtifactService prevents path traversal outside of its storage roots.""" artifact_service = FileArtifactService(root_dir=tmp_path / "artifacts") part = types.Part(text="content") - with pytest.raises(ValueError): + with pytest.raises(InputValidationError): await artifact_service.save_artifact( app_name="myapp", user_id="user123", @@ -757,7 +758,7 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path): / "diagram.png" ) part = types.Part(text="content") - with pytest.raises(ValueError): + with pytest.raises(InputValidationError): await artifact_service.save_artifact( app_name="myapp", user_id="user123", diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 2d7b9472ba..75d5679084 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -30,7 +30,9 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App +from google.adk.artifacts.base_artifact_service import ArtifactVersion from google.adk.cli.fast_api import get_fast_api_app +from google.adk.errors.input_validation_error import InputValidationError from google.adk.evaluation.eval_case import EvalCase from google.adk.evaluation.eval_case import Invocation from google.adk.evaluation.eval_result import EvalSetResult @@ -190,6 +192,14 @@ def 
load_agent(self, app_name): def list_agents(self): return ["test_app"] + def list_agents_detailed(self): + return [{ + "name": "test_app", + "root_agent_name": "test_agent", + "description": "A test agent for unit testing", + "language": "python", + }] + return MockAgentLoader(".") @@ -203,48 +213,135 @@ def mock_session_service(): def mock_artifact_service(): """Create a mock artifact service.""" - # Storage for artifacts - artifacts = {} + artifacts: dict[str, list[dict[str, Any]]] = {} + + def _artifact_key( + app_name: str, user_id: str, session_id: Optional[str], filename: str + ) -> str: + if session_id is None: + return f"{app_name}:{user_id}:user:{filename}" + return f"{app_name}:{user_id}:{session_id}:{filename}" + + def _canonical_uri( + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + version: int, + ) -> str: + if session_id is None: + return ( + f"artifact://apps/{app_name}/users/{user_id}/artifacts/" + f"{filename}/versions/{version}" + ) + return ( + f"artifact://apps/{app_name}/users/{user_id}/sessions/{session_id}/" + f"artifacts/{filename}/versions/{version}" + ) class MockArtifactService: + def __init__(self): + self._artifacts = artifacts + self.save_artifact_side_effect: Optional[BaseException] = None + + async def save_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + artifact: types.Part, + session_id: Optional[str] = None, + custom_metadata: Optional[dict[str, Any]] = None, + ) -> int: + if self.save_artifact_side_effect is not None: + effect = self.save_artifact_side_effect + if isinstance(effect, BaseException): + raise effect + raise TypeError( + "save_artifact_side_effect must be an exception instance." + ) + key = _artifact_key(app_name, user_id, session_id, filename) + entries = artifacts.setdefault(key, []) + version = len(entries) + artifact_version = ArtifactVersion( + version=version, + canonical_uri=_canonical_uri( + app_name, user_id, session_id, filename, version + ), + custom_metadata=custom_metadata or {}, + ) + if artifact.inline_data is not None: + artifact_version.mime_type = artifact.inline_data.mime_type + elif artifact.text is not None: + artifact_version.mime_type = "text/plain" + elif artifact.file_data is not None: + artifact_version.mime_type = artifact.file_data.mime_type + + entries.append({ + "version": version, + "artifact": artifact, + "metadata": artifact_version, + }) + return version + async def load_artifact( self, app_name, user_id, session_id, filename, version=None ): """Load an artifact by filename.""" - key = f"{app_name}:{user_id}:{session_id}:{filename}" + key = _artifact_key(app_name, user_id, session_id, filename) if key not in artifacts: return None if version is not None: - # Get a specific version - for v in artifacts[key]: - if v["version"] == version: - return v["artifact"] + for entry in artifacts[key]: + if entry["version"] == version: + return entry["artifact"] return None - # Get the latest version - return sorted(artifacts[key], key=lambda x: x["version"])[-1]["artifact"] + return artifacts[key][-1]["artifact"] async def list_artifact_keys(self, app_name, user_id, session_id): """List artifact names for a session.""" prefix = f"{app_name}:{user_id}:{session_id}:" return [ - k.split(":")[-1] for k in artifacts.keys() if k.startswith(prefix) + key.split(":")[-1] + for key in artifacts.keys() + if key.startswith(prefix) ] async def list_versions(self, app_name, user_id, session_id, filename): """List versions of an artifact.""" - key = 
f"{app_name}:{user_id}:{session_id}:{filename}" + key = _artifact_key(app_name, user_id, session_id, filename) if key not in artifacts: return [] - return [a["version"] for a in artifacts[key]] + return [entry["version"] for entry in artifacts[key]] async def delete_artifact(self, app_name, user_id, session_id, filename): """Delete an artifact.""" - key = f"{app_name}:{user_id}:{session_id}:{filename}" - if key in artifacts: - del artifacts[key] + key = _artifact_key(app_name, user_id, session_id, filename) + artifacts.pop(key, None) + + async def get_artifact_version( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + ) -> Optional[ArtifactVersion]: + key = _artifact_key(app_name, user_id, session_id, filename) + entries = artifacts.get(key) + if not entries: + return None + if version is None: + return entries[-1]["metadata"] + for entry in entries: + if entry["version"] == version: + return entry["metadata"] + return None return MockArtifactService() @@ -412,8 +509,6 @@ async def create_test_eval_set( @pytest.fixture def temp_agents_dir_with_a2a(): """Create a temporary agents directory with A2A agent configurations for testing.""" - if sys.version_info < (3, 10): - pytest.skip("A2A requires Python 3.10+") with tempfile.TemporaryDirectory() as temp_dir: # Create test agent directory agent_dir = Path(temp_dir) / "test_a2a_agent" @@ -457,9 +552,6 @@ def test_app_with_a2a( temp_agents_dir_with_a2a, ): """Create a TestClient for the FastAPI app with A2A enabled.""" - if sys.version_info < (3, 10): - pytest.skip("A2A requires Python 3.10+") - # Mock A2A related classes with ( patch("signal.signal", return_value=None), @@ -548,6 +640,26 @@ def test_list_apps(test_app): logger.info(f"Listed apps: {data}") +def test_list_apps_detailed(test_app): + """Test listing available applications with detailed metadata.""" + response = test_app.get("/list-apps?detailed=true") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, dict) + assert "apps" in data + assert isinstance(data["apps"], list) + + for app in data["apps"]: + assert "name" in app + assert "rootAgentName" in app + assert "description" in app + assert "language" in app + assert app["language"] in ["yaml", "python"] + + logger.info(f"Listed apps: {data}") + + def test_create_session_with_id(test_app, test_session_info): """Test creating a session with a specific ID.""" new_session_id = "new_session_id" @@ -782,6 +894,87 @@ def test_list_artifact_names(test_app, create_test_session): logger.info(f"Listed {len(data)} artifacts") +def test_save_artifact(test_app, create_test_session, mock_artifact_service): + """Test saving an artifact through the FastAPI endpoint.""" + info = create_test_session + url = ( + f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/" + f"{info['session_id']}/artifacts" + ) + artifact_part = types.Part(text="hello world") + payload = { + "filename": "greeting.txt", + "artifact": artifact_part.model_dump(by_alias=True, exclude_none=True), + } + + response = test_app.post(url, json=payload) + assert response.status_code == 200 + data = response.json() + assert data["version"] == 0 + assert data["customMetadata"] == {} + assert data["mimeType"] in (None, "text/plain") + assert data["canonicalUri"].endswith( + f"/sessions/{info['session_id']}/artifacts/" + f"{payload['filename']}/versions/0" + ) + assert isinstance(data["createTime"], float) + + key = ( + 
f"{info['app_name']}:{info['user_id']}:{info['session_id']}:" + f"{payload['filename']}" + ) + stored = mock_artifact_service._artifacts[key][0] + assert stored["artifact"].text == "hello world" + + +def test_save_artifact_returns_400_on_validation_error( + test_app, create_test_session, mock_artifact_service +): + """Test save artifact endpoint surfaces validation errors as HTTP 400.""" + info = create_test_session + url = ( + f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/" + f"{info['session_id']}/artifacts" + ) + artifact_part = types.Part(text="bad data") + payload = { + "filename": "invalid.txt", + "artifact": artifact_part.model_dump(by_alias=True, exclude_none=True), + } + + mock_artifact_service.save_artifact_side_effect = InputValidationError( + "invalid artifact" + ) + + response = test_app.post(url, json=payload) + assert response.status_code == 400 + assert response.json()["detail"] == "invalid artifact" + + +def test_save_artifact_returns_500_on_unexpected_error( + test_app, create_test_session, mock_artifact_service +): + """Test save artifact endpoint surfaces unexpected errors as HTTP 500.""" + info = create_test_session + url = ( + f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/" + f"{info['session_id']}/artifacts" + ) + artifact_part = types.Part(text="bad data") + payload = { + "filename": "invalid.txt", + "artifact": artifact_part.model_dump(by_alias=True, exclude_none=True), + } + + mock_artifact_service.save_artifact_side_effect = RuntimeError( + "unexpected failure" + ) + + response = test_app.post(url, json=payload) + assert response.status_code == 500 + assert response.json()["detail"] == "unexpected failure" + + def test_create_eval_set(test_app, test_session_info): """Test creating an eval set.""" url = f"/apps/{test_session_info['app_name']}/eval_sets/test_eval_set_id" @@ -952,9 +1145,6 @@ def list_agents(self): assert "dotSrc" in response.json() -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) def test_a2a_agent_discovery(test_app_with_a2a): """Test that A2A agents are properly discovered and configured.""" # This test mainly verifies that the A2A setup doesn't break the app @@ -963,9 +1153,6 @@ def test_a2a_agent_discovery(test_app_with_a2a): logger.info("A2A agent discovery test passed") -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) def test_a2a_disabled_by_default(test_app): """Test that A2A functionality is disabled by default.""" # The regular test_app fixture has a2a=False diff --git a/tests/unittests/cli/utils/test_agent_loader.py b/tests/unittests/cli/utils/test_agent_loader.py index 5c66160aed..4950fecbd3 100644 --- a/tests/unittests/cli/utils/test_agent_loader.py +++ b/tests/unittests/cli/utils/test_agent_loader.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import ntpath import os from pathlib import Path +from pathlib import PureWindowsPath +import re import sys import tempfile from textwrap import dedent +from google.adk.cli.utils import agent_loader as agent_loader_module from google.adk.cli.utils.agent_loader import AgentLoader from pydantic import ValidationError import pytest @@ -280,6 +284,43 @@ def test_load_multiple_different_agents(self): assert agent2 is not agent3 assert agent1.agent_id != agent2.agent_id != agent3.agent_id + def test_error_messages_use_os_sep_consistently(self): + """Verify error messages use os.sep instead of hardcoded '/'.""" + del self + with tempfile.TemporaryDirectory() as temp_dir: + loader = AgentLoader(temp_dir) + agent_name = "missing_agent" + + expected_path = os.path.join(temp_dir, agent_name) + + with pytest.raises(ValueError) as exc_info: + loader.load_agent(agent_name) + + exc_info.match(re.escape(expected_path)) + exc_info.match(re.escape(f"{agent_name}{os.sep}root_agent.yaml")) + exc_info.match(re.escape(f"{os.sep}")) + + def test_agent_loader_with_mocked_windows_path(self, monkeypatch): + """Mock Path() to simulate Windows behavior and catch regressions. + + REGRESSION TEST: Fails with rstrip('/'), passes with str(Path()). + """ + del self + windows_path = "C:\\Users\\dev\\agents\\" + + with monkeypatch.context() as m: + m.setattr( + agent_loader_module, + "Path", + lambda path_str: PureWindowsPath(path_str), + ) + loader = AgentLoader(windows_path) + + expected = str(PureWindowsPath(windows_path)) + assert loader.agents_dir == expected + assert not loader.agents_dir.endswith("\\") + assert not loader.agents_dir.endswith("/") + def test_agent_not_found_error(self): """Test that appropriate error is raised when agent is not found.""" with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 0de59598b3..73ae89a986 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -28,7 +28,12 @@ import click from google.adk.agents.base_agent import BaseAgent from google.adk.apps.app import App +from google.adk.artifacts.file_artifact_service import FileArtifactService +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService import google.adk.cli.cli as cli +from google.adk.cli.utils.service_factory import create_artifact_service_from_options +from google.adk.sessions.in_memory_session_service import InMemorySessionService import pytest @@ -151,9 +156,9 @@ def _echo(msg: str) -> None: input_path = tmp_path / "input.json" input_path.write_text(json.dumps(input_json)) - artifact_service = cli.InMemoryArtifactService() - session_service = cli.InMemorySessionService() - credential_service = cli.InMemoryCredentialService() + artifact_service = InMemoryArtifactService() + session_service = InMemorySessionService() + credential_service = InMemoryCredentialService() dummy_root = BaseAgent(name="root") session = await cli.run_input_file( @@ -189,6 +194,34 @@ async def test_run_cli_with_input_file(fake_agent, tmp_path: Path) -> None: ) +@pytest.mark.asyncio +async def test_run_cli_loads_services_module( + fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_cli should load custom services from the agents directory.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": ["ping"]} + input_path = tmp_path / 
"input.json" + input_path.write_text(json.dumps(input_json)) + + loaded_dirs: list[str] = [] + monkeypatch.setattr( + cli, "load_services_module", lambda path: loaded_dirs.append(path) + ) + + agent_root = parent_dir / folder_name + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + ) + + assert loaded_dirs == [str(agent_root.resolve())] + + @pytest.mark.asyncio async def test_run_cli_app_uses_app_name_for_sessions( fake_app_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch @@ -197,15 +230,20 @@ async def test_run_cli_app_uses_app_name_for_sessions( parent_dir, folder_name, app_name = fake_app_agent created_app_names: List[str] = [] - original_session_cls = cli.InMemorySessionService - - class _SpySessionService(original_session_cls): + class _SpySessionService(InMemorySessionService): async def create_session(self, *, app_name: str, **kwargs: Any) -> Any: created_app_names.append(app_name) return await super().create_session(app_name=app_name, **kwargs) - monkeypatch.setattr(cli, "InMemorySessionService", _SpySessionService) + spy_session_service = _SpySessionService() + + def _session_factory(**_: Any) -> InMemorySessionService: + return spy_session_service + + monkeypatch.setattr( + cli, "create_session_service_from_options", _session_factory + ) input_json = {"state": {}, "queries": ["ping"]} input_path = tmp_path / "input_app.json" @@ -253,16 +291,76 @@ async def test_run_cli_save_session( assert "id" in data and "events" in data +def test_create_artifact_service_defaults_to_file(tmp_path: Path) -> None: + """Service factory should default to FileArtifactService when URI is unset.""" + service = create_artifact_service_from_options(base_dir=tmp_path) + assert isinstance(service, FileArtifactService) + expected_root = Path(tmp_path) / ".adk" / "artifacts" + assert service.root_dir == expected_root + assert expected_root.exists() + + +def test_create_artifact_service_uses_shared_root( + tmp_path: Path, +) -> None: + """Artifact service should use a single file artifact service.""" + service = create_artifact_service_from_options(base_dir=tmp_path) + assert isinstance(service, FileArtifactService) + expected_root = Path(tmp_path) / ".adk" / "artifacts" + assert service.root_dir == expected_root + assert expected_root.exists() + + +def test_create_artifact_service_respects_memory_uri(tmp_path: Path) -> None: + """Service factory should honor memory:// URIs.""" + service = create_artifact_service_from_options( + base_dir=tmp_path, artifact_service_uri="memory://" + ) + assert isinstance(service, InMemoryArtifactService) + + +def test_create_artifact_service_accepts_file_uri(tmp_path: Path) -> None: + """Service factory should allow custom local roots via file:// URIs.""" + custom_root = tmp_path / "custom_artifacts" + service = create_artifact_service_from_options( + base_dir=tmp_path, artifact_service_uri=custom_root.as_uri() + ) + assert isinstance(service, FileArtifactService) + assert service.root_dir == custom_root + assert custom_root.exists() + + +@pytest.mark.asyncio +async def test_run_cli_accepts_memory_scheme( + fake_agent, tmp_path: Path +) -> None: + """run_cli should allow configuring in-memory services via memory:// URIs.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": []} + input_path = tmp_path / "noop.json" + input_path.write_text(json.dumps(input_json)) + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + 
agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + session_service_uri="memory://", + artifact_service_uri="memory://", + ) + + @pytest.mark.asyncio async def test_run_interactively_whitespace_and_exit( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: """run_interactively should skip blank input, echo once, then exit.""" # make a session that belongs to dummy agent - session_service = cli.InMemorySessionService() + session_service = InMemorySessionService() sess = await session_service.create_session(app_name="dummy", user_id="u") - artifact_service = cli.InMemoryArtifactService() - credential_service = cli.InMemoryCredentialService() + artifact_service = InMemoryArtifactService() + credential_service = InMemoryCredentialService() root_agent = BaseAgent(name="root") # fake user input: blank -> 'hello' -> 'exit' diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index be9015ca87..95b561e57b 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -76,8 +76,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401 # Fixtures @pytest.fixture(autouse=True) -def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None: +def _mute_click(request, monkeypatch: pytest.MonkeyPatch) -> None: """Suppress click output during tests.""" + # Allow tests to opt-out of muting by using the 'unmute_click' marker + if "unmute_click" in request.keywords: + return monkeypatch.setattr(click, "echo", lambda *a, **k: None) # Keep secho for error messages # monkeypatch.setattr(click, "secho", lambda *a, **k: None) @@ -121,32 +124,70 @@ def test_cli_create_cmd_invokes_run_cmd( cli_tools_click.main, ["create", "--model", "gemini", "--api_key", "key123", str(app_dir)], ) - assert result.exit_code == 0 + assert result.exit_code == 0, (result.output, repr(result.exception)) assert rec.calls, "cli_create.run_cmd must be called" # cli run -@pytest.mark.asyncio -async def test_cli_run_invokes_run_cli( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch +@pytest.mark.parametrize( + "cli_args,expected_session_uri,expected_artifact_uri", + [ + pytest.param( + [ + "--session_service_uri", + "memory://", + "--artifact_service_uri", + "memory://", + ], + "memory://", + "memory://", + id="memory_scheme_uris", + ), + pytest.param( + [], + None, + None, + id="default_uris_none", + ), + ], +) +def test_cli_run_service_uris( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + cli_args: list, + expected_session_uri: str, + expected_artifact_uri: str, ) -> None: - """`adk run` should call run_cli via asyncio.run with correct parameters.""" - rec = _Recorder() - monkeypatch.setattr(cli_tools_click, "run_cli", lambda **kwargs: rec(kwargs)) - monkeypatch.setattr( - cli_tools_click.asyncio, "run", lambda coro: coro - ) # pass-through - - # create dummy agent directory + """`adk run` should forward service URIs correctly to run_cli.""" agent_dir = tmp_path / "agent" agent_dir.mkdir() (agent_dir / "__init__.py").touch() (agent_dir / "agent.py").touch() + # Capture the coroutine's locals before closing it + captured_locals = [] + + def capture_asyncio_run(coro): + # Extract the locals before closing the coroutine + if coro.cr_frame is not None: + captured_locals.append(dict(coro.cr_frame.f_locals)) + coro.close() # Properly close the coroutine to avoid warnings + + monkeypatch.setattr(cli_tools_click.asyncio, "run", 
capture_asyncio_run) + runner = CliRunner() - result = runner.invoke(cli_tools_click.main, ["run", str(agent_dir)]) - assert result.exit_code == 0 - assert rec.calls and rec.calls[0][0][0]["agent_folder_name"] == "agent" + result = runner.invoke( + cli_tools_click.main, + ["run", *cli_args, str(agent_dir)], + ) + assert result.exit_code == 0, (result.output, repr(result.exception)) + assert len(captured_locals) == 1, "Expected asyncio.run to be called once" + + # Verify the kwargs passed to run_cli + coro_locals = captured_locals[0] + assert coro_locals.get("session_service_uri") == expected_session_uri + assert coro_locals.get("artifact_service_uri") == expected_artifact_uri + assert coro_locals["agent_folder_name"] == "agent" # cli deploy cloud_run @@ -520,10 +561,13 @@ def test_cli_web_passes_service_uris( assert called_kwargs.get("memory_service_uri") == "rag://mycorpus" -def test_cli_web_passes_deprecated_uris( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch, _patch_uvicorn: _Recorder +@pytest.mark.unmute_click +def test_cli_web_warns_and_maps_deprecated_uris( + tmp_path: Path, + _patch_uvicorn: _Recorder, + monkeypatch: pytest.MonkeyPatch, ) -> None: - """`adk web` should use deprecated URIs if new ones are not provided.""" + """`adk web` should accept deprecated URI flags with warnings.""" agents_dir = tmp_path / "agents" agents_dir.mkdir() @@ -542,11 +586,14 @@ def test_cli_web_passes_deprecated_uris( "gs://deprecated", ], ) + assert result.exit_code == 0 - assert mock_get_app.calls called_kwargs = mock_get_app.calls[0][1] assert called_kwargs.get("session_service_uri") == "sqlite:///deprecated.db" assert called_kwargs.get("artifact_service_uri") == "gs://deprecated" + # Check output for deprecation warnings (CliRunner captures both stdout and stderr) + assert "--session_db_url" in result.output + assert "--artifact_storage_uri" in result.output def test_cli_eval_with_eval_set_file_path( diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py new file mode 100644 index 0000000000..9d9afdd23b --- /dev/null +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -0,0 +1,141 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for service factory helpers.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import Mock + +from google.adk.cli.utils.local_storage import PerAgentDatabaseSessionService +import google.adk.cli.utils.service_factory as service_factory +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.sessions.database_session_service import DatabaseSessionService +import pytest + + +def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): + registry = Mock() + expected = object() + registry.create_session_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri="sqlite:///test.db", + ) + + assert result is expected + registry.create_session_service.assert_called_once_with( + "sqlite:///test.db", + agents_dir=str(tmp_path), + ) + + +@pytest.mark.asyncio +async def test_create_session_service_defaults_to_per_agent_sqlite( + tmp_path: Path, +) -> None: + agent_dir = tmp_path / "agent_a" + agent_dir.mkdir() + service = service_factory.create_session_service_from_options( + base_dir=tmp_path, + ) + + assert isinstance(service, PerAgentDatabaseSessionService) + session = await service.create_session(app_name="agent_a", user_id="user") + assert session.app_name == "agent_a" + assert (agent_dir / ".adk" / "session.db").exists() + + +def test_create_session_service_fallbacks_to_database( + tmp_path: Path, monkeypatch +): + registry = Mock() + registry.create_session_service.return_value = None + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + service = service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri="sqlite+aiosqlite:///:memory:", + session_db_kwargs={"echo": True}, + ) + + assert isinstance(service, DatabaseSessionService) + assert service.db_engine.url.drivername == "sqlite+aiosqlite" + assert service.db_engine.echo is True + registry.create_session_service.assert_called_once_with( + "sqlite+aiosqlite:///:memory:", + agents_dir=str(tmp_path), + echo=True, + ) + + +def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch): + registry = Mock() + expected = object() + registry.create_artifact_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_artifact_service_from_options( + base_dir=tmp_path, + artifact_service_uri="gs://bucket/path", + ) + + assert result is expected + registry.create_artifact_service.assert_called_once_with( + "gs://bucket/path", + agents_dir=str(tmp_path), + ) + + +def test_create_memory_service_uses_registry(tmp_path: Path, monkeypatch): + registry = Mock() + expected = object() + registry.create_memory_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_memory_service_from_options( + base_dir=tmp_path, + memory_service_uri="rag://my-corpus", + ) + + assert result is expected + registry.create_memory_service.assert_called_once_with( + "rag://my-corpus", + agents_dir=str(tmp_path), + ) + + +def test_create_memory_service_defaults_to_in_memory(tmp_path: Path): + service = service_factory.create_memory_service_from_options( + base_dir=tmp_path + ) + + assert isinstance(service, InMemoryMemoryService) + + +def 
test_create_memory_service_raises_on_unknown_scheme( + tmp_path: Path, monkeypatch +): + registry = Mock() + registry.create_memory_service.return_value = None + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + with pytest.raises(ValueError): + service_factory.create_memory_service_from_options( + base_dir=tmp_path, + memory_service_uri="unknown://foo", + ) diff --git a/tests/unittests/code_executors/test_unsafe_local_code_executor.py b/tests/unittests/code_executors/test_unsafe_local_code_executor.py index eeb10b34fa..e5d5c4f792 100644 --- a/tests/unittests/code_executors/test_unsafe_local_code_executor.py +++ b/tests/unittests/code_executors/test_unsafe_local_code_executor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import textwrap from unittest.mock import MagicMock from google.adk.agents.base_agent import BaseAgent @@ -101,3 +102,22 @@ def test_execute_code_empty(self, mock_invocation_context: InvocationContext): result = executor.execute_code(mock_invocation_context, code_input) assert result.stdout == "" assert result.stderr == "" + + def test_execute_code_nested_function_call( + self, mock_invocation_context: InvocationContext + ): + executor = UnsafeLocalCodeExecutor() + code_input = CodeExecutionInput(code=(textwrap.dedent(""" + def helper(name): + return f'hi {name}' + + def run(): + print(helper('ada')) + + run() + """))) + + result = executor.execute_code(mock_invocation_context, code_input) + + assert result.stderr == "" + assert result.stdout == "hi ada\n" diff --git a/tests/unittests/evaluation/simulation/__init__.py b/tests/unittests/evaluation/simulation/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/tests/unittests/evaluation/simulation/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
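The service-factory tests above all exercise the same dispatch pattern: consult a pluggable registry for the URI scheme first, fall back to a local default when no URI is given, and raise on an unknown scheme. The snippet below is a minimal, self-contained sketch of that pattern under assumed names; SimpleRegistry and make_session_service are illustrative placeholders, not the actual google.adk.cli.utils.service_factory API.

from urllib.parse import urlparse


class SimpleRegistry:
    """Illustrative stand-in for a pluggable service registry."""

    def __init__(self):
        self._factories = {}  # maps URI scheme -> factory callable

    def register(self, scheme, factory):
        self._factories[scheme] = factory

    def create(self, uri, **kwargs):
        factory = self._factories.get(urlparse(uri).scheme)
        return factory(uri, **kwargs) if factory else None


def make_session_service(registry, base_dir, uri=None, **kwargs):
    """Hypothetical helper mirroring the dispatch order the tests assert."""
    if uri is None:
        # No URI given: default to per-agent local storage under base_dir/.adk.
        return ("sqlite", f"{base_dir}/.adk/session.db")
    service = registry.create(uri, agents_dir=str(base_dir), **kwargs)
    if service is not None:
        return service  # a registered factory handled the scheme
    if urlparse(uri).scheme == "memory":
        return ("in_memory", None)  # built-in fallback for memory://
    raise ValueError(f"Unsupported session service URI: {uri}")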
diff --git a/tests/unittests/evaluation/test_llm_backed_user_simulator.py b/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py similarity index 95% rename from tests/unittests/evaluation/test_llm_backed_user_simulator.py rename to tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py index 6ef3969e70..75db778bc7 100644 --- a/tests/unittests/evaluation/test_llm_backed_user_simulator.py +++ b/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py @@ -15,9 +15,9 @@ from __future__ import annotations from google.adk.evaluation import conversation_scenarios -from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulator -from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig -from google.adk.evaluation.user_simulator import Status +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulator +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig +from google.adk.evaluation.simulation.user_simulator import Status from google.adk.events.event import Event from google.genai import types import pytest @@ -112,7 +112,7 @@ async def to_async_iter(items): def mock_llm_agent(mocker): """Provides a mock LLM agent.""" mock_llm_registry_cls = mocker.patch( - "google.adk.evaluation.llm_backed_user_simulator.LLMRegistry" + "google.adk.evaluation.simulation.llm_backed_user_simulator.LLMRegistry" ) mock_llm_registry = mocker.MagicMock() mock_llm_registry_cls.return_value = mock_llm_registry diff --git a/tests/unittests/evaluation/test_static_user_simulator.py b/tests/unittests/evaluation/simulation/test_static_user_simulator.py similarity index 93% rename from tests/unittests/evaluation/test_static_user_simulator.py rename to tests/unittests/evaluation/simulation/test_static_user_simulator.py index 5cc70c80e6..f18c23f5f2 100644 --- a/tests/unittests/evaluation/test_static_user_simulator.py +++ b/tests/unittests/evaluation/simulation/test_static_user_simulator.py @@ -14,9 +14,9 @@ from __future__ import annotations -from google.adk.evaluation import static_user_simulator -from google.adk.evaluation import user_simulator from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.simulation import static_user_simulator +from google.adk.evaluation.simulation import user_simulator from google.genai import types import pytest diff --git a/tests/unittests/evaluation/test_user_simulator.py b/tests/unittests/evaluation/simulation/test_user_simulator.py similarity index 90% rename from tests/unittests/evaluation/test_user_simulator.py rename to tests/unittests/evaluation/simulation/test_user_simulator.py index c3e1e606ee..dbe7aff1db 100644 --- a/tests/unittests/evaluation/test_user_simulator.py +++ b/tests/unittests/evaluation/simulation/test_user_simulator.py @@ -14,8 +14,8 @@ from __future__ import annotations -from google.adk.evaluation.user_simulator import NextUserMessage -from google.adk.evaluation.user_simulator import Status +from google.adk.evaluation.simulation.user_simulator import NextUserMessage +from google.adk.evaluation.simulation.user_simulator import Status from google.genai.types import Content import pytest diff --git a/tests/unittests/evaluation/test_user_simulator_provider.py b/tests/unittests/evaluation/simulation/test_user_simulator_provider.py similarity index 86% rename from tests/unittests/evaluation/test_user_simulator_provider.py rename to 
tests/unittests/evaluation/simulation/test_user_simulator_provider.py index 7cff4241b6..c4fb826fb7 100644 --- a/tests/unittests/evaluation/test_user_simulator_provider.py +++ b/tests/unittests/evaluation/simulation/test_user_simulator_provider.py @@ -16,10 +16,10 @@ from google.adk.evaluation import conversation_scenarios from google.adk.evaluation import eval_case -from google.adk.evaluation import user_simulator_provider -from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulator -from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig -from google.adk.evaluation.static_user_simulator import StaticUserSimulator +from google.adk.evaluation.simulation import user_simulator_provider +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulator +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig +from google.adk.evaluation.simulation.static_user_simulator import StaticUserSimulator from google.genai import types import pytest @@ -52,7 +52,7 @@ def test_provide_static_user_simulator(self): def test_provide_llm_backed_user_simulator(self, mocker): """Tests the case when a LlmBackedUserSimulator should be provided.""" mock_llm_registry = mocker.patch( - 'google.adk.evaluation.llm_backed_user_simulator.LLMRegistry', + 'google.adk.evaluation.simulation.llm_backed_user_simulator.LLMRegistry', autospec=True, ) mock_llm_registry.return_value.resolve.return_value = mocker.Mock() diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py index 27372f12c2..873239e7f4 100644 --- a/tests/unittests/evaluation/test_evaluation_generator.py +++ b/tests/unittests/evaluation/test_evaluation_generator.py @@ -18,9 +18,9 @@ from google.adk.evaluation.app_details import AppDetails from google.adk.evaluation.evaluation_generator import EvaluationGenerator from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin -from google.adk.evaluation.user_simulator import NextUserMessage -from google.adk.evaluation.user_simulator import Status as UserSimulatorStatus -from google.adk.evaluation.user_simulator import UserSimulator +from google.adk.evaluation.simulation.user_simulator import NextUserMessage +from google.adk.evaluation.simulation.user_simulator import Status as UserSimulatorStatus +from google.adk.evaluation.simulation.user_simulator import UserSimulator from google.adk.events.event import Event from google.adk.models.llm_request import LlmRequest from google.genai import types diff --git a/tests/unittests/evaluation/test_local_eval_service.py b/tests/unittests/evaluation/test_local_eval_service.py index cf2ca342f3..66080828d8 100644 --- a/tests/unittests/evaluation/test_local_eval_service.py +++ b/tests/unittests/evaluation/test_local_eval_service.py @@ -536,9 +536,6 @@ def test_generate_final_eval_status_doesn_t_throw_on(eval_service): @pytest.mark.asyncio -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) async def test_mcp_stdio_agent_no_runtime_error(mocker): """Test that LocalEvalService can handle MCP stdio agents without RuntimeError. 
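Most of the rename hunks above only rewrite mocker.patch targets from google.adk.evaluation.* to google.adk.evaluation.simulation.*. The underlying rule is that a patch string must name the module where the symbol is looked up at call time, so moving a module forces every patch target that referenced it to change as well. A runnable toy example of the same mechanic, using only the standard library:

from unittest import mock
import random


def pick():
    # Resolves `random.choice` through the `random` module at call time.
    return random.choice(["a", "b"])


def test_pick_sees_the_patched_name():
    # Patch the attribute on the module the caller actually consults;
    # if that module were renamed, this patch string would have to change too.
    with mock.patch("random.choice", return_value="stub"):
        assert pick() == "stub"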
diff --git a/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py b/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py index be97a627a1..b180a589cb 100644 --- a/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py +++ b/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py @@ -126,15 +126,16 @@ async def test_agent_transfer_includes_sorted_agent_names_in_system_instructions Agent description: Peer agent -If you are the best to answer the question according to your description, you -can answer it. +If you are the best to answer the question according to your description, +you can answer it. If another agent is better for answering the question according to its -description, call `transfer_to_agent` function to transfer the -question to that agent. When transferring, do not generate any text other than -the function call. +description, call `transfer_to_agent` function to transfer the question to that +agent. When transferring, do not generate any text other than the function +call. -**NOTE**: the only available agents for `transfer_to_agent` function are `a_agent`, `m_agent`, `parent_agent`, `peer_agent`, `z_agent`. +**NOTE**: the only available agents for `transfer_to_agent` function are +`a_agent`, `m_agent`, `parent_agent`, `peer_agent`, `z_agent`. If neither you nor the other agents are best for the question, transfer to your parent agent parent_agent.""" @@ -189,15 +190,16 @@ async def test_agent_transfer_system_instructions_without_parent(): Agent description: Second sub-agent -If you are the best to answer the question according to your description, you -can answer it. +If you are the best to answer the question according to your description, +you can answer it. If another agent is better for answering the question according to its -description, call `transfer_to_agent` function to transfer the -question to that agent. When transferring, do not generate any text other than -the function call. +description, call `transfer_to_agent` function to transfer the question to that +agent. When transferring, do not generate any text other than the function +call. -**NOTE**: the only available agents for `transfer_to_agent` function are `agent1`, `agent2`.""" +**NOTE**: the only available agents for `transfer_to_agent` function are +`agent1`, `agent2`.""" assert expected_content in instructions @@ -248,15 +250,16 @@ async def test_agent_transfer_simplified_parent_instructions(): Agent description: Parent agent -If you are the best to answer the question according to your description, you -can answer it. +If you are the best to answer the question according to your description, +you can answer it. If another agent is better for answering the question according to its -description, call `transfer_to_agent` function to transfer the -question to that agent. When transferring, do not generate any text other than -the function call. +description, call `transfer_to_agent` function to transfer the question to that +agent. When transferring, do not generate any text other than the function +call. -**NOTE**: the only available agents for `transfer_to_agent` function are `parent_agent`, `sub_agent`. +**NOTE**: the only available agents for `transfer_to_agent` function are +`parent_agent`, `sub_agent`. 
If neither you nor the other agents are best for the question, transfer to your parent agent parent_agent.""" diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py index cf55630b67..b2aa91dbee 100644 --- a/tests/unittests/flows/llm_flows/test_contents.py +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -465,6 +465,43 @@ async def test_events_with_empty_content_are_skipped(): author="user", content=types.UserContent("How are you?"), ), + # Event with content that has only empty text part + Event( + invocation_id="inv6", + author="user", + content=types.Content(parts=[types.Part(text="")], role="model"), + ), + # Event with content that has only inline data part + Event( + invocation_id="inv7", + author="user", + content=types.Content( + parts=[ + types.Part( + inline_data=types.Blob( + data=b"test", mime_type="image/png" + ) + ) + ], + role="user", + ), + ), + # Event with content that has only file data part + Event( + invocation_id="inv8", + author="user", + content=types.Content( + parts=[ + types.Part( + file_data=types.FileData( + file_uri="gs://test_bucket/test_file.png", + mime_type="image/png", + ) + ) + ], + role="user", + ), + ), ] invocation_context.session.events = events @@ -478,4 +515,23 @@ async def test_events_with_empty_content_are_skipped(): assert llm_request.contents == [ types.UserContent("Hello"), types.UserContent("How are you?"), + types.Content( + parts=[ + types.Part( + inline_data=types.Blob(data=b"test", mime_type="image/png") + ) + ], + role="user", + ), + types.Content( + parts=[ + types.Part( + file_data=types.FileData( + file_uri="gs://test_bucket/test_file.png", + mime_type="image/png", + ) + ) + ], + role="user", + ), ] diff --git a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py index e64613ff8d..e589d51c7d 100644 --- a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py +++ b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py @@ -397,3 +397,245 @@ def test_progressive_sse_preserves_part_ordering(): # Part 4: Second function call (from chunk8) assert parts[4].function_call.name == "get_weather" assert parts[4].function_call.args["location"] == "New York" + + +def test_progressive_sse_streaming_function_call_arguments(): + """Test streaming function call arguments feature. 
+ + This test simulates the streamFunctionCallArguments feature where a function + call's arguments are streamed incrementally across multiple chunks: + + Chunk 1: FC name + partial location argument ("New ") + Chunk 2: Continue location argument ("York") -> concatenated to "New York" + Chunk 3: Add unit argument ("celsius"), willContinue=False -> FC complete + + Expected result: FunctionCall(name="get_weather", + args={"location": "New York", "unit": + "celsius"}, + id="fc_001") + """ + + aggregator = StreamingResponseAggregator() + + # Chunk 1: FC name + partial location argument + chunk1_fc = types.FunctionCall( + name="get_weather", + id="fc_001", + partial_args=[ + types.PartialArg(json_path="$.location", string_value="New ") + ], + will_continue=True, + ) + chunk1 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk1_fc)] + ) + ) + ] + ) + + # Chunk 2: Continue streaming location argument + chunk2_fc = types.FunctionCall( + partial_args=[ + types.PartialArg(json_path="$.location", string_value="York") + ], + will_continue=True, + ) + chunk2 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk2_fc)] + ) + ) + ] + ) + + # Chunk 3: Add unit argument, FC complete + chunk3_fc = types.FunctionCall( + partial_args=[ + types.PartialArg(json_path="$.unit", string_value="celsius") + ], + will_continue=False, # FC complete + ) + chunk3 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk3_fc)] + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + # Process all chunks through aggregator + processed_chunks = [] + for chunk in [chunk1, chunk2, chunk3]: + + async def process(): + results = [] + async for response in aggregator.process_response(chunk): + results.append(response) + return results + + import asyncio + + chunk_results = asyncio.run(process()) + processed_chunks.extend(chunk_results) + + # Get final aggregated response + final_response = aggregator.close() + + # Verify final aggregated response has complete FC + assert final_response is not None + assert len(final_response.content.parts) == 1 + + fc_part = final_response.content.parts[0] + assert fc_part.function_call is not None + assert fc_part.function_call.name == "get_weather" + assert fc_part.function_call.id == "fc_001" + + # Verify arguments were correctly assembled from streaming chunks + args = fc_part.function_call.args + assert args["location"] == "New York" # "New " + "York" concatenated + assert args["unit"] == "celsius" + + +def test_progressive_sse_preserves_thought_signature(): + """Test that thought_signature is preserved when streaming FC arguments. + + This test verifies that when a streaming function call has a thought_signature + in the Part, it is correctly preserved in the final aggregated FunctionCall. 
+ """ + + aggregator = StreamingResponseAggregator() + + # Create a thought signature (simulating what Gemini returns) + # thought_signature is bytes (base64 encoded) + test_thought_signature = b"test_signature_abc123" + + # Chunk with streaming FC args and thought_signature + chunk_fc = types.FunctionCall( + name="add_5_numbers", + id="fc_003", + partial_args=[ + types.PartialArg(json_path="$.num1", number_value=10), + types.PartialArg(json_path="$.num2", number_value=20), + ], + will_continue=False, + ) + + # Create Part with both function_call AND thought_signature + chunk_part = types.Part( + function_call=chunk_fc, thought_signature=test_thought_signature + ) + + chunk = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content(role="model", parts=[chunk_part]), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + # Process chunk through aggregator + async def process(): + results = [] + async for response in aggregator.process_response(chunk): + results.append(response) + return results + + import asyncio + + asyncio.run(process()) + + # Get final aggregated response + final_response = aggregator.close() + + # Verify thought_signature was preserved in the Part + assert final_response is not None + assert len(final_response.content.parts) == 1 + + fc_part = final_response.content.parts[0] + assert fc_part.function_call is not None + assert fc_part.function_call.name == "add_5_numbers" + + assert fc_part.thought_signature == test_thought_signature + + +def test_progressive_sse_handles_empty_function_call(): + """Test that empty function calls are skipped. + + When using streamFunctionCallArguments, Gemini may send an empty + functionCall: {} as the final chunk to signal streaming completion. + This test verifies that such empty function calls are properly skipped + and don't cause errors. 
+ """ + + aggregator = StreamingResponseAggregator() + + # Chunk 1: Streaming FC with partial args + chunk1_fc = types.FunctionCall( + name="concat_number_and_string", + id="fc_001", + partial_args=[ + types.PartialArg(json_path="$.num", number_value=100), + types.PartialArg(json_path="$.s", string_value="ADK"), + ], + will_continue=False, + ) + chunk1 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk1_fc)] + ) + ) + ] + ) + + # Chunk 2: Empty function call (streaming end marker) + chunk2_fc = types.FunctionCall() # Empty function call + chunk2 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk2_fc)] + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + # Process all chunks through aggregator + async def process(): + results = [] + for chunk in [chunk1, chunk2]: + async for response in aggregator.process_response(chunk): + results.append(response) + return results + + import asyncio + + asyncio.run(process()) + + # Get final aggregated response + final_response = aggregator.close() + + # Verify final response only has the real FC, not the empty one + assert final_response is not None + assert len(final_response.content.parts) == 1 + + fc_part = final_response.content.parts[0] + assert fc_part.function_call is not None + assert fc_part.function_call.name == "concat_number_and_string" + assert fc_part.function_call.id == "fc_001" + + # Verify arguments + args = fc_part.function_call.args + assert args["num"] == 100 + assert args["s"] == "ADK" diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py index e5ac8cc051..e1880abf0d 100644 --- a/tests/unittests/models/test_anthropic_llm.py +++ b/tests/unittests/models/test_anthropic_llm.py @@ -19,7 +19,9 @@ from anthropic import types as anthropic_types from google.adk import version as adk_version from google.adk.models import anthropic_llm +from google.adk.models.anthropic_llm import AnthropicLlm from google.adk.models.anthropic_llm import Claude +from google.adk.models.anthropic_llm import content_to_message_param from google.adk.models.anthropic_llm import function_declaration_to_tool_param from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse @@ -358,6 +360,37 @@ async def mock_coro(): assert responses[0].content.parts[0].text == "Hello, how can I help you?" +@pytest.mark.asyncio +async def test_anthropic_llm_generate_content_async( + llm_request, generate_content_response, generate_llm_response +): + anthropic_llm_instance = AnthropicLlm(model="claude-sonnet-4-20250514") + with mock.patch.object( + anthropic_llm_instance, "_anthropic_client" + ) as mock_client: + with mock.patch.object( + anthropic_llm, + "message_to_generate_content_response", + return_value=generate_llm_response, + ): + # Create a mock coroutine that returns the generate_content_response. + async def mock_coro(): + return generate_content_response + + # Assign the coroutine to the mocked method + mock_client.messages.create.return_value = mock_coro() + + responses = [ + resp + async for resp in anthropic_llm_instance.generate_content_async( + llm_request, stream=False + ) + ] + assert len(responses) == 1 + assert isinstance(responses[0], LlmResponse) + assert responses[0].content.parts[0].text == "Hello, how can I help you?" 
+ + @pytest.mark.asyncio async def test_generate_content_async_with_max_tokens( llm_request, generate_content_response, generate_llm_response @@ -462,3 +495,81 @@ def test_part_to_message_block_with_multiple_content_items(): assert isinstance(result, dict) # Multiple text items should be joined with newlines assert result["content"] == "First part\nSecond part" + + +content_to_message_param_test_cases = [ + ( + "user_role_with_text_and_image", + Content( + role="user", + parts=[ + Part.from_text(text="What's in this image?"), + Part( + inline_data=types.Blob( + mime_type="image/jpeg", data=b"fake_image_data" + ) + ), + ], + ), + "user", + 2, # Expected content length + False, # Should not log warning + ), + ( + "model_role_with_text_and_image", + Content( + role="model", + parts=[ + Part.from_text(text="I see a cat."), + Part( + inline_data=types.Blob( + mime_type="image/png", data=b"fake_image_data" + ) + ), + ], + ), + "assistant", + 1, # Image filtered out, only text remains + True, # Should log warning + ), + ( + "assistant_role_with_text_and_image", + Content( + role="assistant", + parts=[ + Part.from_text(text="Here's what I found."), + Part( + inline_data=types.Blob( + mime_type="image/webp", data=b"fake_image_data" + ) + ), + ], + ), + "assistant", + 1, # Image filtered out, only text remains + True, # Should log warning + ), +] + + +@pytest.mark.parametrize( + "_, content, expected_role, expected_content_length, should_log_warning", + content_to_message_param_test_cases, + ids=[case[0] for case in content_to_message_param_test_cases], +) +def test_content_to_message_param_with_images( + _, content, expected_role, expected_content_length, should_log_warning +): + """Test content_to_message_param handles images correctly based on role.""" + with mock.patch("google.adk.models.anthropic_llm.logger") as mock_logger: + result = content_to_message_param(content) + + assert result["role"] == expected_role + assert len(result["content"]) == expected_content_length + + if should_log_warning: + mock_logger.warning.assert_called_once_with( + "Image data is not supported in Claude for assistant turns." 
+ ) + else: + mock_logger.warning.assert_not_called() diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 6d3e685748..190007603c 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -15,6 +15,7 @@ from unittest import mock from google.adk.models.gemini_llm_connection import GeminiLlmConnection +from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types import pytest @@ -28,7 +29,17 @@ def mock_gemini_session(): @pytest.fixture def gemini_connection(mock_gemini_session): """GeminiLlmConnection instance with mocked session.""" - return GeminiLlmConnection(mock_gemini_session) + return GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + + +@pytest.fixture +def gemini_api_connection(mock_gemini_session): + """GeminiLlmConnection instance with mocked session for Gemini API.""" + return GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.GEMINI_API + ) @pytest.fixture @@ -225,6 +236,227 @@ async def mock_receive_generator(): assert content_response.content == mock_content +@pytest.mark.asyncio +async def test_receive_transcript_finished_on_interrupt( + gemini_api_connection, + mock_gemini_session, +): + """Test receive finishes transcription on interrupt signal.""" + + message1 = mock.Mock() + message1.usage_metadata = None + message1.server_content = mock.Mock() + message1.server_content.model_turn = None + message1.server_content.interrupted = False + message1.server_content.input_transcription = types.Transcription( + text='Hello', finished=False + ) + message1.server_content.output_transcription = None + message1.server_content.turn_complete = False + message1.server_content.generation_complete = False + message1.tool_call = None + message1.session_resumption_update = None + + message2 = mock.Mock() + message2.usage_metadata = None + message2.server_content = mock.Mock() + message2.server_content.model_turn = None + message2.server_content.interrupted = False + message2.server_content.input_transcription = None + message2.server_content.output_transcription = types.Transcription( + text='How can', finished=False + ) + message2.server_content.turn_complete = False + message2.server_content.generation_complete = False + message2.tool_call = None + message2.session_resumption_update = None + + message3 = mock.Mock() + message3.usage_metadata = None + message3.server_content = mock.Mock() + message3.server_content.model_turn = None + message3.server_content.interrupted = True + message3.server_content.input_transcription = None + message3.server_content.output_transcription = None + message3.server_content.turn_complete = False + message3.server_content.generation_complete = False + message3.tool_call = None + message3.session_resumption_update = None + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_api_connection.receive()] + + assert len(responses) == 5 + assert responses[4].interrupted is True + + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is False + assert responses[0].partial is True + assert responses[1].output_transcription.text == 'How can' + assert responses[1].output_transcription.finished 
is False + assert responses[1].partial is True + assert responses[2].input_transcription.text == 'Hello' + assert responses[2].input_transcription.finished is True + assert responses[2].partial is False + assert responses[3].output_transcription.text == 'How can' + assert responses[3].output_transcription.finished is True + assert responses[3].partial is False + + +@pytest.mark.asyncio +async def test_receive_transcript_finished_on_generation_complete( + gemini_api_connection, + mock_gemini_session, +): + """Test receive finishes transcription on generation_complete signal.""" + + message1 = mock.Mock() + message1.usage_metadata = None + message1.server_content = mock.Mock() + message1.server_content.model_turn = None + message1.server_content.interrupted = False + message1.server_content.input_transcription = types.Transcription( + text='Hello', finished=False + ) + message1.server_content.output_transcription = None + message1.server_content.turn_complete = False + message1.server_content.generation_complete = False + message1.tool_call = None + message1.session_resumption_update = None + + message2 = mock.Mock() + message2.usage_metadata = None + message2.server_content = mock.Mock() + message2.server_content.model_turn = None + message2.server_content.interrupted = False + message2.server_content.input_transcription = None + message2.server_content.output_transcription = types.Transcription( + text='How can', finished=False + ) + message2.server_content.turn_complete = False + message2.server_content.generation_complete = False + message2.tool_call = None + message2.session_resumption_update = None + + message3 = mock.Mock() + message3.usage_metadata = None + message3.server_content = mock.Mock() + message3.server_content.model_turn = None + message3.server_content.interrupted = False + message3.server_content.input_transcription = None + message3.server_content.output_transcription = None + message3.server_content.turn_complete = False + message3.server_content.generation_complete = True + message3.tool_call = None + message3.session_resumption_update = None + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_api_connection.receive()] + + assert len(responses) == 4 + + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is False + assert responses[0].partial is True + assert responses[1].output_transcription.text == 'How can' + assert responses[1].output_transcription.finished is False + assert responses[1].partial is True + assert responses[2].input_transcription.text == 'Hello' + assert responses[2].input_transcription.finished is True + assert responses[2].partial is False + assert responses[3].output_transcription.text == 'How can' + assert responses[3].output_transcription.finished is True + assert responses[3].partial is False + + +@pytest.mark.asyncio +async def test_receive_transcript_finished_on_turn_complete( + gemini_api_connection, + mock_gemini_session, +): + """Test receive finishes transcription on interrupt or complete signals.""" + + message1 = mock.Mock() + message1.usage_metadata = None + message1.server_content = mock.Mock() + message1.server_content.model_turn = None + message1.server_content.interrupted = False + message1.server_content.input_transcription = types.Transcription( + text='Hello', finished=False + ) 
+ message1.server_content.output_transcription = None + message1.server_content.turn_complete = False + message1.server_content.generation_complete = False + message1.tool_call = None + message1.session_resumption_update = None + + message2 = mock.Mock() + message2.usage_metadata = None + message2.server_content = mock.Mock() + message2.server_content.model_turn = None + message2.server_content.interrupted = False + message2.server_content.input_transcription = None + message2.server_content.output_transcription = types.Transcription( + text='How can', finished=False + ) + message2.server_content.turn_complete = False + message2.server_content.generation_complete = False + message2.tool_call = None + message2.session_resumption_update = None + + message3 = mock.Mock() + message3.usage_metadata = None + message3.server_content = mock.Mock() + message3.server_content.model_turn = None + message3.server_content.interrupted = False + message3.server_content.input_transcription = None + message3.server_content.output_transcription = None + message3.server_content.turn_complete = True + message3.server_content.generation_complete = False + message3.tool_call = None + message3.session_resumption_update = None + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_api_connection.receive()] + + assert len(responses) == 5 + assert responses[4].turn_complete is True + + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is False + assert responses[0].partial is True + assert responses[1].output_transcription.text == 'How can' + assert responses[1].output_transcription.finished is False + assert responses[1].partial is True + assert responses[2].input_transcription.text == 'Hello' + assert responses[2].input_transcription.finished is True + assert responses[2].partial is False + assert responses[3].output_transcription.text == 'How can' + assert responses[3].output_transcription.finished is True + assert responses[3].partial is False + + @pytest.mark.asyncio async def test_receive_handles_input_transcription_fragments( gemini_connection, mock_gemini_session @@ -240,6 +472,7 @@ async def test_receive_handles_input_transcription_fragments( ) message1.server_content.output_transcription = None message1.server_content.turn_complete = False + message1.server_content.generation_complete = False message1.tool_call = None message1.session_resumption_update = None @@ -253,6 +486,7 @@ async def test_receive_handles_input_transcription_fragments( ) message2.server_content.output_transcription = None message2.server_content.turn_complete = False + message2.server_content.generation_complete = False message2.tool_call = None message2.session_resumption_update = None @@ -266,6 +500,7 @@ async def test_receive_handles_input_transcription_fragments( ) message3.server_content.output_transcription = None message3.server_content.turn_complete = False + message3.server_content.generation_complete = False message3.tool_call = None message3.session_resumption_update = None @@ -306,6 +541,7 @@ async def test_receive_handles_output_transcription_fragments( text='How can', finished=False ) message1.server_content.turn_complete = False + message1.server_content.generation_complete = False message1.tool_call = None message1.session_resumption_update = None @@ -319,6 +555,7 @@ 
async def test_receive_handles_output_transcription_fragments( text=' I help?', finished=False ) message2.server_content.turn_complete = False + message2.server_content.generation_complete = False message2.tool_call = None message2.session_resumption_update = None @@ -332,6 +569,7 @@ async def test_receive_handles_output_transcription_fragments( text=None, finished=True ) message3.server_content.turn_complete = False + message3.server_content.generation_complete = False message3.tool_call = None message3.session_resumption_update = None diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index f2419daf3f..ddf1b07667 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -22,8 +22,6 @@ from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.models.cache_metadata import CacheMetadata from google.adk.models.gemini_llm_connection import GeminiLlmConnection -from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME -from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.models.google_llm import _build_function_declaration_log from google.adk.models.google_llm import _build_request_log from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE @@ -31,6 +29,8 @@ from google.adk.models.google_llm import Gemini from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse +from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME +from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types from google.genai.errors import ClientError @@ -142,13 +142,6 @@ def llm_request_with_computer_use(): ) -@pytest.fixture -def mock_os_environ(): - initial_env = os.environ.copy() - with mock.patch.dict(os.environ, initial_env, clear=False) as m: - yield m - - def test_supported_models(): models = Gemini.supported_models() assert len(models) == 4 @@ -193,12 +186,15 @@ def test_client_version_header(): ) -def test_client_version_header_with_agent_engine(mock_os_environ): - os.environ[_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME] = "my_test_project" +def test_client_version_header_with_agent_engine(monkeypatch): + monkeypatch.setenv( + _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, "my_test_project" + ) model = Gemini(model="gemini-1.5-flash") client = model.api_client - # Check that ADK version with telemetry tag and Python version are present in headers + # Check that ADK version with telemetry tag and Python version are present in + # headers adk_version_with_telemetry = ( f"google-adk/{adk_version.__version__}+{_AGENT_ENGINE_TELEMETRY_TAG}" ) @@ -473,8 +469,9 @@ async def test_generate_content_async_with_custom_headers( """Test that tracking headers are updated when custom headers are provided.""" # Add custom headers to the request config custom_headers = {"custom-header": "custom-value"} - for key in gemini_llm._tracking_headers: - custom_headers[key] = "custom " + gemini_llm._tracking_headers[key] + tracking_headers = gemini_llm._tracking_headers() + for key in tracking_headers: + custom_headers[key] = "custom " + tracking_headers[key] llm_request.config.http_options = types.HttpOptions(headers=custom_headers) with mock.patch.object(gemini_llm, "api_client") as mock_client: @@ -497,8 +494,9 @@ async def mock_coro(): config_arg = 
call_args.kwargs["config"] for key, value in config_arg.http_options.headers.items(): - if key in gemini_llm._tracking_headers: - assert value == gemini_llm._tracking_headers[key] + " custom" + tracking_headers = gemini_llm._tracking_headers() + if key in tracking_headers: + assert value == tracking_headers[key] + " custom" else: assert value == custom_headers[key] @@ -547,7 +545,7 @@ async def mock_coro(): config_arg = call_args.kwargs["config"] expected_headers = custom_headers.copy() - expected_headers.update(gemini_llm._tracking_headers) + expected_headers.update(gemini_llm._tracking_headers()) assert config_arg.http_options.headers == expected_headers assert len(responses) == 2 @@ -601,7 +599,7 @@ async def mock_coro(): assert final_config.http_options is not None assert ( final_config.http_options.headers["x-goog-api-client"] - == gemini_llm._tracking_headers["x-goog-api-client"] + == gemini_llm._tracking_headers()["x-goog-api-client"] ) assert len(responses) == 2 if stream else 1 @@ -635,7 +633,7 @@ def test_live_api_client_properties(gemini_llm): assert http_options.api_version == "v1beta1" # Check that tracking headers are included - tracking_headers = gemini_llm._tracking_headers + tracking_headers = gemini_llm._tracking_headers() for key, value in tracking_headers.items(): assert key in http_options.headers assert value in http_options.headers[key] @@ -673,7 +671,7 @@ async def __aexit__(self, *args): # Verify that tracking headers were merged with custom headers expected_headers = custom_headers.copy() - expected_headers.update(gemini_llm._tracking_headers) + expected_headers.update(gemini_llm._tracking_headers()) assert config_arg.http_options.headers == expected_headers # Verify that API version was set diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 8f2ae50b42..f65fc77a61 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -17,7 +17,6 @@ from unittest.mock import Mock import warnings -from google.adk.models.lite_llm import _build_function_declaration_log from google.adk.models.lite_llm import _content_to_message_param from google.adk.models.lite_llm import _FINISH_REASON_MAPPING from google.adk.models.lite_llm import _function_declaration_to_tool_param @@ -25,7 +24,9 @@ from google.adk.models.lite_llm import _get_content from google.adk.models.lite_llm import _message_to_generate_content_response from google.adk.models.lite_llm import _model_response_to_chunk +from google.adk.models.lite_llm import _model_response_to_generate_content_response from google.adk.models.lite_llm import _parse_tool_calls_from_text +from google.adk.models.lite_llm import _schema_to_dict from google.adk.models.lite_llm import _split_message_content_and_tool_calls from google.adk.models.lite_llm import _to_litellm_response_format from google.adk.models.lite_llm import _to_litellm_role @@ -286,6 +287,28 @@ def test_to_litellm_response_format_handles_genai_schema_instance(): ) +def test_schema_to_dict_filters_none_enum_values(): + # Use model_construct to bypass strict enum validation. 
+ top_level_schema = types.Schema.model_construct( + type=types.Type.STRING, + enum=["ACTIVE", None, "INACTIVE"], + ) + nested_schema = types.Schema.model_construct( + type=types.Type.OBJECT, + properties={ + "status": types.Schema.model_construct( + type=types.Type.STRING, enum=["READY", None, "DONE"] + ), + }, + ) + + assert _schema_to_dict(top_level_schema)["enum"] == ["ACTIVE", "INACTIVE"] + assert _schema_to_dict(nested_schema)["properties"]["status"]["enum"] == [ + "READY", + "DONE", + ] + + MULTIPLE_FUNCTION_CALLS_STREAM = [ ModelResponse( choices=[ @@ -630,54 +653,6 @@ def completion(self, model, messages, tools, stream, **kwargs): ) -def test_build_function_declaration_log(): - """Test that _build_function_declaration_log formats function declarations correctly.""" - # Test case 1: Function with parameters and response - func_decl1 = types.FunctionDeclaration( - name="test_func1", - description="Test function 1", - parameters=types.Schema( - type=types.Type.OBJECT, - properties={ - "param1": types.Schema( - type=types.Type.STRING, description="param1 desc" - ) - }, - ), - response=types.Schema(type=types.Type.BOOLEAN, description="return bool"), - ) - log1 = _build_function_declaration_log(func_decl1) - assert log1 == ( - "test_func1: {'param1': {'description': 'param1 desc', 'type':" - " }} -> {'description': 'return bool', 'type':" - " }" - ) - - # Test case 2: Function with JSON schema parameters and response - func_decl2 = types.FunctionDeclaration( - name="test_func2", - description="Test function 2", - parameters_json_schema={ - "type": "object", - "properties": {"param2": {"type": "integer"}}, - }, - response_json_schema={"type": "string"}, - ) - log2 = _build_function_declaration_log(func_decl2) - assert log2 == ( - "test_func2: {'type': 'object', 'properties': {'param2': {'type':" - " 'integer'}}} -> {'type': 'string'}" - ) - - # Test case 3: Function with no parameters and no response - func_decl3 = types.FunctionDeclaration( - name="test_func3", - description="Test function 3", - ) - log3 = _build_function_declaration_log(func_decl3) - assert log3 == "test_func3: {} -> None" - - @pytest.mark.asyncio async def test_generate_content_async(mock_acompletion, lite_llm_instance): @@ -1220,7 +1195,7 @@ async def test_generate_content_async_with_system_instruction( _, kwargs = mock_acompletion.call_args assert kwargs["model"] == "test_model" - assert kwargs["messages"][0]["role"] == "developer" + assert kwargs["messages"][0]["role"] == "system" assert kwargs["messages"][0]["content"] == "Test system instruction" assert kwargs["messages"][1]["role"] == "user" assert kwargs["messages"][1]["content"] == "Test prompt" @@ -1535,6 +1510,42 @@ def test_message_to_generate_content_response_with_model(): assert response.model_version == "gemini-2.5-pro" +def test_message_to_generate_content_response_reasoning_content(): + message = { + "role": "assistant", + "content": "Visible text", + "reasoning_content": "Hidden chain", + } + response = _message_to_generate_content_response(message) + + assert len(response.content.parts) == 2 + thought_part = response.content.parts[0] + text_part = response.content.parts[1] + assert thought_part.text == "Hidden chain" + assert thought_part.thought is True + assert text_part.text == "Visible text" + + +def test_model_response_to_generate_content_response_reasoning_content(): + model_response = ModelResponse( + model="thinking-model", + choices=[{ + "message": { + "role": "assistant", + "content": "Answer", + "reasoning_content": "Step-by-step", + }, + 
"finish_reason": "stop", + }], + ) + + response = _model_response_to_generate_content_response(model_response) + + assert response.content.parts[0].text == "Step-by-step" + assert response.content.parts[0].thought is True + assert response.content.parts[1].text == "Answer" + + def test_parse_tool_calls_from_text_multiple_calls(): text = ( '{"name":"alpha","arguments":{"value":1}}\n' diff --git a/tests/unittests/models/test_models.py b/tests/unittests/models/test_models.py index 70246c7bc1..8575064baa 100644 --- a/tests/unittests/models/test_models.py +++ b/tests/unittests/models/test_models.py @@ -15,7 +15,7 @@ from google.adk import models from google.adk.models.anthropic_llm import Claude from google.adk.models.google_llm import Gemini -from google.adk.models.registry import LLMRegistry +from google.adk.models.lite_llm import LiteLlm import pytest @@ -34,6 +34,7 @@ ], ) def test_match_gemini_family(model_name): + """Test that Gemini models are resolved correctly.""" assert models.LLMRegistry.resolve(model_name) is Gemini @@ -51,12 +52,63 @@ def test_match_gemini_family(model_name): ], ) def test_match_claude_family(model_name): - LLMRegistry.register(Claude) - + """Test that Claude models are resolved correctly.""" assert models.LLMRegistry.resolve(model_name) is Claude +@pytest.mark.parametrize( + 'model_name', + [ + 'openai/gpt-4o', + 'openai/gpt-4o-mini', + 'groq/llama3-70b-8192', + 'groq/mixtral-8x7b-32768', + 'anthropic/claude-3-opus-20240229', + 'anthropic/claude-3-5-sonnet-20241022', + ], +) +def test_match_litellm_family(model_name): + """Test that LiteLLM models are resolved correctly.""" + assert models.LLMRegistry.resolve(model_name) is LiteLlm + + def test_non_exist_model(): with pytest.raises(ValueError) as e_info: models.LLMRegistry.resolve('non-exist-model') assert 'Model non-exist-model not found.' in str(e_info.value) + + +def test_helpful_error_for_claude_without_extensions(): + """Test that missing Claude models show helpful install instructions. + + Note: This test may pass even when anthropic IS installed, because it + only checks the error message format when a model is not found. + """ + # Use a non-existent Claude model variant to trigger error + with pytest.raises(ValueError) as e_info: + models.LLMRegistry.resolve('claude-nonexistent-model-xyz') + + error_msg = str(e_info.value) + # The error should mention anthropic package and installation instructions + # These checks work whether or not anthropic is actually installed + assert 'Model claude-nonexistent-model-xyz not found' in error_msg + assert 'anthropic package' in error_msg + assert 'pip install' in error_msg + + +def test_helpful_error_for_litellm_without_extensions(): + """Test that missing LiteLLM models show helpful install instructions. + + Note: This test may pass even when litellm IS installed, because it + only checks the error message format when a model is not found. 
+ """ + # Use a non-existent provider to trigger error + with pytest.raises(ValueError) as e_info: + models.LLMRegistry.resolve('unknown-provider/gpt-4o') + + error_msg = str(e_info.value) + # The error should mention litellm package for provider-style models + assert 'Model unknown-provider/gpt-4o not found' in error_msg + assert 'litellm package' in error_msg + assert 'pip install' in error_msg + assert 'Provider-style models' in error_msg diff --git a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py index 1e15f33899..2cf52e99cb 100644 --- a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py +++ b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py @@ -57,10 +57,8 @@ async def extract_error_from_result( return None -# Inheriting from IsolatedAsyncioTestCase ensures these tests works in Python -# 3.9. See https://github.com/pytest-dev/pytest-asyncio/issues/1039 -# Without this, the tests will fail with a "RuntimeError: There is no current -# event loop in thread 'MainThread'." +# Inheriting from IsolatedAsyncioTestCase ensures consistent behavior. +# See https://github.com/pytest-dev/pytest-asyncio/issues/1039 class TestReflectAndRetryToolPlugin(IsolatedAsyncioTestCase): """Comprehensive tests for ReflectAndRetryToolPlugin focusing on behavior.""" diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 661e6ead59..45aa3feede 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -531,6 +531,79 @@ async def test_append_event_complete(service_type, tmp_path): ) +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', + [ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ], +) +async def test_session_last_update_time_updates_on_event( + service_type, tmp_path +): + session_service = get_session_service(service_type, tmp_path) + app_name = 'my_app' + user_id = 'user' + + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) + original_update_time = session.last_update_time + + event_timestamp = original_update_time + 10 + event = Event( + invocation_id='invocation', + author='user', + timestamp=event_timestamp, + ) + await session_service.append_event(session=session, event=event) + + assert session.last_update_time == pytest.approx(event_timestamp, abs=1e-6) + + refreshed_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert refreshed_session is not None + assert refreshed_session.last_update_time == pytest.approx( + event_timestamp, abs=1e-6 + ) + assert refreshed_session.last_update_time > original_update_time + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', + [ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ], +) +async def test_get_session_with_config(service_type): + session_service = get_session_service(service_type) + app_name = 'my_app' + user_id = 'user' + + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) + original_update_time = session.last_update_time + + event = Event(invocation_id='invocation', author='user') + await session_service.append_event(session=session, event=event) + + assert session.last_update_time >= event.timestamp + + refreshed_session = await session_service.get_session( + app_name=app_name, 
user_id=user_id, session_id=session.id + ) + assert refreshed_session is not None + assert refreshed_session.last_update_time >= event.timestamp + assert refreshed_session.last_update_time > original_update_time + + @pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', diff --git a/tests/unittests/telemetry/test_functional.py b/tests/unittests/telemetry/test_functional.py index 409571ad1f..43fe672333 100644 --- a/tests/unittests/telemetry/test_functional.py +++ b/tests/unittests/telemetry/test_functional.py @@ -103,7 +103,7 @@ def wrapped_firstiter(coro): isinstance(referrer, Aclosing) or isinstance(indirect_referrer, Aclosing) for referrer in gc.get_referrers(coro) - # Some coroutines have a layer of indirection in python 3.9 and 3.10 + # Some coroutines have a layer of indirection in Python 3.10 for indirect_referrer in gc.get_referrers(referrer) ), f'Coro `{coro.__name__}` is not wrapped with Aclosing' firstiter(coro) diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index eef83a1f5e..1791100e1f 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -1709,6 +1709,65 @@ def test_execute_sql_job_labels( } +@pytest.mark.parametrize( + ("write_mode", "dry_run", "query_call_count", "query_and_wait_call_count"), + [ + pytest.param(WriteMode.ALLOWED, False, 0, 1, id="write-allowed"), + pytest.param(WriteMode.ALLOWED, True, 1, 0, id="write-allowed-dry-run"), + pytest.param(WriteMode.BLOCKED, False, 1, 1, id="write-blocked"), + pytest.param(WriteMode.BLOCKED, True, 2, 0, id="write-blocked-dry-run"), + pytest.param(WriteMode.PROTECTED, False, 2, 1, id="write-protected"), + pytest.param( + WriteMode.PROTECTED, True, 3, 0, id="write-protected-dry-run" + ), + ], +) +def test_execute_sql_user_job_labels_augment_internal_labels( + write_mode, dry_run, query_call_count, query_and_wait_call_count +): + """Test execute_sql tool augments user job_labels with internal labels.""" + project = "my_project" + query = "SELECT 123 AS num" + statement_type = "SELECT" + credentials = mock.create_autospec(Credentials, instance=True) + user_labels = {"environment": "test", "team": "data"} + tool_settings = BigQueryToolConfig( + write_mode=write_mode, + job_labels=user_labels, + ) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = None + + with mock.patch.object(bigquery, "Client", autospec=True) as Client: + bq_client = Client.return_value + + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + bq_client.query.return_value = query_job + + query_tool.execute_sql( + project, + query, + credentials, + tool_settings, + tool_context, + dry_run=dry_run, + ) + + assert bq_client.query.call_count == query_call_count + assert bq_client.query_and_wait.call_count == query_and_wait_call_count + # Build expected labels from user_labels + internal label + expected_labels = {**user_labels, "adk-bigquery-tool": "execute_sql"} + for call_args_list in [ + bq_client.query.call_args_list, + bq_client.query_and_wait.call_args_list, + ]: + for call_args in call_args_list: + _, mock_kwargs = call_args + # Verify user labels are preserved and internal label is added + assert mock_kwargs["job_config"].labels == expected_labels + + @pytest.mark.parametrize( ("tool_call", "expected_tool_label"), [ @@ -1850,6 +1909,94 @@ def 
test_ml_tool_job_labels_w_application_name(tool_call, expected_tool_label): assert mock_kwargs["job_config"].labels == expected_labels +@pytest.mark.parametrize( + ("tool_call", "expected_labels"), + [ + pytest.param( + lambda tool_context: query_tool.forecast( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + timestamp_col="ts_col", + data_col="data_col", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "forecaster"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "forecaster", + "adk-bigquery-tool": "forecast", + }, + id="forecast", + ), + pytest.param( + lambda tool_context: query_tool.analyze_contribution( + project_id="test-project", + input_data="test-dataset.test-table", + dimension_id_cols=["dim1", "dim2"], + contribution_metric="SUM(metric)", + is_test_col="is_test", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "analyzer"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "analyzer", + "adk-bigquery-tool": "analyze_contribution", + }, + id="analyze-contribution", + ), + pytest.param( + lambda tool_context: query_tool.detect_anomalies( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "detector"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "detector", + "adk-bigquery-tool": "detect_anomalies", + }, + id="detect-anomalies", + ), + ], +) +def test_ml_tool_user_job_labels_augment_internal_labels( + tool_call, expected_labels +): + """Test ML tools augment user job_labels with internal labels.""" + + with mock.patch.object(bigquery, "Client", autospec=True) as Client: + bq_client = Client.return_value + + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = None + tool_call(tool_context) + + for call_args_list in [ + bq_client.query.call_args_list, + bq_client.query_and_wait.call_args_list, + ]: + for call_args in call_args_list: + _, mock_kwargs = call_args + # Verify user labels are preserved and internal label is added + assert mock_kwargs["job_config"].labels == expected_labels + + def test_execute_sql_max_rows_config(): """Test execute_sql tool respects max_query_result_rows from config.""" project = "my_project" @@ -2014,3 +2161,93 @@ def test_tool_call_doesnt_change_global_settings(tool_call): # Test settings write mode after assert settings.write_mode == WriteMode.ALLOWED + + +@pytest.mark.parametrize( + ("tool_call",), + [ + pytest.param( + lambda settings, tool_context: query_tool.execute_sql( + project_id="test-project", + query="SELECT * FROM `test-dataset.test-table`", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="execute-sql", + ), + pytest.param( + lambda settings, tool_context: query_tool.forecast( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + timestamp_col="ts_col", + data_col="data_col", + credentials=mock.create_autospec(Credentials, 
instance=True), + settings=settings, + tool_context=tool_context, + ), + id="forecast", + ), + pytest.param( + lambda settings, tool_context: query_tool.analyze_contribution( + project_id="test-project", + input_data="test-dataset.test-table", + dimension_id_cols=["dim1", "dim2"], + contribution_metric="SUM(metric)", + is_test_col="is_test", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="analyze-contribution", + ), + pytest.param( + lambda settings, tool_context: query_tool.detect_anomalies( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="detect-anomalies", + ), + ], +) +def test_tool_call_doesnt_mutate_job_labels(tool_call): + """Test query tools don't mutate job_labels in global settings.""" + original_labels = {"environment": "test", "team": "data"} + settings = BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels=original_labels.copy(), + ) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = ( + "test-bq-session-id", + "_anonymous_dataset", + ) + + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + # The mock instance + bq_client = Client.return_value + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.destination.dataset_id = "_anonymous_dataset" + bq_client.query.return_value = query_job + bq_client.query_and_wait.return_value = [] + + # Test job_labels before + assert settings.job_labels == original_labels + assert "adk-bigquery-tool" not in settings.job_labels + + # Call the tool + result = tool_call(settings, tool_context) + + # Test successful execution of the tool + assert result == {"status": "SUCCESS", "rows": []} + + # Test job_labels remain unchanged after tool call + assert settings.job_labels == original_labels + assert "adk-bigquery-tool" not in settings.job_labels diff --git a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py index 5854c97797..072ccea7d0 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py +++ b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py @@ -77,3 +77,61 @@ def test_bigquery_tool_config_invalid_maximum_bytes_billed(): ), ): BigQueryToolConfig(maximum_bytes_billed=10_485_759) + + +@pytest.mark.parametrize( + "labels", + [ + pytest.param( + {"environment": "test", "team": "data"}, + id="valid-labels", + ), + pytest.param( + {}, + id="empty-labels", + ), + pytest.param( + None, + id="none-labels", + ), + ], +) +def test_bigquery_tool_config_valid_labels(labels): + """Test BigQueryToolConfig accepts valid labels.""" + with pytest.warns(UserWarning): + config = BigQueryToolConfig(job_labels=labels) + assert config.job_labels == labels + + +@pytest.mark.parametrize( + ("labels", "message"), + [ + pytest.param( + "invalid", + "Input should be a valid dictionary", + id="invalid-type", + ), + pytest.param( + {123: "value"}, + "Input should be a valid string", + id="non-str-key", + ), + pytest.param( + {"key": 123}, + "Input should be a valid string", + id="non-str-value", + ), + pytest.param( + {"": "value"}, + "Label keys cannot be empty", + id="empty-label-key", + ), + ], +) +def 
test_bigquery_tool_config_invalid_labels(labels, message): + """Test BigQueryToolConfig raises an exception with invalid labels.""" + with pytest.raises( + ValueError, + match=message, + ): + BigQueryToolConfig(job_labels=labels) diff --git a/tests/unittests/tools/computer_use/test_computer_use_tool.py b/tests/unittests/tools/computer_use/test_computer_use_tool.py index 4dbdfbb5c0..f3843b87a6 100644 --- a/tests/unittests/tools/computer_use/test_computer_use_tool.py +++ b/tests/unittests/tools/computer_use/test_computer_use_tool.py @@ -47,7 +47,7 @@ async def tool_context(self): @pytest.fixture def mock_computer_function(self): """Fixture providing a mock computer function.""" - # Create a real async function instead of AsyncMock for Python 3.9 compatibility + # Create a real async function instead of AsyncMock for better test control calls = [] async def mock_func(*args, **kwargs): diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 6c001ccf65..74eabe9d4d 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -22,46 +22,14 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from mcp import StdioServerParameters import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource - from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - MCPSessionManager = DummyClass - retry_on_closed_resource = lambda x: x - SseConnectionParams = DummyClass - StdioConnectionParams = DummyClass - StreamableHTTPConnectionParams = DummyClass - else: - raise e - -# Import real MCP classes -try: - from mcp import StdioServerParameters -except ImportError: - # Create a mock if MCP is not available - class StdioServerParameters: - - def __init__(self, command="test_command", args=None): - self.command = command - self.args = args or [] - class MockClientSession: """Mock ClientSession for testing.""" @@ -375,12 +343,12 @@ async def test_close_with_errors(self): assert "Close error 1" in error_output -def test_retry_on_closed_resource_decorator(): - """Test the retry_on_closed_resource decorator.""" +def test_retry_on_errors_decorator(): + """Test the retry_on_errors decorator.""" call_count = 0 - @retry_on_closed_resource + @retry_on_errors async def mock_function(self): nonlocal call_count call_count += 1 
diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 17b1d8e54e..1284e73bce 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch @@ -23,39 +22,15 @@ from google.adk.auth.auth_credential import HttpCredentials from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_credential import ServiceAccount +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.tool_context import ToolContext +from google.genai.types import FunctionDeclaration +from google.genai.types import Type +from mcp.types import CallToolResult +from mcp.types import TextContent import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_tool import MCPTool - from google.adk.tools.tool_context import ToolContext - from google.genai.types import FunctionDeclaration - from google.genai.types import Type - from mcp.types import CallToolResult - from mcp.types import TextContent -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - MCPSessionManager = DummyClass - MCPTool = DummyClass - ToolContext = DummyClass - FunctionDeclaration = DummyClass - Type = DummyClass - CallToolResult = DummyClass - TextContent = DummyClass - else: - raise e - # Mock MCP Tool from mcp.types class MockMCPTool: diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 82a5c9a3e7..5809efe56f 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -20,47 +20,17 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.agents.readonly_context import ReadonlyContext from google.adk.auth.auth_credential import AuthCredential +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +from mcp import StdioServerParameters import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.agents.readonly_context import ReadonlyContext - from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams - from 
google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams - from google.adk.tools.mcp_tool.mcp_tool import MCPTool - from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset - from mcp import StdioServerParameters -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - class StdioServerParameters: - - def __init__(self, command="test_command", args=None): - self.command = command - self.args = args or [] - - MCPSessionManager = DummyClass - SseConnectionParams = DummyClass - StdioConnectionParams = DummyClass - StreamableHTTPConnectionParams = DummyClass - MCPTool = DummyClass - MCPToolset = DummyClass - ReadonlyContext = DummyClass - else: - raise e - class MockMCPTool: """Mock MCP Tool for testing.""" diff --git a/tests/unittests/tools/openapi_tool/common/test_common.py b/tests/unittests/tools/openapi_tool/common/test_common.py index 47aeb79fdb..faece5be89 100644 --- a/tests/unittests/tools/openapi_tool/common/test_common.py +++ b/tests/unittests/tools/openapi_tool/common/test_common.py @@ -74,6 +74,24 @@ def test_api_parameter_keyword_rename(self): ) assert param.py_name == 'param_in' + def test_api_parameter_uses_location_default_when_name_missing(self): + schema = Schema(type='string') + param = ApiParameter( + original_name='', + param_location='body', + param_schema=schema, + ) + assert param.py_name == 'body' + + def test_api_parameter_uses_value_default_when_location_unknown(self): + schema = Schema(type='integer') + param = ApiParameter( + original_name='', + param_location='', + param_schema=schema, + ) + assert param.py_name == 'value' + def test_api_parameter_custom_py_name(self): schema = Schema(type='integer') param = ApiParameter( diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py index 26cb944a22..83741c97a2 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py @@ -164,6 +164,40 @@ def test_process_request_body_no_name(): assert parser._params[0].param_location == 'body' +def test_process_request_body_one_of_schema_assigns_name(): + """Ensures oneOf bodies result in a named parameter.""" + operation = Operation( + operationId='one_of_request', + requestBody=RequestBody( + content={ + 'application/json': MediaType( + schema=Schema( + oneOf=[ + Schema( + type='object', + properties={ + 'type': Schema(type='string'), + 'stage': Schema(type='string'), + }, + ) + ], + discriminator={'propertyName': 'type'}, + ) + ) + } + ), + responses={'200': Response(description='ok')}, + ) + parser = OperationParser(operation) + params = parser.get_parameters() + assert len(params) == 1 + assert params[0].original_name == 'body' + assert params[0].py_name == 'body' + schema = parser.get_json_schema() + assert 'body' in schema['properties'] + assert '' not in schema['properties'] + + def test_process_request_body_empty_object(): """Test _process_request_body with a schema that is of type object but with no properties.""" operation = Operation( diff --git a/tests/unittests/tools/retrieval/test_files_retrieval.py b/tests/unittests/tools/retrieval/test_files_retrieval.py index 
ea4b99cd98..dfb7215dce 100644 --- a/tests/unittests/tools/retrieval/test_files_retrieval.py +++ b/tests/unittests/tools/retrieval/test_files_retrieval.py @@ -14,7 +14,6 @@ """Tests for FilesRetrieval tool.""" -import sys import unittest.mock as mock from google.adk.tools.retrieval.files_retrieval import _get_default_embedding_model @@ -111,9 +110,6 @@ def mock_import(name, *args, **kwargs): def test_get_default_embedding_model_success(self): """Test _get_default_embedding_model returns Google embedding when available.""" - # Skip this test in Python 3.9 where llama_index.embeddings.google_genai may not be available - if sys.version_info < (3, 10): - pytest.skip("llama_index.embeddings.google_genai requires Python 3.10+") # Mock the module creation to avoid import issues mock_module = mock.MagicMock() diff --git a/tests/unittests/tools/spanner/test_search_tool.py b/tests/unittests/tools/spanner/test_search_tool.py index 1b330b01d5..c69aa444ec 100644 --- a/tests/unittests/tools/spanner/test_search_tool.py +++ b/tests/unittests/tools/spanner/test_search_tool.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock from unittest.mock import MagicMock -from unittest.mock import patch +from google.adk.tools.spanner import client from google.adk.tools.spanner import search_tool +from google.adk.tools.spanner import utils from google.cloud.spanner_admin_database_v1.types import DatabaseDialect import pytest @@ -35,29 +37,59 @@ def mock_spanner_ids(): } -@patch("google.adk.tools.spanner.client.get_spanner_client") +@pytest.mark.parametrize( + ("embedding_option_key", "embedding_option_value", "expected_embedding"), + [ + pytest.param( + "spanner_googlesql_embedding_model_name", + "EmbeddingsModel", + [0.1, 0.2, 0.3], + id="spanner_googlesql_embedding_model", + ), + pytest.param( + "vertex_ai_embedding_model_name", + "text-embedding-005", + [0.4, 0.5, 0.6], + id="vertex_ai_embedding_model", + ), + ], +) +@mock.patch.object(utils, "embed_contents") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_knn_success( - mock_get_spanner_client, mock_spanner_ids, mock_credentials + mock_get_spanner_client, + mock_embed_contents, + mock_spanner_ids, + mock_credentials, + embedding_option_key, + embedding_option_value, + expected_embedding, ): """Test similarity_search function with kNN success.""" mock_spanner_client = MagicMock() mock_instance = MagicMock() mock_database = MagicMock() mock_snapshot = MagicMock() - mock_embedding_result = MagicMock() - mock_embedding_result.one.return_value = ([0.1, 0.2, 0.3],) - # First call to execute_sql is for getting the embedding - # Second call is for the kNN search - mock_snapshot.execute_sql.side_effect = [ - mock_embedding_result, - iter([("result1",), ("result2",)]), - ] mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL mock_instance.database.return_value = mock_database mock_spanner_client.instance.return_value = mock_instance mock_get_spanner_client.return_value = mock_spanner_client + if embedding_option_key == "vertex_ai_embedding_model_name": + mock_embed_contents.return_value = [expected_embedding] + # execute_sql is called once for the kNN search + mock_snapshot.execute_sql.return_value = iter([("result1",), ("result2",)]) + else: + mock_embedding_result = MagicMock() + mock_embedding_result.one.return_value = (expected_embedding,) + # First call 
to execute_sql is for getting the embedding, + # second call is for the kNN search + mock_snapshot.execute_sql.side_effect = [ + mock_embedding_result, + iter([("result1",), ("result2",)]), + ] + result = search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], @@ -66,10 +98,8 @@ def test_similarity_search_knn_success( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={"spanner_embedding_model_name": "test_model"}, + embedding_options={embedding_option_key: embedding_option_value}, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "SUCCESS", result assert result["rows"] == [("result1",), ("result2",)] @@ -79,10 +109,14 @@ def test_similarity_search_knn_success( sql = call_args.args[0] assert "COSINE_DISTANCE" in sql assert "@embedding" in sql - assert call_args.kwargs == {"params": {"embedding": [0.1, 0.2, 0.3]}} + assert call_args.kwargs == {"params": {"embedding": expected_embedding}} + if embedding_option_key == "vertex_ai_embedding_model_name": + mock_embed_contents.assert_called_once_with( + embedding_option_value, ["test query"], None + ) -@patch("google.adk.tools.spanner.client.get_spanner_client") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_ann_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -113,10 +147,10 @@ def test_similarity_search_ann_success( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={"spanner_embedding_model_name": "test_model"}, + embedding_options={ + "spanner_googlesql_embedding_model_name": "test_model" + }, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), search_options={ "nearest_neighbors_algorithm": "APPROXIMATE_NEAREST_NEIGHBORS" }, @@ -130,7 +164,7 @@ def test_similarity_search_ann_success( assert call_args.kwargs == {"params": {"embedding": [0.1, 0.2, 0.3]}} -@patch("google.adk.tools.spanner.client.get_spanner_client") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -143,17 +177,17 @@ def test_similarity_search_error( table_name=mock_spanner_ids["table_name"], query="test query", embedding_column_to_search="embedding_col", - embedding_options={"spanner_embedding_model_name": "test_model"}, + embedding_options={ + "spanner_googlesql_embedding_model_name": "test_model" + }, columns=["col1"], credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "ERROR" - assert result["error_details"] == "Test Exception" + assert "Test Exception" in result["error_details"] -@patch("google.adk.tools.spanner.client.get_spanner_client") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_postgresql_knn_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -182,10 +216,12 @@ def test_similarity_search_postgresql_knn_success( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={"vertex_ai_embedding_model_endpoint": "test_endpoint"}, + embedding_options={ + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test_endpoint" + ) + }, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "SUCCESS", result assert result["rows"] == [("pg_result",)] @@ 
-196,7 +232,7 @@ def test_similarity_search_postgresql_knn_success( assert call_args.kwargs == {"params": {"p1": [0.1, 0.2, 0.3]}} -@patch("google.adk.tools.spanner.client.get_spanner_client") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_postgresql_ann_unsupported( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -217,27 +253,28 @@ def test_similarity_search_postgresql_ann_unsupported( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={"vertex_ai_embedding_model_endpoint": "test_endpoint"}, + embedding_options={ + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test_endpoint" + ) + }, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), search_options={ "nearest_neighbors_algorithm": "APPROXIMATE_NEAREST_NEIGHBORS" }, ) assert result["status"] == "ERROR" assert ( - result["error_details"] - == "APPROXIMATE_NEAREST_NEIGHBORS is not supported for PostgreSQL" - " dialect." + "APPROXIMATE_NEAREST_NEIGHBORS is not supported for PostgreSQL dialect." + in result["error_details"] ) -@patch("google.adk.tools.spanner.client.get_spanner_client") -def test_similarity_search_missing_spanner_embedding_model_name_error( +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_gsql_missing_embedding_model_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): - """Test similarity_search with missing spanner_embedding_model_name.""" + """Test similarity_search with missing embedding_options for GoogleSQL dialect.""" mock_spanner_client = MagicMock() mock_instance = MagicMock() mock_database = MagicMock() @@ -254,24 +291,27 @@ def test_similarity_search_missing_spanner_embedding_model_name_error( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={}, + embedding_options={ + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test_endpoint" + ) + }, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "ERROR" assert ( - "embedding_options['spanner_embedding_model_name'] must be" - " specified for GoogleSQL dialect." + "embedding_options['vertex_ai_embedding_model_name'] or" + " embedding_options['spanner_googlesql_embedding_model_name'] must be" + " specified for GoogleSQL dialect Spanner database." 
in result["error_details"] ) -@patch("google.adk.tools.spanner.client.get_spanner_client") -def test_similarity_search_missing_vertex_ai_embedding_model_endpoint_error( +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_pg_missing_embedding_model_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): - """Test similarity_search with missing vertex_ai_embedding_model_endpoint.""" + """Test similarity_search with missing embedding_options for PostgreSQL dialect.""" mock_spanner_client = MagicMock() mock_instance = MagicMock() mock_database = MagicMock() @@ -288,14 +328,153 @@ def test_similarity_search_missing_vertex_ai_embedding_model_endpoint_error( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={}, + embedding_options={ + "spanner_googlesql_embedding_model_name": "EmbeddingsModel" + }, + credentials=mock_credentials, + ) + assert result["status"] == "ERROR" + assert ( + "embedding_options['vertex_ai_embedding_model_name'] or" + " embedding_options['spanner_postgresql_vertex_ai_embedding_model_endpoint']" + " must be specified for PostgreSQL dialect Spanner database." + in result["error_details"] + ) + + +@pytest.mark.parametrize( + "embedding_options", + [ + pytest.param( + { + "vertex_ai_embedding_model_name": "test-model", + "spanner_googlesql_embedding_model_name": "test-model-2", + }, + id="vertex_ai_and_googlesql", + ), + pytest.param( + { + "vertex_ai_embedding_model_name": "test-model", + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test-endpoint" + ), + }, + id="vertex_ai_and_postgresql", + ), + pytest.param( + { + "spanner_googlesql_embedding_model_name": "test-model", + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test-endpoint" + ), + }, + id="googlesql_and_postgresql", + ), + pytest.param( + { + "vertex_ai_embedding_model_name": "test-model", + "spanner_googlesql_embedding_model_name": "test-model-2", + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test-endpoint" + ), + }, + id="all_three_models", + ), + pytest.param( + {}, + id="no_models", + ), + ], +) +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_multiple_embedding_options_error( + mock_get_spanner_client, + mock_spanner_ids, + mock_credentials, + embedding_options, +): + """Test similarity_search with multiple embedding models.""" + mock_spanner_client = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + mock_instance.database.return_value = mock_database + mock_spanner_client.instance.return_value = mock_instance + mock_get_spanner_client.return_value = mock_spanner_client + + result = search_tool.similarity_search( + project_id=mock_spanner_ids["project_id"], + instance_id=mock_spanner_ids["instance_id"], + database_id=mock_spanner_ids["database_id"], + table_name=mock_spanner_ids["table_name"], + query="test query", + embedding_column_to_search="embedding_col", + columns=["col1"], + embedding_options=embedding_options, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "ERROR" assert ( - "embedding_options['vertex_ai_embedding_model_endpoint'] must " - "be specified for PostgreSQL dialect." + "Exactly one embedding model option must be specified." 
in result["error_details"] ) + + +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_output_dimensionality_gsql_error( + mock_get_spanner_client, mock_spanner_ids, mock_credentials +): + """Test similarity_search with output_dimensionality and spanner_googlesql_embedding_model_name.""" + mock_spanner_client = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + mock_instance.database.return_value = mock_database + mock_spanner_client.instance.return_value = mock_instance + mock_get_spanner_client.return_value = mock_spanner_client + + result = search_tool.similarity_search( + project_id=mock_spanner_ids["project_id"], + instance_id=mock_spanner_ids["instance_id"], + database_id=mock_spanner_ids["database_id"], + table_name=mock_spanner_ids["table_name"], + query="test query", + embedding_column_to_search="embedding_col", + columns=["col1"], + embedding_options={ + "spanner_googlesql_embedding_model_name": "EmbeddingsModel", + "output_dimensionality": 128, + }, + credentials=mock_credentials, + ) + assert result["status"] == "ERROR" + assert "is not supported when" in result["error_details"] + + +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_unsupported_algorithm_error( + mock_get_spanner_client, mock_spanner_ids, mock_credentials +): + """Test similarity_search with an unsupported nearest neighbors algorithm.""" + mock_spanner_client = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + mock_instance.database.return_value = mock_database + mock_spanner_client.instance.return_value = mock_instance + mock_get_spanner_client.return_value = mock_spanner_client + + result = search_tool.similarity_search( + project_id=mock_spanner_ids["project_id"], + instance_id=mock_spanner_ids["instance_id"], + database_id=mock_spanner_ids["database_id"], + table_name=mock_spanner_ids["table_name"], + query="test query", + embedding_column_to_search="embedding_col", + columns=["col1"], + embedding_options={"vertex_ai_embedding_model_name": "test-model"}, + credentials=mock_credentials, + search_options={"nearest_neighbors_algorithm": "INVALID_ALGORITHM"}, + ) + assert result["status"] == "ERROR" + assert "Unsupported search_options" in result["error_details"] diff --git a/tests/unittests/tools/spanner/test_spanner_tool_settings.py b/tests/unittests/tools/spanner/test_spanner_tool_settings.py index f74922b248..730c9e0efe 100644 --- a/tests/unittests/tools/spanner/test_spanner_tool_settings.py +++ b/tests/unittests/tools/spanner/test_spanner_tool_settings.py @@ -15,9 +15,23 @@ from __future__ import annotations from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings +from pydantic import ValidationError import pytest +def common_spanner_vector_store_settings(vector_length=None): + return { + "project_id": "test-project", + "instance_id": "test-instance", + "database_id": "test-database", + "table_name": "test-table", + "content_column": "test-content-column", + "embedding_column": "test-embedding-column", + "vector_length": 128 if vector_length is None else vector_length, + } + + def test_spanner_tool_settings_experimental_warning(): """Test SpannerToolSettings experimental warning.""" with pytest.warns( @@ -25,3 +39,34 @@ def test_spanner_tool_settings_experimental_warning(): match="Tool 
settings defaults may have breaking change in the future.", ): SpannerToolSettings() + + +def test_spanner_vector_store_settings_all_fields_present(): + """Test SpannerVectorStoreSettings with all required fields present.""" + settings = SpannerVectorStoreSettings( + **common_spanner_vector_store_settings(), + vertex_ai_embedding_model_name="test-embedding-model", + ) + assert settings is not None + assert settings.selected_columns == ["test-content-column"] + assert settings.vertex_ai_embedding_model_name == "test-embedding-model" + + +def test_spanner_vector_store_settings_missing_embedding_model_name(): + """Test SpannerVectorStoreSettings with missing vertex_ai_embedding_model_name.""" + with pytest.raises(ValidationError) as excinfo: + SpannerVectorStoreSettings(**common_spanner_vector_store_settings()) + assert "Field required" in str(excinfo.value) + assert "vertex_ai_embedding_model_name" in str(excinfo.value) + + +def test_spanner_vector_store_settings_invalid_vector_length(): + """Test SpannerVectorStoreSettings with invalid vector_length.""" + with pytest.raises(ValidationError) as excinfo: + SpannerVectorStoreSettings( + **common_spanner_vector_store_settings(vector_length=0), + vertex_ai_embedding_model_name="test-embedding-model", + ) + assert "Invalid vector length in the Spanner vector store settings." in str( + excinfo.value + ) diff --git a/tests/unittests/tools/spanner/test_spanner_toolset.py b/tests/unittests/tools/spanner/test_spanner_toolset.py index 163832559d..a583a2f884 100644 --- a/tests/unittests/tools/spanner/test_spanner_toolset.py +++ b/tests/unittests/tools/spanner/test_spanner_toolset.py @@ -18,6 +18,7 @@ from google.adk.tools.spanner import SpannerCredentialsConfig from google.adk.tools.spanner import SpannerToolset from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings import pytest @@ -184,3 +185,50 @@ async def test_spanner_toolset_without_read_capability( expected_tool_names = set(returned_tools) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names + + +@pytest.mark.asyncio +async def test_spanner_toolset_with_vector_store_search(): + """Test Spanner toolset with vector store search. + + This test verifies the behavior of the Spanner toolset when vector store + settings is provided. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + + spanner_tool_settings = SpannerToolSettings( + vector_store_settings=SpannerVectorStoreSettings( + project_id="test-project", + instance_id="test-instance", + database_id="test-database", + table_name="test-table", + content_column="test-content-column", + embedding_column="test-embedding-column", + vector_length=128, + vertex_ai_embedding_model_name="test-embedding-model", + ) + ) + toolset = SpannerToolset( + credentials_config=credentials_config, + spanner_tool_settings=spanner_tool_settings, + ) + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == 8 + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set([ + "list_table_names", + "list_table_indexes", + "list_table_index_columns", + "list_named_schemas", + "get_table_schema", + "execute_sql", + "similarity_search", + "vector_store_similarity_search", + ]) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index 85e8b9caa1..a9723b4347 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -570,3 +570,135 @@ class CustomInput(BaseModel): # Should have string response schema for VERTEX_AI when no output_schema assert declaration.response is not None assert declaration.response.type == types.Type.STRING + + +def test_include_plugins_default_true(): + """Test that plugins are propagated by default (include_plugins=True).""" + + # Create a test plugin that tracks callbacks + class TrackingPlugin(BasePlugin): + + def __init__(self, name: str): + super().__init__(name) + self.before_agent_calls = 0 + + async def before_agent_callback(self, **kwargs): + self.before_agent_calls += 1 + + tracking_plugin = TrackingPlugin(name='tracking') + + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent)], # Default include_plugins=True + ) + + runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin]) + runner.run('test1') + + # Plugin should be called for both root_agent and tool_agent + assert tracking_plugin.before_agent_calls == 2 + + +def test_include_plugins_explicit_true(): + """Test that plugins are propagated when include_plugins=True.""" + + class TrackingPlugin(BasePlugin): + + def __init__(self, name: str): + super().__init__(name) + self.before_agent_calls = 0 + + async def before_agent_callback(self, **kwargs): + self.before_agent_calls += 1 + + tracking_plugin = TrackingPlugin(name='tracking') + + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent, include_plugins=True)], + ) + + runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin]) + runner.run('test1') + + # Plugin should be called for both root_agent and tool_agent + assert tracking_plugin.before_agent_calls == 2 + + +def test_include_plugins_false(): + """Test that plugins are NOT propagated when include_plugins=False.""" + + 
class TrackingPlugin(BasePlugin): + + def __init__(self, name: str): + super().__init__(name) + self.before_agent_calls = 0 + + async def before_agent_callback(self, **kwargs): + self.before_agent_calls += 1 + + tracking_plugin = TrackingPlugin(name='tracking') + + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent, include_plugins=False)], + ) + + runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin]) + runner.run('test1') + + # Plugin should only be called for root_agent, not tool_agent + assert tracking_plugin.before_agent_calls == 1 + + +def test_agent_tool_description_with_input_schema(): + """Test that agent description is propagated when using input_schema.""" + + class CustomInput(BaseModel): + """This is the Pydantic model docstring.""" + + custom_input: str + + agent_description = 'This is the agent description that should be used' + tool_agent = Agent( + name='tool_agent', + model=testing_utils.MockModel.create(responses=['test response']), + description=agent_description, + input_schema=CustomInput, + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + # The description should come from the agent, not the Pydantic model + assert declaration.description == agent_description diff --git a/tests/unittests/tools/test_api_registry.py b/tests/unittests/tools/test_api_registry.py new file mode 100644 index 0000000000..df54786049 --- /dev/null +++ b/tests/unittests/tools/test_api_registry.py @@ -0,0 +1,205 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import unittest +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.tools.api_registry import ApiRegistry +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +import httpx + +MOCK_MCP_SERVERS_LIST = { + "mcpServers": [ + { + "name": "test-mcp-server-1", + "urls": ["mcp.server1.com"], + }, + { + "name": "test-mcp-server-2", + "urls": ["mcp.server2.com"], + }, + { + "name": "test-mcp-server-no-url", + }, + ] +} + + +class TestApiRegistry(unittest.IsolatedAsyncioTestCase): + """Unit tests for ApiRegistry.""" + + def setUp(self): + self.project_id = "test-project" + self.location = "global" + self.mock_credentials = MagicMock() + self.mock_credentials.token = "mock_token" + self.mock_credentials.refresh = MagicMock() + mock_auth_patcher = patch( + "google.auth.default", + return_value=(self.mock_credentials, None), + autospec=True, + ) + mock_auth_patcher.start() + self.addCleanup(mock_auth_patcher.stop) + + @patch("httpx.Client", autospec=True) + def test_init_success(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + self.assertEqual(len(api_registry._mcp_servers), 3) + self.assertIn("test-mcp-server-1", api_registry._mcp_servers) + self.assertIn("test-mcp-server-2", api_registry._mcp_servers) + self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers) + mock_client_instance.get.assert_called_once_with( + f"https://cloudapiregistry.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", + headers={ + "Authorization": "Bearer mock_token", + "Content-Type": "application/json", + }, + ) + + @patch("httpx.Client", autospec=True) + def test_init_http_error(self, MockHttpClient): + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.side_effect = httpx.RequestError( + "Connection failed" + ) + + with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"): + ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + @patch("httpx.Client", autospec=True) + def test_init_bad_response(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError( + "Not Found", request=MagicMock(), response=MagicMock() + ) + ) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"): + ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + mock_response.raise_for_status.assert_called_once() + + @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch("httpx.Client", autospec=True) + async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value 
+ mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + toolset = api_registry.get_toolset("test-mcp-server-1") + + MockMcpToolset.assert_called_once_with( + connection_params=StreamableHTTPConnectionParams( + url="https://mcp.server1.com", + headers={"Authorization": "Bearer mock_token"}, + ), + tool_filter=None, + tool_name_prefix=None, + header_provider=None, + ) + self.assertEqual(toolset, MockMcpToolset.return_value) + + @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch("httpx.Client", autospec=True) + async def test_get_toolset_with_filter_and_prefix( + self, MockHttpClient, MockMcpToolset + ): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + tool_filter = ["tool1"] + tool_name_prefix = "prefix_" + toolset = api_registry.get_toolset( + "test-mcp-server-1", + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + ) + + MockMcpToolset.assert_called_once_with( + connection_params=StreamableHTTPConnectionParams( + url="https://mcp.server1.com", + headers={"Authorization": "Bearer mock_token"}, + ), + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + header_provider=None, + ) + self.assertEqual(toolset, MockMcpToolset.return_value) + + @patch("httpx.Client", autospec=True) + async def test_get_toolset_server_not_found(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + with self.assertRaisesRegex(ValueError, "not found in API Registry"): + api_registry.get_toolset("non-existent-server") + + @patch("httpx.Client", autospec=True) + async def test_get_toolset_server_no_url(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + with self.assertRaisesRegex(ValueError, "has no URLs"): + api_registry.get_toolset("test-mcp-server-no-url") diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index f57c3d3838..8a562c7cf4 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -411,3 +411,19 @@ def transfer_to_agent(agent_name: str, tool_context: ToolContext): # Changed: Now uses Any type instead of NULL for no return annotation assert function_decl.response is not None 
assert function_decl.response.type is None # Any type maps to None in schema + + +def test_transfer_to_agent_tool_with_enum_constraint(): + """Test TransferToAgentTool adds enum constraint to agent_name.""" + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + agent_names = ['agent_a', 'agent_b', 'agent_c'] + tool = TransferToAgentTool(agent_names=agent_names) + + function_decl = tool._get_declaration() + + assert function_decl.name == 'transfer_to_agent' + assert function_decl.parameters.type == 'OBJECT' + assert function_decl.parameters.properties['agent_name'].type == 'STRING' + assert function_decl.parameters.properties['agent_name'].enum == agent_names + assert 'tool_context' not in function_decl.parameters.properties diff --git a/tests/unittests/tools/test_mcp_toolset.py b/tests/unittests/tools/test_mcp_toolset.py index a3a6598e35..7bfd912669 100644 --- a/tests/unittests/tools/test_mcp_toolset.py +++ b/tests/unittests/tools/test_mcp_toolset.py @@ -14,31 +14,12 @@ """Unit tests for McpToolset.""" -import sys from unittest.mock import AsyncMock from unittest.mock import MagicMock +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.tools.mcp_tool.mcp_toolset import McpToolset -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - McpToolset = DummyClass - else: - raise e - @pytest.mark.asyncio async def test_mcp_toolset_with_prefix(): diff --git a/tests/unittests/tools/test_transfer_to_agent_tool.py b/tests/unittests/tools/test_transfer_to_agent_tool.py new file mode 100644 index 0000000000..14b7b3abea --- /dev/null +++ b/tests/unittests/tools/test_transfer_to_agent_tool.py @@ -0,0 +1,164 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for TransferToAgentTool enum constraint functionality.""" + +from unittest.mock import patch + +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool +from google.genai import types + + +def test_transfer_to_agent_tool_enum_constraint(): + """Test that TransferToAgentTool adds enum constraint to agent_name.""" + agent_names = ['agent_a', 'agent_b', 'agent_c'] + tool = TransferToAgentTool(agent_names=agent_names) + + decl = tool._get_declaration() + + assert decl is not None + assert decl.name == 'transfer_to_agent' + assert decl.parameters is not None + assert decl.parameters.type == types.Type.OBJECT + assert 'agent_name' in decl.parameters.properties + + agent_name_schema = decl.parameters.properties['agent_name'] + assert agent_name_schema.type == types.Type.STRING + assert agent_name_schema.enum == agent_names + + # Verify that agent_name is marked as required + assert decl.parameters.required == ['agent_name'] + + +def test_transfer_to_agent_tool_single_agent(): + """Test TransferToAgentTool with a single agent.""" + tool = TransferToAgentTool(agent_names=['single_agent']) + + decl = tool._get_declaration() + + assert decl is not None + agent_name_schema = decl.parameters.properties['agent_name'] + assert agent_name_schema.enum == ['single_agent'] + + +def test_transfer_to_agent_tool_multiple_agents(): + """Test TransferToAgentTool with multiple agents.""" + agent_names = ['agent_1', 'agent_2', 'agent_3', 'agent_4', 'agent_5'] + tool = TransferToAgentTool(agent_names=agent_names) + + decl = tool._get_declaration() + + assert decl is not None + agent_name_schema = decl.parameters.properties['agent_name'] + assert agent_name_schema.enum == agent_names + assert len(agent_name_schema.enum) == 5 + + +def test_transfer_to_agent_tool_empty_list(): + """Test TransferToAgentTool with an empty agent list.""" + tool = TransferToAgentTool(agent_names=[]) + + decl = tool._get_declaration() + + assert decl is not None + agent_name_schema = decl.parameters.properties['agent_name'] + assert agent_name_schema.enum == [] + + +def test_transfer_to_agent_tool_preserves_description(): + """Test that TransferToAgentTool preserves the original description.""" + tool = TransferToAgentTool(agent_names=['agent_a', 'agent_b']) + + decl = tool._get_declaration() + + assert decl is not None + assert decl.description is not None + assert 'Transfer the question to another agent' in decl.description + + +def test_transfer_to_agent_tool_preserves_parameter_type(): + """Test that TransferToAgentTool preserves the parameter type.""" + tool = TransferToAgentTool(agent_names=['agent_a']) + + decl = tool._get_declaration() + + assert decl is not None + agent_name_schema = decl.parameters.properties['agent_name'] + # Should still be a string type, just with enum constraint + assert agent_name_schema.type == types.Type.STRING + + +def test_transfer_to_agent_tool_no_extra_parameters(): + """Test that TransferToAgentTool doesn't add extra parameters.""" + tool = TransferToAgentTool(agent_names=['agent_a']) + + decl = tool._get_declaration() + + assert decl is not None + # Should only have agent_name parameter (tool_context is ignored) + assert len(decl.parameters.properties) == 1 + assert 'agent_name' in decl.parameters.properties + assert 'tool_context' not in decl.parameters.properties + + +def test_transfer_to_agent_tool_maintains_inheritance(): + """Test that TransferToAgentTool inherits from FunctionTool correctly.""" + tool = 
TransferToAgentTool(agent_names=['agent_a']) + + assert isinstance(tool, FunctionTool) + assert hasattr(tool, '_get_declaration') + assert hasattr(tool, 'process_llm_request') + + +def test_transfer_to_agent_tool_handles_parameters_json_schema(): + """Test that TransferToAgentTool handles parameters_json_schema format.""" + agent_names = ['agent_x', 'agent_y', 'agent_z'] + + # Create a mock FunctionDeclaration with parameters_json_schema + mock_decl = type('MockDecl', (), {})() + mock_decl.parameters = None # No Schema object + mock_decl.parameters_json_schema = { + 'type': 'object', + 'properties': { + 'agent_name': { + 'type': 'string', + 'description': 'Agent name to transfer to', + } + }, + 'required': ['agent_name'], + } + + # Temporarily patch FunctionTool._get_declaration + with patch.object( + FunctionTool, + '_get_declaration', + return_value=mock_decl, + ): + tool = TransferToAgentTool(agent_names=agent_names) + result = tool._get_declaration() + + # Verify enum was added to parameters_json_schema + assert result.parameters_json_schema is not None + assert 'agent_name' in result.parameters_json_schema['properties'] + assert ( + result.parameters_json_schema['properties']['agent_name']['enum'] + == agent_names + ) + assert ( + result.parameters_json_schema['properties']['agent_name']['type'] + == 'string' + ) + # Verify required field is preserved + assert result.parameters_json_schema['required'] == ['agent_name'] diff --git a/tests/unittests/tools/test_vertex_ai_search_tool.py b/tests/unittests/tools/test_vertex_ai_search_tool.py index 0df19288a3..1ec1572b90 100644 --- a/tests/unittests/tools/test_vertex_ai_search_tool.py +++ b/tests/unittests/tools/test_vertex_ai_search_tool.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging + from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.sequential_agent import SequentialAgent from google.adk.models.llm_request import LlmRequest @@ -24,6 +26,10 @@ from google.genai import types import pytest +VERTEX_SEARCH_TOOL_LOGGER_NAME = ( + 'google_adk.google.adk.tools.vertex_ai_search_tool' +) + async def _create_tool_context() -> ToolContext: session_service = InMemorySessionService() @@ -121,12 +127,34 @@ def test_init_with_data_store_id(self): tool = VertexAiSearchTool(data_store_id='test_data_store') assert tool.data_store_id == 'test_data_store' assert tool.search_engine_id is None + assert tool.data_store_specs is None def test_init_with_search_engine_id(self): """Test initialization with search engine ID.""" tool = VertexAiSearchTool(search_engine_id='test_search_engine') assert tool.search_engine_id == 'test_search_engine' assert tool.data_store_id is None + assert tool.data_store_specs is None + + def test_init_with_engine_and_specs(self): + """Test initialization with search engine ID and specs.""" + specs = [ + types.VertexAISearchDataStoreSpec( + dataStore=( + 'projects/p/locations/l/collections/c/dataStores/spec_store' + ) + ) + ] + engine_id = ( + 'projects/p/locations/l/collections/c/engines/test_search_engine' + ) + tool = VertexAiSearchTool( + search_engine_id=engine_id, + data_store_specs=specs, + ) + assert tool.search_engine_id == engine_id + assert tool.data_store_id is None + assert tool.data_store_specs == specs def test_init_with_neither_raises_error(self): """Test that initialization without either ID raises ValueError.""" @@ -146,10 +174,34 @@ def test_init_with_both_raises_error(self): data_store_id='test_data_store', search_engine_id='test_search_engine' ) + def test_init_with_specs_but_no_engine_raises_error(self): + """Test that specs without engine ID raises ValueError.""" + specs = [ + types.VertexAISearchDataStoreSpec( + dataStore=( + 'projects/p/locations/l/collections/c/dataStores/spec_store' + ) + ) + ] + with pytest.raises( + ValueError, + match=( + 'search_engine_id must be specified if data_store_specs is' + ' specified' + ), + ): + VertexAiSearchTool( + data_store_id='test_data_store', data_store_specs=specs + ) + @pytest.mark.asyncio - async def test_process_llm_request_with_simple_gemini_model(self): + async def test_process_llm_request_with_simple_gemini_model(self, caplog): """Test processing LLM request with simple Gemini model name.""" - tool = VertexAiSearchTool(data_store_id='test_data_store') + caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME) + + tool = VertexAiSearchTool( + data_store_id='test_data_store', filter='f', max_results=5 + ) tool_context = await _create_tool_context() llm_request = LlmRequest( @@ -162,17 +214,56 @@ async def test_process_llm_request_with_simple_gemini_model(self): assert llm_request.config.tools is not None assert len(llm_request.config.tools) == 1 - assert llm_request.config.tools[0].retrieval is not None - assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + assert ( + retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store' + ) + assert retrieval_tool.retrieval.vertex_ai_search.engine is None + assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f' + assert retrieval_tool.retrieval.vertex_ai_search.max_results == 5 + + # Verify 
debug log + debug_records = [ + r + for r in caplog.records + if 'Adding Vertex AI Search tool config' in r.message + ] + assert len(debug_records) == 1 + log_message = debug_records[0].getMessage() + assert 'datastore=test_data_store' in log_message + assert 'engine=None' in log_message + assert 'filter=f' in log_message + assert 'max_results=5' in log_message + assert 'data_store_specs=None' in log_message @pytest.mark.asyncio - async def test_process_llm_request_with_path_based_gemini_model(self): + async def test_process_llm_request_with_path_based_gemini_model(self, caplog): """Test processing LLM request with path-based Gemini model name.""" - tool = VertexAiSearchTool(data_store_id='test_data_store') + caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME) + + specs = [ + types.VertexAISearchDataStoreSpec( + dataStore=( + 'projects/p/locations/l/collections/c/dataStores/spec_store' + ) + ) + ] + engine_id = 'projects/p/locations/l/collections/c/engines/test_engine' + tool = VertexAiSearchTool( + search_engine_id=engine_id, + data_store_specs=specs, + filter='f2', + max_results=10, + ) tool_context = await _create_tool_context() llm_request = LlmRequest( - model='projects/265104255505/locations/us-central1/publishers/google/models/gemini-2.0-flash-001', + model=( + 'projects/265104255505/locations/us-central1/publishers/' + 'google/models/gemini-2.0-flash-001' + ), config=types.GenerateContentConfig(), ) @@ -182,8 +273,28 @@ async def test_process_llm_request_with_path_based_gemini_model(self): assert llm_request.config.tools is not None assert len(llm_request.config.tools) == 1 - assert llm_request.config.tools[0].retrieval is not None - assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + assert retrieval_tool.retrieval.vertex_ai_search.datastore is None + assert retrieval_tool.retrieval.vertex_ai_search.engine == engine_id + assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f2' + assert retrieval_tool.retrieval.vertex_ai_search.max_results == 10 + assert retrieval_tool.retrieval.vertex_ai_search.data_store_specs == specs + + # Verify debug log + debug_records = [ + r + for r in caplog.records + if 'Adding Vertex AI Search tool config' in r.message + ] + assert len(debug_records) == 1 + log_message = debug_records[0].getMessage() + assert 'datastore=None' in log_message + assert f'engine={engine_id}' in log_message + assert 'filter=f2' in log_message + assert 'max_results=10' in log_message + assert 'data_store_specs=1 spec(s): [spec_store]' in log_message @pytest.mark.asyncio async def test_process_llm_request_with_gemini_1_and_other_tools_raises_error( @@ -291,9 +402,11 @@ async def test_process_llm_request_with_path_based_non_gemini_model_raises_error @pytest.mark.asyncio async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds( - self, + self, caplog ): """Test that Gemini 2.x with other tools succeeds.""" + caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME) + tool = VertexAiSearchTool(data_store_id='test_data_store') tool_context = await _create_tool_context() @@ -316,5 +429,23 @@ async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds( assert llm_request.config.tools is not None assert len(llm_request.config.tools) == 2 assert llm_request.config.tools[0] == existing_tool - assert llm_request.config.tools[1].retrieval is 
not None - assert llm_request.config.tools[1].retrieval.vertex_ai_search is not None + retrieval_tool = llm_request.config.tools[1] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + assert ( + retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store' + ) + + # Verify debug log + debug_records = [ + r + for r in caplog.records + if 'Adding Vertex AI Search tool config' in r.message + ] + assert len(debug_records) == 1 + log_message = debug_records[0].getMessage() + assert 'datastore=test_data_store' in log_message + assert 'engine=None' in log_message + assert 'filter=None' in log_message + assert 'max_results=None' in log_message + assert 'data_store_specs=None' in log_message diff --git a/tests/unittests/utils/test_client_labels_utils.py b/tests/unittests/utils/test_client_labels_utils.py new file mode 100644 index 0000000000..b1d6acb001 --- /dev/null +++ b/tests/unittests/utils/test_client_labels_utils.py @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from google.adk import version +from google.adk.utils import _client_labels_utils +import pytest + + +def test_get_client_labels_default(): + """Test get_client_labels returns default labels.""" + labels = _client_labels_utils.get_client_labels() + assert len(labels) == 2 + assert f"google-adk/{version.__version__}" == labels[0] + assert f"gl-python/{sys.version.split()[0]}" == labels[1] + + +def test_get_client_labels_with_agent_engine_id(monkeypatch): + """Test get_client_labels returns agent engine tag when env var is set.""" + monkeypatch.setenv( + _client_labels_utils._AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, + "test-agent-id", + ) + labels = _client_labels_utils.get_client_labels() + assert len(labels) == 2 + assert ( + f"google-adk/{version.__version__}+{_client_labels_utils._AGENT_ENGINE_TELEMETRY_TAG}" + == labels[0] + ) + assert f"gl-python/{sys.version.split()[0]}" == labels[1] + + +def test_get_client_labels_with_context(): + """Test get_client_labels includes label from context.""" + with _client_labels_utils.client_label_context("my-label/1.0"): + labels = _client_labels_utils.get_client_labels() + assert len(labels) == 3 + assert f"google-adk/{version.__version__}" == labels[0] + assert f"gl-python/{sys.version.split()[0]}" == labels[1] + assert "my-label/1.0" == labels[2] + + +def test_client_label_context_nested_error(): + """Test client_label_context raises error when nested.""" + with pytest.raises(ValueError, match="Client label already exists"): + with _client_labels_utils.client_label_context("my-label/1.0"): + with _client_labels_utils.client_label_context("another-label/1.0"): + pass + + +def test_eval_client_label(): + """Test EVAL_CLIENT_LABEL has correct format.""" + assert ( + f"google-adk-eval/{version.__version__}" + == _client_labels_utils.EVAL_CLIENT_LABEL + )
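
For readers skimming the patch, here is a minimal usage sketch of the two new public surfaces these tests exercise (the enum-constrained TransferToAgentTool declaration and the client label helpers). It is an illustration only, not part of the diff; it assumes the ADK modules imported by the tests above are available, and the agent names and the "example-label/1.0" string are made-up example values.

import sys

from google.adk import version
from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool
from google.adk.utils import _client_labels_utils


def describe_transfer_tool() -> None:
  # Build the tool with a fixed set of agent names; per the tests above, the
  # generated declaration constrains `agent_name` to exactly these values.
  tool = TransferToAgentTool(agent_names=['billing_agent', 'support_agent'])
  decl = tool._get_declaration()
  print(decl.name)  # 'transfer_to_agent'
  print(decl.parameters.properties['agent_name'].enum)


def describe_client_labels() -> None:
  # Default labels carry the ADK and Python versions; an extra label can be
  # attached for the duration of a context block.
  print(_client_labels_utils.get_client_labels())
  with _client_labels_utils.client_label_context('example-label/1.0'):
    labels = _client_labels_utils.get_client_labels()
    assert f'google-adk/{version.__version__}' == labels[0]
    assert f'gl-python/{sys.version.split()[0]}' == labels[1]
    assert 'example-label/1.0' == labels[2]


if __name__ == '__main__':
  describe_transfer_tool()
  describe_client_labels()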